Browse Source
update
Signed-off-by: junjie.jiang <junjie.jiang@zilliz.com>
main
1 changed files with
13 additions and
6 deletions
-
template_prompt.py
|
|
@ -1,18 +1,25 @@ |
|
|
|
from typing import List, Tuple, Dict, Optional |
|
|
|
from typing import List |
|
|
|
|
|
|
|
from towhee.operator import PyOperator |
|
|
|
|
|
|
|
|
|
|
|
class TemplatePrompt(PyOperator): |
|
|
|
def __init__(self, temp: str): |
|
|
|
def __init__(self, temp: str, keys: List[str]): |
|
|
|
super().__init__() |
|
|
|
self._template = temp |
|
|
|
self._keys = keys |
|
|
|
|
|
|
|
def __call__(self, **kwargs) -> List[Dict[str, str]]: |
|
|
|
history = kwargs.get('history', []) |
|
|
|
prompt_str = self._template.format(**kwargs) |
|
|
|
def __call__(self, *args) -> List[Dict[str, str]]: |
|
|
|
if len(self._keys) == len(args): |
|
|
|
history = [] |
|
|
|
else: |
|
|
|
history = args[-1] |
|
|
|
|
|
|
|
kws = {(item[0], item[1]) for item in zip(self._keys, args)} |
|
|
|
prompt_str = self._template.format(**kws) |
|
|
|
ret = [{'question': prompt_str}] |
|
|
|
if not isinstance(history, list): |
|
|
|
|
|
|
|
if not history: |
|
|
|
return ret |
|
|
|
else: |
|
|
|
history_data = [] |
|
|
|