From 7dd5772f06a3848976c6b45086d9d80ccfc1c482 Mon Sep 17 00:00:00 2001 From: "junjie.jiang" Date: Wed, 7 Jun 2023 14:43:04 +0800 Subject: [PATCH] Add customize prompt Signed-off-by: junjie.jiang --- README.md | 4 ++++ eqa_search.py | 9 ++++++++- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 90fcb3a..1751bbb 100644 --- a/README.md +++ b/README.md @@ -124,6 +124,10 @@ The dolly model name, defaults to `databricks/dolly-v2-3b`. Users customize LLM. +**customize_prompt (Any):** + +Users customize prompt. +
diff --git a/eqa_search.py b/eqa_search.py index 39f341d..1b37b74 100644 --- a/eqa_search.py +++ b/eqa_search.py @@ -41,6 +41,7 @@ class EnhancedQASearchConfig(BaseModel): dolly_model: Optional[str] = 'databricks/dolly-v2-3b' customize_llm: Optional[Any] = None + customize_prompt: Optional[Any] = None _hf_models = ops.sentence_embedding.transformers().get_op().supported_model_names() @@ -89,6 +90,12 @@ def _get_llm_op(config): raise RuntimeError('Unknown llm source: [%s], only support \'openai\' and \'dolly\'' % (config.llm_src)) +def _get_prompt(config): + if config.customize_prompt: + return config.customize_prompt + return ops.prompt.question_answer(llm_name=config.llm_src) + + @AutoPipes.register def enhanced_qa_search_pipe(config): allow_triton, sentence_embedding_op = _get_embedding_op(config) @@ -126,7 +133,7 @@ def enhanced_qa_search_pipe(config): p = ( p.map('result', 'docs', lambda x: '\n'.join([i[2] for i in x])) - .map(('question', 'docs', 'history'), 'prompt', ops.prompt.question_answer(llm_name=config.llm_src)) + .map(('question', 'docs', 'history'), 'prompt', _get_prompt(config)) .map('prompt', 'answer', llm_op) )