diff --git a/__init__.py b/__init__.py index 38ad296..31a0af9 100644 --- a/__init__.py +++ b/__init__.py @@ -1,4 +1,5 @@ +from typing import List from .template_prompt import TemplatePrompt -def template(temp: str): - return TemplatePrompt(temp) +def template(temp: str, keys: List[str]): + return TemplatePrompt(temp, keys) diff --git a/template_prompt.py b/template_prompt.py index b208fc8..bfa988d 100644 --- a/template_prompt.py +++ b/template_prompt.py @@ -15,7 +15,7 @@ class TemplatePrompt(PyOperator): else: history = args[-1] - kws = {(item[0], item[1]) for item in zip(self._keys, args)} + kws = dict((item[0], item[1]) for item in zip(self._keys, args)) prompt_str = self._template.format(**kws) ret = [{'question': prompt_str}]