# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import os
from typing import List

from huggingface_hub import hf_hub_download
from llama_cpp import Llama

from towhee.operator.base import PyOperator, SharedType


class LlamaCpp(PyOperator):
    '''
    Chat operator that runs a llama.cpp model through llama-cpp-python.

    Args:
        model_name_or_file (`str`):
            Either a key of `supported_model_names()` (the weights file is then
            resolved with `hf_hub_download`) or a path to a local model file.
        **kwargs:
            Extra keyword arguments. Those whose names match a parameter of
            `Llama.__init__` are passed to the model constructor; the remaining
            ones are passed to `Llama.create_chat_completion` on every call.
    '''

    # Message key accepted by this operator -> chat role sent to the model.
    _ROLE_BY_KEY = {
        'system': 'system',
        'question': 'user',
        'answer': 'assistant',
    }

    def __init__(self,
                 model_name_or_file: str = 'llama-2-7b-chat',
                 **kwargs
                 ):
        self.kwargs = kwargs
        supported_model_names = self.supported_model_names()

        if model_name_or_file in supported_model_names:
            model_info = supported_model_names[model_name_or_file]
            hf_id = model_info['hf_id']
            model_filename = model_info['filename']
            self.model_path = hf_hub_download(repo_id=hf_id, filename=model_filename)
        else:
            # Not a known model name: treat the argument as a file path.
            self.model_path = model_name_or_file
        # NOTE: kept as an assert so callers still see AssertionError; be aware
        # that the check is skipped when Python runs with -O.
        assert os.path.isfile(self.model_path), f'Invalid model path: {self.model_path}'

        # Split kwargs by the constructor's real signature. Inspecting
        # vars(Llama) would only match class attribute / method names, so
        # genuine constructor options would leak into every chat call instead.
        init_keys = self._llama_init_params() & self.kwargs.keys()
        init_kwargs = {k: self.kwargs.pop(k) for k in init_keys}
        self.model = Llama(model_path=self.model_path, **init_kwargs)

    @staticmethod
    def _llama_init_params():
        '''Return the set of keyword names that `Llama.__init__` declares explicitly.'''
        by_keyword = (
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
            inspect.Parameter.KEYWORD_ONLY,
        )
        params = inspect.signature(Llama.__init__).parameters
        # `self` is bound implicitly and `model_path` is always supplied by
        # this operator, so neither may be taken from user kwargs.
        return {
            name for name, param in params.items()
            if param.kind in by_keyword and name not in ('self', 'model_path')
        }

    def __call__(self, messages: List[dict]):
        '''
        Run one chat completion and return the text of the first choice.

        Args:
            messages (`List[dict]`):
                Conversation history; see `parse_inputs` for the accepted keys.
        '''
        messages = self.parse_inputs(messages)
        resp = self.model.create_chat_completion(messages, **self.kwargs)
        answer = self.parse_outputs(resp)
        return answer

    def parse_inputs(self, messages: List[dict]):
        '''
        Convert operator-style messages into role/content chat messages.

        Every key/value pair of every dictionary becomes one message, in
        iteration order: "system" -> role "system", "question" -> role "user",
        "answer" -> role "assistant".

        Raises:
            AssertionError: if `messages` is not a list.
            KeyError: if a dictionary contains any other key.
        '''
        assert isinstance(messages, list), \
            'Inputs must be a list of dictionaries with keys from ["system", "question", "answer"].'
        new_messages = []
        for m in messages:
            for k, v in m.items():
                if k not in self._ROLE_BY_KEY:
                    raise KeyError(f'Invalid key of message: {k}')
                new_messages.append({'role': self._ROLE_BY_KEY[k], 'content': v})
        return new_messages

    def parse_outputs(self, response):
        '''Return the message content of the first choice in a chat completion response.'''
        return response['choices'][0]['message']['content']

    @staticmethod
    def supported_model_names():
        '''Return a mapping of model name to its Hugging Face repo id and weights filename.'''
        models = {
            'llama-2-7b-chat': {
                'hf_id': 'TheBloke/Llama-2-7B-Chat-GGML',
                'filename': 'llama-2-7b-chat.ggmlv3.q4_0.bin'
            },
            'llama-2-13b-chat': {
                'hf_id': 'TheBloke/Llama-2-13B-chat-GGML',
                'filename': 'llama-2-13b-chat.ggmlv3.q4_0.bin'
            }
        }
        return models

    @property
    def shared_type(self):
        '''Sharing mode reported for this operator: `SharedType.Shareable`.'''
        return SharedType.Shareable