logo
OpenAI
repo-copy-icon

copied

Browse Source

Add stream

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 1 year ago
parent
commit
f0e4d7b305
  1. 12
      openai_chat.py

12
openai_chat.py

@ -28,6 +28,7 @@ class OpenaiChat(PyOperator):
): ):
openai.api_key = os.getenv('OPENAI_API_KEY', api_key) openai.api_key = os.getenv('OPENAI_API_KEY', api_key)
self._model = model_name self._model = model_name
self.stream = kwargs.pop('stream') if 'stream' in kwargs else False
self.kwargs = kwargs self.kwargs = kwargs
def __call__(self, messages: List[dict]): def __call__(self, messages: List[dict]):
@ -36,12 +37,11 @@ class OpenaiChat(PyOperator):
model=self._model, model=self._model,
messages=messages, messages=messages,
n=1, n=1,
stream=self.stream,
**self.kwargs **self.kwargs
) )
if self.kwargs.get('stream'):
for chunk in response:
ans = chunk['choices'][0]['delta']
yield ans
if self.stream:
return self.stream_output(response)
else: else:
answer = response['choices'][0]['message']['content'] answer = response['choices'][0]['message']['content']
return answer return answer
@ -65,6 +65,10 @@ class OpenaiChat(PyOperator):
'Invalid message key: only accept key value from ["system", "question", "answer"].' 'Invalid message key: only accept key value from ["system", "question", "answer"].'
new_messages.append(new_m) new_messages.append(new_m)
return new_messages return new_messages
def stream_output(self, response):
for resp in response:
yield resp['choices'][0]['delta']
@staticmethod @staticmethod
def supported_model_names(): def supported_model_names():

Loading…
Cancel
Save