4 changed files with 177 additions and 1 deletions
			
			
		| @ -1,2 +1,85 @@ | |||
| # Ernie | |||
| # 文心一言 | |||
| 
 | |||
| *author: Jael* | |||
| 
 | |||
| <br /> | |||
| 
 | |||
| ## Description | |||
| 
 | |||
| A LLM operator generates answer given prompt in messages using a large language model or service. | |||
| This operator is implemented with Ernie Bot from [Baidu](https://cloud.baidu.com/wenxin.html). | |||
| Please note you will need [Ernie API key & Secret key](https://ai.baidu.com/ai-doc/REFERENCE/Lkru0zoz4) to access the service. | |||
| 
 | |||
| <br /> | |||
| 
 | |||
| ## Code Example | |||
| 
 | |||
| Use the default model to continue the conversation from given messages. | |||
| 
 | |||
| *Write a pipeline with explicit inputs/outputs name specifications:* | |||
| 
 | |||
| ```python | |||
| from towhee import pipe, ops | |||
| 
 | |||
| p = ( | |||
|     pipe.input('messages') | |||
|         .map('messages', 'answer', ops.LLM.Ernie(api_key=ERNIE_API_KEY, secret_key=ERNIE_SECRET_KEY)) | |||
|         .output('messages', 'answer') | |||
| ) | |||
| 
 | |||
| messages=[ | |||
|         {'question': 'Zilliz Cloud 是什么?', 'answer': 'Zilliz Cloud 是一种全托管的向量检索服务。'}, | |||
|         {'question': '它和 Milvus 的关系是什么?'} | |||
|     ] | |||
| answer = p(messages).get()[0] | |||
| ``` | |||
| 
 | |||
| <br /> | |||
| 
 | |||
| ## Factory Constructor | |||
| 
 | |||
| Create the operator via the following factory method: | |||
| 
 | |||
| ***LLM.Ernie(api_key: str, secret_key: str)*** | |||
| 
 | |||
| **Parameters:** | |||
| 
 | |||
| 
 | |||
| ***api_key***: *str=None* | |||
| 
 | |||
| The Ernie API key in string, defaults to None. If None, it will use the environment variable `ERNIE_API_KEY`. | |||
| 
 | |||
| ***secret_key***: *str=None* | |||
| 
 | |||
| The Ernie Secret key in string, defaults to None. If None, it will use the environment variable `ERNIE_SECRET_KEY`. | |||
| 
 | |||
| ***\*\*kwargs*** | |||
| 
 | |||
| Other OpenAI parameters such as temperature, etc. | |||
| 
 | |||
| <br /> | |||
| 
 | |||
| ## Interface | |||
| 
 | |||
| The operator takes a piece of text in string as input. | |||
| It returns answer in json. | |||
| 
 | |||
| ***\_\_call\_\_(txt)*** | |||
| 
 | |||
| **Parameters:** | |||
| 
 | |||
| ***messages***: *list* | |||
| 
 | |||
| 	A list of messages to set up chat. | |||
| Must be a list of dictionaries with key value from "question", "answer". For example, [{"question": "a past question?", "answer": "a past answer."}, {"question": "current question?"}]. | |||
| It also accepts the orignal Ernie message format like [{"role": "user", "content": "a question?"}, {"role": "assistant", "content": "an answer."}] | |||
| 
 | |||
| **Returns**: | |||
| 
 | |||
| *answer: str* | |||
| 
 | |||
| 	The next answer generated by role "assistant". | |||
| 
 | |||
| <br /> | |||
| 
 | |||
| 
 | |||
|  | |||
| @ -0,0 +1,5 @@ | |||
| from .ernie_chat import ErnieChat | |||
| 
 | |||
| 
 | |||
| def Ernie(*args, **kwargs): | |||
|     return ErnieChat(*args, **kwargs) | |||
| @ -0,0 +1,87 @@ | |||
| # Copyright 2021 Zilliz. All rights reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| #     http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| 
 | |||
| import os | |||
| import requests | |||
| import json | |||
| from typing import List | |||
| 
 | |||
| from towhee.operator.base import PyOperator | |||
| 
 | |||
| 
 | |||
| class ErnieChat(PyOperator): | |||
|     '''Wrapper of OpenAI Chat API''' | |||
|     def __init__(self, | |||
|                  api_key: str = None, | |||
|                  secret_key: str = None, | |||
|                  **kwargs | |||
|                  ): | |||
|         self.api_key = api_key or os.getenv('ERNIE_API_KEY') | |||
|         self.secret_key = secret_key or os.getenv('ERNIE_SECRET_KEY') | |||
|         self.kwargs = kwargs | |||
| 
 | |||
|         try: | |||
|             self.access_token = self.get_access_token(api_key=self.api_key, secret_key=self.secret_key) | |||
|         except Exception as e: | |||
|             raise RuntimeError(f'Failed to get access token: {e}') | |||
|         self.url = 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions?access_token=' \ | |||
|             + self.access_token | |||
| 
 | |||
|     def __call__(self, messages: List[dict]): | |||
|         messages = self.parse_inputs(messages) | |||
|         self.kwargs['messages'] = messages | |||
|         payload = json.dumps(self.kwargs) | |||
|         headers = { | |||
|             'Content-Type': 'application/json' | |||
|         } | |||
|      | |||
|         response = requests.request('POST', self.url, headers=headers, data=payload) | |||
|      | |||
|         # if self.kwargs.get('stream', False): | |||
|         #     return self.stream_output(response) | |||
|          | |||
|         answer = response.json()['result'] | |||
|         return answer | |||
| 
 | |||
|     def parse_inputs(self, messages: List[dict]): | |||
|         assert isinstance(messages, list), \ | |||
|             'Inputs must be a list of dictionaries with keys from ["question", "answer"] or ["role", "content"].' | |||
|         new_messages = [] | |||
|         for m in messages: | |||
|             if ('role' and 'content' in m) and (m['role'] in ['assistant', 'user']): | |||
|                 new_messages.append(m) | |||
|             else: | |||
|                 for k, v in m.items(): | |||
|                     if k == 'question': | |||
|                         new_m = {'role': 'user', 'content': v} | |||
|                     elif k == 'answer': | |||
|                         new_m = {'role': 'assistant', 'content': v} | |||
|                     else: | |||
|                         'Invalid message key: only accept key value from ["question", "answer"].' | |||
|                     new_messages.append(new_m) | |||
|         return new_messages | |||
|      | |||
|     def stream_output(self, response): | |||
|          # todo | |||
|          pass | |||
|      | |||
|     @staticmethod | |||
|     def get_access_token(api_key, secret_key): | |||
|         url = 'https://aip.baidubce.com/oauth/2.0/token' | |||
|         params = { | |||
|             'grant_type': 'client_credentials', | |||
|             'client_id': api_key, | |||
|             'client_secret': secret_key | |||
|             } | |||
|         return str(requests.post(url, params=params).json().get('access_token')) | |||
| @ -0,0 +1 @@ | |||
| requests | |||
					Loading…
					
					
				
		Reference in new issue
	
	