diff --git a/eqa_insert.py b/eqa_insert.py index 33698a4..35d54c4 100644 --- a/eqa_insert.py +++ b/eqa_insert.py @@ -12,30 +12,32 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Dict, Optional, Any +from pydantic import BaseModel + from towhee import ops, pipe, AutoPipes, AutoConfig @AutoConfig.register -class EnhancedQAInsertConfig: +class EnhancedQAInsertConfig(BaseModel): """ Config of pipeline """ - def __init__(self): - # config for text_splitter - self.type = 'RecursiveCharacter' - self.chunk_size = 300 - self.splitter_kwargs = {} - # config for sentence_embedding - self.model = 'all-MiniLM-L6-v2' - self.openai_api_key = None - self.device = -1 - # config for insert_milvus - self.host = '127.0.0.1' - self.port = '19530' - self.collection_name = 'chatbot' - self.user = None - self.password = None - + # config for text_splitter + type: Optional[str] = 'RecursiveCharacter' + chunk_size: int = 300 + splitter_kwargs: Optional[Dict[str, Any]] = {} + # config for sentence_embedding + model: Optional[str] = 'all-MiniLM-L6-v2' + openai_api_key: Optional[str] = None + device: Optional[int] = -1 + # config for insert_milvus + host: Optional[str] = '127.0.0.1' + port: Optional[str] = '19530' + collection_name: Optional[str] = 'chatbot' + user: Optional[str] = None + password: Optional[str] = None + _hf_models = ops.sentence_embedding.transformers().get_op().supported_model_names() _sbert_models = ops.sentence_embedding.sbert().get_op().supported_model_names() @@ -56,7 +58,7 @@ def _get_embedding_op(config): return True, ops.sentence_embedding.sbert(model_name=config.model, device=device) if config.model in _openai_models: return False, ops.sentence_embedding.openai(model_name=config.model, api_key=config.openai_api_key) - raise RuntimeError('Unknown model: [%s], only support: %s' % (config.model, _hf_models + _openai_models)) + raise RuntimeError('Unknown model: [%s], only support: %s' % (config.model, _hf_models + _sbert_models + _openai_models))