From a7575a9ba5248d679ce8d4b89a70e01a4c6882b9 Mon Sep 17 00:00:00 2001 From: shiyu22 Date: Thu, 15 Jun 2023 17:43:40 +0800 Subject: [PATCH] Update config --- README.md | 31 +++++++++++++++++++++++++------ osschat_insert.py | 37 ++++++++++++++++++++++--------------- 2 files changed, 47 insertions(+), 21 deletions(-) diff --git a/README.md b/README.md index e04b10c..92f0177 100644 --- a/README.md +++ b/README.md @@ -43,7 +43,7 @@ collection.create_index(field_name="embedding", index_params=index_params) ### **Create pipeline and set the configuration** -> You need also start [elasticsearch](https://www.elastic.co/elasticsearch/). +> If you set config.es_enable to True, you need also start [elasticsearch](https://www.elastic.co/elasticsearch/). > > More parameters refer to the Configuration. @@ -54,11 +54,12 @@ config = AutoConfig.load_config('osschat-insert') config.embedding_model = 'all-MiniLM-L6-v2' config.milvus_host = '127.0.0.1' config.milvus_port = '19530' +config.es_enable = True config.es_host = '127.0.0.1' config.es_port = '9200' p = AutoPipes.pipeline('osschat-insert', config=config) -res = p('https://github.com/towhee-io/towhee/blob/main/README.md', 'osschat', 'osschat') +res = p('https://github.com/towhee-io/towhee/blob/main/README.md', 'osschat') ``` Then you can run `collection.flush() ` and `collection.num_entities` to check the number of the data in Milvus as a knowledge base. @@ -81,6 +82,7 @@ And run `es_client.search(index='osschat', body={"query":{"match_all":{}}})['hit The type of splitter, defaults to 'RecursiveCharacter'. You can set this parameter in ['[RecursiveCharacter](https://python.langchain.com/en/latest/modules/indexes/text_splitters/examples/recursive_text_splitter.html)', '[Markdown](https://python.langchain.com/en/latest/modules/indexes/text_splitters/examples/markdown.html)', '[PythonCode](https://python.langchain.com/en/latest/modules/indexes/text_splitters/examples/python.html)', '[Character](https://python.langchain.com/en/latest/modules/indexes/text_splitters/examples/character_text_splitter.html#)', '[NLTK](https://python.langchain.com/en/latest/modules/indexes/text_splitters/examples/nltk.html)', '[Spacy](https://python.langchain.com/en/latest/modules/indexes/text_splitters/examples/spacy.html)', '[Tiktoken](https://python.langchain.com/en/latest/modules/indexes/text_splitters/examples/tiktoken_splitter.html)', '[HuggingFace](https://python.langchain.com/en/latest/modules/indexes/text_splitters/examples/huggingface_length_function.html)']. ***chunk_size***: int + The size of each chunk, defaults to 300. ***splitter_kwargs***: dict @@ -90,43 +92,62 @@ The kwargs for the splitter, defaults to {}. #### **Configuration for Sentence Embedding:** ***embedding_model***: str + The model name for sentence embedding, defaults to `'all-MiniLM-L6-v2'`. You can refer to the above [Model(s) list ](https://towhee.io/tasks/detail/operator?field_name=Natural-Language-Processing&task_name=Sentence-Embedding)to set the model, some of these models are from [HuggingFace](https://huggingface.co/) (open source), and some are from [OpenAI](https://openai.com/) (not open, required API key). ***openai_api_key***: str + The api key of openai, default to `None`. This key is required if the model is from OpenAI, you can check the model provider in the above [Model(s) list](https://towhee.io/sentence-embedding/openai). ***embedding_device:*** int + The number of device, defaults to `-1`, which means using the CPU. If the setting is not `-1`, the specified GPU device will be used. +***embedding_normalize:*** bool + +Whether to normalize the embedding vectors, defaults to `True`. + #### **Configuration for [Milvus](https://towhee.io/ann-insert/osschat-milvus):** ***milvus_host***: str + Host of Milvus vector database, default is `'127.0.0.1'`. ***milvus_port***: str + Port of Milvus vector database, default is `'19530'`. ***milvus_user***: str + The user name for [Cloud user](https://zilliz.com/cloud), defaults to `None`. ***milvus_password***: str + The user password for [Cloud user](https://zilliz.com/cloud), defaults to `None`. #### **Configuration for [Elasticsearch](https://towhee.io/elasticsearch/osschat-index):** +***es_enable***: bool + +Whether to use Elasticsearch, default is `True`. + ***es_host***: str + Host of Elasticsearch, default is `'127.0.0.1'`. ***es_port***: str + Port of Elasticsearche, default is `'9200'`. ***es_user***: str + The user name for Elasticsearch, defaults to `None`. ***es_password***: str + The user password for Elasticsearch, defaults to `None`.
@@ -143,11 +164,9 @@ Insert documentation into Milvus as a knowledge base. Path or url of the document to be loaded. -***milvus_collection***: str -The collection name for Milvus vector database, is required when inserting data into Milvus. +***project_name***: str -***es_index***: str -The index name of elasticsearch. +The collection name for Milvus vector database, also the index name of elasticsearch.
diff --git a/osschat_insert.py b/osschat_insert.py index a64d233..f91e6f4 100644 --- a/osschat_insert.py +++ b/osschat_insert.py @@ -14,7 +14,6 @@ from typing import Dict, Optional, Any from pydantic import BaseModel -from datetime import datetime from towhee import ops, pipe, AutoPipes, AutoConfig @@ -32,12 +31,14 @@ class OSSChatInsertConfig(BaseModel): embedding_model: Optional[str] = 'all-MiniLM-L6-v2' openai_api_key: Optional[str] = None embedding_device: Optional[int] = -1 + embedding_normalize: Optional[bool] = True # config for insert_milvus milvus_host: Optional[str] = '127.0.0.1' milvus_port: Optional[str] = '19530' milvus_user: Optional[str] = None milvus_password: Optional[str] = None # config for elasticsearch + es_enable: Optional[bool] = True es_host: Optional[str] = '127.0.0.1' es_port: Optional[str] = '9200' es_user: Optional[str] = None @@ -74,13 +75,6 @@ def osschat_insert_pipe(config): chunk_size=config.chunk_size, **config.splitter_kwargs) - es_index_op = ops.elasticsearch.osschat_index(host=config.es_host, - port=config.es_port, - user=config.es_user, - password=config.es_password, - ca_certs=config.es_ca_certs, - ) - allow_triton, sentence_embedding_op = _get_embedding_op(config) sentence_embedding_config = {} if allow_triton: @@ -95,14 +89,27 @@ def osschat_insert_pipe(config): password=config.milvus_password, ) - return ( - pipe.input('doc', 'milvus_collection', 'es_index') + p = ( + pipe.input('doc', 'project_name') .map('doc', 'text', ops.text_loader()) .flat_map('text', 'sentence', text_split_op) - .map('sentence', 'es_sentence', lambda x: {'sentence': x}) - .map(('es_index', 'es_sentence'), 'es_res', es_index_op) .map('sentence', 'embedding', sentence_embedding_op, config=sentence_embedding_config) - .map('embedding', 'embedding', ops.towhee.np_normalize()) - .map(('milvus_collection', 'doc', 'sentence', 'embedding'), 'milvus_res', insert_milvus_op) - .output('milvus_res', 'es_res') ) + if config.embedding_normalize: + p = p.map('embedding', 'embedding', ops.towhee.np_normalize()) + + p = p.map(('project_name', 'doc', 'sentence', 'embedding'), 'milvus_res', insert_milvus_op) + + if config.es_enable: + es_index_op = ops.elasticsearch.osschat_index(host=config.es_host, + port=config.es_port, + user=config.es_user, + password=config.es_password, + ca_certs=config.es_ca_certs, + ) + p = ( + p.map('sentence', 'es_sentence', lambda x: {'sentence': x}) + .map(('project_name', 'es_sentence'), 'es_res', es_index_op) + ) + return p.output('milvus_res', 'es_res') + return p.output('milvus_res')