|
|
@@ -36,6 +36,8 @@ class EnhancedQASearchConfig:
|
|
|
self.threshold = 0.6 |
|
|
|
# config for llm |
|
|
|
self.llm_device = -1 |
|
|
|
self.llm_model = 'openai' |
|
|
|
|
|
|
|
|
|
|
|
_hf_models = ops.sentence_embedding.transformers().get_op().supported_model_names() |
|
|
|
_sbert_models = ops.sentence_embedding.sbert().get_op().supported_model_names() |
|
|
@@ -72,6 +74,11 @@ def _get_similarity_evaluation_op(config):
|
|
|
return lambda x: [i for i in x if i[1] >= config.threshold] |
|
|
|
|
|
|
|
|
|
|
|
def _get_llm_op(config):
    """Build the LLM operator selected by ``config.llm_model``.

    Args:
        config: pipeline configuration; reads ``llm_model`` and, for the
            'openai' backend, ``openai_api_key``.

    Returns:
        The towhee LLM operator for the configured backend.

    Raises:
        ValueError: if ``config.llm_model`` names an unsupported backend.
            (Previously this fell off the end and returned ``None``, which
            only failed later inside the pipeline's ``.map`` call with a
            much less helpful error.)
    """
    if config.llm_model == 'openai':
        return ops.LLM.OpenAI(api_key=config.openai_api_key)
    raise ValueError(f'Unsupported llm_model: {config.llm_model!r}')
|
|
|
|
|
|
|
|
|
|
|
@AutoPipes.register |
|
|
|
def enhanced_qa_search_pipe(config): |
|
|
|
allow_triton, sentence_embedding_op = _get_embedding_op(config) |
|
|
@@ -93,6 +100,8 @@ def enhanced_qa_search_pipe(config):
|
|
|
password=config.password, |
|
|
|
) |
|
|
|
|
|
|
|
llm_op = _get_llm_op(config) |
|
|
|
|
|
|
|
p = ( |
|
|
|
pipe.input('question', 'history') |
|
|
|
.map('question', 'embedding', sentence_embedding_op, config=sentence_embedding_config) |
|
|
@@ -108,7 +117,7 @@ def enhanced_qa_search_pipe(config):
|
|
|
p = ( |
|
|
|
p.map('result', 'docs', lambda x:[i[2] for i in x]) |
|
|
|
.map(('question', 'docs', 'history'), 'prompt', ops.prompt.question_answer()) |
|
|
|
-            .map('prompt', 'answer', ops.LLM.OpenAI(api_key=config.openai_api_key))
|
|
|
+            .map('prompt', 'answer', llm_op)
|
|
|
) |
|
|
|
|
|
|
|
return p.output('answer') |
|
|
|