diff --git a/README.md b/README.md index 384648f..434f9e4 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,83 @@ -# Dolly +# Dolly Generation + +*author: Jael* + +
+ +## Description + +An LLM operator generates an answer given a prompt in messages using a large language model or service. +This operator uses a pretrained [Dolly](https://github.com/databrickslabs/dolly) to generate responses. +It will download the model from [HuggingFace Models](https://huggingface.co/models). + +
+ +## Code Example + +Use the default model to continue the conversation from given messages. + +*Write a pipeline with explicit inputs/outputs name specifications:* + +```python +from towhee import pipe, ops + +p = ( + pipe.input('messages') + .map('messages', 'answer', ops.LLM.Dolly()) + .output('messages', 'answer') +) + +messages=[ + {'question': 'Who won the world series in 2020?', 'answer': 'The Los Angeles Dodgers won the World Series in 2020.'}, + {'question': 'Where was it played?'} + ] +answer = p(messages) +``` + +
+ +## Factory Constructor + +Create the operator via the following factory method: + +***LLM.Dolly(model_name: str)*** + +**Parameters:** + +***model_name***: *str* + +The model name in string, defaults to 'databricks/dolly-v2-12b'. Supported model names: +- databricks/dolly-v2-12b +- databricks/dolly-v2-7b +- databricks/dolly-v2-3b +- databricks/dolly-v1-6b + +***\*\*kwargs*** + +Other Dolly model parameters such as device_map. + +
+ +## Interface + +The operator takes a list of chat messages as input. +It returns the generated answer as a string. + +***\_\_call\_\_(messages)*** + +**Parameters:** + +***messages***: *list* + +​ A list of messages to set up the chat. +Must be a list of dictionaries with keys from "system", "question", "answer". For example, [{"question": "a past question?", "answer": "a past answer."}, {"question": "current question?"}] + +**Returns**: + +*answer: str* + +​ The generated answer. + +
+
diff --git a/__init__.py b/__init__.py
new file mode 100644
index 0000000..b89bb8e
--- /dev/null
+++ b/__init__.py
@@ -0,0 +1,5 @@
+from .hf_dolly import HuggingfaceDolly
+
+
+def Dolly(*args, **kwargs):
+    return HuggingfaceDolly(*args, **kwargs)
diff --git a/hf_dolly.py b/hf_dolly.py
new file mode 100644
index 0000000..0558779
--- /dev/null
+++ b/hf_dolly.py
@@ -0,0 +1,67 @@
+# Copyright 2021 Zilliz. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List
+
+import torch
+from transformers import pipeline
+
+from towhee.operator.base import PyOperator
+
+
+class HuggingfaceDolly(PyOperator):
+    '''Wrapper of a pretrained Dolly model loaded via the HuggingFace `transformers` pipeline.'''
+    def __init__(self,
+                 model_name: str = 'databricks/dolly-v2-12b',
+                 **kwargs
+                 ):
+        super().__init__()
+        # Model-loading options; defaults follow the official Dolly usage instructions.
+        torch_dtype = kwargs.get('torch_dtype', torch.bfloat16)
+        trust_remote_code = kwargs.get('trust_remote_code', True)
+        device_map = kwargs.get('device_map', 'auto')
+
+        self.pipeline = pipeline(model=model_name, torch_dtype=torch_dtype, trust_remote_code=trust_remote_code, device_map=device_map)
+
+    def __call__(self, messages: List[dict]):
+        # Flatten the chat messages into a single text prompt, then generate an answer.
+        prompt = self.parse_inputs(messages)
+        ans = self.pipeline(prompt)
+        return ans
+
+    def parse_inputs(self, messages: List[dict]):
+        # Build the prompt: previous turns rendered as "key: value" lines, current question last.
+        assert isinstance(messages, list), \
+            'Inputs must be a list of dictionaries with keys from ["system", "question", "answer"].'
+        prompt = messages[-1]['question']
+        history = ''
+        for m in messages[:-1]:
+            for k, v in m.items():
+                line = k + ': ' + v + '\n'
+                history += line
+        # History precedes the current question so the model completes an answer to it.
+        return history + '\n' + prompt
+
+
+    @staticmethod
+    def supported_model_names():
+        model_list = [
+            'databricks/dolly-v2-12b',
+            'databricks/dolly-v2-7b',
+            'databricks/dolly-v2-3b',
+            'databricks/dolly-v1-6b'
+        ]
+        model_list.sort()
+        return model_list
+
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..f3592a7
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,2 @@
+transformers[torch]>=4.28.1,<5
+torch>=1.13.1,<2