4 changed files with 151 additions and 1 deletions
@ -1,2 +1,83 @@ |
|||
# Dolly |
|||
# OpenAI Chat Completion |
|||
|
|||
*author: Jael* |
|||
|
|||
<br /> |
|||
|
|||
## Description |
|||
|
|||
An LLM operator generates an answer for a given prompt in messages using a large language model or service. |
|||
This operator uses a pretrained [Dolly](https://github.com/databrickslabs/dolly) to generate response. |
|||
It will download model from [HuggingFace Models](https://huggingface.co/models). |
|||
|
|||
<br /> |
|||
|
|||
## Code Example |
|||
|
|||
Use the default model to continue the conversation from given messages. |
|||
|
|||
*Write a pipeline with explicit inputs/outputs name specifications:* |
|||
|
|||
```python |
|||
from towhee import pipe, ops |
|||
|
|||
p = ( |
|||
pipe.input('messages') |
|||
.map('messages', 'answer', ops.LLM.Dolly()) |
|||
.output('messages', 'answer') |
|||
) |
|||
|
|||
messages=[ |
|||
{'question': 'Who won the world series in 2020?', 'answer': 'The Los Angeles Dodgers won the World Series in 2020.'}, |
|||
{'question': 'Where was it played?'} |
|||
] |
|||
answer = p(messages) |
|||
``` |
|||
|
|||
<br /> |
|||
|
|||
## Factory Constructor |
|||
|
|||
Create the operator via the following factory method: |
|||
|
|||
***LLM.Dolly(model_name: str)*** |
|||
|
|||
**Parameters:** |
|||
|
|||
***model_name***: *str* |
|||
|
|||
The model name in string, defaults to 'databricks/dolly-v2-12b'. Supported model names: |
|||
- databricks/dolly-v2-12b |
|||
- databricks/dolly-v2-7b |
|||
- databricks/dolly-v2-3b |
|||
- databricks/dolly-v1-6b |
|||
|
|||
***\*\*kwargs*** |
|||
|
|||
Other Dolly model parameters such as device_map. |
|||
|
|||
<br /> |
|||
|
|||
## Interface |
|||
|
|||
The operator takes a list of chat messages as input. |
|||
It returns the generated answer as a string. |
|||
|
|||
***\_\_call\_\_(messages)*** |
|||
|
|||
**Parameters:** |
|||
|
|||
***messages***: *list* |
|||
|
|||
A list of messages to set up chat. |
|||
Must be a list of dictionaries with key value from "system", "question", "answer". For example, [{"question": "a past question?", "answer": "a past answer."}, {"question": "current question?"}] |
|||
|
|||
**Returns**: |
|||
|
|||
*answer: str* |
|||
|
|||
The answer generated. |
|||
|
|||
<br /> |
|||
|
|||
|
|||
|
@ -0,0 +1,5 @@ |
|||
from .hf_dolly import HuggingfaceDolly |
|||
|
|||
|
|||
def Dolly(*args, **kwargs):
    """Factory method for the Dolly LLM operator.

    All positional and keyword arguments are forwarded unchanged to
    ``HuggingfaceDolly`` (e.g. ``model_name``, ``device_map``).

    Returns:
        A ``HuggingfaceDolly`` operator instance.
    """
    return HuggingfaceDolly(*args, **kwargs)
@ -0,0 +1,62 @@ |
|||
# Copyright 2021 Zilliz. All rights reserved. |
|||
# |
|||
# Licensed under the Apache License, Version 2.0 (the "License"); |
|||
# you may not use this file except in compliance with the License. |
|||
# You may obtain a copy of the License at |
|||
# |
|||
# http://www.apache.org/licenses/LICENSE-2.0 |
|||
# |
|||
# Unless required by applicable law or agreed to in writing, software |
|||
# distributed under the License is distributed on an "AS IS" BASIS, |
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|||
# See the License for the specific language governing permissions and |
|||
# limitations under the License. |
|||
|
|||
from typing import List |
|||
|
|||
import torch |
|||
from transformers import pipeline |
|||
|
|||
from towhee.operator.base import PyOperator |
|||
|
|||
|
|||
class HuggingfaceDolly(PyOperator):
    '''Wrapper of a pretrained Databricks Dolly model for chat-style generation.

    Downloads the model from HuggingFace Hub on first use and answers the
    last question in a list of chat messages.
    '''

    def __init__(self,
                 model_name: str = 'databricks/dolly-v2-12b',
                 **kwargs
                 ):
        '''Load the Dolly generation pipeline.

        Args:
            model_name: HuggingFace model id; see ``supported_model_names``.
            **kwargs: extra ``transformers.pipeline`` options (e.g. ``device_map``).
        '''
        # Fill in defaults without discarding unknown caller-supplied options:
        # the README promises that arbitrary pipeline parameters are honored,
        # so everything in kwargs is forwarded, not just the three known keys.
        kwargs.setdefault('torch_dtype', torch.bfloat16)
        # Dolly ships a custom instruct pipeline, so remote code must be trusted.
        kwargs.setdefault('trust_remote_code', True)
        kwargs.setdefault('device_map', 'auto')

        self.pipeline = pipeline(model=model_name, **kwargs)

    def __call__(self, messages: List[dict]):
        '''Generate an answer for the last question in ``messages``.

        Args:
            messages: list of dicts with keys from {"system", "question", "answer"};
                the final entry carries the current question.

        Returns:
            The raw output of the underlying ``transformers`` pipeline.
        '''
        prompt = self.parse_inputs(messages)
        ans = self.pipeline(prompt)
        return ans

    def parse_inputs(self, messages: List[dict]):
        '''Flatten the chat messages into a single prompt string.

        The conversation history is placed BEFORE the current question so the
        model reads context first; the previous implementation appended the
        history after the question, which buried the instruction.

        Args:
            messages: list of dicts with keys from {"system", "question", "answer"}.

        Raises:
            TypeError: if ``messages`` is not a list. (A raise is used instead
                of ``assert``, which would be stripped under ``python -O``.)
        '''
        if not isinstance(messages, list):
            raise TypeError(
                'Inputs must be a list of dictionaries with keys from ["system", "question", "answer"].')
        question = messages[-1]['question']
        history = ''
        for turn in messages[:-1]:
            for role, text in turn.items():
                history += role + ': ' + text + '\n'
        return history + question

    @staticmethod
    def supported_model_names():
        '''Return the supported Dolly model names, sorted alphabetically.'''
        return sorted([
            'databricks/dolly-v2-12b',
            'databricks/dolly-v2-7b',
            'databricks/dolly-v2-3b',
            'databricks/dolly-v1-6b',
        ])
|||
|
@ -0,0 +1,2 @@ |
|||
transformers[torch]>=4.28.1,<5 |
|||
torch>=1.13.1,<2 |
Loading…
Reference in new issue