From 87f0dc9b9c97706c23103b3d57a3c32f99ad44d4 Mon Sep 17 00:00:00 2001 From: jinlingxu06 Date: Sun, 8 Oct 2023 12:08:04 +0800 Subject: [PATCH] Update the operator. Signed-off-by: jinlingxu06 --- README.md | 86 ++++++++++++++++++++++++++++++++++++++- __init__.py | 4 ++ azure_openai_embedding.py | 61 +++++++++++++++++++++++++++ requirements.txt | 1 + 4 files changed, 151 insertions(+), 1 deletion(-) create mode 100644 __init__.py create mode 100644 azure_openai_embedding.py create mode 100644 requirements.txt diff --git a/README.md b/README.md index 7752a69..a094d89 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,86 @@ -# azure-openai +# Sentence Embedding with OpenAI + +*author: Junjie, Jael* + +
+ +## Description + +A sentence embedding operator generates one embedding vector in ndarray for each input text. +The embedding represents the semantic information of the whole input text as one vector. +This operator is implemented with embedding models from [OpenAI](https://platform.openai.com/docs/guides/embeddings). +Please note you need an [OpenAI API key](https://platform.openai.com/account/api-keys) to access OpenAI. + +
+ +## Code Example + +Use the pre-trained model 'text-embedding-ada-002' +to generate an embedding for the sentence "Hello, world.". + +*Write a pipeline with explicit input/output name specifications:* + +```python +from towhee import pipe, ops, DataCollection + +p = ( + pipe.input('text') + .map('text', 'vec', + ops.sentence_embedding.azure_openai(engine='text-embedding-ada-002', api_key=OPENAI_API_KEY)) + .output('text', 'vec') +) + +DataCollection(p('Hello, world.')).show() +``` + +
+ +## Factory Constructor + +Create the operator via the following factory method: + +***sentence_embedding.azure_openai(engine='text-embedding-ada-002', api_type='azure', api_version='2023-07-01-preview', api_key=None, api_base=None)*** + +**Parameters:** + +***engine***: *str* + +The engine name in string, defaults to 'text-embedding-ada-002'. Supported model names: +- text-embedding-ada-002 +- text-similarity-davinci-001 +- text-similarity-curie-001 +- text-similarity-babbage-001 +- text-similarity-ada-001 + +***api_key***: *str=None* + +The OpenAI API key in string, defaults to None. The remaining arguments (api_type, api_version, api_base) are passed to the embedding request unchanged. + +
+ +## Interface + +The operator takes a piece of text in string as input. +It returns the text embedding as a numpy.ndarray or a list. + +***\_\_call\_\_(text)*** + +**Parameters:** + +***text***: *str* + +The text in string. + +**Returns**: + +*numpy.ndarray or list* + +The text embedding extracted by the model. + +
+
+***supported_model_names()***
+
+Get a list of supported model names.
+
diff --git a/__init__.py b/__init__.py
new file mode 100644
index 0000000..74bf5a3
--- /dev/null
+++ b/__init__.py
@@ -0,0 +1,7 @@
+from .azure_openai_embedding import AzureOpenaiEmbeding
+
+
+def azure_openai(*args, **kwargs):
+    """Build an ``AzureOpenaiEmbeding`` operator, forwarding every argument to it."""
+    operator = AzureOpenaiEmbeding(*args, **kwargs)
+    return operator
diff --git a/azure_openai_embedding.py b/azure_openai_embedding.py
new file mode 100644
index 0000000..85b0331
--- /dev/null
+++ b/azure_openai_embedding.py
@@ -0,0 +1,61 @@
+# Copyright 2021 Zilliz. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from openai import Embedding
+from tenacity import retry, stop_after_attempt, wait_random_exponential
+from towhee.operator.base import PyOperator
+
+
+class AzureOpenaiEmbeding(PyOperator):
+    """Towhee operator that requests a text embedding through ``openai.Embedding``."""
+
+    def __init__(self,
+                 engine='text-embedding-ada-002',
+                 api_type: str = 'azure',
+                 api_version: str = '2023-07-01-preview',
+                 api_key=None,
+                 api_base=None):
+        """Keep the request settings; each is forwarded to ``Embedding.create``."""
+        # Initialize the PyOperator base before setting this operator's own fields.
+        super().__init__()
+        self._engine = engine
+        self._api_type = api_type
+        self._api_version = api_version
+        self._api_key = api_key
+        self._api_base = api_base
+
+    @retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
+    def _call(self, text):
+        """Fetch the embedding of ``text``; failed calls are retried (6 attempts max)."""
+        # Newlines are flattened to spaces before the text is sent.
+        response = Embedding.create(input=[text.replace("\n", " ")],
+                                    engine=self._engine,
+                                    api_key=self._api_key,
+                                    api_type=self._api_type,
+                                    api_version=self._api_version,
+                                    api_base=self._api_base)
+        # Only one input is sent; read the embedding from the first "data" entry.
+        return response["data"][0]["embedding"]
+
+    def __call__(self, text):
+        """Embed ``text`` (a str) and return the embedding found in the response."""
+        return self._call(text)
+
+    @staticmethod
+    def supported_model_names():
+        """Return the supported model names, sorted in ascending order."""
+        return sorted(['text-embedding-ada-002', 'text-similarity-davinci-001',
+                       'text-similarity-curie-001', 'text-similarity-babbage-001',
+                       'text-similarity-ada-001'])
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..ec838c5
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1 @@
+openai<1.0  # the operator uses the pre-1.0 openai.Embedding interface