From f34e610e6c60d96c4da106ee27826319f4e33f46 Mon Sep 17 00:00:00 2001 From: Jael Gu Date: Sat, 2 Apr 2022 15:32:35 +0800 Subject: [PATCH] Update Signed-off-by: Jael Gu --- README.md | 72 +++++++++++++++++++++++++---------------------------- __init__.py | 4 +-- realm.py | 14 ++++++++--- 3 files changed, 47 insertions(+), 43 deletions(-) diff --git a/README.md b/README.md index 3434df4..bc7d2e6 100644 --- a/README.md +++ b/README.md @@ -1,74 +1,70 @@ # Text Embedding with Transformers -*author: Jael Gu and David Wang* +*author: Jael Gu* ## Desription -A REALM[1] text embedding operator implemented with pretrained models from [Huggingface Transformers](https://huggingface.co/docs/transformers). +A text embedding operator takes a sentence, paragraph, or document in string as an input +and output an embedding vector in ndarray which captures the input's core semantic elements. +This operator uses the REALM model, which is a retrieval-augmented language model that firstly retrieves documents from a textual knowledge corpus and then utilizes retrieved documents to process question answering tasks. [1] +The original model was proposed in REALM: Retrieval-Augmented Language Model Pre-Training by Kelvin Guu, Kenton Lee, Zora Tung, Panupong Pasupat and Ming-Wei Chang.[2] +### Reference +[1].https://huggingface.co/docs/transformers/model_doc/realm -```python -from towhee import ops +[2].https://arxiv.org/abs/2002.08909 -text_encoder = ops.text_embedding.realm('google/realm-cc-news-pretrained-encoder') -text_embedding = text_encoder("Hello, world.") -``` +## Code Example -## Factory Constructor +Use the pretrained model "google/realm-cc-news-pretrained-embedder" +to generate a text embedding for the sentence "Hello, world.". -Create the operator via the following factory method + *Write the pipeline*: -***ops.text_embedding.realm(model_name)*** +```python +from towhee import dc +dc.stream(["Hello, world."]) + .text_embedding.realm(model_name="google/realm-cc-news-pretrained-embedder") + .show() +``` -## Interface +## Factory Constructor -A text embedding operator takes a sentence, paragraph, or document in string as an input -and output an embedding vector in ndarray which captures the input's core semantic elements. +Create the operator via the following factory method +***text_embedding.transformers(model_name="google/realm-cc-news-pretrained-embedder")*** **Parameters:** -​ ***text***: *str* - -​ The text in string. - +​ ***model_name***: *str* +​ The model name in string. +You can get the list of supported model names by calling `get_model_list` from [realm.py](https://towhee.io/text-embedding/realm/src/branch/main/realm.py). -**Returns**: *numpy.ndarray* -​ The text embedding extracted by model. +## Interface +The operator takes a text in string as input. +It loads tokenizer and pre-trained model using model name. +and then return text embedding in ndarray. -## Code Example +**Parameters:** -Use the pretrained model ('google/realm-cc-news-pretrained-encoder') -to generate a text embedding for the sentence "Hello, world.". +​ ***text***: *str* - *Write the pipeline in simplified style*: +​ The text in string. -```python -import towhee.DataCollection as dc -dc.glob("Hello, world.") - .text_embedding.realm('google/realm-cc-news-pretrained-encoder') - .show() -``` -*Write a same pipeline with explicit inputs/outputs name specifications:* +**Returns**: -```python -from towhee import DataCollection as dc +​ *numpy.ndarray* -dc.glob['text']('Hello, world.') - .text_embedding.realm['text', 'vec']('bert-base-cased') - .select('vec') - .show() -``` +​ The text embedding extracted by model. -[1] https://arxiv.org/abs/2002.08909 diff --git a/__init__.py b/__init__.py index d8f3b33..593d5a7 100644 --- a/__init__.py +++ b/__init__.py @@ -15,5 +15,5 @@ from .realm import Realm -def realm(model_name: str): - return Realm(model_name) +def realm(**kwargs): + return Realm(**kwargs) diff --git a/realm.py b/realm.py index bde3552..acfa934 100644 --- a/realm.py +++ b/realm.py @@ -23,6 +23,7 @@ from towhee import register import warnings warnings.filterwarnings('ignore') +logging.getLogger("transformers").setLevel(logging.ERROR) log = logging.getLogger() @@ -35,7 +36,7 @@ class Realm(NNOperator): Which model to use for the embeddings. """ - def __init__(self, model_name: str) -> None: + def __init__(self, model_name: str = "google/realm-cc-news-pretrained-embedder") -> None: super().__init__() self.model_name = model_name try: @@ -65,6 +66,13 @@ class Realm(NNOperator): except Exception as e: log.error(f'Fail to extract features by model: {self.model_name}') raise e - feature_vector = features.detach().numpy() - return feature_vector + vec = features.detach().numpy() + return vec + +def get_model_list(): + full_list = [ + "google/realm-cc-news-pretrained-embedder" + ] + full_list.sort() + return full_list