From 6c3c8e24f16adee90b67bf3f379b8ff7c38d16e9 Mon Sep 17 00:00:00 2001 From: Jael Gu Date: Mon, 26 Jun 2023 14:59:08 +0800 Subject: [PATCH] Add system message Signed-off-by: Jael Gu --- README.md | 12 ++++++++---- __init__.py | 4 ++-- template_prompt.py | 11 ++++++++--- 3 files changed, 18 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index 6ea4b3a..52cb347 100644 --- a/README.md +++ b/README.md @@ -31,27 +31,31 @@ input: {context} """ +sys_message = """Your name is TowheeChat.""" + p = ( pipe.input('question', 'doc', 'history') .map('doc', 'doc', lambda x: x[:2000]) - .map(('question', 'doc', 'history'), 'prompt', ops.prompt.template(temp, ['question', 'context'])) + .map(('question', 'doc', 'history'), 'prompt', ops.prompt.template(temp, ['question', 'context'], sys_message)) .map('prompt', 'answer', ops.LLM.OpenAI()) .output('answer') ) -an1 = p('Tell me something about Towhee', towhee_docs, []).get()[0] +an1 = p('Who are you?', [], []).get()[0] print(an1) -an2 = p('How to use it', towhee_docs, [('Tell me something about Towhee', an1)]).get()[0] +an2 = p('Tell me something about Towhee', towhee_docs, []).get()[0] print(an2) +an3 = p('How to use it', towhee_docs, [('Tell me something about Towhee', an2)]).get()[0] +print(an3) ``` ## Factory Constructor Create the operator via the following factory method: -***ops.prompt.template(temp, keys)*** +***ops.prompt.template(temp, keys, sys_msg)***
diff --git a/__init__.py b/__init__.py index 31a0af9..946f64d 100644 --- a/__init__.py +++ b/__init__.py @@ -1,5 +1,5 @@ from typing import List from .template_prompt import TemplatePrompt -def template(temp: str, keys: List[str]): - return TemplatePrompt(temp, keys) +def template(*args, **kwargs): + return TemplatePrompt(*args, **kwargs) diff --git a/template_prompt.py b/template_prompt.py index bfa988d..599da51 100644 --- a/template_prompt.py +++ b/template_prompt.py @@ -4,12 +4,17 @@ from towhee.operator import PyOperator class TemplatePrompt(PyOperator): - def __init__(self, temp: str, keys: List[str]): + def __init__(self, temp: str, keys: List[str], sys_msg: str = None): super().__init__() self._template = temp self._keys = keys + self._sys_msg = sys_msg def __call__(self, *args) -> List[Dict[str, str]]: + if self._sys_msg: + system_message = [{'system': self._sys_msg}] + else: + system_message = [] if len(self._keys) == len(args): history = [] else: @@ -20,9 +25,9 @@ class TemplatePrompt(PyOperator): ret = [{'question': prompt_str}] if not history: - return ret + return system_message + ret else: history_data = [] for item in history: history_data.append({'question': item[0], 'answer': item[1]}) - return history_data + ret + return system_message + history_data + ret