From 805417c3f29a08fca9d1b9cb2d1be3dc09dc4fd6 Mon Sep 17 00:00:00 2001 From: "junjie.jiang" Date: Tue, 30 May 2023 13:38:53 +0800 Subject: [PATCH] Update Signed-off-by: junjie.jiang --- README.md | 10 ++++++++++ qa_prompt.py | 9 ++++++++- 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index fc723ed..045f768 100644 --- a/README.md +++ b/README.md @@ -45,6 +45,16 @@ Create the operator via the following factory method: ***ops.prompt.question_answer()*** +Args: + +**temp: str** + + User-defined prompt; must contain {context} and {question}. + +**llm_name: str** + + Name of a pre-defined prompt; currently supports openai and dolly. The openai prompt is used by default. +
diff --git a/qa_prompt.py b/qa_prompt.py index db2838c..26af6e9 100644 --- a/qa_prompt.py +++ b/qa_prompt.py @@ -1,7 +1,9 @@ from typing import List, Tuple, Dict, Optional +import logging from towhee.operator import PyOperator +logger = logging.getLogger() gpt_prompt = """Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer. @@ -26,9 +28,14 @@ class QAPrompt(PyOperator): if temp: self._template = temp else: - if llm_name.lower() == 'dolly': + if not llm_name: + self._template = gpt_prompt + elif llm_name.lower() == 'dolly': self._template = dolly_prompt + elif llm_name.lower() == 'openai': + self._template = gpt_prompt else: + logger.warning('Unknown llm_name, use default prompt') self._template = gpt_prompt def __call__(self, question: str, docs: List[str], history=Optional[List[Tuple]]) -> List[Dict[str, str]]: