# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from typing import List

from packaging.version import Version
import openai

from towhee.operator.base import PyOperator


class OpenaiChat(PyOperator):
    '''Wrapper of OpenAI Chat API.

    Sends a conversation to the chat-completions endpoint and returns the
    content of the first choice. When ``stream=True`` is given, it instead
    returns a generator of the per-chunk ``delta`` dicts.

    Args:
        model_name: model identifier passed to the API as ``model``.
        api_key: OpenAI API key; falls back to the ``OPENAI_API_KEY``
            environment variable when not given.
        **kwargs: extra parameters forwarded to ``create``. ``stream`` is
            consumed here and decides whether ``__call__`` streams.
    '''

    # Shorthand message keys accepted by ``parse_inputs`` -> OpenAI role.
    _KEY_TO_ROLE = {
        'question': 'user',
        'answer': 'assistant',
        'system': 'system',
    }

    def __init__(self,
                 model_name: str = 'gpt-3.5-turbo',
                 api_key: str = None,
                 **kwargs
                 ):
        # Stored on the module (used by the pre-1.0 client); for >=1.0 the
        # same key is also passed explicitly to the OpenAI client below.
        openai.api_key = api_key or os.getenv('OPENAI_API_KEY')
        self._openai_version = openai.__version__
        # Compare versions once; __call__ needs the answer on every request.
        self._is_v1 = Version(self._openai_version) >= Version('1.0.0')
        if self._is_v1:
            from openai import OpenAI
            self.client = OpenAI(api_key=openai.api_key)
            self.openai_completion = self.client.chat.completions
        else:
            self.openai_completion = openai.ChatCompletion
        self._model = model_name
        self.stream = kwargs.pop('stream', False)
        self.kwargs = kwargs

    def __call__(self, messages: List[dict]):
        '''Send ``messages`` to the chat endpoint.

        Args:
            messages: conversation turns; see ``parse_inputs`` for the
                accepted forms.

        Returns:
            The first choice's message content, or, when streaming, a
            generator yielding each chunk's ``delta``.
        '''
        messages = self.parse_inputs(messages)
        response = self.openai_completion.create(
            model=self._model,
            messages=messages,
            n=1,
            stream=self.stream,
            **self.kwargs
        )
        if self._is_v1:
            # openai>=1.0 returns model objects; convert them to plain dicts
            # so the indexing below works for both client generations.
            if self.stream:
                response = (res.dict() for res in response)
            else:
                response = response.dict()
        if self.stream:
            return self.stream_output(response)
        return response['choices'][0]['message']['content']

    def parse_inputs(self, messages: List[dict]):
        '''Normalize ``messages`` into ``{'role': ..., 'content': ...}`` dicts.

        Each item is either an OpenAI-style message (has both ``role`` and
        ``content``, with ``role`` in system/assistant/user), which is passed
        through unchanged, or a shorthand dict whose keys are drawn from
        ``system``/``question``/``answer``. Each shorthand key becomes one
        message, in the dict's iteration order.

        Raises:
            AssertionError: if ``messages`` is not a list.
            KeyError: if a shorthand dict contains an unsupported key.
        '''
        assert isinstance(messages, list), \
            'Inputs must be a list of dictionaries with keys from ["system", "question", "answer"].'
        new_messages = []
        for m in messages:
            # Test both keys explicitly: ``'role' and 'content' in m`` would
            # only check 'content' ('in' binds tighter than 'and').
            if 'role' in m and 'content' in m and m['role'] in ('system', 'assistant', 'user'):
                new_messages.append(m)
                continue
            for k, v in m.items():
                role = self._KEY_TO_ROLE.get(k)
                if role is None:
                    raise KeyError('Invalid message key: only accept key value from ["system", "question", "answer"].')
                new_messages.append({'role': role, 'content': v})
        return new_messages

    def stream_output(self, response):
        '''Yield the first choice's ``delta`` from each streamed chunk.

        Chunks whose ``choices`` list is empty are skipped. For example, the
        trailing usage-only chunk sent when ``stream_options.include_usage``
        is requested would otherwise raise ``IndexError``.
        '''
        for resp in response:
            choices = resp['choices']
            if choices:
                yield choices[0]['delta']

    @staticmethod
    def supported_model_names():
        '''Return the model names this operator lists as supported, sorted.'''
        return sorted([
            'gpt-3.5-turbo',
            'gpt-3.5-turbo-0301',
        ])