diff --git a/README.md b/README.md index fc723ed..045f768 100644 --- a/README.md +++ b/README.md @@ -45,6 +45,18 @@ Create the operator via the following factory method: ***ops.prompt.question_answer()*** +Args: + +Args: + +**temp: str** + + User-defined prompt, must contain {context} and {question}" + +**llm_name: str** + + Pre-defined prompt, currently supports openai and dolly, openai prompt is used by default." +
diff --git a/qa_prompt.py b/qa_prompt.py index db2838c..26af6e9 100644 --- a/qa_prompt.py +++ b/qa_prompt.py @@ -1,7 +1,9 @@ from typing import List, Tuple, Dict, Optional +import logging from towhee.operator import PyOperator +logger = logging.getLogger() gpt_prompt = """Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer. @@ -26,9 +28,14 @@ class QAPrompt(PyOperator): if temp: self._template = temp else: - if llm_name.lower() == 'dolly': + if not llm_name: + self._template = gpt_prompt + elif llm_name.lower() == 'dolly': self._template = dolly_prompt + elif llm_name.lower() == 'openai': + self._template = gpt_prompt else: + logger.warning('Unkown llm_name, use default prompt') self._template = gpt_prompt def __call__(self, question: str, docs: List[str], history=Optional[List[Tuple]]) -> List[Dict[str, str]]: