diff --git a/ernie_chat.py b/ernie_chat.py index 951a967..654da56 100644 --- a/ernie_chat.py +++ b/ernie_chat.py @@ -1,4 +1,4 @@ -# Copyright 2021 Zilliz. All rights reserved. +# Copyright 2023 Zilliz. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,86 +13,71 @@ # limitations under the License. import os -import requests -import json from typing import List from towhee.operator.base import PyOperator +import erniebot + class ErnieChat(PyOperator): - '''Wrapper of OpenAI Chat API''' + '''Wrapper of Ernie Bot SDK''' + def __init__(self, - api_key: str = None, - secret_key: str = None, + model_name: str = 'ernie-bot-turbo', + api_type: str = None, + access_token: str = None, **kwargs ): - self.api_key = api_key or os.getenv('ERNIE_API_KEY') - self.secret_key = secret_key or os.getenv('ERNIE_SECRET_KEY') + erniebot.api_type = api_type or os.getenv('ERNIEBOT_API_TYPE') + erniebot.access_token = access_token or os.getenv('ERNIEBOT_ACCESS_TOKEN') + self._model = model_name + self.stream = kwargs.pop('stream') if 'stream' in kwargs else False self.kwargs = kwargs - try: - self.access_token = self.get_access_token(api_key=self.api_key, secret_key=self.secret_key) - except Exception as e: - raise RuntimeError(f'Failed to get access token: {e}') - self.url = 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions?access_token=' \ - + self.access_token - def __call__(self, messages: List[dict]): - messages = self.parse_inputs(messages) - self.kwargs['messages'] = messages - payload = json.dumps(self.kwargs) - headers = { - 'Content-Type': 'application/json' - } - - response = requests.request('POST', self.url, headers=headers, data=payload) - - # if self.kwargs.get('stream', False): - # return self.stream_output(response) - - answer = response.json()['result'] - return answer + response = erniebot.ChatCompletion.create( + model=self._model, + messages=messages, + stream=self.stream, + **self.kwargs + ) + if self.stream: + return self.stream_output(response) + else: + answer = response.result + return answer def parse_inputs(self, messages: List[dict]): assert isinstance(messages, list), \ - 'Inputs must be a list of dictionaries with keys from ["question", "answer"] or ["role", "content"].' + 'Inputs must be a list of dictionaries with keys from ["system", "question", "answer"].' new_messages = [] for m in messages: - if ('role' and 'content' in m) and (m['role'] in ['assistant', 'user']): + if ('role' and 'content' in m) and (m['role'] in ['system', 'assistant', 'user']): new_messages.append(m) else: for k, v in m.items(): - if ('role' and 'content' in m) and (m['role'] in ['system', 'assistant', 'user']): - if m['role'] == 'system': - new_messages.append({'role': 'user', 'content': m['content']}) - new_messages.append({'role': 'assistant', 'content': 'OK.'}) - else: - new_messages.append(m) + if k == 'question': + new_m = {'role': 'user', 'content': v} + elif k == 'answer': + new_m = {'role': 'assistant', 'content': v} + elif k == 'system': + new_m = {'role': 'system', 'content': v} else: - for k, v in m.items(): - if k == 'question': - new_ms = [{'role': 'user', 'content': v}] - elif k == 'answer': - new_ms = [{'role': 'assistant', 'content': v}] - elif k == 'system': - new_ms = [{'role': 'user', 'content': v}, {'role': 'assistant', 'content': 'OK.'}] - else: - raise KeyError( - 'Invalid message key: only accept key value from ["question", "answer"].') - new_messages += new_ms + raise KeyError( + 'Invalid message key: only accept key value from ["system", "question", "answer"].') + new_messages.append(new_m) return new_messages - + def stream_output(self, response): - # todo - pass - + for resp in response: + yield resp.result + @staticmethod - def get_access_token(api_key, secret_key): - url = 'https://aip.baidubce.com/oauth/2.0/token' - params = { - 'grant_type': 'client_credentials', - 'client_id': api_key, - 'client_secret': secret_key - } - return str(requests.post(url, params=params).json().get('access_token')) + def supported_model_names(): + model_list = [ + 'ernie-bot', + 'ernie-bot-turbo' + ] + model_list.sort() + return model_list diff --git a/legacy/__init__.py b/legacy/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/legacy/ernie_chat_v1.py b/legacy/ernie_chat_v1.py new file mode 100644 index 0000000..a3ce503 --- /dev/null +++ b/legacy/ernie_chat_v1.py @@ -0,0 +1,101 @@ +# Copyright 2021 Zilliz. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This is a legacy code, for the new version, please use ../ernie_chat.py + +import os +import requests +import json +from typing import List + +from towhee.operator.base import PyOperator + + +class ErnieChat(PyOperator): + '''Wrapper of OpenAI Chat API''' + def __init__(self, + api_key: str = None, + secret_key: str = None, + **kwargs + ): + self.api_key = api_key or os.getenv('ERNIE_API_KEY') + self.secret_key = secret_key or os.getenv('ERNIE_SECRET_KEY') + self.kwargs = kwargs + + try: + self.access_token = self.get_access_token(api_key=self.api_key, secret_key=self.secret_key) + except Exception as e: + raise RuntimeError(f'Failed to get access token: {e}') + self.url = 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions?access_token=' \ + + self.access_token + + def __call__(self, messages: List[dict]): + messages = self.parse_inputs(messages) + self.kwargs['messages'] = messages + payload = json.dumps(self.kwargs) + headers = { + 'Content-Type': 'application/json' + } + + response = requests.request('POST', self.url, headers=headers, data=payload) + + # if self.kwargs.get('stream', False): + # return self.stream_output(response) + + answer = response.json()['result'] + return answer + + def parse_inputs(self, messages: List[dict]): + assert isinstance(messages, list), \ + 'Inputs must be a list of dictionaries with keys from ["question", "answer"] or ["role", "content"].' + new_messages = [] + for m in messages: + if ('role' and 'content' in m) and (m['role'] in ['assistant', 'user']): + new_messages.append(m) + else: + for k, v in m.items(): + if ('role' and 'content' in m) and (m['role'] in ['system', 'assistant', 'user']): + if m['role'] == 'system': + new_messages.append({'role': 'user', 'content': m['content']}) + new_messages.append({'role': 'assistant', 'content': 'OK.'}) + else: + new_messages.append(m) + else: + for k, v in m.items(): + if k == 'question': + new_ms = [{'role': 'user', 'content': v}] + elif k == 'answer': + new_ms = [{'role': 'assistant', 'content': v}] + elif k == 'system': + new_ms = [{'role': 'user', 'content': v}, {'role': 'assistant', 'content': 'OK.'}] + else: + raise KeyError( + 'Invalid message key: only accept key value from ["question", "answer"].') + new_messages += new_ms + return new_messages + + def stream_output(self, response): + # todo + pass + + @staticmethod + def get_access_token(api_key, secret_key): + url = 'https://aip.baidubce.com/oauth/2.0/token' + params = { + 'grant_type': 'client_credentials', + 'client_id': api_key, + 'client_secret': secret_key + } + return str(requests.post(url, params=params).json().get('access_token')) +