diff --git a/README.md b/README.md
index 2fe958e..6adc108 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,135 @@
-# Azure-OpenAI
+# Azure OpenAI Chat Completion
+
+*author: David Wang*
+
+<br />
+
+## Description
+
+An LLM operator generates an answer given a prompt in messages using a large language model or service.
+This operator is implemented with the Chat Completion method from [Azure OpenAI](https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/chatgpt?pivots=programming-language-chat-completions).
+Please note that you need an [OpenAI API key](https://platform.openai.com/account/api-keys) to access the service.
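+
+If `api_key` or `api_base` is not passed to the operator, it falls back to the `OPENAI_API_KEY` and `OPENAI_API_BASE` environment variables. A minimal setup sketch, assuming placeholder values that you replace with your own Azure OpenAI key and endpoint:
+
+```python
+import os
+
+# Placeholder values (assumptions); substitute your own credentials.
+os.environ['OPENAI_API_KEY'] = '<your-azure-openai-key>'
+os.environ['OPENAI_API_BASE'] = 'https://<your-resource-name>.openai.azure.com/'
+```
+
+<br />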
+
+## Code Example
+
+Use the default model to continue the conversation from the given messages.
+
+*Write a pipeline with explicit input/output name specifications:*
+
+```python
+from towhee import pipe, ops
+
+# OPENAI_API_KEY and OPENAI_API_BASE hold your Azure OpenAI key and endpoint.
+p = (
+    pipe.input('messages')
+    .map('messages', 'answer', ops.LLM.Azure_OpenAI(api_key=OPENAI_API_KEY, api_base=OPENAI_API_BASE))
+    .output('messages', 'answer')
+)
+
+messages = [
+    {'question': 'Who won the world series in 2020?', 'answer': 'The Los Angeles Dodgers won the World Series in 2020.'},
+    {'question': 'Where was it played?'}
+]
+# The pipeline outputs ('messages', 'answer'), so the answer is the second item.
+answer = p(messages).get()[1]
+```
+
+*Write a [retrieval-augmented generation pipeline](https://towhee.io/tasks/detail/pipeline/retrieval-augmented-generation) with explicit input/output name specifications:*
+
+```python
+from towhee import pipe, ops
+
+
+temp = '''Use the following pieces of context to answer the question at the end.
+If you don't know the answer, just say that you don't know, don't try to make up an answer.
+
+{context}
+
+Question: {question}
+
+Helpful Answer:
+'''
+
+
+docs = ['You can install towhee via command `pip install towhee`.']
+history = [
+    ('What is Towhee?', 'Towhee is an open-source machine learning pipeline that helps you encode your unstructured data into embeddings.')
+]
+question = 'How to install it?'
+
+p = (
+    pipe.input('question', 'docs', 'history')
+    .map(('question', 'docs', 'history'), 'prompt', ops.prompt.template(temp, ['question', 'context']))
+    .map('prompt', 'answer',
+         ops.LLM.Azure_OpenAI(api_key=OPENAI_API_KEY, api_base=OPENAI_API_BASE, temperature=0.5, max_tokens=100)
+    )
+    .output('answer')
+)
+
+answer = p(question, docs, history).get()[0]
+```
+
+<br />
+
+## Factory Constructor
+
+Create the operator via the following factory method:
+
+***LLM.Azure_OpenAI(model_name: str, api_key: str, api_base: str)***
+
+**Parameters:**
+
+***model_name***: *str*
+
+The model name in string, defaults to 'gpt-3.5-turbo'. Supported model names:
+- gpt-3.5-turbo
+- gpt-3.5-turbo-16k
+- gpt-3.5-turbo-instruct
+- gpt-3.5-turbo-0613
+- gpt-3.5-turbo-16k-0613
+
+***api_type***: *str='azure'*
+
+The OpenAI API type in string, defaults to 'azure'.
+
+***api_version***: *str='2023-07-01-preview'*
+
+The OpenAI API version in string, defaults to '2023-07-01-preview'.
+
+***api_key***: *str=None*
+
+The OpenAI API key in string, defaults to None, in which case the `OPENAI_API_KEY` environment variable is used.
+
+***api_base***: *str=None*
+
+The OpenAI API base URL in string, defaults to None, in which case the `OPENAI_API_BASE` environment variable is used.
+
+***\*\*kwargs***
+
+Other OpenAI parameters such as max_tokens, stream, temperature, etc.
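+
+A minimal construction sketch, assuming placeholder credentials:
+
+```python
+from towhee import ops
+
+# Placeholder key and endpoint (assumptions); replace with your own values.
+chat = ops.LLM.Azure_OpenAI(
+    model_name='gpt-3.5-turbo',
+    api_type='azure',
+    api_version='2023-07-01-preview',
+    api_key='<your-azure-openai-key>',
+    api_base='https://<your-resource-name>.openai.azure.com/',
+    temperature=0.5,   # forwarded to the Chat Completion call via **kwargs
+    max_tokens=100,
+)
+```
+
+<br />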
+
+## Interface
+
+The operator takes a list of messages as input.
+It returns the generated answer in string (or a generator of response deltas when `stream=True`).
+
+***\_\_call\_\_(messages)***
+
+**Parameters:**
+
+***messages***: *list*
+
+A list of messages to set up the chat.
+Must be a list of dictionaries with keys from "system", "question", and "answer". For example, [{"question": "a past question?", "answer": "a past answer."}, {"question": "current question?"}].
+Dictionaries already in the OpenAI format, with "role" and "content" keys, are passed through unchanged.
+
+**Returns**:
+
+*answer: str*
+
+The next answer generated by role "assistant". When `stream=True`, a generator yielding response deltas is returned instead.
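+
+A usage sketch of calling the operator directly, outside a pipeline. It assumes `OPENAI_API_KEY` and `OPENAI_API_BASE` are set in the environment and that the operator proxy from `towhee.ops` is invoked directly:
+
+```python
+from towhee import ops
+
+# Assumes credentials are provided via environment variables.
+chat = ops.LLM.Azure_OpenAI()
+answer = chat([
+    {'system': 'You are a helpful assistant.'},
+    {'question': 'Who won the world series in 2020?', 'answer': 'The Los Angeles Dodgers won the World Series in 2020.'},
+    {'question': 'Where was it played?'},
+])
+print(answer)
+
+# With stream=True, the call returns a generator of response deltas instead.
+stream_chat = ops.LLM.Azure_OpenAI(stream=True)
+for delta in stream_chat([{'question': 'What is Towhee?'}]):
+    print(delta.get('content', ''), end='')
+```
+
+<br />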
+
diff --git a/__init__.py b/__init__.py
new file mode 100644
index 0000000..b30e137
--- /dev/null
+++ b/__init__.py
@@ -0,0 +1,5 @@
+from .azure_openai_chat import AzureOpenaiChat
+
+
+def AzureOpenAI(*args, **kwargs):
+    return AzureOpenaiChat(*args, **kwargs)
diff --git a/azure_openai_chat.py b/azure_openai_chat.py
new file mode 100644
index 0000000..9dbde77
--- /dev/null
+++ b/azure_openai_chat.py
@@ -0,0 +1,87 @@
+# Copyright 2021 Zilliz. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from typing import List
+
+import openai
+from towhee.operator.base import PyOperator
+
+
+class AzureOpenaiChat(PyOperator):
+    '''Wrapper of the Azure OpenAI Chat Completion API.'''
+    def __init__(self,
+                 model_name: str = 'gpt-3.5-turbo',
+                 api_type: str = 'azure',
+                 api_version: str = '2023-07-01-preview',
+                 api_key: str = None,
+                 api_base: str = None,
+                 **kwargs
+                 ):
+        # Fall back to environment variables when credentials are not passed in.
+        openai.api_key = api_key or os.getenv('OPENAI_API_KEY')
+        openai.api_base = api_base or os.getenv('OPENAI_API_BASE')
+        openai.api_type = api_type
+        openai.api_version = api_version
+        self._model = model_name
+        self.stream = kwargs.pop('stream', False)
+        self.kwargs = kwargs
+
+    def __call__(self, messages: List[dict]):
+        messages = self.parse_inputs(messages)
+        response = openai.ChatCompletion.create(
+            model=self._model,
+            messages=messages,
+            n=1,
+            stream=self.stream,
+            **self.kwargs  # Azure deployments may also need `engine` or `deployment_id` passed here.
+        )
+        if self.stream:
+            return self.stream_output(response)
+        else:
+            answer = response['choices'][0]['message']['content']
+            return answer
+
+    def parse_inputs(self, messages: List[dict]):
+        '''Convert messages into the OpenAI chat format with "role" and "content" keys.'''
+        assert isinstance(messages, list), \
+            'Inputs must be a list of dictionaries with keys from ["system", "question", "answer"].'
+        new_messages = []
+        for m in messages:
+            if ('role' in m and 'content' in m) and (m['role'] in ['system', 'assistant', 'user']):
+                # Already in the OpenAI format; pass through unchanged.
+                new_messages.append(m)
+            else:
+                for k, v in m.items():
+                    if k == 'question':
+                        new_m = {'role': 'user', 'content': v}
+                    elif k == 'answer':
+                        new_m = {'role': 'assistant', 'content': v}
+                    elif k == 'system':
+                        new_m = {'role': 'system', 'content': v}
+                    else:
+                        raise KeyError('Invalid message key: only accepts keys from ["system", "question", "answer"].')
+                    new_messages.append(new_m)
+        return new_messages
+
+    def stream_output(self, response):
+        # Yield the incremental message deltas from a streaming response.
+        for resp in response:
+            yield resp['choices'][0]['delta']
+
+    @staticmethod
+    def supported_model_names():
+        model_list = [
+            'gpt-3.5-turbo',
+            'gpt-3.5-turbo-16k',
+            'gpt-3.5-turbo-instruct',
+            'gpt-3.5-turbo-0613',
+            'gpt-3.5-turbo-16k-0613'
+        ]
+        model_list.sort()
+        return model_list