# Copyright 2023 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from typing import List

from towhee.operator.base import PyOperator

import erniebot


class ErnieChat(PyOperator):
    '''Wrapper of Ernie Bot SDK'''

    def __init__(self,
                 model_name: str = 'ernie-bot-turbo',
                 eb_api_type: str = None,
                 eb_access_token: str = None,
                 **kwargs
                 ):
        '''
        Configure the erniebot module and remember the request options.

        Args:
            model_name: value passed as `model` to `erniebot.ChatCompletion.create`.
            eb_api_type: assigned to `erniebot.api_type`; when falsy, the
                `EB_API_TYPE` environment variable is used instead.
            eb_access_token: assigned to `erniebot.access_token`; when falsy,
                the `EB_ACCESS_TOKEN` environment variable is used instead.
            **kwargs: `stream` (default False) selects streamed output; every
                other entry is forwarded unchanged to
                `erniebot.ChatCompletion.create` on each call.
        '''
        # These are attributes of the erniebot module itself, so the values
        # are shared by every user of the module, not only this instance.
        erniebot.api_type = eb_api_type or os.getenv('EB_API_TYPE')
        erniebot.access_token = eb_access_token or os.getenv('EB_ACCESS_TOKEN')

        self._model = model_name
        self.stream = kwargs.pop('stream', False)
        self.kwargs = kwargs

    def __call__(self, messages: List[dict]):
        '''
        Send `messages` to the chat completion endpoint.

        Args:
            messages: list of message dictionaries, see `parse_inputs`.

        Returns:
            `response.result` when streaming is disabled, otherwise a
            generator yielding the `result` of each streamed item.
        '''
        messages = self.parse_inputs(messages)
        response = erniebot.ChatCompletion.create(
            model=self._model,
            messages=messages,
            stream=self.stream,
            **self.kwargs
        )
        if self.stream:
            return self.stream_output(response)
        return response.result

    def parse_inputs(self, messages: List[dict]):
        '''
        Convert the input messages into a list of role/content dictionaries.

        A dictionary that has both a 'role' and a 'content' key, with a role
        in ("system", "assistant", "user"), is kept as is. Any other
        dictionary is expanded key by key: 'question' becomes a 'user'
        message, 'answer' becomes an 'assistant' message and 'system' is
        skipped.

        Args:
            messages: list of message dictionaries.

        Returns:
            A new list of `{'role': ..., 'content': ...}` dictionaries.

        Raises:
            AssertionError: if `messages` is not a list.
            KeyError: if a dictionary that is not a valid role/content message
                contains a key other than 'question', 'answer' or 'system'.
        '''
        assert isinstance(messages, list), \
            'Inputs must be a list of dictionaries with keys from ["system", "question", "answer"].'

        valid_roles = ('system', 'assistant', 'user')
        new_messages = []
        for m in messages:
            # Both keys must be tested explicitly: `'role' and 'content' in m`
            # only checks for 'content', because the literal 'role' is truthy.
            if 'role' in m and 'content' in m and m['role'] in valid_roles:
                new_messages.append(m)
                continue

            for k, v in m.items():
                if k == 'question':
                    new_messages.append({'role': 'user', 'content': v})
                elif k == 'answer':
                    new_messages.append({'role': 'assistant', 'content': v})
                elif k == 'system':
                    # 'system' entries are accepted but produce no message.
                    pass
                else:
                    raise KeyError(
                        'Invalid message key: only accept key value from ["question", "answer"].')
        return new_messages

    def stream_output(self, response):
        '''Yield the `result` attribute of every item in `response`.'''
        for resp in response:
            yield resp.result

    @staticmethod
    def supported_model_names():
        '''Return the supported model names in sorted order.'''
        return sorted([
            'ernie-bot',
            'ernie-bot-turbo',
        ])