# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List

import torch
from transformers import pipeline

from towhee.operator.base import PyOperator, SharedType


class HuggingfaceDolly(PyOperator):
    '''Wrapper of Dolly inference.

    Loads a Databricks Dolly model through the HuggingFace ``pipeline``
    API and exposes it as a towhee operator that answers the last
    question in a chat-style message list.
    '''

    def __init__(self,
                 model_name: str = 'databricks/dolly-v2-12b',
                 **kwargs
                 ):
        '''Create the underlying HuggingFace text pipeline.

        Args:
            model_name: HuggingFace model id; see ``supported_model_names``.
            **kwargs: optional overrides — ``torch_dtype`` (default
                ``torch.bfloat16``), ``trust_remote_code`` (default True,
                required because Dolly ships custom pipeline code) and
                ``device_map`` (default ``'auto'``).
        '''
        # Fix: the original skipped base-class initialization entirely.
        super().__init__()
        torch_dtype = kwargs.get('torch_dtype', torch.bfloat16)
        trust_remote_code = kwargs.get('trust_remote_code', True)
        device_map = kwargs.get('device_map', 'auto')
        self.pipeline = pipeline(model=model_name,
                                 torch_dtype=torch_dtype,
                                 trust_remote_code=trust_remote_code,
                                 device_map=device_map)

    def __call__(self, messages: List[dict]) -> str:
        '''Run inference on a message list and return the generated text.

        Args:
            messages: list of dicts; see ``parse_inputs`` for the
                expected keys.

        Returns:
            The model's generated answer for the final question.
        '''
        prompt = self.parse_inputs(messages)
        ans = self.pipeline(prompt)
        # pipeline returns a list of generation dicts; use the first.
        return ans[0]['generated_text']

    def parse_inputs(self, messages: List[dict]) -> str:
        '''Flatten a chat-style message list into a single prompt string.

        The last message's ``'question'`` value becomes the prompt; the
        ``'answer'`` values of all earlier messages are concatenated
        (newline-terminated) and appended after it.

        NOTE(review): the history is appended *after* the prompt, which
        looks inverted for a conversation transcript — preserved as-is;
        confirm the intended prompt format before changing.

        Args:
            messages: list of dicts with keys from
                ``["system", "question", "answer"]``.

        Returns:
            The assembled prompt string.
        '''
        assert isinstance(messages, list), \
            'Inputs must be a list of dictionaries with keys from ["system", "question", "answer"].'
        prompt = messages[-1]['question']
        # Single join instead of quadratic `+=`; only the 'answer' key of
        # each prior message is used (dict keys are unique, so this is
        # behavior-identical to the original items() scan).
        history = ''.join(m['answer'] + '\n'
                          for m in messages[:-1] if 'answer' in m)
        return prompt + '\n' + history

    @staticmethod
    def supported_model_names():
        '''Return the supported Dolly model ids, sorted alphabetically.'''
        model_list = [
            'databricks/dolly-v2-12b',
            'databricks/dolly-v2-7b',
            'databricks/dolly-v2-3b',
            'databricks/dolly-v1-6b'
        ]
        model_list.sort()
        return model_list

    @property
    def shared_type(self):
        # The loaded pipeline can be shared across executor instances.
        return SharedType.Shareable