diff --git a/template_prompt.py b/template_prompt.py index 77abdd5..fff1f90 100644 --- a/template_prompt.py +++ b/template_prompt.py @@ -1,18 +1,25 @@ -from typing import List, Tuple, Dict, Optional +from typing import List from towhee.operator import PyOperator class TemplatePrompt(PyOperator): - def __init__(self, temp: str): + def __init__(self, temp: str, keys: List[str]): super().__init__() self._template = temp + self._keys = keys - def __call__(self, **kwargs) -> List[Dict[str, str]]: - history = kwargs.get('history', []) - prompt_str = self._template.format(**kwargs) + def __call__(self, *args) -> List[Dict[str, str]]: + if len(self._keys) == len(args): + history = [] + else: + history = args[-1] + + kws = {(item[0], item[1]) for item in zip(self._keys, args)} + prompt_str = self._template.format(**kws) ret = [{'question': prompt_str}] - if not isinstance(history, list): + + if not history: return ret else: history_data = []