4 changed files with 177 additions and 1 deletion
			
			
		| @ -1,2 +1,85 @@ | |||||
| # Ernie |  | ||||
|  | # 文心一言 (ERNIE Bot) | ||||
|  | 
 | ||||
|  | *author: Jael* | ||||
|  | 
 | ||||
|  | <br /> | ||||
|  | 
 | ||||
|  | ## Description | ||||
|  | 
 | ||||
|  | An LLM operator generates an answer for the prompt in the given messages using a large language model or service. | ||||
|  | This operator is implemented with Ernie Bot from [Baidu](https://cloud.baidu.com/wenxin.html). | ||||
|  | Please note that you will need an [Ernie API key & Secret key](https://ai.baidu.com/ai-doc/REFERENCE/Lkru0zoz4) to access the service. | ||||
|  | 
 | ||||
|  | <br /> | ||||
|  | 
 | ||||
|  | ## Code Example | ||||
|  | 
 | ||||
|  | Use the default model to continue the conversation from the given messages. | ||||
|  | 
 | ||||
|  | *Write a pipeline with explicit input/output name specifications:* | ||||
|  | 
 | ||||
|  | ```python | ||||
|  | from towhee import pipe, ops | ||||
|  | 
 | ||||
|  | p = ( | ||||
|  |     pipe.input('messages') | ||||
|  |         .map('messages', 'answer', ops.LLM.Ernie(api_key=ERNIE_API_KEY, secret_key=ERNIE_SECRET_KEY)) | ||||
|  |         .output('messages', 'answer') | ||||
|  | ) | ||||
|  | 
 | ||||
|  | messages = [ | ||||
|  |     {'question': 'Zilliz Cloud 是什么?', 'answer': 'Zilliz Cloud 是一种全托管的向量检索服务。'}, | ||||
|  |     {'question': '它和 Milvus 的关系是什么?'} | ||||
|  | ] | ||||
|  | answer = p(messages).get()[1]  # get() returns the output columns in order: [messages, answer] | ||||
|  | ``` | ||||
|  | 
 | ||||
|  | <br /> | ||||
|  | 
 | ||||
|  | ## Factory Constructor | ||||
|  | 
 | ||||
|  | Create the operator via the following factory method: | ||||
|  | 
 | ||||
|  | ***LLM.Ernie(api_key: str = None, secret_key: str = None, \*\*kwargs)*** | ||||
|  | 
 | ||||
|  | **Parameters:** | ||||
|  | 
 | ||||
|  | 
 | ||||
|  | ***api_key***: *str=None* | ||||
|  | 
 | ||||
|  | The Ernie API key as a string; defaults to None. If None, the environment variable `ERNIE_API_KEY` will be used. | ||||
|  | 
 | ||||
|  | ***secret_key***: *str=None* | ||||
|  | 
 | ||||
|  | The Ernie Secret key as a string; defaults to None. If None, the environment variable `ERNIE_SECRET_KEY` will be used. | ||||
|  | 
 | ||||
|  | ***\*\*kwargs*** | ||||
|  | 
 | ||||
|  | Other Ernie Bot generation parameters, such as `temperature` (see the sketch below). | ||||
|  | 
 | ||||
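|  | *A minimal sketch, assuming the keys are exported as the environment variables `ERNIE_API_KEY` and `ERNIE_SECRET_KEY`; extra generation parameters such as `temperature` are passed through as keyword arguments:* | ||||
|  |  | ||||
|  | ```python | ||||
|  | from towhee import ops | ||||
|  |  | ||||
|  | # api_key & secret_key fall back to ERNIE_API_KEY / ERNIE_SECRET_KEY when not passed explicitly. | ||||
|  | ernie_op = ops.LLM.Ernie(temperature=0.5) | ||||
|  | ``` | ||||
|  |  | ||||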
|  | <br /> | ||||
|  | 
 | ||||
|  | ## Interface | ||||
|  | 
 | ||||
|  | The operator takes a list of chat messages as input. | ||||
|  | It returns the generated answer as a string. | ||||
|  | 
 | ||||
|  | ***\_\_call\_\_(messages)*** | ||||
|  | 
 | ||||
|  | **Parameters:** | ||||
|  | 
 | ||||
|  | ***messages***: *list* | ||||
|  | 
 | ||||
|  | 	A list of messages to set up the chat. | ||||
|  | Must be a list of dictionaries with keys from "question" and "answer". For example, [{"question": "a past question?", "answer": "a past answer."}, {"question": "current question?"}]. | ||||
|  | It also accepts the original Ernie message format, e.g. [{"role": "user", "content": "a question?"}, {"role": "assistant", "content": "an answer."}]. | ||||
|  | 
 | ||||
|  | **Returns**: | ||||
|  | 
 | ||||
|  | *answer: str* | ||||
|  | 
 | ||||
|  | 	The next answer generated by role "assistant". | ||||
|  | 
 | ||||
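|  | *A minimal sketch of calling the operator directly with the original Ernie role/content message format (the `ernie_op` name and placeholder keys are illustrative):* | ||||
|  |  | ||||
|  | ```python | ||||
|  | from towhee import ops | ||||
|  |  | ||||
|  | ernie_op = ops.LLM.Ernie(api_key=ERNIE_API_KEY, secret_key=ERNIE_SECRET_KEY) | ||||
|  | history = [ | ||||
|  |     {'role': 'user', 'content': 'What is Zilliz Cloud?'}, | ||||
|  |     {'role': 'assistant', 'content': 'Zilliz Cloud is a fully managed vector search service.'}, | ||||
|  |     {'role': 'user', 'content': 'How does it relate to Milvus?'} | ||||
|  | ] | ||||
|  | answer = ernie_op(history) | ||||
|  | ``` | ||||
|  |  | ||||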
|  | <br /> | ||||
|  | 
 | ||||
| 
 | 
 | ||||
|  | |||||
| @ -0,0 +1,5 @@ | |||||
|  | from .ernie_chat import ErnieChat | ||||
|  | 
 | ||||
|  | 
 | ||||
|  | def Ernie(*args, **kwargs): | ||||
|  |     return ErnieChat(*args, **kwargs) | ||||
| @ -0,0 +1,87 @@ | |||||
|  | # Copyright 2021 Zilliz. All rights reserved. | ||||
|  | # | ||||
|  | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
|  | # you may not use this file except in compliance with the License. | ||||
|  | # You may obtain a copy of the License at | ||||
|  | # | ||||
|  | #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | # | ||||
|  | # Unless required by applicable law or agreed to in writing, software | ||||
|  | # distributed under the License is distributed on an "AS IS" BASIS, | ||||
|  | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
|  | # See the License for the specific language governing permissions and | ||||
|  | # limitations under the License. | ||||
|  | 
 | ||||
|  | import os | ||||
|  | import requests | ||||
|  | import json | ||||
|  | from typing import List | ||||
|  | 
 | ||||
|  | from towhee.operator.base import PyOperator | ||||
|  | 
 | ||||
|  | 
 | ||||
|  | class ErnieChat(PyOperator): | ||||
|  |     '''Wrapper of the Ernie Bot (Wenxin) chat completion API''' | ||||
|  |     def __init__(self, | ||||
|  |                  api_key: str = None, | ||||
|  |                  secret_key: str = None, | ||||
|  |                  **kwargs | ||||
|  |                  ): | ||||
|  |         super().__init__() | ||||
|  |         self.api_key = api_key or os.getenv('ERNIE_API_KEY') | ||||
|  |         self.secret_key = secret_key or os.getenv('ERNIE_SECRET_KEY') | ||||
|  |         self.kwargs = kwargs | ||||
|  | 
 | ||||
|  |         try: | ||||
|  |             self.access_token = self.get_access_token(api_key=self.api_key, secret_key=self.secret_key) | ||||
|  |         except Exception as e: | ||||
|  |             raise RuntimeError(f'Failed to get access token: {e}') | ||||
|  |         self.url = 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions?access_token=' \ | ||||
|  |             + self.access_token | ||||
|  | 
 | ||||
|  |     def __call__(self, messages: List[dict]): | ||||
|  |         messages = self.parse_inputs(messages) | ||||
|  |         self.kwargs['messages'] = messages | ||||
|  |         payload = json.dumps(self.kwargs) | ||||
|  |         headers = { | ||||
|  |             'Content-Type': 'application/json' | ||||
|  |         } | ||||
|  |  | ||||
|  |         response = requests.post(self.url, headers=headers, data=payload) | ||||
|  |  | ||||
|  |         # if self.kwargs.get('stream', False): | ||||
|  |         #     return self.stream_output(response) | ||||
|  |  | ||||
|  |         result = response.json() | ||||
|  |         if 'result' not in result: | ||||
|  |             raise RuntimeError(f'Ernie Bot request failed: {result}') | ||||
|  |         return result['result'] | ||||
|  | 
 | ||||
|  |     def parse_inputs(self, messages: List[dict]): | ||||
|  |         assert isinstance(messages, list), \ | ||||
|  |             'Inputs must be a list of dictionaries with keys from ["question", "answer"] or ["role", "content"].' | ||||
|  |         new_messages = [] | ||||
|  |         for m in messages: | ||||
|  |             if ('role' in m and 'content' in m) and (m['role'] in ['assistant', 'user']): | ||||
|  |                 new_messages.append(m) | ||||
|  |             else: | ||||
|  |                 for k, v in m.items(): | ||||
|  |                     if k == 'question': | ||||
|  |                         new_m = {'role': 'user', 'content': v} | ||||
|  |                     elif k == 'answer': | ||||
|  |                         new_m = {'role': 'assistant', 'content': v} | ||||
|  |                     else: | ||||
|  |                         raise KeyError('Invalid message key: only accept keys from ["question", "answer"].') | ||||
|  |                     new_messages.append(new_m) | ||||
|  |         return new_messages | ||||
|  |      | ||||
|  |     def stream_output(self, response): | ||||
|  |         # TODO: support streaming responses | ||||
|  |         pass | ||||
|  |      | ||||
|  |     @staticmethod | ||||
|  |     def get_access_token(api_key, secret_key): | ||||
|  |         # Exchange the API key & secret key for an OAuth access token via Baidu's token endpoint. | ||||
|  |         url = 'https://aip.baidubce.com/oauth/2.0/token' | ||||
|  |         params = { | ||||
|  |             'grant_type': 'client_credentials', | ||||
|  |             'client_id': api_key, | ||||
|  |             'client_secret': secret_key | ||||
|  |         } | ||||
|  |         token = requests.post(url, params=params).json().get('access_token') | ||||
|  |         if not token: | ||||
|  |             raise RuntimeError('No access_token returned; please check the API key & secret key.') | ||||
|  |         return str(token) | ||||
| @ -0,0 +1 @@ | |||||
|  | requests | ||||