logo
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
Readme
Files and versions

33 lines
1.0 KiB

from typing import List, Dict
from towhee.operator import PyOperator
class TemplatePrompt(PyOperator):
    """Operator that renders a prompt template into a list of chat messages.

    The template is a ``str.format`` string whose named fields are filled
    from the positional arguments of ``__call__``, matched by position
    against ``keys``. An optional system message and an optional chat
    history are placed before the rendered prompt in the returned list.
    """

    def __init__(self, temp: str, keys: List[str], sys_msg: str = None):
        """Store the template configuration.

        Args:
            temp: A ``str.format`` template with named replacement fields.
            keys: Field names; the i-th positional argument passed to
                ``__call__`` is bound to ``keys[i]``.
            sys_msg: Optional system message. When truthy it is emitted as
                the first entry of every result.
        """
        super().__init__()
        self._template = temp
        self._keys = keys
        self._sys_msg = sys_msg

    def __call__(self, *args) -> List[Dict[str, str]]:
        """Render the template and assemble the message list.

        Args:
            *args: One value per entry in ``keys``, optionally followed by
                a chat history: an iterable of items where ``item[0]`` is a
                past question and ``item[1]`` its answer. When extra
                arguments are present, the last one is used as the history.

        Returns:
            A list of dicts, in order: ``{'system': ...}`` if a system
            message was configured, one ``{'question': ..., 'answer': ...}``
            per history item, then ``{'question': <rendered prompt>}``.

        Raises:
            KeyError: If the template references a field for which no
                positional argument was supplied.
        """
        messages = [{'system': self._sys_msg}] if self._sys_msg else []

        # Only an argument *beyond* the template keys is chat history. With
        # fewer arguments than keys there is no history; the last template
        # value must not be mistaken for one.
        history = args[-1] if len(args) > len(self._keys) else None

        # zip() stops at the shorter sequence, so a trailing history argument
        # is never bound to a template field.
        prompt_str = self._template.format(**dict(zip(self._keys, args)))

        if history:
            messages.extend(
                {'question': item[0], 'answer': item[1]} for item in history
            )
        messages.append({'question': prompt_str})
        return messages