Commit

Add token count (#76)

Signed-off-by: Mengjia Gu <mengjia.gu@zilliz.com>
jaelgu authored Aug 28, 2023
1 parent 2cfbe13 commit ce5c97f

Showing 11 changed files with 73 additions and 31 deletions.
9 changes: 6 additions & 3 deletions gradio_demo.py

@@ -17,9 +17,9 @@
     'The service should start with either "--langchain" or "--towhee".'

 if USE_LANGCHAIN:
-    from src_langchain.operations import chat, insert, check, drop, get_history, clear_history  # pylint: disable=C0413
+    from src_langchain.operations import chat, insert, check, drop, get_history, clear_history, count  # pylint: disable=C0413
 if USE_TOWHEE:
-    from src_towhee.operations import chat, insert, check, drop, get_history, clear_history  # pylint: disable=C0413
+    from src_towhee.operations import chat, insert, check, drop, get_history, clear_history, count  # pylint: disable=C0413


 def create_session_id():
@@ -52,7 +52,10 @@ def add_project(project, data_url: str = None, data_file: object = None):
 def check_project(project):
     status = check(project)
     if status['store']:
-        return 'Project exists. You can upload more documents or directly start conversation.'
+        counts = count(project)
+        vector_num = counts['vector store']
+        scalar_num = counts['scalar store']
+        return f'Project exists: {vector_num} in vector store, {scalar_num} in scalar store.'
     else:
         return 'Project does not exist. You need to upload the first document before conversation.'
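
Note on the count helper imported above: the commit does not include its implementation, but check_project expects it to return per-store entity counts keyed 'vector store' and 'scalar store'. A minimal sketch of such a helper, assuming Milvus as the vector store and Elasticsearch as the scalar store (Akcio's usual backends; the connection settings are placeholders):

import tiktoken  # not needed here, shown in a later sketch
from pymilvus import Collection
from elasticsearch import Elasticsearch

es = Elasticsearch('http://localhost:9200')  # placeholder connection

def count(project: str) -> dict:
    # Entities persisted in the Milvus collection for this project.
    vector_num = Collection(project).num_entities
    # Documents indexed in Elasticsearch under the same project name.
    scalar_num = es.count(index=project)['count']
    return {'vector store': vector_num, 'scalar store': scalar_num}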
9 changes: 6 additions & 3 deletions main.py

@@ -80,14 +80,17 @@ def do_project_add_api(project: str, url: str = None, file: UploadFile = None):
     assert url or file, 'You need to upload file or enter url of document to add data.'
     try:
         if url:
-            num = insert(data_src=url, project=project, source_type='url')
+            chunk_num, token_count = insert(data_src=url, project=project, source_type='url')
         if file:
             temp_file = os.path.join(TEMP_DIR, file.filename)
             with open(temp_file, 'wb') as f:
                 content = file.file.read()
                 f.write(content)
-            num = insert(data_src=temp_file, project=project, source_type='file')
-        return jsonable_encoder({'status': True, 'msg': f'Successfully inserted doc chunks: {num}'}), 200
+            chunk_num, token_count = insert(data_src=temp_file, project=project, source_type='file')
+        return jsonable_encoder({'status': True, 'msg': {
+            'chunk count': chunk_num,
+            'token count': token_count
+        }}), 200
     except Exception as e:  # pylint: disable=W0703
         return jsonable_encoder({'status': False, 'msg': f'Failed to load data:\n{e}'}), 400
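
The endpoint now reports both numbers in a structured msg rather than a single formatted string. A sketch of a client call and the new response shape (the route path, host, and port are assumptions for illustration, and the printed counts are made-up example values):

import requests

resp = requests.post(
    'http://localhost:8900/project/add',  # assumed route and port
    params={'project': 'akcio_ut', 'url': 'https://towhee.io'},
)
print(resp.json())
# e.g. {'status': True, 'msg': {'chunk count': 5, 'token count': 290}}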
14 changes: 10 additions & 4 deletions src_langchain/data_loader/data_parser.py

@@ -1,14 +1,14 @@
-from config import DATAPARSER_CONFIG  # pylint: disable=C0413
 import os
 import sys

 from typing import List, Optional
+import tiktoken
 from langchain.docstore.document import Document
 from langchain.text_splitter import TextSplitter, RecursiveCharacterTextSplitter

 sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))

+from config import DATAPARSER_CONFIG  # pylint: disable=C0413


 CHUNK_SIZE = DATAPARSER_CONFIG.get('chunk_size', 300)

@@ -18,9 +18,11 @@ class DataParser:

     def __init__(self,
                  splitter: TextSplitter = RecursiveCharacterTextSplitter(
-                     chunk_size=CHUNK_SIZE)
+                     chunk_size=CHUNK_SIZE),
+                 token_model: str = 'gpt-3.5-turbo'
                  ):
         self.splitter = splitter
         self.enc = tiktoken.encoding_for_model(token_model)

     def __call__(self, data_src, source_type: str = 'file') -> List[str]:
         if not isinstance(data_src, list):
@@ -34,7 +36,11 @@ def __call__(self, data_src, source_type: str = 'file') -> List[str]:
                 'Invalid source type. Only support "file" or "url".')

         docs = self.splitter.split_documents(docs)
-        return [str(doc.page_content) for doc in docs]
+        docs = [str(doc.page_content) for doc in docs]
+        token_count = 0
+        for doc in docs:
+            token_count += len(self.enc.encode(doc))
+        return docs, token_count

     def from_files(self, files: list, encoding: Optional[str] = None) -> List[Document]:
         '''Load documents from path or file-like object, return a list of unsplit LangChain Documents'''
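
The token counting added here is plain tiktoken usage: resolve the tokenizer for a model name, encode each chunk, and sum the encoded lengths. A standalone illustration matching the unit test further down, where ['ab', 'c', 'd'] encodes to three tokens:

import tiktoken

enc = tiktoken.encoding_for_model('gpt-3.5-turbo')
chunks = ['ab', 'c', 'd']
token_count = sum(len(enc.encode(chunk)) for chunk in chunks)
print(token_count)  # 3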
4 changes: 2 additions & 2 deletions src_langchain/operations.py

@@ -56,9 +56,9 @@ def insert(data_src, project, source_type: str = 'file'):
     '''
     doc_db = DocStore(table_name=project,
                       embedding_func=encoder)
-    docs = load_data(data_src=data_src, source_type=source_type)
+    docs, token_count = load_data(data_src=data_src, source_type=source_type)
     num = doc_db.insert(docs)
-    return num
+    return num, token_count


 def drop(project):
10 changes: 7 additions & 3 deletions src_towhee/operations.py

@@ -48,7 +48,10 @@ def insert(data_src, project, source_type: str = 'file'):  # pylint: disable=W061
     res = insert_pipeline(data_src, project).to_list()
     num = towhee_pipelines.count_entities(project)['vector store']
     assert len(res) <= num, 'Failed to insert data.'
-    return len(res)
+    token_count = 0
+    for r in res:
+        token_count += r[0]['token_count']
+    return len(res), token_count


 def drop(project):
@@ -126,8 +129,9 @@ def clear_history(project, session_id):
 # question1 = 'What is Towhee?'
 # question2 = 'What does it do?'

-# count = insert(data_src=data_src, project=project, source_type='url')
-# print('\nCount:', count)
+# chunk_count, token_count = insert(data_src=data_src, project=project, source_type='url')
+# print('\nChunk count:', chunk_count)
+# print('\nToken count:', token_count)
 # print('\nCheck:', check(project))

 # new_question, answer = chat(project=project, session_id=session_id, question=question0)
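
The accumulation above relies on each pipeline result row carrying a dict with a 'token_count' field at index 0, the same shape the mocked insert_pipeline in tests/unit_tests/src_towhee/test_operations.py produces. An equivalent one-liner over sample rows of that shape (the values are illustrative; real rows come from insert_pipeline(...).to_list()):

res = [
    [{'milvus': 1, 'es': None, 'token_count': 120}],
    [{'milvus': 1, 'es': None, 'token_count': 170}],
]
token_count = sum(r[0]['token_count'] for r in res)
print(len(res), token_count)  # 2 290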
2 changes: 1 addition & 1 deletion src_towhee/pipelines/insert/__init__.py

@@ -12,7 +12,7 @@ def build_insert_pipeline(
     except RuntimeError as e:  # pylint: disable=W0703
         if name.replace('-', '_') == 'generate_questions':
             sys.path.append(os.path.dirname(__file__))
-            from generate_questions import custom_pipeline # pylint: disable=c0415
+            from generate_questions import custom_pipeline  # pylint: disable=c0415

             insert_pipeline = custom_pipeline(config=config)
         else:
1 change: 1 addition & 0 deletions tests/requirements.txt

@@ -1,6 +1,7 @@
 towhee
 torch
 langchain
+tiktoken
 milvus
 transformers
 dashscope
10 changes: 5 additions & 5 deletions tests/unit_tests/src_langchain/data_loader/test_data_parser.py

@@ -8,7 +8,7 @@
 from langchain.schema import Document
 from langchain.text_splitter import RecursiveCharacterTextSplitter

-sys.path.append(os.path.join(os.path.dirname(__file__), '../../../../..'))
+sys.path.append(os.path.join(os.path.dirname(__file__), '../../../..'))

 from src_langchain.data_loader import DataParser

@@ -23,14 +23,14 @@ def test_call_from_files(self):
             tmp_file_path = fp.name
             with io.open(tmp_file_path, 'w') as file:
                 file.write(text)
-            output = self.data_parser(tmp_file_path, source_type='file')
-            assert output == ['ab', 'c', 'd']
+            output, token_count = self.data_parser(tmp_file_path, source_type='file')
+            assert output == ['ab', 'c', 'd'], token_count == 3

     def test_call_from_urls(self):
         with patch('langchain.document_loaders.UnstructuredURLLoader.load') as mock_url_loader:
             mock_url_loader.return_value = [Document(page_content='ab\ncd', metadata={})]
-            output = self.data_parser('www.mockurl.com', source_type='url')
-            assert output == ['ab', 'c', 'd']
+            output, token_count = self.data_parser('www.mockurl.com', source_type='url')
+            assert output == ['ab', 'c', 'd'], token_count == 3


 if __name__ == '__main__':
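
One caveat in the added assertions: in `assert output == ['ab', 'c', 'd'], token_count == 3`, everything after the comma is the assert message, so the token count is never actually checked (the same pattern appears in tests/unit_tests/src_towhee/test_operations.py below). A form that verifies both values:

assert output == ['ab', 'c', 'd']
assert token_count == 3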
2 changes: 1 addition & 1 deletion tests/unit_tests/src_langchain/data_loader/test_data_splitter.py

@@ -2,7 +2,7 @@
 import sys
 import unittest

-sys.path.append(os.path.join(os.path.dirname(__file__), '../../../../..'))
+sys.path.append(os.path.join(os.path.dirname(__file__), '../../../..'))

 from src_langchain.data_loader.data_splitter import MarkDownSplitter
34 changes: 29 additions & 5 deletions tests/unit_tests/src_towhee/pipelines/test_pipelines.py

@@ -55,7 +55,7 @@ def create_pipelines(llm_src):

 class TestPipelines(unittest.TestCase):
     project = 'akcio_ut'
-    data_src = 'https://github.com/towhee-io/towhee/blob/main/requirements.txt'
+    data_src = 'https://towhee.io'
     question = 'test question'

     @classmethod
@@ -81,6 +81,10 @@ def test_openai(self):

         insert_pipeline = pipelines.insert_pipeline
         res = insert_pipeline(self.data_src, self.project).to_list()
+        token_count = 0
+        for x in res:
+            token_count += x[0]['token_count']
+        assert token_count == 290
         num = pipelines.count_entities(self.project)['vector store']
         assert len(res) <= num

@@ -114,6 +118,10 @@ def test_chatglm(self):

         insert_pipeline = pipelines.insert_pipeline
         res = insert_pipeline(self.data_src, self.project).to_list()
+        token_count = 0
+        for x in res:
+            token_count += x[0]['token_count']
+        assert token_count == 290
         num = pipelines.count_entities(self.project)['vector store']
         assert len(res) <= num

@@ -149,6 +157,10 @@ def json(self):

         insert_pipeline = pipelines.insert_pipeline
         res = insert_pipeline(self.data_src, self.project).to_list()
+        token_count = 0
+        for x in res:
+            token_count += x[0]['token_count']
+        assert token_count == 290
         num = pipelines.count_entities(self.project)['vector store']
         assert len(res) <= num

@@ -190,6 +202,10 @@ def output(self):

         insert_pipeline = pipelines.insert_pipeline
         res = insert_pipeline(self.data_src, self.project).to_list()
+        token_count = 0
+        for x in res:
+            token_count += x[0]['token_count']
+        assert token_count == 290
         num = pipelines.count_entities(self.project)['vector store']
         assert len(res) <= num

@@ -225,6 +241,10 @@ def json(self):

         insert_pipeline = pipelines.insert_pipeline
         res = insert_pipeline(self.data_src, self.project).to_list()
+        token_count = 0
+        for x in res:
+            token_count += x[0]['token_count']
+        assert token_count == 290
         num = pipelines.count_entities(self.project)['vector store']
         assert len(res) <= num

@@ -262,6 +282,10 @@ def iter_lines(self):

         insert_pipeline = pipelines.insert_pipeline
         res = insert_pipeline(self.data_src, self.project).to_list()
+        token_count = 0
+        for x in res:
+            token_count += x[0]['token_count']
+        assert token_count == 290
         num = pipelines.count_entities(self.project)['vector store']
         assert len(res) <= num

@@ -286,10 +310,6 @@ def __call__(self, *args, **kwargs):

         pipelines = create_pipelines('dolly')

-        self.project = 'akcio_ut'
-        self.data_src = 'https://github.com/towhee-io/towhee/blob/main/requirements.txt'
-        self.question = 'test question'
-
         # Check insert
         if pipelines.check(self.project):
             pipelines.drop(self.project)
@@ -300,6 +320,10 @@ def __call__(self, *args, **kwargs):

         insert_pipeline = pipelines.insert_pipeline
         res = insert_pipeline(self.data_src, self.project).to_list()
+        token_count = 0
+        for x in res:
+            token_count += x[0]['token_count']
+        assert token_count == 290
         num = pipelines.count_entities(self.project)['vector store']
         assert len(res) <= num
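
The same four-line accumulation and assertion is repeated in all seven test paths above; a small helper would keep it in one place (a possible refactor, not part of this commit):

def total_token_count(res) -> int:
    # Sum the per-chunk token counts carried in each pipeline result row.
    return sum(x[0]['token_count'] for x in res)

# in each test: assert total_token_count(res) == 290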
9 changes: 5 additions & 4 deletions tests/unit_tests/src_towhee/test_operations.py

@@ -39,7 +39,7 @@ def search_pipeline(self, *args, **kwargs):

     def insert_pipeline(self, data_src, project, source_type='file'):
         self.projects[project] = [data_src]
-        self.insert_que.put([(data_src)])
+        self.insert_que.put([{'milvus': len(data_src), 'es': None, 'token_count': len(data_src.split(' '))}])
         self.insert_que.seal()
         return self.insert_que

@@ -61,8 +61,9 @@ class TestOperations(unittest.TestCase):
     '''Test operations'''
     session_id = 'test000'
     project = 'akcio_ut'
-    test_src = 'test_src'
+    test_src = 'test src'
     expect_len = 1
+    expect_token_count = len(test_src.split(' '))
     question = 'the first question'
     expect_answer = 'mock answer'

@@ -113,8 +114,8 @@ def test_insert(self):

         from src_towhee.operations import insert, check, drop

-        count = insert(self.test_src, self.project)
-        assert count == self.expect_len
+        chunk_count, token_count = insert(self.test_src, self.project)
+        assert chunk_count == self.expect_len, token_count == self.expect_token_count
         status = check(self.project)
         assert status['store']
