diff --git a/eqa_search.py b/eqa_search.py index ac0ffbb..f4d897e 100644 --- a/eqa_search.py +++ b/eqa_search.py @@ -35,7 +35,6 @@ class EnhancedQASearchConfig: # config for similarity evaluation self.threshold = 0.6 # config for llm - self.llm_device = -1 self.llm_model = 'openai' @@ -75,8 +74,12 @@ def _get_similarity_evaluation_op(config): def _get_llm_op(config): - if config.llm_model == 'openai': + if config.llm_model.lower() == 'openai': return ops.LLM.OpenAI(api_key=config.openai_api_key) + if config.llm_model.lower() == 'dolly': + return ops.LLM.Dolly() + + raise RuntimeError('Unknown llm model: [%s], only support \'openai\' and \'dolly\'' % (config.model)) @AutoPipes.register @@ -116,7 +119,7 @@ def enhanced_qa_search_pipe(config): p = ( p.map('result', 'docs', lambda x:[i[2] for i in x]) - .map(('question', 'docs', 'history'), 'prompt', ops.prompt.question_answer()) + .map(('question', 'docs', 'history'), 'prompt', ops.prompt.question_answer(llm_name=config.llm_model)) .map('prompt', 'answer', llm_op) )