Browse Source
Add stream
Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
1 changed files with
8 additions and
4 deletions
-
openai_chat.py
|
@ -28,6 +28,7 @@ class OpenaiChat(PyOperator): |
|
|
): |
|
|
): |
|
|
openai.api_key = os.getenv('OPENAI_API_KEY', api_key) |
|
|
openai.api_key = os.getenv('OPENAI_API_KEY', api_key) |
|
|
self._model = model_name |
|
|
self._model = model_name |
|
|
|
|
|
self.stream = kwargs.pop('stream') if 'stream' in kwargs else False |
|
|
self.kwargs = kwargs |
|
|
self.kwargs = kwargs |
|
|
|
|
|
|
|
|
def __call__(self, messages: List[dict]): |
|
|
def __call__(self, messages: List[dict]): |
|
@ -36,12 +37,11 @@ class OpenaiChat(PyOperator): |
|
|
model=self._model, |
|
|
model=self._model, |
|
|
messages=messages, |
|
|
messages=messages, |
|
|
n=1, |
|
|
n=1, |
|
|
|
|
|
stream=self.stream, |
|
|
**self.kwargs |
|
|
**self.kwargs |
|
|
) |
|
|
) |
|
|
if self.kwargs.get('stream'): |
|
|
|
|
|
for chunk in response: |
|
|
|
|
|
ans = chunk['choices'][0]['delta'] |
|
|
|
|
|
yield ans |
|
|
|
|
|
|
|
|
if self.stream: |
|
|
|
|
|
return self.stream_output(response) |
|
|
else: |
|
|
else: |
|
|
answer = response['choices'][0]['message']['content'] |
|
|
answer = response['choices'][0]['message']['content'] |
|
|
return answer |
|
|
return answer |
|
@ -66,6 +66,10 @@ class OpenaiChat(PyOperator): |
|
|
new_messages.append(new_m) |
|
|
new_messages.append(new_m) |
|
|
return new_messages |
|
|
return new_messages |
|
|
|
|
|
|
|
|
|
|
|
def stream_output(self, response): |
|
|
|
|
|
for resp in response: |
|
|
|
|
|
yield resp['choices'][0]['delta'] |
|
|
|
|
|
|
|
|
@staticmethod |
|
|
@staticmethod |
|
|
def supported_model_names(): |
|
|
def supported_model_names(): |
|
|
model_list = [ |
|
|
model_list = [ |
|
|