From df73beb1ebdba030515ddce827594d13c9c53e33 Mon Sep 17 00:00:00 2001 From: jinlingxu06 Date: Sun, 8 Oct 2023 18:12:05 +0800 Subject: [PATCH] Update the Azure OpenAI Chat Operator. Signed-off-by: jinlingxu06 --- azure_openai_chat.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/azure_openai_chat.py b/azure_openai_chat.py index 9dbde77..ff58b8e 100644 --- a/azure_openai_chat.py +++ b/azure_openai_chat.py @@ -28,8 +28,12 @@ class AzureOpenaiChat(PyOperator): api_base = None, **kwargs ): - openai.api_key = api_key or os.getenv('OPENAI_API_KEY') - openai.api_base = api_base or os.getenv('OPENAI_API_BASE') + + self._api_key = api_key or os.getenv('OPENAI_API_KEY') + self._api_base = api_base or os.getenv('OPENAI_API_BASE') + self._api_type = api_type + self._api_version = api_version + self._model = model_name self.stream = kwargs.pop('stream') if 'stream' in kwargs else False self.kwargs = kwargs @@ -37,10 +41,14 @@ class AzureOpenaiChat(PyOperator): def __call__(self, messages: List[dict]): messages = self.parse_inputs(messages) response = openai.ChatCompletion.create( - model=self._model, + engine=self._model, messages=messages, n=1, stream=self.stream, + api_key = self._api_key, + api_type = self._api_type, + api_base = self._api_base, + api_version = self._api_version, **self.kwargs ) if self.stream: