diff --git a/README.md b/README.md index a094d89..30fb6d9 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ -# Sentence Embedding with OpenAI +# Sentence Embedding with Azure OpenAI -*author: Junjie, Jael* +*author: David*
@@ -9,7 +9,7 @@ A sentence embedding operator generates one embedding vector in ndarray for each input text. The embedding represents the semantic information of the whole input text as one vector. This operator is implemented with embedding models from [OpenAI](https://platform.openai.com/docs/guides/embeddings). -Please note you need an [OpenAI API key](https://platform.openai.com/account/api-keys) to access OpenAI. +This operator is designed specifically for Azure OpenAI. For more information, see the [Azure OpenAI embeddings tutorial](https://learn.microsoft.com/en-us/azure/ai-services/openai/tutorials/embeddings?tabs=command-line).
@@ -25,11 +25,9 @@ from towhee import pipe, ops, DataCollection p = ( pipe.input('text') - .map('text', 'vec', - ops.sentence_embedding.openai(model_name='text-embedding-ada-002', api_key=OPENAI_API_KEY)) + .map('text', 'vec', ops.sentence_embedding.azure_openai(model_name='text-embedding-ada-002', api_key=api_key, api_base=api_base)) .output('text', 'vec') ) - DataCollection(p('Hello, world.')).show() ``` @@ -39,7 +37,7 @@ DataCollection(p('Hello, world.')).show() Create the operator via the following factory method: -***sentence_embedding.openai(model_name='text-embedding-ada-002')*** +***sentence_embedding.azure_openai(model_name='text-embedding-ada-002')*** **Parameters:** @@ -52,10 +50,23 @@ The model name in string, defaults to 'text-embedding-ada-002'. Supported model - text-similarity-babbage-001 - text-similarity-ada-001 +***api_type***: *str='azure'* + +The OpenAI API type in string, defaults to 'azure'. + +***api_version***: *str='2023-07-01-preview'* + +The OpenAI API version in string, defaults to '2023-07-01-preview'. + ***api_key***: *str=None* The OpenAI API key in string, defaults to None. +***api_base***: *str=None* + +The OpenAI API base URL (the Azure OpenAI resource endpoint) in string, defaults to None. + +
## Interface diff --git a/azure_openai_embedding.py b/azure_openai_embedding.py index 85b0331..5cdf307 100644 --- a/azure_openai_embedding.py +++ b/azure_openai_embedding.py @@ -20,12 +20,12 @@ from towhee.operator.base import PyOperator class AzureOpenaiEmbeding(PyOperator): def __init__(self, - engine='text-embedding-ada-002', + model_name='text-embedding-ada-002', api_type: str = 'azure', api_version: str = '2023-07-01-preview', api_key=None, api_base=None): - self._engine = engine + self._engine = model_name self._api_type = api_type self._api_version = api_version self._api_key = api_key