From 298deff019807bc382182e83af4d017f17e4ca02 Mon Sep 17 00:00:00 2001 From: "junjie.jiang" Date: Tue, 30 May 2023 16:18:38 +0800 Subject: [PATCH] update Signed-off-by: junjie.jiang --- template_prompt.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/template_prompt.py b/template_prompt.py index 77abdd5..fff1f90 100644 --- a/template_prompt.py +++ b/template_prompt.py @@ -1,18 +1,25 @@ -from typing import List, Tuple, Dict, Optional +from typing import List from towhee.operator import PyOperator class TemplatePrompt(PyOperator): - def __init__(self, temp: str): + def __init__(self, temp: str, keys: List[str]): super().__init__() self._template = temp + self._keys = keys - def __call__(self, **kwargs) -> List[Dict[str, str]]: - history = kwargs.get('history', []) - prompt_str = self._template.format(**kwargs) + def __call__(self, *args) -> List[Dict[str, str]]: + if len(self._keys) == len(args): + history = [] + else: + history = args[-1] + + kws = {(item[0], item[1]) for item in zip(self._keys, args)} + prompt_str = self._template.format(**kws) ret = [{'question': prompt_str}] - if not isinstance(history, list): + + if not history: return ret else: history_data = []