You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
|
|
|
from typing import List, Dict
|
|
|
|
|
|
|
|
from towhee.operator import PyOperator
|
|
|
|
|
|
|
|
|
|
|
|
class TemplatePrompt(PyOperator):
    """Turn positional inputs into a list of question/answer messages.

    The operator fills a ``str.format`` template with the call arguments,
    matching them to ``keys`` by position, and optionally prepends earlier
    question/answer pairs supplied as one extra trailing argument.
    """

    def __init__(self, temp: str, keys: List[str]):
        """Store the template and the names of its placeholders.

        Args:
            temp: Template passed to ``str.format``, e.g. ``'Q: {question}'``.
            keys: Placeholder names, in the order their values are passed
                to ``__call__``.
        """
        super().__init__()
        self._template = temp
        self._keys = keys

    def __call__(self, *args) -> List[Dict[str, str]]:
        """Format the template and return it with any history in front.

        Args:
            *args: One value per entry in ``keys``, in the same order. A
                further trailing argument, if present, is the history: an
                iterable of items whose index 0 is a question and index 1
                is the matching answer. A falsy history (``None``, ``[]``)
                is ignored.

        Returns:
            The history items as ``{'question': ..., 'answer': ...}`` dicts,
            followed by ``{'question': <formatted template>}``.

        Raises:
            KeyError: If the template references a key that received no
                value (raised by ``str.format``).
        """
        # Only an argument beyond the declared keys can be history. With
        # fewer arguments than keys, the last one is a template value and
        # must not be iterated as question/answer pairs.
        history = args[-1] if len(args) > len(self._keys) else []

        # zip() stops at the shorter input, so surplus arguments (including
        # the history) never end up in the format mapping.
        prompt_str = self._template.format(**dict(zip(self._keys, args)))
        current = [{'question': prompt_str}]

        if not history:
            return current

        past = [
            {'question': item[0], 'answer': item[1]}
            for item in history
        ]
        return past + current
|