4 changed files with 151 additions and 1 deletions
			
			
		@ -1,2 +1,83 @@ | 
				
			|||
# Dolly
				
			|||
 | 
				
			|||
*author: Jael* | 
				
			|||
 | 
				
			|||
<br /> | 
				
			|||
 | 
				
			|||
## Description | 
				
			|||
 | 
				
			|||
An LLM operator generates an answer given a prompt in messages using a large language model or service. | 
				
			|||
This operator uses a pretrained [Dolly](https://github.com/databrickslabs/dolly) to generate response. | 
				
			|||
It will download model from [HuggingFace Models](https://huggingface.co/models). | 
				
			|||
 | 
				
			|||
<br /> | 
				
			|||
 | 
				
			|||
## Code Example | 
				
			|||
 | 
				
			|||
Use the default model to continue the conversation from given messages. | 
				
			|||
 | 
				
			|||
*Write a pipeline with explicit inputs/outputs name specifications:* | 
				
			|||
 | 
				
			|||
```python | 
				
			|||
from towhee import pipe, ops | 
				
			|||
 | 
				
			|||
p = ( | 
				
			|||
    pipe.input('messages') | 
				
			|||
        .map('messages', 'answer', ops.LLM.Dolly()) | 
				
			|||
        .output('messages', 'answer') | 
				
			|||
) | 
				
			|||
 | 
				
			|||
messages=[ | 
				
			|||
        {'question': 'Who won the world series in 2020?', 'answer': 'The Los Angeles Dodgers won the World Series in 2020.'}, | 
				
			|||
        {'question': 'Where was it played?'} | 
				
			|||
    ] | 
				
			|||
answer = p(messages) | 
				
			|||
``` | 
				
			|||
 | 
				
			|||
<br /> | 
				
			|||
 | 
				
			|||
## Factory Constructor | 
				
			|||
 | 
				
			|||
Create the operator via the following factory method: | 
				
			|||
 | 
				
			|||
***LLM.Dolly(model_name: str)*** | 
				
			|||
 | 
				
			|||
**Parameters:** | 
				
			|||
 | 
				
			|||
***model_name***: *str* | 
				
			|||
 | 
				
			|||
The model name in string, defaults to 'databricks/dolly-v2-12b'. Supported model names: | 
				
			|||
- databricks/dolly-v2-12b | 
				
			|||
- databricks/dolly-v2-7b | 
				
			|||
- databricks/dolly-v2-3b | 
				
			|||
- databricks/dolly-v1-6b | 
				
			|||
 | 
				
			|||
***\*\*kwargs*** | 
				
			|||
 | 
				
			|||
Other Dolly model parameters such as device_map. | 
				
			|||
 | 
				
			|||
<br /> | 
				
			|||
 | 
				
			|||
## Interface | 
				
			|||
 | 
				
			|||
The operator takes a list of chat messages (dictionaries) as input. | 
				
			|||
It returns the generated answer as a string. | 
				
			|||
 | 
				
			|||
***\_\_call\_\_(messages)*** | 
				
			|||
 | 
				
			|||
**Parameters:** | 
				
			|||
 | 
				
			|||
***messages***: *list* | 
				
			|||
 | 
				
			|||
	A list of messages to set up chat. | 
				
			|||
Must be a list of dictionaries with key value from "system", "question", "answer". For example, [{"question": "a past question?", "answer": "a past answer."}, {"question": "current question?"}] | 
				
			|||
 | 
				
			|||
**Returns**: | 
				
			|||
 | 
				
			|||
*answer: str* | 
				
			|||
 | 
				
			|||
	The answer generated. | 
				
			|||
 | 
				
			|||
<br /> | 
				
			|||
 | 
				
			|||
 | 
				
			|||
 | 
				
			|||
@ -0,0 +1,5 @@ | 
				
			|||
from .hf_dolly import HuggingfaceDolly | 
				
			|||
 | 
				
			|||
 | 
				
			|||
def Dolly(*args, **kwargs):
    """Factory method: build the Dolly LLM operator.

    All positional and keyword arguments are forwarded unchanged to
    `HuggingfaceDolly` (model name, pipeline options, etc.).
    """
    operator = HuggingfaceDolly(*args, **kwargs)
    return operator
				
			|||
@ -0,0 +1,62 @@ | 
				
			|||
# Copyright 2021 Zilliz. All rights reserved. | 
				
			|||
# | 
				
			|||
# Licensed under the Apache License, Version 2.0 (the "License"); | 
				
			|||
# you may not use this file except in compliance with the License. | 
				
			|||
# You may obtain a copy of the License at | 
				
			|||
# | 
				
			|||
#     http://www.apache.org/licenses/LICENSE-2.0 | 
				
			|||
# | 
				
			|||
# Unless required by applicable law or agreed to in writing, software | 
				
			|||
# distributed under the License is distributed on an "AS IS" BASIS, | 
				
			|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | 
				
			|||
# See the License for the specific language governing permissions and | 
				
			|||
# limitations under the License. | 
				
			|||
 | 
				
			|||
from typing import List | 
				
			|||
 | 
				
			|||
import torch | 
				
			|||
from transformers import pipeline | 
				
			|||
 | 
				
			|||
from towhee.operator.base import PyOperator | 
				
			|||
 | 
				
			|||
 | 
				
			|||
class HuggingfaceDolly(PyOperator):
    '''Wrapper of a pretrained Databricks Dolly model, served through the
    HuggingFace `transformers` text-generation pipeline.

    Args:
        model_name (str): HuggingFace model id, defaults to
            'databricks/dolly-v2-12b'. See `supported_model_names()` for the
            known-good options.
        **kwargs: extra keyword arguments forwarded verbatim to
            `transformers.pipeline`, e.g. `device_map`, `torch_dtype`.
    '''
    def __init__(self,
                 model_name: str = 'databricks/dolly-v2-12b',
                 **kwargs
                 ):
        # Fill in defaults, then forward *all* remaining kwargs: the original
        # code read only these three keys and silently dropped every other
        # user-supplied pipeline option.
        kwargs.setdefault('torch_dtype', torch.bfloat16)
        # Dolly ships custom pipeline code on the Hub, so remote code must be trusted.
        kwargs.setdefault('trust_remote_code', True)
        kwargs.setdefault('device_map', 'auto')

        self.pipeline = pipeline(model=model_name, **kwargs)

    def __call__(self, messages: List[dict]):
        '''Generate an answer for the last question in `messages`.

        Args:
            messages (List[dict]): chat turns as dicts with keys from
                {"system", "question", "answer"}; the last dict carries the
                current question.

        Returns:
            The raw pipeline output.
            NOTE(review): the Dolly instruct pipeline presumably returns a
            list of {'generated_text': ...} dicts — confirm, and unwrap here
            if callers expect a plain string as the README states.
        '''
        prompt = self.parse_inputs(messages)
        ans = self.pipeline(prompt)
        return ans

    def parse_inputs(self, messages: List[dict]):
        '''Flatten chat history plus the current question into one prompt.

        Past turns are rendered as "key: value" lines and placed *before*
        the current question. (The original code appended the history after
        the question, which inverted the conversation order.)
        '''
        assert isinstance(messages, list), \
            'Inputs must be a list of dictionaries with keys from ["system", "question", "answer"].'
        prompt = messages[-1]['question']
        history = ''
        for m in messages[:-1]:
            for k, v in m.items():
                history += k + ': ' + v + '\n'
        # History first, then the question to be answered.
        return history + prompt

    @staticmethod
    def supported_model_names():
        '''Return the supported Dolly model names, sorted alphabetically.'''
        model_list = [
            'databricks/dolly-v2-12b',
            'databricks/dolly-v2-7b',
            'databricks/dolly-v2-3b',
            'databricks/dolly-v1-6b'
        ]
        model_list.sort()
        return model_list
				
			|||
 | 
				
			|||
@ -0,0 +1,2 @@ | 
				
			|||
transformers[torch]>=4.28.1,<5 | 
				
			|||
torch>=1.13.1,<2 | 
				
			|||
					Loading…
					
					
				
		Reference in new issue