From 9937f688b67f191b828e7b0855693b1d0b98494d Mon Sep 17 00:00:00 2001
From: Kaiyuan Hu
Date: Tue, 30 May 2023 16:48:46 +0800
Subject: [PATCH] Set llm option

Signed-off-by: Kaiyuan Hu
---
 eqa_search.py | 11 ++++++++++-
 1 file changed, 10 insertions(+), 1 deletion(-)

diff --git a/eqa_search.py b/eqa_search.py
index ad40c84..ac0ffbb 100644
--- a/eqa_search.py
+++ b/eqa_search.py
@@ -36,6 +36,8 @@ class EnhancedQASearchConfig:
         self.threshold = 0.6
         # config for llm
         self.llm_device = -1
+        self.llm_model = 'openai'
+
 
 _hf_models = ops.sentence_embedding.transformers().get_op().supported_model_names()
 _sbert_models = ops.sentence_embedding.sbert().get_op().supported_model_names()
@@ -72,6 +74,11 @@ def _get_similarity_evaluation_op(config):
     return lambda x: [i for i in x if i[1] >= config.threshold]
 
 
+def _get_llm_op(config):
+    if config.llm_model == 'openai':
+        return ops.LLM.OpenAI(api_key=config.openai_api_key)
+
+
 @AutoPipes.register
 def enhanced_qa_search_pipe(config):
     allow_triton, sentence_embedding_op = _get_embedding_op(config)
@@ -93,6 +100,8 @@ def enhanced_qa_search_pipe(config):
         password=config.password,
     )
 
+    llm_op = _get_llm_op(config)
+
     p = (
         pipe.input('question', 'history')
         .map('question', 'embedding', sentence_embedding_op, config=sentence_embedding_config)
@@ -108,7 +117,7 @@ def enhanced_qa_search_pipe(config):
     p = (
         p.map('result', 'docs', lambda x:[i[2] for i in x])
         .map(('question', 'docs', 'history'), 'prompt', ops.prompt.question_answer())
-        .map('prompt', 'answer', ops.LLM.OpenAI(api_key=config.openai_api_key))
+        .map('prompt', 'answer', llm_op)
    )

     return p.output('answer')
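
---

Usage note (not part of the patch): a minimal sketch of how the new `llm_model` option could be set when building this pipeline through Towhee's AutoPipes. The pipeline/config name 'eqa-search' and the question/history call signature are assumptions inferred from the diff, not confirmed by it; the config fields shown are the ones this patch touches.

    from towhee import AutoPipes, AutoConfig

    # Load the default EnhancedQASearchConfig (pipeline name 'eqa-search' assumed).
    config = AutoConfig.load_config('eqa-search')

    # Option introduced by this patch: choose the LLM backend handled by _get_llm_op.
    config.llm_model = 'openai'
    config.openai_api_key = 'sk-...'  # required when llm_model == 'openai'

    # Build the search pipeline; inputs mirror pipe.input('question', 'history').
    qa_search = AutoPipes.pipeline('eqa-search', config=config)
    answer = qa_search('What is Towhee?', []).get()[0]

Routing the prompt through `llm_op` instead of constructing `ops.LLM.OpenAI` inline keeps the LLM choice in one place (`_get_llm_op`), so additional backends can later be added behind the same `llm_model` switch without touching the pipeline wiring.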