# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from typing import List

import zhipuai
from towhee.operator.base import PyOperator


class ZhipuaiChat(PyOperator):
    '''Wrapper of the Zhipuai chat API.

    Args:
        model_name: model identifier passed as ``model`` to the Zhipuai SDK.
        api_key: API key; when falsy, the ``ZHIPUAI_API_KEY`` environment
            variable is used instead.
        **kwargs: extra keyword arguments forwarded to
            ``zhipuai.model_api.invoke`` / ``zhipuai.model_api.sse_invoke``.
            The special key ``stream`` is not forwarded; it selects streaming
            mode (``sse_invoke``) when truthy.
    '''

    def __init__(self,
                 model_name: str = 'chatglm_std',
                 api_key: str = None,
                 **kwargs
                 ):
        # This assigns a module-level attribute on `zhipuai`, so it affects
        # every ZhipuaiChat instance in the process, not just this one.
        zhipuai.api_key = api_key or os.getenv("ZHIPUAI_API_KEY")
        self._model = model_name
        self.kwargs = kwargs

    def __call__(self, messages: List[dict]):
        '''Send a conversation to the model.

        Args:
            messages: list of dicts, either ``{'role', 'content'}`` messages or
                ``{'question': ..., 'answer': ...}`` style turns (see
                ``parse_inputs``).

        Returns:
            The raw ``invoke`` response in non-streaming mode, or a generator of
            event dicts (see ``stream_output``) in streaming mode.
        '''
        messages = self.parse_inputs(messages)
        # Read the flag instead of popping it: popping mutated self.kwargs, so
        # every call after the first silently fell back to non-streaming mode.
        self.stream = self.kwargs.get('stream', False)
        # `stream` is consumed here and must not be forwarded to the SDK.
        params = {k: v for k, v in self.kwargs.items() if k != 'stream'}
        if self.stream:
            response = zhipuai.model_api.sse_invoke(
                model=self._model,
                prompt=messages,
                **params
            )
            return self.stream_output(response)
        response = zhipuai.model_api.invoke(
            model=self._model,
            prompt=messages,
            **params
        )
        return response

    def parse_inputs(self, messages: List[dict]):
        '''Normalize input turns into ``{'role', 'content'}`` messages.

        A dict that already has both ``role`` and ``content`` keys, with a role
        of ``'assistant'`` or ``'user'``, is passed through unchanged. Otherwise
        each ``question`` key becomes a ``user`` message and each ``answer`` key
        becomes an ``assistant`` message, in the dict's key order.

        Raises:
            AssertionError: if ``messages`` is not a list.
            KeyError: if a non-passthrough dict contains any other key.
        '''
        assert isinstance(messages, list), \
            'Inputs must be a list of dictionaries with keys from ["question", "answer"].'
        new_messages = []
        for m in messages:
            # Both membership tests are explicit. The original
            # `('role' and 'content' in m)` only checked 'content', which led to
            # a bare KeyError('role') on dicts without a role.
            if 'role' in m and 'content' in m and m['role'] in ['assistant', 'user']:
                new_messages.append(m)
            else:
                for k, v in m.items():
                    if k == 'question':
                        new_m = {'role': 'user', 'content': v}
                    elif k == 'answer':
                        new_m = {'role': 'assistant', 'content': v}
                    else:
                        raise KeyError('Invalid message key: only accept key value from ["question", "answer"].')
                    new_messages.append(new_m)
        return new_messages

    @staticmethod
    def stream_output(response):
        '''Yield one dict per server-sent event of an ``sse_invoke`` response.'''
        for x in response.events():
            yield {'event': x.event, 'id': x.id, 'data': x.data, 'meta': x.meta}

    @staticmethod
    def supported_model_names():
        '''Return the sorted list of model names this operator advertises.'''
        # NOTE(review): the constructor default 'chatglm_std' is not in this
        # list; confirm against the current Zhipuai model catalogue.
        model_list = [
            'chatglm_130b',
            'chatglm_6b'
        ]
        model_list.sort()
        return model_list