From 048f1cecd044bae0d3810e4e433af62b1dc441f0 Mon Sep 17 00:00:00 2001 From: Jael Gu Date: Fri, 30 Jun 2023 16:27:00 +0800 Subject: [PATCH] Allow system message Signed-off-by: Jael Gu --- ernie_chat.py | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/ernie_chat.py b/ernie_chat.py index 7a44f00..61b3363 100644 --- a/ernie_chat.py +++ b/ernie_chat.py @@ -63,13 +63,24 @@ class ErnieChat(PyOperator): new_messages.append(m) else: for k, v in m.items(): - if k == 'question': - new_m = {'role': 'user', 'content': v} - elif k == 'answer': - new_m = {'role': 'assistant', 'content': v} + if ('role' and 'content' in m) and (m['role'] in ['system', 'assistant', 'user']): + if m['role'] == 'system': + new_messages.append(m) + new_messages.append({'role': 'assistant', 'content': 'OK.'}) + else: + new_messages.append(m) else: - raise KeyError('Invalid message key: only accept key value from ["question", "answer"].') - new_messages.append(new_m) + for k, v in m.items(): + if k == 'question': + new_ms = [{'role': 'user', 'content': v}] + elif k == 'answer': + new_ms = [{'role': 'assistant', 'content': v}] + elif k == 'system': + new_ms = [{'role': 'user', 'content': v}, {'role': 'assistant', 'content': 'OK.'}] + else: + raise KeyError( + 'Invalid message key: only accept key value from ["question", "answer"].') + new_messages += new_ms return new_messages def stream_output(self, response):