towhee
/
            
              eqa-insert
              
                 
                
            
          copied
			You cannot select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
			
			Readme
Files and versions
		
      
        
        
          
            95 lines
          
        
        
          
            3.7 KiB
          
        
        
      
		
    
      
      
    
	
  
	
            95 lines
          
        
        
          
            3.7 KiB
          
        
        
      | # Copyright 2021 Zilliz. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| #     http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| 
 | |
| from typing import Dict, Optional, Any | |
| from pydantic import BaseModel | |
| 
 | |
| from towhee import ops, pipe, AutoPipes, AutoConfig | |
| 
 | |
| 
 | |
@AutoConfig.register
class EnhancedQAInsertConfig(BaseModel):
    """
    Settings for the enhanced-QA insert pipeline.

    Fields are grouped by the pipeline stage that consumes them:
    text splitting, sentence embedding, and insertion into Milvus.
    """
    # ---- text_splitter ----
    # Splitter type, forwarded as ``type=`` to ``ops.text_splitter``.
    type: Optional[str] = 'RecursiveCharacter'
    # Forwarded as ``chunk_size=`` to ``ops.text_splitter``.
    chunk_size: int = 300
    # Extra keyword arguments expanded into the ``ops.text_splitter`` call.
    splitter_kwargs: Optional[Dict[str, Any]] = {}

    # ---- sentence_embedding ----
    # Embedding model name; it selects the transformers, sbert or OpenAI operator.
    model: Optional[str] = 'all-MiniLM-L6-v2'
    # API key, only passed on when an OpenAI model is selected.
    openai_api_key: Optional[str] = None
    # -1 is translated to 'cpu'; any other value is passed to the operator as-is.
    device: Optional[int] = -1

    # ---- insert_milvus ----
    # Connection and target settings forwarded to ``ops.ann_insert.milvus_client``.
    host: Optional[str] = '127.0.0.1'
    port: Optional[str] = '19530'
    collection_name: Optional[str] = 'chatbot'
    user: Optional[str] = None
    password: Optional[str] = None
| 
 | |
| 
 | |
# Model names reported as supported by each local embedding operator.
# Both lookups run once, when this module is imported.
_hf_models = ops.sentence_embedding.transformers().get_op().supported_model_names()
_sbert_models = ops.sentence_embedding.sbert().get_op().supported_model_names()

# OpenAI embedding models accepted by this pipeline (fixed list).
_openai_models = [
    'text-embedding-ada-002',
    'text-similarity-davinci-001',
    'text-similarity-curie-001',
    'text-similarity-babbage-001',
    'text-similarity-ada-001',
]
| 
 | |
| 
 | |
def _get_embedding_op(config):
    """
    Pick the sentence-embedding operator that matches ``config.model``.

    The model name is looked up, in order, in the transformers, sbert and
    OpenAI model lists; the first match decides which operator is built.

    Args:
        config: object exposing ``model``, ``device`` and ``openai_api_key``.

    Returns:
        A ``(flag, op)`` tuple. ``flag`` is True for the transformers and
        sbert operators and False for the OpenAI operator; the pipeline
        builder uses it to decide whether to attach a Triton config.

    Raises:
        RuntimeError: if the model name is in none of the three lists.
    """
    model_name = config.model
    # A device of -1 means "run on CPU"; any other value is forwarded untouched.
    device = 'cpu' if config.device == -1 else config.device

    if model_name in _hf_models:
        return True, ops.sentence_embedding.transformers(model_name=model_name, device=device)

    if model_name in _sbert_models:
        return True, ops.sentence_embedding.sbert(model_name=model_name, device=device)

    if model_name in _openai_models:
        # The OpenAI operator takes an API key instead of a device.
        return False, ops.sentence_embedding.openai(model_name=model_name, api_key=config.openai_api_key)

    supported = _hf_models + _sbert_models + _openai_models
    raise RuntimeError(f'Unknown model: [{model_name}], only support: {supported}')
| 
 | |
| 
 | |
| 
 | |
@AutoPipes.register
def enhanced_qa_insert_pipe(config):
    """
    Build the document-insert pipeline described by ``config``.

    Stages, in order: load the document with ``ops.text_loader``, split the
    text into sentences, embed each sentence, normalise the embedding with
    ``ops.towhee.np_normalize`` and insert ``(doc, sentence, embedding)``
    through the Milvus client operator. The insert operator's result is the
    pipeline output (``'mr'``).

    Args:
        config: an ``EnhancedQAInsertConfig``-like object providing the
            splitter, embedding and Milvus settings read below.

    Returns:
        The assembled pipeline taking ``'doc'`` as input and emitting ``'mr'``.

    Raises:
        RuntimeError: propagated from ``_get_embedding_op`` when
            ``config.model`` is not a supported model name.
    """
    # ``splitter_kwargs`` is declared Optional; expanding None with ``**``
    # would raise TypeError, so treat None as "no extra arguments".
    splitter_kwargs = config.splitter_kwargs or {}
    text_split_op = ops.text_splitter(type=config.type,
                                      chunk_size=config.chunk_size,
                                      **splitter_kwargs)

    allow_triton, sentence_embedding_op = _get_embedding_op(config)
    # Only the local (transformers / sbert) operators get a Triton config;
    # for the OpenAI operator the config stays empty.
    sentence_embedding_config = {}
    if allow_triton:
        if config.device >= 0:
            sentence_embedding_config = AutoConfig.TritonGPUConfig(device_ids=[config.device], max_batch_size=128)
        else:
            sentence_embedding_config = AutoConfig.TritonCPUConfig()

    insert_milvus_op = ops.ann_insert.milvus_client(host=config.host,
                                                    port=config.port,
                                                    collection_name=config.collection_name,
                                                    user=config.user,
                                                    password=config.password,
                                                    )

    return (
        pipe.input('doc')
            .map('doc', 'text', ops.text_loader())
            .flat_map('text', 'sentence', text_split_op)
            .map('sentence', 'embedding', sentence_embedding_op, config=sentence_embedding_config)
            .map('embedding', 'embedding', ops.towhee.np_normalize())
            .map(('doc', 'sentence', 'embedding'), 'mr', insert_milvus_op)
            .output('mr')
    )
 | 
