diff --git a/README.md b/README.md
index e04b10c..92f0177 100644
--- a/README.md
+++ b/README.md
@@ -43,7 +43,7 @@ collection.create_index(field_name="embedding", index_params=index_params)
### **Create pipeline and set the configuration**
-> You need also start [elasticsearch](https://www.elastic.co/elasticsearch/).
+> If you set config.es_enable to True, you need also start [elasticsearch](https://www.elastic.co/elasticsearch/).
>
> More parameters refer to the Configuration.
@@ -54,11 +54,12 @@ config = AutoConfig.load_config('osschat-insert')
config.embedding_model = 'all-MiniLM-L6-v2'
config.milvus_host = '127.0.0.1'
config.milvus_port = '19530'
+config.es_enable = True
config.es_host = '127.0.0.1'
config.es_port = '9200'
p = AutoPipes.pipeline('osschat-insert', config=config)
-res = p('https://github.com/towhee-io/towhee/blob/main/README.md', 'osschat', 'osschat')
+res = p('https://github.com/towhee-io/towhee/blob/main/README.md', 'osschat')
```
Then you can run `collection.flush() ` and `collection.num_entities` to check the number of the data in Milvus as a knowledge base.
@@ -81,6 +82,7 @@ And run `es_client.search(index='osschat', body={"query":{"match_all":{}}})['hit
The type of splitter, defaults to 'RecursiveCharacter'. You can set this parameter in ['[RecursiveCharacter](https://python.langchain.com/en/latest/modules/indexes/text_splitters/examples/recursive_text_splitter.html)', '[Markdown](https://python.langchain.com/en/latest/modules/indexes/text_splitters/examples/markdown.html)', '[PythonCode](https://python.langchain.com/en/latest/modules/indexes/text_splitters/examples/python.html)', '[Character](https://python.langchain.com/en/latest/modules/indexes/text_splitters/examples/character_text_splitter.html#)', '[NLTK](https://python.langchain.com/en/latest/modules/indexes/text_splitters/examples/nltk.html)', '[Spacy](https://python.langchain.com/en/latest/modules/indexes/text_splitters/examples/spacy.html)', '[Tiktoken](https://python.langchain.com/en/latest/modules/indexes/text_splitters/examples/tiktoken_splitter.html)', '[HuggingFace](https://python.langchain.com/en/latest/modules/indexes/text_splitters/examples/huggingface_length_function.html)'].
***chunk_size***: int
+
The size of each chunk, defaults to 300.
***splitter_kwargs***: dict
@@ -90,43 +92,62 @@ The kwargs for the splitter, defaults to {}.
#### **Configuration for Sentence Embedding:**
***embedding_model***: str
+
The model name for sentence embedding, defaults to `'all-MiniLM-L6-v2'`.
You can refer to the above [Model(s) list ](https://towhee.io/tasks/detail/operator?field_name=Natural-Language-Processing&task_name=Sentence-Embedding)to set the model, some of these models are from [HuggingFace](https://huggingface.co/) (open source), and some are from [OpenAI](https://openai.com/) (not open, required API key).
***openai_api_key***: str
+
The api key of openai, default to `None`.
This key is required if the model is from OpenAI, you can check the model provider in the above [Model(s) list](https://towhee.io/sentence-embedding/openai).
***embedding_device:*** int
+
The number of device, defaults to `-1`, which means using the CPU.
If the setting is not `-1`, the specified GPU device will be used.
+***embedding_normalize:*** bool
+
+Whether to normalize the embedding vectors, defaults to `True`.
+
#### **Configuration for [Milvus](https://towhee.io/ann-insert/osschat-milvus):**
***milvus_host***: str
+
Host of Milvus vector database, default is `'127.0.0.1'`.
***milvus_port***: str
+
Port of Milvus vector database, default is `'19530'`.
***milvus_user***: str
+
The user name for [Cloud user](https://zilliz.com/cloud), defaults to `None`.
***milvus_password***: str
+
The user password for [Cloud user](https://zilliz.com/cloud), defaults to `None`.
#### **Configuration for [Elasticsearch](https://towhee.io/elasticsearch/osschat-index):**
+***es_enable***: bool
+
+Whether to use Elasticsearch, default is `True`.
+
***es_host***: str
+
Host of Elasticsearch, default is `'127.0.0.1'`.
***es_port***: str
+
Port of Elasticsearche, default is `'9200'`.
***es_user***: str
+
The user name for Elasticsearch, defaults to `None`.
***es_password***: str
+
The user password for Elasticsearch, defaults to `None`.
@@ -143,11 +164,9 @@ Insert documentation into Milvus as a knowledge base.
Path or url of the document to be loaded.
-***milvus_collection***: str
-The collection name for Milvus vector database, is required when inserting data into Milvus.
+***project_name***: str
-***es_index***: str
-The index name of elasticsearch.
+The collection name for Milvus vector database, also the index name of elasticsearch.
diff --git a/osschat_insert.py b/osschat_insert.py
index a64d233..f91e6f4 100644
--- a/osschat_insert.py
+++ b/osschat_insert.py
@@ -14,7 +14,6 @@
from typing import Dict, Optional, Any
from pydantic import BaseModel
-from datetime import datetime
from towhee import ops, pipe, AutoPipes, AutoConfig
@@ -32,12 +31,14 @@ class OSSChatInsertConfig(BaseModel):
embedding_model: Optional[str] = 'all-MiniLM-L6-v2'
openai_api_key: Optional[str] = None
embedding_device: Optional[int] = -1
+ embedding_normalize: Optional[bool] = True
# config for insert_milvus
milvus_host: Optional[str] = '127.0.0.1'
milvus_port: Optional[str] = '19530'
milvus_user: Optional[str] = None
milvus_password: Optional[str] = None
# config for elasticsearch
+ es_enable: Optional[bool] = True
es_host: Optional[str] = '127.0.0.1'
es_port: Optional[str] = '9200'
es_user: Optional[str] = None
@@ -74,13 +75,6 @@ def osschat_insert_pipe(config):
chunk_size=config.chunk_size,
**config.splitter_kwargs)
- es_index_op = ops.elasticsearch.osschat_index(host=config.es_host,
- port=config.es_port,
- user=config.es_user,
- password=config.es_password,
- ca_certs=config.es_ca_certs,
- )
-
allow_triton, sentence_embedding_op = _get_embedding_op(config)
sentence_embedding_config = {}
if allow_triton:
@@ -95,14 +89,27 @@ def osschat_insert_pipe(config):
password=config.milvus_password,
)
- return (
- pipe.input('doc', 'milvus_collection', 'es_index')
+ p = (
+ pipe.input('doc', 'project_name')
.map('doc', 'text', ops.text_loader())
.flat_map('text', 'sentence', text_split_op)
- .map('sentence', 'es_sentence', lambda x: {'sentence': x})
- .map(('es_index', 'es_sentence'), 'es_res', es_index_op)
.map('sentence', 'embedding', sentence_embedding_op, config=sentence_embedding_config)
- .map('embedding', 'embedding', ops.towhee.np_normalize())
- .map(('milvus_collection', 'doc', 'sentence', 'embedding'), 'milvus_res', insert_milvus_op)
- .output('milvus_res', 'es_res')
)
+ if config.embedding_normalize:
+ p = p.map('embedding', 'embedding', ops.towhee.np_normalize())
+
+ p = p.map(('project_name', 'doc', 'sentence', 'embedding'), 'milvus_res', insert_milvus_op)
+
+ if config.es_enable:
+ es_index_op = ops.elasticsearch.osschat_index(host=config.es_host,
+ port=config.es_port,
+ user=config.es_user,
+ password=config.es_password,
+ ca_certs=config.es_ca_certs,
+ )
+ p = (
+ p.map('sentence', 'es_sentence', lambda x: {'sentence': x})
+ .map(('project_name', 'es_sentence'), 'es_res', es_index_op)
+ )
+ return p.output('milvus_res', 'es_res')
+ return p.output('milvus_res')