From 1389d9923cce357f055339e205115d3f5314acbc Mon Sep 17 00:00:00 2001 From: Jael Gu Date: Sat, 2 Apr 2022 12:14:46 +0800 Subject: [PATCH] Refactor Signed-off-by: Jael Gu --- README.md | 74 +++++++++++++----------- __init__.py | 6 +- auto_transformers.py | 134 +++++++++++++++++++++++++++++++++++++++++-- 3 files changed, 175 insertions(+), 39 deletions(-) diff --git a/README.md b/README.md index e943a57..5141205 100644 --- a/README.md +++ b/README.md @@ -6,69 +6,75 @@ ## Desription -A text embedding operator implemented with pretrained models from [Huggingface Transformers](https://huggingface.co/docs/transformers). +A text embedding operator takes a sentence, paragraph, or document in string as an input +and outputs an embedding vector in ndarray which captures the input's core semantic elements. +This operator is implemented with pretrained models from [Huggingface Transformers](https://huggingface.co/docs/transformers). + +## Code Example +Use the pretrained model 'distilbert-base-cased' +to generate a text embedding for the sentence "Hello, world.". 
+ *Write the pipeline in simplified style*: ```python -from towhee import ops +from towhee import dc + -text_encoder = ops.text_embedding.transformers(model_name="bert-base-cased") -text_embedding = text_encoder("Hello, world.") +dc.stream(["Hello, world."]) + .text_embedding.transformers('distilbert-base-cased') + .show() ``` -## Factory Constructor +*Write a same pipeline with explicit inputs/outputs name specifications:* -Create the operator via the following factory method +```python +from towhee import dc -***ops.text_embedding.transformers(model_name)*** +dc.stream['txt'](["Hello, world."]) + .text_embedding.transformers['txt', 'vec']('distilbert-base-cased') + .select('txt', 'vec') + .show() +``` -## Interface +## Factory Constructor -A text embedding operator takes a sentence, paragraph, or document in string as an input -and output an embedding vector in ndarray which captures the input's core semantic elements. +Create the operator via the following factory method +***text_embedding.transformers(model_name="bert-base-uncased")*** **Parameters:** -​ ***text***: *str* - -​ The text in string. - +​ ***model_name***: *str* +​ The model name in string. +You can get the list of supported model names by calling `model_list` of the operator: +```python +from towhee import ops -**Returns**: *numpy.ndarray* -​ The text embedding extracted by model. +ops.text_embedding.transformers.model_list() +``` +## Interface -## Code Example +The operator takes a text in string as input. +It loads tokenizer and pre-trained model using model name. +Text embeddings are returned in ndarray. -Use the pretrained Bert-Base-Cased model ('bert-base-cased') -to generate a text embedding for the sentence "Hello, world.". - *Write the pipeline in simplified style*: +**Parameters:** -```python -import towhee.DataCollection as dc +​ ***text***: *str* -dc.glob("Hello, world.") - .text_embedding.transformers('bert-base-cased') - .show() -``` +​ The text in string. 
-*Write a same pipeline with explicit inputs/outputs name specifications:* -```python -from towhee import DataCollection as dc -dc.glob['text']('Hello, world.') - .text_embedding.transformers['text', 'vec']('bert-base-cased') - .select('vec') - .show() -``` +**Returns**: *numpy.ndarray* +​ The text embedding extracted by model. diff --git a/__init__.py b/__init__.py index b6d9faf..43cd234 100644 --- a/__init__.py +++ b/__init__.py @@ -12,8 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .auto_transformers import AutoTransformers +from .auto_transformers import AutoTransformers, get_model_list def transformers(model_name: str): return AutoTransformers(model_name) + + +def model_list(): + return get_model_list() diff --git a/auto_transformers.py b/auto_transformers.py index 649110a..80e2a55 100644 --- a/auto_transformers.py +++ b/auto_transformers.py @@ -23,6 +23,7 @@ from towhee import register import warnings warnings.filterwarnings('ignore') +logging.getLogger("transformers").setLevel(logging.ERROR) log = logging.getLogger() @@ -35,13 +36,17 @@ class AutoTransformers(NNOperator): Which model to use for the embeddings. """ - def __init__(self, model_name: str) -> None: + def __init__(self, model_name: str = "bert-base-uncased") -> None: super().__init__() self.model_name = model_name try: self.model = AutoModel.from_pretrained(model_name) except Exception as e: - log.error(f'Fail to load model by name: {self.model_name}') + model_list = get_model_list() + if model_name not in model_list: + log.error(f"Invalid model name: {model_name}. 
Supported model names: {model_list}") + else: + log.error(f"Fail to load model by name: {self.model_name}") raise e try: self.tokenizer = AutoTokenizer.from_pretrained(model_name) @@ -65,6 +70,127 @@ class AutoTransformers(NNOperator): except Exception as e: log.error(f'Fail to extract features by model: {self.model_name}') raise e - feature_vector = features.detach().numpy() - return feature_vector + vec = features.detach().numpy() + return vec + +def get_model_list(): + full_list = [ + "bert-large-uncased", + "bert-base-cased", + "bert-large-cased", + "bert-base-multilingual-uncased", + "bert-base-multilingual-cased", + "bert-base-chinese", + "bert-base-german-cased", + "bert-large-uncased-whole-word-masking", + "bert-large-cased-whole-word-masking", + "bert-large-uncased-whole-word-masking-finetuned-squad", + "bert-large-cased-whole-word-masking-finetuned-squad", + "bert-base-cased-finetuned-mrpc", + "bert-base-german-dbmdz-cased", + "bert-base-german-dbmdz-uncased", + "cl-tohoku/bert-base-japanese-whole-word-masking", + "cl-tohoku/bert-base-japanese-char", + "cl-tohoku/bert-base-japanese-char-whole-word-masking", + "TurkuNLP/bert-base-finnish-cased-v1", + "TurkuNLP/bert-base-finnish-uncased-v1", + "wietsedv/bert-base-dutch-cased", + "google/bigbird-roberta-base", + "google/bigbird-roberta-large", + "google/bigbird-base-trivia-itc", + "albert-base-v1", + "albert-large-v1", + "albert-xlarge-v1", + "albert-xxlarge-v1", + "albert-base-v2", + "albert-large-v2", + "albert-xlarge-v2", + "albert-xxlarge-v2", + "facebook/bart-large", + "google/bert_for_seq_generation_L-24_bbc_encoder", + "google/bigbird-pegasus-large-arxiv", + "google/bigbird-pegasus-large-pubmed", + "google/bigbird-pegasus-large-bigpatent", + "google/canine-s", + "google/canine-c", + "YituTech/conv-bert-base", + "YituTech/conv-bert-medium-small", + "YituTech/conv-bert-small", + "ctrl", + "microsoft/deberta-base", + "microsoft/deberta-large", + "microsoft/deberta-xlarge", + 
"microsoft/deberta-base-mnli", + "microsoft/deberta-large-mnli", + "microsoft/deberta-xlarge-mnli", + "distilbert-base-uncased", + "distilbert-base-uncased-distilled-squad", + "distilbert-base-cased", + "distilbert-base-cased-distilled-squad", + "distilbert-base-german-cased", + "distilbert-base-multilingual-cased", + "distilbert-base-uncased-finetuned-sst-2-english", + "google/electra-small-generator", + "google/electra-base-generator", + "google/electra-large-generator", + "google/electra-small-discriminator", + "google/electra-base-discriminator", + "google/electra-large-discriminator", + "google/fnet-base", + "google/fnet-large", + "facebook/wmt19-ru-en", + "funnel-transformer/small", + "funnel-transformer/small-base", + "funnel-transformer/medium", + "funnel-transformer/medium-base", + "funnel-transformer/intermediate", + "funnel-transformer/intermediate-base", + "funnel-transformer/large", + "funnel-transformer/large-base", + "funnel-transformer/xlarge-base", + "funnel-transformer/xlarge", + "gpt2", + "gpt2-medium", + "gpt2-large", + "gpt2-xl", + "distilgpt2", + "EleutherAI/gpt-neo-1.3B", + "EleutherAI/gpt-j-6B", + "kssteven/ibert-roberta-base", + "allenai/led-base-16384", + "google/mobilebert-uncased", + "microsoft/mpnet-base", + "uw-madison/nystromformer-512", + "openai-gpt", + "google/reformer-crime-and-punishment", + "tau/splinter-base", + "tau/splinter-base-qass", + "tau/splinter-large", + "tau/splinter-large-qass", + "squeezebert/squeezebert-uncased", + "squeezebert/squeezebert-mnli", + "squeezebert/squeezebert-mnli-headless", + "transfo-xl-wt103", + "xlm-mlm-en-2048", + "xlm-mlm-ende-1024", + "xlm-mlm-enfr-1024", + "xlm-mlm-enro-1024", + "xlm-mlm-tlm-xnli15-1024", + "xlm-mlm-xnli15-1024", + "xlm-clm-enfr-1024", + "xlm-clm-ende-1024", + "xlm-mlm-17-1280", + "xlm-mlm-100-1280", + "xlm-roberta-base", + "xlm-roberta-large", + "xlm-roberta-large-finetuned-conll02-dutch", + "xlm-roberta-large-finetuned-conll02-spanish", + 
"xlm-roberta-large-finetuned-conll03-english", + "xlm-roberta-large-finetuned-conll03-german", + "xlnet-base-cased", + "xlnet-large-cased", + "uw-madison/yoso-4096", + ] + full_list.sort() + return full_list