# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, Optional, Any

from pydantic import BaseModel

from towhee import ops, pipe, AutoPipes, AutoConfig


@AutoConfig.register
class EnhancedQAInsertConfig(BaseModel):
    """
    Config of the enhanced QA insert pipeline.

    Fields are grouped by the operator they configure: the text splitter,
    the sentence-embedding model, and the Milvus insert operator.
    """
    # config for text_splitter
    # Splitter type forwarded as ``ops.text_splitter(type=...)``.
    type: Optional[str] = 'RecursiveCharacter'
    # Chunk size forwarded as ``ops.text_splitter(chunk_size=...)``.
    chunk_size: int = 300
    # Extra keyword arguments forwarded to ``ops.text_splitter``;
    # ``None`` is treated the same as an empty dict.
    splitter_kwargs: Optional[Dict[str, Any]] = {}

    # config for sentence_embedding
    # Must be one of the names in _hf_models, _sbert_models or _openai_models.
    model: Optional[str] = 'all-MiniLM-L6-v2'
    # Only used when ``model`` is an OpenAI model.
    openai_api_key: Optional[str] = None
    # A non-negative value selects that GPU index; a negative value
    # (default -1) or None selects CPU.
    device: Optional[int] = -1

    # config for insert_milvus
    host: Optional[str] = '127.0.0.1'
    port: Optional[str] = '19530'
    collection_name: Optional[str] = 'chatbot'
    user: Optional[str] = None
    password: Optional[str] = None


# Model names accepted by each embedding backend. These are computed once at
# import time by instantiating the corresponding operators.
_hf_models = ops.sentence_embedding.transformers().get_op().supported_model_names()
_sbert_models = ops.sentence_embedding.sbert().get_op().supported_model_names()
_openai_models = ['text-embedding-ada-002', 'text-similarity-davinci-001',
                  'text-similarity-curie-001', 'text-similarity-babbage-001',
                  'text-similarity-ada-001']


def _is_gpu_device(device: Optional[int]) -> bool:
    """Return True when ``device`` is a GPU index (a non-negative int); None or negative means CPU."""
    return device is not None and device >= 0


def _get_embedding_op(config):
    """
    Select the sentence-embedding operator for ``config.model``.

    Args:
        config: an ``EnhancedQAInsertConfig``.

    Returns:
        A ``(allow_triton, op)`` tuple. ``allow_triton`` is True for the
        local HuggingFace/SBERT backends and False for the OpenAI backend,
        which is a remote API call.

    Raises:
        RuntimeError: if ``config.model`` is not supported by any backend.
    """
    # Use the same CPU/GPU rule as the pipeline's Triton config so the two never disagree.
    device = config.device if _is_gpu_device(config.device) else 'cpu'

    if config.model in _hf_models:
        return True, ops.sentence_embedding.transformers(model_name=config.model, device=device)
    if config.model in _sbert_models:
        return True, ops.sentence_embedding.sbert(model_name=config.model, device=device)
    if config.model in _openai_models:
        return False, ops.sentence_embedding.openai(model_name=config.model, api_key=config.openai_api_key)
    raise RuntimeError('Unknown model: [%s], only support: %s' % (config.model, _hf_models + _sbert_models + _openai_models))


@AutoPipes.register
def enhanced_qa_insert_pipe(config):
    """
    Build the insert pipeline.

    The pipeline loads a document, splits it into sentences, embeds and
    normalizes each sentence, and inserts ``(doc, sentence, embedding)``
    into Milvus.

    Args:
        config: an ``EnhancedQAInsertConfig``.

    Returns:
        A towhee pipeline with input ``doc``. ``doc`` is presumably a path or
        URL consumed by ``ops.text_loader``; TODO confirm. The pipeline emits
        ``mr``, the Milvus insert op's result, once per split sentence.
    """
    # splitter_kwargs is Optional; unpacking None with ** would raise TypeError.
    splitter_kwargs = config.splitter_kwargs or {}
    text_split_op = ops.text_splitter(type=config.type, chunk_size=config.chunk_size, **splitter_kwargs)

    allow_triton, sentence_embedding_op = _get_embedding_op(config)
    sentence_embedding_config = {}
    # Only the local embedding backends can be served through Triton.
    if allow_triton:
        if _is_gpu_device(config.device):
            sentence_embedding_config = AutoConfig.TritonGPUConfig(device_ids=[config.device], max_batch_size=128)
        else:
            sentence_embedding_config = AutoConfig.TritonCPUConfig()

    insert_milvus_op = ops.ann_insert.milvus_client(host=config.host,
                                                    port=config.port,
                                                    collection_name=config.collection_name,
                                                    user=config.user,
                                                    password=config.password,
                                                    )

    return (
        pipe.input('doc')
        .map('doc', 'text', ops.text_loader())
        .flat_map('text', 'sentence', text_split_op)
        .map('sentence', 'embedding', sentence_embedding_op, config=sentence_embedding_config)
        .map('embedding', 'embedding', ops.towhee.np_normalize())
        .map(('doc', 'sentence', 'embedding'), 'mr', insert_milvus_op)
        .output('mr')
    )