diff --git a/openai_chat.py b/openai_chat.py index dc1850d..9cd9e0e 100644 --- a/openai_chat.py +++ b/openai_chat.py @@ -28,6 +28,7 @@ class OpenaiChat(PyOperator): ): openai.api_key = os.getenv('OPENAI_API_KEY', api_key) self._model = model_name + self.stream = kwargs.pop('stream') if 'stream' in kwargs else False self.kwargs = kwargs def __call__(self, messages: List[dict]): @@ -36,12 +37,11 @@ class OpenaiChat(PyOperator): model=self._model, messages=messages, n=1, + stream=self.stream, **self.kwargs ) - if self.kwargs.get('stream'): - for chunk in response: - ans = chunk['choices'][0]['delta'] - yield ans + if self.stream: + return self.stream_output(response) else: answer = response['choices'][0]['message']['content'] return answer @@ -65,6 +65,10 @@ class OpenaiChat(PyOperator): 'Invalid message key: only accept key value from ["system", "question", "answer"].' new_messages.append(new_m) return new_messages + + def stream_output(self, response): + for resp in response: + yield resp['choices'][0]['delta'] @staticmethod def supported_model_names():