From 0339dc92e7e569e6edecac9d783698b893c12a18 Mon Sep 17 00:00:00 2001 From: ChengZi Date: Thu, 16 Nov 2023 14:48:47 +0800 Subject: [PATCH] fix api key passing problem Signed-off-by: ChengZi --- openai_chat.py | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/openai_chat.py b/openai_chat.py index 3d1e23b..802df07 100644 --- a/openai_chat.py +++ b/openai_chat.py @@ -27,34 +27,32 @@ class OpenaiChat(PyOperator): **kwargs ): openai.api_key = api_key or os.getenv('OPENAI_API_KEY') + self._openai_version = openai.__version__ + if Version(self._openai_version) >= Version('1.0.0'): + from openai import OpenAI + self.client = OpenAI(api_key=openai.api_key) + self.openai_completion = self.client.chat.completions + else: + self.openai_completion = openai.ChatCompletion + self._model = model_name self.stream = kwargs.pop('stream') if 'stream' in kwargs else False self.kwargs = kwargs def __call__(self, messages: List[dict]): messages = self.parse_inputs(messages) - if Version(openai.__version__) >= Version('1.0.0'): - from openai import OpenAI - client = OpenAI() - response = client.chat.completions.create( + response = self.openai_completion.create( model=self._model, messages=messages, n=1, stream=self.stream, **self.kwargs ) + if Version(self._openai_version) >= Version('1.0.0'): if self.stream: response = (res.dict() for res in response) else: response = response.dict() - else: - response = openai.ChatCompletion.create( - model=self._model, - messages=messages, - n=1, - stream=self.stream, - **self.kwargs - ) if self.stream: return self.stream_output(response) else: