# Copyright 2021 Zilliz. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import os import requests import json from typing import List from towhee.operator.base import PyOperator class ErnieChat(PyOperator): '''Wrapper of OpenAI Chat API''' def __init__(self, api_key: str = None, secret_key: str = None, **kwargs ): self.api_key = api_key or os.getenv('ERNIE_API_KEY') self.secret_key = secret_key or os.getenv('ERNIE_SECRET_KEY') self.kwargs = kwargs try: self.access_token = self.get_access_token(api_key=self.api_key, secret_key=self.secret_key) except Exception as e: raise RuntimeError(f'Failed to get access token: {e}') self.url = 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions?access_token=' \ + self.access_token def __call__(self, messages: List[dict]): messages = self.parse_inputs(messages) self.kwargs['messages'] = messages payload = json.dumps(self.kwargs) headers = { 'Content-Type': 'application/json' } response = requests.request('POST', self.url, headers=headers, data=payload) # if self.kwargs.get('stream', False): # return self.stream_output(response) answer = response.json()['result'] return answer def parse_inputs(self, messages: List[dict]): assert isinstance(messages, list), \ 'Inputs must be a list of dictionaries with keys from ["question", "answer"] or ["role", "content"].' new_messages = [] for m in messages: if ('role' and 'content' in m) and (m['role'] in ['assistant', 'user']): new_messages.append(m) else: for k, v in m.items(): if k == 'question': new_m = {'role': 'user', 'content': v} elif k == 'answer': new_m = {'role': 'assistant', 'content': v} else: raise KeyError('Invalid message key: only accept key value from ["question", "answer"].') new_messages.append(new_m) return new_messages def stream_output(self, response): # todo pass @staticmethod def get_access_token(api_key, secret_key): url = 'https://aip.baidubce.com/oauth/2.0/token' params = { 'grant_type': 'client_credentials', 'client_id': api_key, 'client_secret': secret_key } return str(requests.post(url, params=params).json().get('access_token'))