|
@ -21,7 +21,7 @@ from towhee.operator.base import PyOperator |
|
|
class AzureOpenaiChat(PyOperator): |
|
|
class AzureOpenaiChat(PyOperator): |
|
|
'''Wrapper of OpenAI Chat API''' |
|
|
'''Wrapper of OpenAI Chat API''' |
|
|
def __init__(self, |
|
|
def __init__(self, |
|
|
model_name: str = 'gpt-3.5-turbo', |
|
|
|
|
|
|
|
|
deployment_name: str = 'gpt-3.5-turbo', |
|
|
api_type: str = 'azure', |
|
|
api_type: str = 'azure', |
|
|
api_version: str = '2023-07-01-preview', |
|
|
api_version: str = '2023-07-01-preview', |
|
|
api_key: str = None, |
|
|
api_key: str = None, |
|
@ -34,14 +34,14 @@ class AzureOpenaiChat(PyOperator): |
|
|
self._api_type = api_type |
|
|
self._api_type = api_type |
|
|
self._api_version = api_version |
|
|
self._api_version = api_version |
|
|
|
|
|
|
|
|
self._model = model_name |
|
|
|
|
|
|
|
|
self._deployment = deployment_name |
|
|
self.stream = kwargs.pop('stream') if 'stream' in kwargs else False |
|
|
self.stream = kwargs.pop('stream') if 'stream' in kwargs else False |
|
|
self.kwargs = kwargs |
|
|
self.kwargs = kwargs |
|
|
|
|
|
|
|
|
def __call__(self, messages: List[dict]): |
|
|
def __call__(self, messages: List[dict]): |
|
|
messages = self.parse_inputs(messages) |
|
|
messages = self.parse_inputs(messages) |
|
|
response = openai.ChatCompletion.create( |
|
|
response = openai.ChatCompletion.create( |
|
|
engine=self._model, |
|
|
|
|
|
|
|
|
engine=self._deployment, |
|
|
messages=messages, |
|
|
messages=messages, |
|
|
n=1, |
|
|
n=1, |
|
|
stream=self.stream, |
|
|
stream=self.stream, |
|
@ -81,15 +81,3 @@ class AzureOpenaiChat(PyOperator): |
|
|
for resp in response: |
|
|
for resp in response: |
|
|
yield resp['choices'][0]['delta'] |
|
|
yield resp['choices'][0]['delta'] |
|
|
|
|
|
|
|
|
@staticmethod |
|
|
|
|
|
def supported_model_names(): |
|
|
|
|
|
model_list = [ |
|
|
|
|
|
'gpt-3.5-turbo', |
|
|
|
|
|
'gpt-3.5-turbo-16k', |
|
|
|
|
|
'gpt-3.5-turbo-instruct', |
|
|
|
|
|
'gpt-3.5-turbo-0613', |
|
|
|
|
|
'gpt-3.5-turbo-16k-0613' |
|
|
|
|
|
] |
|
|
|
|
|
model_list.sort() |
|
|
|
|
|
return model_list |
|
|
|
|
|
|
|
|
|
|
|