diff --git a/README.md b/README.md index 8ec76ef..d9a1c96 100644 --- a/README.md +++ b/README.md @@ -137,6 +137,10 @@ Whether to use Elasticsearch, default is `True`. The connection arguments to connect elastic service. +***token_model*** + +The model used to count tokens, defaults to `gpt-3.5-turbo`. +
diff --git a/osschat_insert.py b/osschat_insert.py index 8007b7e..5e69af5 100644 --- a/osschat_insert.py +++ b/osschat_insert.py @@ -40,6 +40,8 @@ class OSSChatInsertConfig(BaseModel, extra=Extra.allow): # config for elasticsearch es_enable: Optional[bool] = True es_connection_kwargs: Optional[dict] = {'hosts': ['https://127.0.0.1:9200'], 'basic_auth': ('elastic', 'my_password')} + # token count + token_model: Optional[str] = 'gpt-3.5-turbo' _hf_models = ops.sentence_embedding.transformers().get_op().supported_model_names() @@ -96,6 +98,7 @@ def osschat_insert_pipe(config): p = ( pipe.input('doc', 'project_name') .map('doc', 'text', data_loader) + .map('doc', 'token_count', ops.token_counter(config.token_model)) .flat_map('text', 'chunk', text_split_op) .map('chunk', 'embedding', sentence_embedding_op, config=sentence_embedding_config) ) @@ -109,8 +112,8 @@ def osschat_insert_pipe(config): p = ( p.map('chunk', 'es_doc', lambda x: {'doc': x}) .map(('project_name', 'es_doc'), 'es_res', es_index_op) - .map(('milvus_res', 'es_res'), 'res', lambda x, y: {'milvus_res': x, 'es_res': y}) + .map(('milvus_res', 'es_res', 'token_count'), 'res', lambda x, y, c: {'milvus_res': x, 'es_res': y, 'token_count': c}) ) else: - p = p.map('milvus_res', 'res', lambda x: {'milvus_res': x, 'es_res': None}) + p = p.map(('milvus_res', 'token_count'), 'res', lambda x, c: {'milvus_res': x, 'es_res': None, 'token_count': c}) return p.output('res')