4 changed files with 151 additions and 1 deletion

@@ -1,2 +1,83 @@
# Dolly Generation

*author: Jael*

<br />

## Description

An LLM operator generates an answer given the prompt in messages, using a large language model or service. This operator uses a pretrained [Dolly](https://github.com/databrickslabs/dolly) model to generate the response. It will download the model from [HuggingFace Models](https://huggingface.co/models).
			

<br />

## Code Example

Use the default model to continue the conversation from the given messages.

*Write a pipeline with explicit input/output name specifications:*

```python
from towhee import pipe, ops

p = (
    pipe.input('messages')
        .map('messages', 'answer', ops.LLM.Dolly())
        .output('messages', 'answer')
)

messages = [
    {'question': 'Who won the world series in 2020?', 'answer': 'The Los Angeles Dodgers won the World Series in 2020.'},
    {'question': 'Where was it played?'}
]
answer = p(messages)
```
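
Depending on the Towhee version, the pipeline call may return a `DataQueue` rather than the raw answer; a hedged sketch of unpacking it (assuming `DataQueue.get()` yields one row of the declared outputs):

```python
# 'answer' above is the pipeline result; .get() fetches one row of
# ('messages', 'answer'). The exact API may vary by Towhee version.
msgs, ans = answer.get()
print(ans)
```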
			

<br />

## Factory Constructor

Create the operator via the following factory method:

***LLM.Dolly(model_name: str)***

**Parameters:**

***model_name***: *str*

The model name as a string; defaults to 'databricks/dolly-v2-12b'. Supported model names:
- databricks/dolly-v2-12b
- databricks/dolly-v2-7b
- databricks/dolly-v2-3b
- databricks/dolly-v1-6b

***\*\*kwargs***

Other Dolly model parameters such as `device_map`, which are forwarded to the underlying HuggingFace pipeline.
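
For instance, a smaller checkpoint and pipeline options can be passed at construction time (a minimal sketch; `device_map='auto'` assumes the `accelerate` package is installed):

```python
from towhee import ops

# Hypothetical construction with a smaller checkpoint; extra keyword
# arguments such as device_map are forwarded to the HuggingFace pipeline.
op = ops.LLM.Dolly('databricks/dolly-v2-3b', device_map='auto')
```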
			

<br />

## Interface

The operator takes a list of messages as input and returns the generated answer as a string.

***\_\_call\_\_(messages)***
			

**Parameters:**

***messages***: *list*

A list of messages to set up the chat. Must be a list of dictionaries with keys from "system", "question", and "answer". For example: `[{"question": "a past question?", "answer": "a past answer."}, {"question": "current question?"}]`

**Returns:**

*answer: str*

The generated answer.
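
A minimal sketch of calling the operator outside of a pipeline (this assumes the instance returned by `ops.LLM.Dolly()` is directly callable, which may differ across Towhee versions):

```python
from towhee import ops

op = ops.LLM.Dolly('databricks/dolly-v2-3b')
messages = [
    {'question': 'Who won the world series in 2020?',
     'answer': 'The Los Angeles Dodgers won the World Series in 2020.'},
    {'question': 'Where was it played?'}
]
answer = op(messages)  # the generated answer as a string
print(answer)
```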
			

<br />

@@ -0,0 +1,5 @@
from .hf_dolly import HuggingfaceDolly


def Dolly(*args, **kwargs):
    # Factory method that creates the operator instance.
    return HuggingfaceDolly(*args, **kwargs)

@@ -0,0 +1,62 @@
			
# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List

import torch
from transformers import pipeline

from towhee.operator.base import PyOperator


class HuggingfaceDolly(PyOperator):
    '''Wrapper of a pretrained HuggingFace Dolly model.'''
    def __init__(self,
                 model_name: str = 'databricks/dolly-v2-12b',
                 **kwargs
                 ):
        torch_dtype = kwargs.get('torch_dtype', torch.bfloat16)
        trust_remote_code = kwargs.get('trust_remote_code', True)
        device_map = kwargs.get('device_map', 'auto')

        self.pipeline = pipeline(model=model_name, torch_dtype=torch_dtype,
                                 trust_remote_code=trust_remote_code, device_map=device_map)
			
    def __call__(self, messages: List[dict]):
        prompt = self.parse_inputs(messages)
        ans = self.pipeline(prompt)
        # The underlying pipeline returns a list of dicts; unwrap the
        # generated text so the operator returns a plain string.
        return ans[0]['generated_text']
			
    def parse_inputs(self, messages: List[dict]):
        assert isinstance(messages, list), \
            'Inputs must be a list of dictionaries with keys from ["system", "question", "answer"].'
        prompt = messages[-1]['question']
        history = ''
        for m in messages[:-1]:
            for k, v in m.items():
                line = k + ': ' + v + '\n'
                history += line
        # Prepend the chat history so that the current question comes last.
        return history + prompt
			
    @staticmethod
    def supported_model_names():
        model_list = [
            'databricks/dolly-v2-12b',
            'databricks/dolly-v2-7b',
            'databricks/dolly-v2-3b',
            'databricks/dolly-v1-6b'
        ]
        model_list.sort()
        return model_list
			
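# A minimal smoke-test sketch (hypothetical, not part of this file):
#
#     op = HuggingfaceDolly(model_name='databricks/dolly-v2-3b')
#     print(op([{'question': 'What is Dolly?'}]))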
@@ -0,0 +1,2 @@
transformers[torch]>=4.28.1,<5
torch>=1.13.1,<2