From 186a18d6b16c8ed3243406ed80835dcf9c7f5faa Mon Sep 17 00:00:00 2001
From: shiyu22
Date: Thu, 1 Jun 2023 19:15:43 +0800
Subject: [PATCH] Update params

Signed-off-by: shiyu22
---
 README.md     | 12 ++++++------
 eqa_insert.py | 26 +++++++++++++-------------
 2 files changed, 19 insertions(+), 19 deletions(-)

diff --git a/README.md b/README.md
index 735096e..7b55a4c 100644
--- a/README.md
+++ b/README.md
@@ -49,7 +49,7 @@ collection.create_index(field_name="embedding", index_params=index_params)
 from towhee import AutoPipes, AutoConfig
 
 config = AutoConfig.load_config('eqa-insert')
-config.model = 'all-MiniLM-L6-v2'
+config.embedding_model = 'all-MiniLM-L6-v2'
 config.host = '127.0.0.1'
 config.port = '19530'
 config.collection_name = collection_name
@@ -84,16 +84,16 @@ The kwargs for the splitter, defaults to {}.
 
 #### **Configuration for Sentence Embedding:**
 
-***model***: str
-The model name in the sentence embedding pipeline, defaults to `'all-MiniLM-L6-v2'`.
+***embedding_model***: str
+The model name for sentence embedding, defaults to `'all-MiniLM-L6-v2'`.
 You can refer to the above [Model(s) list ](https://towhee.io/tasks/detail/operator?field_name=Natural-Language-Processing&task_name=Sentence-Embedding)to set the model, some of these models are from [HuggingFace](https://huggingface.co/) (open source), and some are from [OpenAI](https://openai.com/) (not open, required API key).
 
 ***openai_api_key***: str
-The api key of openai, default to `None`.
+The API key for OpenAI, defaults to `None`. This key is required if the model is from OpenAI; you can check the model provider in the above [Model(s) list](https://towhee.io/sentence-embedding/openai).
 
-***device:*** int
-The number of devices, defaults to `-1`, which means using the CPU.
+***embedding_device***: int
+The device id, defaults to `-1`, which means using the CPU.
 If the setting is not `-1`, the specified GPU device will be used.
 
 #### **Configuration for [Milvus](https://towhee.io/ann-insert/milvus-client):**
diff --git a/eqa_insert.py b/eqa_insert.py
index db94d1e..7c5bb26 100644
--- a/eqa_insert.py
+++ b/eqa_insert.py
@@ -28,9 +28,9 @@ class EnhancedQAInsertConfig(BaseModel):
     chunk_size: Optional[int] = 300
     splitter_kwargs: Optional[Dict[str, Any]] = {}
     # config for sentence_embedding
-    model: Optional[str] = 'all-MiniLM-L6-v2'
+    embedding_model: Optional[str] = 'all-MiniLM-L6-v2'
     openai_api_key: Optional[str] = None
-    device: Optional[int] = -1
+    embedding_device: Optional[int] = -1
     # config for insert_milvus
     host: Optional[str] = '127.0.0.1'
     port: Optional[str] = '19530'
@@ -47,18 +47,18 @@ _openai_models = ['text-embedding-ada-002', 'text-similarity-davinci-001',
 
 
 def _get_embedding_op(config):
-    if config.device == -1:
+    if config.embedding_device == -1:
         device = 'cpu'
     else:
-        device = config.device
+        device = config.embedding_device
 
-    if config.model in _hf_models:
-        return True, ops.sentence_embedding.transformers(model_name=config.model, device=device)
-    if config.model in _sbert_models:
-        return True, ops.sentence_embedding.sbert(model_name=config.model, device=device)
-    if config.model in _openai_models:
-        return False, ops.sentence_embedding.openai(model_name=config.model, api_key=config.openai_api_key)
-    raise RuntimeError('Unknown model: [%s], only support: %s' % (config.model, _hf_models + _sbert_models + _openai_models))
+    if config.embedding_model in _hf_models:
+        return True, ops.sentence_embedding.transformers(model_name=config.embedding_model, device=device)
+    if config.embedding_model in _sbert_models:
+        return True, ops.sentence_embedding.sbert(model_name=config.embedding_model, device=device)
+    if config.embedding_model in _openai_models:
+        return False, ops.sentence_embedding.openai(model_name=config.embedding_model, api_key=config.openai_api_key)
+    raise RuntimeError('Unknown model: [%s], only support: %s' % (config.embedding_model, _hf_models + _sbert_models + _openai_models))
 
 
 
@@ -72,8 +72,8 @@ def enhanced_qa_insert_pipe(config):
     allow_triton, sentence_embedding_op = _get_embedding_op(config)
     sentence_embedding_config = {}
     if allow_triton:
-        if config.device >= 0:
-            sentence_embedding_config = AutoConfig.TritonGPUConfig(device_ids=[config.device], max_batch_size=128)
+        if config.embedding_device >= 0:
+            sentence_embedding_config = AutoConfig.TritonGPUConfig(device_ids=[config.embedding_device], max_batch_size=128)
         else:
             sentence_embedding_config = AutoConfig.TritonCPUConfig()
 
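For readers applying this patch, a minimal usage sketch with the renamed configuration fields is shown below. It follows the README snippet changed above; the collection name, the document URL, and the `AutoPipes.pipeline` call are illustrative assumptions, not part of the patch itself.

```python
# Hypothetical end-to-end usage of the renamed fields; values below are placeholders.
from towhee import AutoPipes, AutoConfig

config = AutoConfig.load_config('eqa-insert')
config.embedding_model = 'all-MiniLM-L6-v2'  # was `config.model` before this patch
config.embedding_device = -1                 # was `config.device`; -1 means run on CPU
config.host = '127.0.0.1'
config.port = '19530'
config.collection_name = 'chatbot'           # assumed collection name

insert_pipe = AutoPipes.pipeline('eqa-insert', config=config)
insert_pipe('https://example.com/docs/faq.html')  # illustrative document source
```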
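The `_get_embedding_op` dispatch in the diff also implies that OpenAI-hosted models return `allow_triton = False`, so the Triton device configuration is skipped and `embedding_device` is not used. A hedged sketch of that configuration path, with a placeholder API key:

```python
# Sketch of the OpenAI branch after the rename; the API key value is a placeholder,
# and the Milvus settings (host, port, collection_name) are omitted for brevity.
from towhee import AutoPipes, AutoConfig

config = AutoConfig.load_config('eqa-insert')
config.embedding_model = 'text-embedding-ada-002'  # listed in _openai_models
config.openai_api_key = 'YOUR_OPENAI_API_KEY'      # required for OpenAI-hosted models
config.embedding_device = -1                       # not used for OpenAI models (no local inference)

insert_pipe = AutoPipes.pipeline('eqa-insert', config=config)
```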