logo
Browse Source

Set llm option

Signed-off-by: Kaiyuan Hu <kaiyuan.hu@zilliz.com>
main
Kaiyuan Hu 2 years ago
parent
commit
9937f688b6
  1. 11
      eqa_search.py

11
eqa_search.py

@ -36,6 +36,8 @@ class EnhancedQASearchConfig:
self.threshold = 0.6
# config for llm
self.llm_device = -1
self.llm_model = 'openai'
_hf_models = ops.sentence_embedding.transformers().get_op().supported_model_names()
_sbert_models = ops.sentence_embedding.sbert().get_op().supported_model_names()
@ -72,6 +74,11 @@ def _get_similarity_evaluation_op(config):
return lambda x: [i for i in x if i[1] >= config.threshold]
def _get_llm_op(config):
if config.llm_model == 'openai':
return ops.LLM.OpenAI(api_key=config.openai_api_key)
@AutoPipes.register
def enhanced_qa_search_pipe(config):
allow_triton, sentence_embedding_op = _get_embedding_op(config)
@ -93,6 +100,8 @@ def enhanced_qa_search_pipe(config):
password=config.password,
)
llm_op = _get_llm_op(config)
p = (
pipe.input('question', 'history')
.map('question', 'embedding', sentence_embedding_op, config=sentence_embedding_config)
@ -108,7 +117,7 @@ def enhanced_qa_search_pipe(config):
p = (
p.map('result', 'docs', lambda x:[i[2] for i in x])
.map(('question', 'docs', 'history'), 'prompt', ops.prompt.question_answer())
.map('prompt', 'answer', ops.LLM.OpenAI(api_key=config.openai_api_key))
.map('prompt', 'answer', llm_op)
)
return p.output('answer')

Loading…
Cancel
Save