
Update the operator.

Signed-off-by: jinlingxu06 <jinling.xu@zilliz.com>
commit dec32368de to main by jinlingxu06, 2 years ago
3 changed files:
  1. README.md (+135)
  2. __init__.py (+5)
  3. azure_openai_chat.py (+87)

README.md
@@ -1,2 +1,135 @@
# OpenAI Chat Completion
*author: David Wang*
<br />
## Description
An LLM operator generates an answer to the given prompt messages using a large language model or service.
This operator is implemented with the Chat Completion method from [Azure OpenAI](https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/chatgpt?pivots=programming-language-chat-completions).
Please note that you need an [OpenAI API key](https://platform.openai.com/account/api-keys) to access OpenAI.
<br />
## Code Example
Use the default model to continue the conversation from the given messages.
*Write a pipeline with explicit inputs/outputs name specifications:*
```python
from towhee import pipe, ops

p = (
    pipe.input('messages')
        .map('messages', 'answer', ops.LLM.Azure_OpenAI(api_key=OPENAI_API_KEY, api_base=OPENAI_API_BASE))
        .output('messages', 'answer')
)

messages = [
    {'question': 'Who won the world series in 2020?', 'answer': 'The Los Angeles Dodgers won the World Series in 2020.'},
    {'question': 'Where was it played?'}
]
answer = p(messages).get()[0]
```
*Write a [retrieval-augmented generation pipeline](https://towhee.io/tasks/detail/pipeline/retrieval-augmented-generation) with explicit inputs/outputs name specifications:*
```python
from towhee import pipe, ops

temp = '''Use the following pieces of context to answer the question at the end.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
{context}
Question: {question}
Helpful Answer:
'''

docs = ['You can install towhee via command `pip install towhee`.']
history = [
    ('What is Towhee?', 'Towhee is an open-source machine learning pipeline that helps you encode your unstructured data into embeddings.')
]
question = 'How to install it?'

p = (
    pipe.input('question', 'docs', 'history')
        .map(('question', 'docs', 'history'), 'prompt', ops.prompt.template(temp, ['question', 'context']))
        .map('prompt', 'answer',
             ops.LLM.Azure_OpenAI(api_key=OPENAI_API_KEY, api_base=OPENAI_API_BASE, temperature=0.5, max_tokens=100))
        .output('answer')
)

answer = p(question, docs, history).get()[0]
```
<br />
## Factory Constructor
Create the operator via the following factory method:
***LLM.Azure_OpenAI(model_name: str, api_type: str, api_version: str, api_key: str, api_base: str)***
**Parameters:**
***model_name***: *str*
The model name as a string, defaults to 'gpt-3.5-turbo'. Supported model names:
- gpt-3.5-turbo
- gpt-3.5-turbo-16k
- gpt-3.5-turbo-instruct
- gpt-3.5-turbo-0613
- gpt-3.5-turbo-16k-0613
***api_type***: *str='azure'*
The OpenAI API type as a string, defaults to 'azure'.
***api_version***: *str='2023-07-01-preview'*
The OpenAI API version as a string, defaults to '2023-07-01-preview'.
***api_key***: *str=None*
The OpenAI API key as a string, defaults to None (falls back to the OPENAI_API_KEY environment variable).
***api_base***: *str=None*
The OpenAI API base URL as a string, defaults to None (falls back to the OPENAI_API_BASE environment variable).
***\*\*kwargs***
Other OpenAI parameters such as max_tokens, stream, temperature, etc.
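The api_key/api_base arguments fall back to environment variables when not given explicitly. That resolution order can be sketched as follows (`resolve` is a hypothetical helper for illustration only, not part of the operator):

```python
import os

def resolve(explicit, env_var):
    """Use the explicit value if given, otherwise fall back to the environment."""
    return explicit if explicit is not None else os.getenv(env_var)

os.environ['OPENAI_API_KEY'] = 'sk-from-env'
print(resolve(None, 'OPENAI_API_KEY'))           # picks up the environment variable
print(resolve('sk-explicit', 'OPENAI_API_KEY'))  # an explicit argument wins
```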
<br />
## Interface
The operator takes a list of messages as input.
It returns the generated answer as a string, or a generator of response chunks when `stream=True`.
***\_\_call\_\_(messages)***
**Parameters:**
***messages***: *list*
A list of messages to set up the chat.
Must be a list of dictionaries with keys from "system", "question", "answer". For example, [{"question": "a past question?", "answer": "a past answer."}, {"question": "current question?"}]
**Returns**:
*answer: str*
The next answer generated by the role "assistant".
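The mapping from this message format to OpenAI's role/content chat format can be sketched as follows (a simplified standalone illustration of what the operator does internally; `to_chat_messages` is not part of the operator's public API):

```python
# Map the operator's message keys to OpenAI chat roles.
ROLE_MAP = {'question': 'user', 'answer': 'assistant', 'system': 'system'}

def to_chat_messages(messages):
    """Convert {'question'/'answer'/'system': text} dicts to OpenAI chat messages."""
    chat = []
    for m in messages:
        for key, text in m.items():
            if key not in ROLE_MAP:
                raise KeyError('only "system", "question", "answer" keys are accepted')
            chat.append({'role': ROLE_MAP[key], 'content': text})
    return chat

print(to_chat_messages([
    {'question': 'Who won the world series in 2020?',
     'answer': 'The Los Angeles Dodgers won the World Series in 2020.'},
    {'question': 'Where was it played?'}
]))
```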
<br />

__init__.py
@@ -0,0 +1,5 @@
from .azure_openai_chat import AzureOpenaiChat


def AzureOpenAI(*args, **kwargs):
    return AzureOpenaiChat(*args, **kwargs)

azure_openai_chat.py
@@ -0,0 +1,87 @@
# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import List

import openai

from towhee.operator.base import PyOperator


class AzureOpenaiChat(PyOperator):
    '''Wrapper of the Azure OpenAI Chat API.'''
    def __init__(self,
                 model_name: str = 'gpt-3.5-turbo',
                 api_type: str = 'azure',
                 api_version: str = '2023-07-01-preview',
                 api_key: str = None,
                 api_base: str = None,
                 **kwargs
                 ):
        # Configure the client for the Azure endpoint.
        openai.api_type = api_type
        openai.api_version = api_version
        openai.api_key = api_key or os.getenv('OPENAI_API_KEY')
        openai.api_base = api_base or os.getenv('OPENAI_API_BASE')
        self._model = model_name
        self.stream = kwargs.pop('stream', False)
        self.kwargs = kwargs

    def __call__(self, messages: List[dict]):
        messages = self.parse_inputs(messages)
        response = openai.ChatCompletion.create(
            model=self._model,
            messages=messages,
            n=1,
            stream=self.stream,
            **self.kwargs
        )
        if self.stream:
            return self.stream_output(response)
        answer = response['choices'][0]['message']['content']
        return answer

    def parse_inputs(self, messages: List[dict]):
        assert isinstance(messages, list), \
            'Inputs must be a list of dictionaries with keys from ["system", "question", "answer"].'
        new_messages = []
        for m in messages:
            # Pass through messages already in OpenAI's role/content format.
            if 'role' in m and 'content' in m and m['role'] in ['system', 'assistant', 'user']:
                new_messages.append(m)
            else:
                for k, v in m.items():
                    if k == 'question':
                        new_m = {'role': 'user', 'content': v}
                    elif k == 'answer':
                        new_m = {'role': 'assistant', 'content': v}
                    elif k == 'system':
                        new_m = {'role': 'system', 'content': v}
                    else:
                        raise KeyError('Invalid message key: only accept keys from ["system", "question", "answer"].')
                    new_messages.append(new_m)
        return new_messages

    def stream_output(self, response):
        for resp in response:
            yield resp['choices'][0]['delta']

    @staticmethod
    def supported_model_names():
        model_list = [
            'gpt-3.5-turbo',
            'gpt-3.5-turbo-16k',
            'gpt-3.5-turbo-instruct',
            'gpt-3.5-turbo-0613',
            'gpt-3.5-turbo-16k-0613'
        ]
        model_list.sort()
        return model_list