logo
Browse Source

Update eqa insert

Signed-off-by: shiyu22 <shiyu.chen@zilliz.com>
main
shiyu22 2 years ago
parent
commit
ad5bfee254
  1. 12
      README.md
  2. 32
      eqa_insert.py

12
README.md

@ -53,7 +53,6 @@ config.model = 'all-MiniLM-L6-v2'
config.host = '127.0.0.1'
config.port = '19530'
config.collection_name = collection_name
config.source_type = 'url'
p = AutoPipes.pipeline('eqa-insert', config=config)
res = p('https://github.com/towhee-io/towhee/blob/main/README.md')
@ -70,14 +69,11 @@ Then you can run `collection.num_entities` to check the number of the data in Mi
### **EnhancedQAInsertConfig**
#### **Configuration for [Text Loader](https://towhee.io/towhee/text-loader):**
#### **Configuration for [Text Spliter](https://towhee.io/towhee/text-spliter):**
***chunk_size: int***
The size of each chunk, defaults to 300.
***source_type: str***
The type of the source, defaults to `'file'`; you can also set it to `'url'` to pass the URL of your documentation.
#### **Configuration for Sentence Embedding:**
***model: str***
@ -88,12 +84,6 @@ You can refer to the above [Model(s) list ](https://towhee.io/tasks/detail/opera
The API key for OpenAI, defaults to `None`.
This key is required if the model is from OpenAI, you can check the model provider in the above [Model(s) list](https://towhee.io/sentence-embedding/openai).
***customize_embedding_op: str***
The name of the custom embedding operator, defaults to `None`.
***normalize_vec: bool***
Whether to normalize the embedding vectors, defaults to `True`.
***device:*** ***int***
The number of devices, defaults to `-1`, which means using the CPU.
If the setting is not `-1`, the specified GPU device will be used.

32
eqa_insert.py

@ -21,14 +21,11 @@ class EnhancedQAInsertConfig:
Config of pipeline
"""
def __init__(self):
# config for text_loader
# config for text_spliter
self.chunk_size = 300
self.source_type = 'file'
# config for sentence_embedding
self.model = 'all-MiniLM-L6-v2'
self.openai_api_key = None
self.customize_embedding_op = None
self.normalize_vec = True
self.device = -1
# config for insert_milvus
self.host = '127.0.0.1'
@ -51,26 +48,18 @@ def _get_embedding_op(config):
else:
device = config.device
if config.customize_embedding_op is not None:
return True, config.customize_embedding_op
if config.model in _hf_models:
return True, ops.sentence_embedding.transformers(model_name=config.model,
device=device)
return True, ops.sentence_embedding.transformers(model_name=config.model, device=device)
if config.model in _sbert_models:
return True, ops.sentence_embedding.sbert(model_name=config.model,
device=device)
return True, ops.sentence_embedding.sbert(model_name=config.model, device=device)
if config.model in _openai_models:
return False, ops.sentence_embedding.openai(model_name=config.model,
api_key=config.openai_api_key)
return False, ops.sentence_embedding.openai(model_name=config.model, api_key=config.openai_api_key)
raise RuntimeError('Unknown model: [%s], only support: %s' % (config.model, _hf_models + _openai_models))
@AutoPipes.register
def enhanced_qa_insert_pipe(config):
text_load_op = ops.text_loader(chunk_size=config.chunk_size, source_type=config.source_type)
allow_triton, sentence_embedding_op = _get_embedding_op(config)
sentence_embedding_config = {}
if allow_triton:
@ -86,15 +75,12 @@ def enhanced_qa_insert_pipe(config):
password=config.password,
)
p = (
return (
pipe.input('doc')
.flat_map('doc', 'sentence', text_load_op)
.map('doc', 'text', ops.text_loader())
.flat_map('text', 'sentence', ops.text_spilter(chunk_size=config.chunk_size))
.map('sentence', 'embedding', sentence_embedding_op, config=sentence_embedding_config)
)
if config.normalize_vec:
p = p.map('embedding', 'embedding', ops.towhee.np_normalize())
return (p.map(('doc', 'sentence', 'embedding'), 'mr', insert_milvus_op)
.map('embedding', 'embedding', ops.towhee.np_normalize())
.map(('doc', 'sentence', 'embedding'), 'mr', insert_milvus_op)
.output('mr')
)

Loading…
Cancel
Save