diff --git a/__init__.py b/__init__.py index 02c054b..80de336 100644 --- a/__init__.py +++ b/__init__.py @@ -1,4 +1,4 @@ from .qa_prompt import QAPrompt -def question_answer(): - return QAPrompt() +def question_answer(temp: str = None, llm_name: str = None): + return QAPrompt(temp, llm_name) diff --git a/qa_prompt.py b/qa_prompt.py index 60001d5..db2838c 100644 --- a/qa_prompt.py +++ b/qa_prompt.py @@ -3,10 +3,7 @@ from typing import List, Tuple, Dict, Optional from towhee.operator import PyOperator -class QAPrompt(PyOperator): - def __init__(self): - super().__init__() - self._template = """Use the following pieces of context to answer the question at the end. +gpt_prompt = """Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer. {context} @@ -16,6 +13,24 @@ Question: {question} Helpful Answer: """ +dolly_prompt = """{question} + +Input: +{context} +""" + + +class QAPrompt(PyOperator): + def __init__(self, temp: str = None, llm_name: str = None): + super().__init__() + if temp: + self._template = temp + else: + if llm_name.lower() == 'dolly': + self._template = dolly_prompt + else: + self._template = gpt_prompt + def __call__(self, question: str, docs: List[str], history=Optional[List[Tuple]]) -> List[Dict[str, str]]: """ history: