From 0a352c2e120ccdb117b1ce000ceac21ea2c475b7 Mon Sep 17 00:00:00 2001 From: Jael Gu Date: Fri, 15 Sep 2023 15:41:52 +0800 Subject: [PATCH] Debug Signed-off-by: Jael Gu --- zhipuai_chat.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/zhipuai_chat.py b/zhipuai_chat.py index c3c0faa..0215b21 100644 --- a/zhipuai_chat.py +++ b/zhipuai_chat.py @@ -41,16 +41,13 @@ class ZhipuaiChat(PyOperator): prompt=messages, **self.kwargs ) + return self.stream_output(response) else: response = zhipuai.model_api.invoke( model=self._model, prompt=messages, **self.kwargs ) - if self.stream: - for x in response.events(): - yield {'event': x.event, 'id': x.id, 'data': x.data, 'meta': x.meta} - else: return response def parse_inputs(self, messages: List[dict]): @@ -67,12 +64,14 @@ class ZhipuaiChat(PyOperator): elif k == 'answer': new_m = {'role': 'assistant', 'content': v} else: - 'Invalid message key: only accept key value from ["question", "answer"].' + raise KeyError('Invalid message key: only accept key value from ["question", "answer"].') new_messages.append(new_m) return new_messages - - def stream_output(self, response): - raise RuntimeError('Stream is not yet supported.') + + @staticmethod + def stream_output(response): + for x in response.events(): + yield {'event': x.event, 'id': x.id, 'data': x.data, 'meta': x.meta} @staticmethod def supported_model_names():