|
@@ -24,21 +24,18 @@ class EnhancedQASearchConfig:
         # config for sentence_embedding
         self.model = 'all-MiniLM-L6-v2'
         self.openai_api_key = None
-        self.normalize_vec = True
-        self.device = -1
+        self.embedding_device = -1
         # config for search_milvus
         self.host = '127.0.0.1'
         self.port = '19530'
         self.collection_name = 'chatbot'
         self.top_k = 5
-        self.metric_type='IP'
-        self.output_fields=['sentence']
         self.user = None
         self.password = None
         # config for similarity evaluation
         self.threshold = 0.6
-        # self.similarity_evaluation = 'score_filter'
+        # config for llm
+        self.llm_device = -1
 
 
 _hf_models = ops.sentence_embedding.transformers().get_op().supported_model_names()
 _sbert_models = ops.sentence_embedding.sbert().get_op().supported_model_names()
|
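With this hunk, the embedding stage gets its own device knob (`embedding_device` replaces the pipeline-wide `device`), the Milvus metric and output fields stop being configurable, the `normalize_vec` switch is dropped, and an `llm_device` option appears. A minimal caller-side sketch, assuming this module is importable under the hypothetical name `enhanced_qa_search`:

    from enhanced_qa_search import EnhancedQASearchConfig

    config = EnhancedQASearchConfig()
    config.embedding_device = 0  # replaces the old `config.device`; -1 selects CPU
    config.llm_device = -1       # new knob for the LLM stage; -1 keeps it on CPU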
@@ -48,13 +45,10 @@ _openai_models = ['text-embedding-ada-002', 'text-similarity-davinci-001',
 
 
 def _get_embedding_op(config):
-    if config.device == -1:
+    if config.embedding_device == -1:
         device = 'cpu'
     else:
-        device = config.device
-
-    if config.customize_embedding_op is not None:
-        return True, config.customize_embedding_op
+        device = config.embedding_device
 
     if config.model in _hf_models:
         return True, ops.sentence_embedding.transformers(
|
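The device-selection branch above reduces to a small pure function; restated for clarity (names hypothetical):

    # -1 selects CPU; any non-negative value passes through as a GPU ordinal.
    def resolve_device(embedding_device):
        return 'cpu' if embedding_device == -1 else embedding_device

    assert resolve_device(-1) == 'cpu'
    assert resolve_device(0) == 0

Note that dropping the `customize_embedding_op` escape hatch means callers can no longer inject their own embedding operator; only models known to the transformers/sbert/OpenAI lookups remain selectable.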
@@ -75,7 +69,6 @@ def _get_embedding_op(config):
 
 
 def _get_similarity_evaluation_op(config):
-    # if config.similarity_evaluation == 'score_filter':
     return lambda x: [i for i in x if i[1] >= config.threshold]
 
 
|
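`_get_similarity_evaluation_op` returns a plain score filter: search hits arrive as (id, score, ...) tuples, and anything scoring below `config.threshold` is dropped. A self-contained sketch with made-up hits:

    threshold = 0.6
    score_filter = lambda x: [i for i in x if i[1] >= threshold]

    hits = [('id1', 0.91, 'relevant passage'), ('id2', 0.42, 'weak match')]
    assert score_filter(hits) == [('id1', 0.91, 'relevant passage')]

Since the search uses inner product ('IP'), higher scores are better, so a lower bound is the right direction for the comparison.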
@@ -84,8 +77,8 @@ def enhanced_qa_search_pipe(config):
 
     allow_triton, sentence_embedding_op = _get_embedding_op(config)
     sentence_embedding_config = {}
     if allow_triton:
-        if config.device >= 0:
-            sentence_embedding_config = AutoConfig.TritonGPUConfig(device_ids=[config.device], max_batch_size=128)
+        if config.embedding_device >= 0:
+            sentence_embedding_config = AutoConfig.TritonGPUConfig(device_ids=[config.embedding_device], max_batch_size=128)
         else:
             sentence_embedding_config = AutoConfig.TritonCPUConfig()
|
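Same rename at the Triton call site: `sentence_embedding_config` stays an empty dict unless the chosen op supports Triton serving, in which case a non-negative `embedding_device` pins the served model to that GPU. A sketch of the branch in isolation, using towhee's `AutoConfig` exactly as in the hunk (values illustrative):

    from towhee import AutoConfig

    def pick_triton_config(allow_triton, embedding_device):
        # Default: no extra scheduling config for locally executed ops.
        config = {}
        if allow_triton:
            if embedding_device >= 0:
                config = AutoConfig.TritonGPUConfig(device_ids=[embedding_device], max_batch_size=128)
            else:
                config = AutoConfig.TritonCPUConfig()
        return config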
@@ -94,8 +87,8 @@ def enhanced_qa_search_pipe(config):
         port=config.port,
         collection_name=config.collection_name,
         limit=config.top_k,
-        output_fields=config.output_fields,
-        metric_type=config.metric_type,
+        output_fields=['text'],
+        metric_type='IP',
         user=config.user,
         password=config.password,
     )
|
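Hardcoding `output_fields=['text']` and `metric_type='IP'` ties the pipeline to collections whose schema stores the passage under a `text` field and whose index was built for inner product; previously both came from the config. For reference, the resulting operator construction (connection values illustrative):

    from towhee import ops

    search_milvus_op = ops.ann_search.milvus_client(
        host='127.0.0.1',
        port='19530',
        collection_name='chatbot',
        limit=5,
        output_fields=['text'],  # collection must expose a `text` field
        metric_type='IP',        # index must be built for inner product
    )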
@@ -103,16 +96,18 @@ def enhanced_qa_search_pipe(config):
     p = (
         pipe.input('question', 'history')
         .map('question', 'embedding', sentence_embedding_op, config=sentence_embedding_config)
+        .map('embedding', 'embedding', ops.towhee.np_normalize())
+        .map('embedding', 'result', search_milvus_op)
     )
 
-    if config.normalize_vec:
-        p = p.map('embedding', 'embedding', ops.towhee.np_normalize())
-
-    p = p.map('embedding', 'result', search_milvus_op)
-
     # if config.similarity_evaluation:
     if config.threshold:
         sim_eval_op = _get_similarity_evaluation_op(config)
         p = p.map('result', 'result', sim_eval_op)
 
-    return p.output('question', 'history', 'result')
+    p = (
+        p.map('result', 'docs', lambda x:[i[2] for i in x])
+        .map(('question', 'docs', 'history'), 'prompt', ops.prompt.question_answer())
+    )
+
+    return p.output('question', 'history', 'docs', 'prompt')
|
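After this hunk the pipeline normalizes embeddings unconditionally (which is also what makes the hardcoded IP metric behave like cosine similarity), and instead of returning raw hits it extracts the retrieved passages into `docs` (index 2 of each hit, i.e. the `text` field) and assembles a ready-made `prompt` via `ops.prompt.question_answer()`. A hedged end-to-end sketch; the `AutoPipes` registration name is an assumption:

    from towhee import AutoPipes

    search = AutoPipes.pipeline('enhanced_qa_search')  # name hypothetical
    res = search('What is Towhee?', [])                # inputs: (question, history)
    question, history, docs, prompt = res.get()        # matches the new output schema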
|