diff --git a/openai_chat.py b/openai_chat.py index 3d1e23b..802df07 100644 --- a/openai_chat.py +++ b/openai_chat.py @@ -27,34 +27,32 @@ class OpenaiChat(PyOperator): **kwargs ): openai.api_key = api_key or os.getenv('OPENAI_API_KEY') + self._openai_version = openai.__version__ + if Version(self._openai_version) >= Version('1.0.0'): + from openai import OpenAI + self.client = OpenAI(api_key=openai.api_key) + self.openai_completion = self.client.chat.completions + else: + self.openai_completion = openai.ChatCompletion + self._model = model_name self.stream = kwargs.pop('stream') if 'stream' in kwargs else False self.kwargs = kwargs def __call__(self, messages: List[dict]): messages = self.parse_inputs(messages) - if Version(openai.__version__) >= Version('1.0.0'): - from openai import OpenAI - client = OpenAI() - response = client.chat.completions.create( + response = self.openai_completion.create( model=self._model, messages=messages, n=1, stream=self.stream, **self.kwargs ) + if Version(self._openai_version) >= Version('1.0.0'): if self.stream: response = (res.dict() for res in response) else: response = response.dict() - else: - response = openai.ChatCompletion.create( - model=self._model, - messages=messages, - n=1, - stream=self.stream, - **self.kwargs - ) if self.stream: return self.stream_output(response) else: