diff --git a/README.md b/README.md index 1b8c285..185ea52 100644 --- a/README.md +++ b/README.md @@ -53,7 +53,6 @@ config.model = 'all-MiniLM-L6-v2' config.host = '127.0.0.1' config.port = '19530' config.collection_name = collection_name -config.source_type = 'url' p = AutoPipes.pipeline('eqa-insert', config=config) res = p('https://github.com/towhee-io/towhee/blob/main/README.md') @@ -70,14 +69,11 @@ Then you can run `collection.num_entities` to check the number of the data in Mi ### **EnhancedQAInsertConfig** -#### **Configuration for [Text Loader](https://towhee.io/towhee/text-loader):** +#### **Configuration for [Text Spliter](https://towhee.io/towhee/text-spliter):** ***chunk_size: int*** The size of each chunk, defaults to 300. -***source_type: str*** -The type of the soure, defaults to `'file'`, you can also set to `'url'` for you url of your documentation. - #### **Configuration for Sentence Embedding:** ***model: str*** @@ -88,12 +84,6 @@ You can refer to the above [Model(s) list ](https://towhee.io/tasks/detail/opera The api key of openai, default to `None`. This key is required if the model is from OpenAI, you can check the model provider in the above [Model(s) list](https://towhee.io/sentence-embedding/openai). -***customize_embedding_op: str*** -The name of the customize embedding operator, defaults to `None`. - -***normalize_vec: bool*** -Whether to normalize the embedding vectors, defaults to `True`. - ***device:*** ***int*** The number of devices, defaults to `-1`, which means using the CPU. If the setting is not `-1`, the specified GPU device will be used. 
diff --git a/eqa_insert.py b/eqa_insert.py index d894d29..b3ff412 100644 --- a/eqa_insert.py +++ b/eqa_insert.py @@ -21,14 +21,11 @@ class EnhancedQAInsertConfig: Config of pipeline """ def __init__(self): - # config for text_loader + # config for text_spliter self.chunk_size = 300 - self.source_type = 'file' # config for sentence_embedding self.model = 'all-MiniLM-L6-v2' self.openai_api_key = None - self.customize_embedding_op = None - self.normalize_vec = True self.device = -1 # config for insert_milvus self.host = '127.0.0.1' @@ -51,26 +48,18 @@ def _get_embedding_op(config): else: device = config.device - if config.customize_embedding_op is not None: - return True, config.customize_embedding_op - if config.model in _hf_models: - return True, ops.sentence_embedding.transformers(model_name=config.model, - device=device) + return True, ops.sentence_embedding.transformers(model_name=config.model, device=device) if config.model in _sbert_models: - return True, ops.sentence_embedding.sbert(model_name=config.model, - device=device) + return True, ops.sentence_embedding.sbert(model_name=config.model, device=device) if config.model in _openai_models: - return False, ops.sentence_embedding.openai(model_name=config.model, - api_key=config.openai_api_key) + return False, ops.sentence_embedding.openai(model_name=config.model, api_key=config.openai_api_key) raise RuntimeError('Unknown model: [%s], only support: %s' % (config.model, _hf_models + _openai_models)) @AutoPipes.register def enhanced_qa_insert_pipe(config): - text_load_op = ops.text_loader(chunk_size=config.chunk_size, source_type=config.source_type) - allow_triton, sentence_embedding_op = _get_embedding_op(config) sentence_embedding_config = {} if allow_triton: @@ -86,15 +75,12 @@ def enhanced_qa_insert_pipe(config): password=config.password, ) - p = ( + return ( pipe.input('doc') - .flat_map('doc', 'sentence', text_load_op) + .map('doc', 'text', ops.text_loader()) + .flat_map('text', 'sentence', 
ops.text_spliter(chunk_size=config.chunk_size)) .map('sentence', 'embedding', sentence_embedding_op, config=sentence_embedding_config) - ) - - if config.normalize_vec: - p = p.map('embedding', 'embedding', ops.towhee.np_normalize()) - - return (p.map(('doc', 'sentence', 'embedding'), 'mr', insert_milvus_op) .output('mr') - ) + .map('embedding', 'embedding', ops.towhee.np_normalize()) + .map(('doc', 'sentence', 'embedding'), 'mr', insert_milvus_op) .output('mr') - ) + )