# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import logging
import os
import warnings
from pathlib import Path

import numpy
import torch
from transformers import AutoTokenizer, AutoModel

from towhee.operator import NNOperator
from towhee import register

warnings.filterwarnings('ignore')

# Module logger; every error path below reports through it before re-raising.
log = logging.getLogger(__name__)

# Serialization formats understood by `AutoTransformers.save_model`.
_EXPORT_FORMATS = ('pytorch', 'torchscript', 'onnx')


@register(output_schema=['vec'])
class AutoTransformers(NNOperator):
    """
    NLP embedding operator that uses the pretrained transformers model gathered by huggingface.

    Args:
        model_name (`str`):
            Which model to use for the embeddings.
        device (`str`, optional):
            Torch device to run on. Defaults to 'cuda' when available, otherwise 'cpu'.
    """

    def __init__(self, model_name: str = "bert-base-uncased", device=None) -> None:
        super().__init__()
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = device
        self.model_name = model_name
        try:
            self.model = AutoModel.from_pretrained(model_name).to(self.device)
            self.model.eval()
        except Exception:
            # Distinguish "unknown name" from "known name that failed to load".
            model_list = self.supported_model_names()
            if model_name not in model_list:
                log.error(f'Invalid model name: {model_name}. Supported model names: {model_list}')
            else:
                log.error(f'Fail to load model by name: {self.model_name}')
            raise
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        except Exception:
            log.error(f'Fail to load tokenizer by name: {self.model_name}')
            raise

    def __call__(self, txt: str) -> numpy.ndarray:
        """
        Embed a single text.

        Args:
            txt (`str`):
                The input text.

        Returns:
            numpy.ndarray: the model's `last_hidden_state` with the leading batch
            dimension squeezed out, i.e. one vector per token (typically shaped
            (num_tokens, hidden_size)).
        """
        try:
            inputs = self.tokenizer(txt, return_tensors='pt').to(self.device)
        except Exception:
            log.error(f'Invalid input for the tokenizer: {self.model_name}')
            raise
        try:
            # Inference only: skip autograd bookkeeping to save memory and time.
            with torch.no_grad():
                outs = self.model(**inputs)
        except Exception:
            log.error(f'Invalid input for the model: {self.model_name}')
            raise
        try:
            features = outs['last_hidden_state'].squeeze(0)
        except Exception:
            log.error(f'Fail to extract features by model: {self.model_name}')
            raise
        return features.cpu().detach().numpy()

    def _example_inputs(self):
        """
        Tokenize a dummy text and order the tensors like `self.model.forward`'s parameters.

        Tracing and ONNX export pass example tensors positionally. The tokenizer's
        key order need not match the forward signature: e.g. BERT tokenizers emit
        input_ids, token_type_ids, attention_mask, while BertModel.forward takes
        input_ids, attention_mask, token_type_ids. Positional passing in tokenizer
        order would feed one input into another's slot.

        NOTE: positional passing is still only correct when these names form a
        prefix of forward's parameters; models that interleave other parameters
        are not handled here.

        Returns:
            (names, tensors): parallel list of input names and tuple of tensors on self.device.
        """
        encoded = self.tokenizer('[CLS]', return_tensors='pt').to(self.device)
        params = list(inspect.signature(self.model.forward).parameters)
        rank = {name: i for i, name in enumerate(params)}
        # Keys unknown to forward() keep their tokenizer order at the end (stable sort).
        names = sorted(encoded.keys(), key=lambda k: rank.get(k, len(params)))
        return names, tuple(encoded[k] for k in names)

    def _export_onnx(self, path: str) -> None:
        """Export `self.model` to ONNX at `path`, retrying with a second `pooler_output` output on failure."""
        names, example_inputs = self._example_inputs()
        # Batch size and sequence length of every input may vary at inference time.
        input_axes = {name: {0: 'batch_size', 1: 'input_length'} for name in names}
        hidden_axes = {0: 'batch_size', 1: 'input_length'}
        try:
            torch.onnx.export(
                self.model,
                example_inputs,
                path,
                input_names=names,
                output_names=['last_hidden_state'],
                dynamic_axes={**input_axes, 'last_hidden_state': hidden_axes},
                opset_version=13,
                do_constant_folding=True,
            )
        except Exception as e:
            log.warning(f'{e}\nTrying with 2 outputs...')
            torch.onnx.export(
                self.model,
                example_inputs,
                path,
                input_names=names,
                output_names=['last_hidden_state', 'pooler_output'],
                dynamic_axes={
                    **input_axes,
                    'last_hidden_state': hidden_axes,
                    # Must match the output name exactly, otherwise the axis is ignored.
                    'pooler_output': {0: 'batch_size'},
                },
                opset_version=13,
                do_constant_folding=True,
            )

    def save_model(self, format: str = 'pytorch', path: str = 'default'):
        """
        Serialize the wrapped model to disk.

        Args:
            format (`str`):
                One of "pytorch", "torchscript" or "onnx".
            path (`str`):
                Output directory. "default" means `<directory of this file>/saved/<format>`.
                The directory is created if missing.

        Returns:
            The path of the written file, or None if `format` is unsupported
            (an error is logged in that case).

        Raises:
            RuntimeError: if the TorchScript conversion or save fails.
        """
        # TODO: support 'tensorrt'.
        if format not in _EXPORT_FORMATS:
            log.error(f'Unsupported format "{format}".')
            return None

        if path == 'default':
            path = os.path.join(str(Path(__file__).parent), 'saved', format)
        os.makedirs(path, exist_ok=True)
        # HF hub names contain '/', which is not valid inside a single file name.
        path = os.path.join(path, self.model_name.replace('/', '-'))

        if format == 'pytorch':
            path = path + '.pt'
            torch.save(self.model, path)
        elif format == 'torchscript':
            path = path + '.pt'
            _, example_inputs = self._example_inputs()
            try:
                try:
                    jit_model = torch.jit.script(self.model)
                except Exception:
                    # Most HF models are not scriptable; fall back to tracing.
                    jit_model = torch.jit.trace(self.model, example_inputs, strict=False)
                torch.jit.save(jit_model, path)
            except Exception as e:
                log.error(f'Fail to save as torchscript: {e}.')
                raise RuntimeError(f'Fail to save as torchscript: {e}.') from e
        else:  # 'onnx'
            path = path + '.onnx'
            self._export_onnx(path)
        return path

    @staticmethod
    def supported_model_names(format: str = None):
        """
        List the model names this operator is known to work with.

        Args:
            format (`str`, optional):
                None for all models, or an export format ("pytorch", "torchscript",
                "onnx") to exclude models known to fail that export.

        Returns:
            A sorted list of model names.

        Raises:
            ValueError: if `format` is not a recognized format.
        """
        full_list = [
            "bert-base-uncased",
            "bert-large-uncased",
            "bert-base-cased",
            "bert-large-cased",
            "bert-base-multilingual-uncased",
            "bert-base-multilingual-cased",
            "bert-base-chinese",
            "bert-base-german-cased",
            "bert-large-uncased-whole-word-masking",
            "bert-large-cased-whole-word-masking",
            "bert-large-uncased-whole-word-masking-finetuned-squad",
            "bert-large-cased-whole-word-masking-finetuned-squad",
            "bert-base-cased-finetuned-mrpc",
            "bert-base-german-dbmdz-cased",
            "bert-base-german-dbmdz-uncased",
            "cl-tohoku/bert-base-japanese-whole-word-masking",
            "cl-tohoku/bert-base-japanese-char",
            "cl-tohoku/bert-base-japanese-char-whole-word-masking",
            "TurkuNLP/bert-base-finnish-cased-v1",
            "TurkuNLP/bert-base-finnish-uncased-v1",
            "wietsedv/bert-base-dutch-cased",
            "google/bigbird-roberta-base",
            "google/bigbird-roberta-large",
            "google/bigbird-base-trivia-itc",
            "albert-base-v1",
            "albert-large-v1",
            "albert-xlarge-v1",
            "albert-xxlarge-v1",
            "albert-base-v2",
            "albert-large-v2",
            "albert-xlarge-v2",
            "albert-xxlarge-v2",
            "facebook/bart-large",
            "google/bert_for_seq_generation_L-24_bbc_encoder",
            "google/bigbird-pegasus-large-arxiv",
            "google/bigbird-pegasus-large-pubmed",
            "google/bigbird-pegasus-large-bigpatent",
            "google/canine-s",
            "google/canine-c",
            "YituTech/conv-bert-base",
            "YituTech/conv-bert-medium-small",
            "YituTech/conv-bert-small",
            "ctrl",
            "microsoft/deberta-base",
            "microsoft/deberta-large",
            "microsoft/deberta-xlarge",
            "microsoft/deberta-base-mnli",
            "microsoft/deberta-large-mnli",
            "microsoft/deberta-xlarge-mnli",
            "distilbert-base-uncased",
            "distilbert-base-uncased-distilled-squad",
            "distilbert-base-cased",
            "distilbert-base-cased-distilled-squad",
            "distilbert-base-german-cased",
            "distilbert-base-multilingual-cased",
            "distilbert-base-uncased-finetuned-sst-2-english",
            "google/electra-small-generator",
            "google/electra-base-generator",
            "google/electra-large-generator",
            "google/electra-small-discriminator",
            "google/electra-base-discriminator",
            "google/electra-large-discriminator",
            "google/fnet-base",
            "google/fnet-large",
            "facebook/wmt19-ru-en",
            "funnel-transformer/small",
            "funnel-transformer/small-base",
            "funnel-transformer/medium",
            "funnel-transformer/medium-base",
            "funnel-transformer/intermediate",
            "funnel-transformer/intermediate-base",
            "funnel-transformer/large",
            "funnel-transformer/large-base",
            "funnel-transformer/xlarge-base",
            "funnel-transformer/xlarge",
            "gpt2",
            "gpt2-medium",
            "gpt2-large",
            "gpt2-xl",
            "distilgpt2",
            "EleutherAI/gpt-neo-1.3B",
            "EleutherAI/gpt-j-6B",
            "kssteven/ibert-roberta-base",
            "allenai/led-base-16384",
            "google/mobilebert-uncased",
            "microsoft/mpnet-base",
            "uw-madison/nystromformer-512",
            "openai-gpt",
            "google/reformer-crime-and-punishment",
            "tau/splinter-base",
            "tau/splinter-base-qass",
            "tau/splinter-large",
            "tau/splinter-large-qass",
            "squeezebert/squeezebert-uncased",
            "squeezebert/squeezebert-mnli",
            "squeezebert/squeezebert-mnli-headless",
            "transfo-xl-wt103",
            "xlm-mlm-en-2048",
            "xlm-mlm-ende-1024",
            "xlm-mlm-enfr-1024",
            "xlm-mlm-enro-1024",
            "xlm-mlm-tlm-xnli15-1024",
            "xlm-mlm-xnli15-1024",
            "xlm-clm-enfr-1024",
            "xlm-clm-ende-1024",
            "xlm-mlm-17-1280",
            "xlm-mlm-100-1280",
            "xlm-roberta-base",
            "xlm-roberta-large",
            "xlm-roberta-large-finetuned-conll02-dutch",
            "xlm-roberta-large-finetuned-conll02-spanish",
            "xlm-roberta-large-finetuned-conll03-english",
            "xlm-roberta-large-finetuned-conll03-german",
            "xlnet-base-cased",
            "xlnet-large-cased",
            "uw-madison/yoso-4096",
            "microsoft/deberta-v2-xlarge",
            "microsoft/deberta-v2-xxlarge",
            "microsoft/deberta-v2-xlarge-mnli",
            "microsoft/deberta-v2-xxlarge-mnli",
            "flaubert/flaubert_small_cased",
            "flaubert/flaubert_base_uncased",
            "flaubert/flaubert_base_cased",
            "flaubert/flaubert_large_cased",
            "camembert-base",
            "Musixmatch/umberto-commoncrawl-cased-v1",
            "Musixmatch/umberto-wikipedia-uncased-v1",
        ]
        full_list.sort()
        if format is None:
            return full_list

        # Models excluded from each export format.
        # TODO: add 'tensorrt'.
        unsupported = {
            'pytorch': [],
            'torchscript': [
                'EleutherAI/gpt-j-6B',
                'EleutherAI/gpt-neo-1.3B',
                'allenai/led-base-16384',
                'ctrl',
                'distilgpt2',
                'facebook/bart-large',
                'google/bigbird-pegasus-large-arxiv',
                'google/bigbird-pegasus-large-bigpatent',
                'google/bigbird-pegasus-large-pubmed',
                'google/canine-c',
                'google/canine-s',
                'google/reformer-crime-and-punishment',
                'gpt2',
                'gpt2-large',
                'gpt2-medium',
                'gpt2-xl',
                'microsoft/deberta-base',
                'microsoft/deberta-base-mnli',
                'microsoft/deberta-large',
                'microsoft/deberta-large-mnli',
                'microsoft/deberta-xlarge',
                'microsoft/deberta-xlarge-mnli',
                'openai-gpt',
                'transfo-xl-wt103',
                'uw-madison/yoso-4096',
                'xlm-clm-ende-1024',
                'xlm-clm-enfr-1024',
                'xlm-mlm-100-1280',
                'xlm-mlm-17-1280',
                'xlm-mlm-en-2048',
                'xlm-mlm-ende-1024',
                'xlm-mlm-enfr-1024',
                'xlm-mlm-enro-1024',
                'xlm-mlm-tlm-xnli15-1024',
                'xlm-mlm-xnli15-1024',
                'xlnet-base-cased',
                'xlnet-large-cased',
                'microsoft/deberta-v2-xlarge',
                'microsoft/deberta-v2-xxlarge',
                'microsoft/deberta-v2-xlarge-mnli',
                'microsoft/deberta-v2-xxlarge-mnli',
                'flaubert/flaubert_small_cased',
                'flaubert/flaubert_base_uncased',
                'flaubert/flaubert_base_cased',
                'flaubert/flaubert_large_cased',
            ],
            'onnx': [
                'allenai/led-base-16384',
                'ctrl',
                'distilgpt2',
                'EleutherAI/gpt-j-6B',
                'EleutherAI/gpt-neo-1.3B',
                'funnel-transformer/large',
                'funnel-transformer/medium',
                'funnel-transformer/small',
                'funnel-transformer/xlarge',
                'google/bigbird-pegasus-large-arxiv',
                'google/bigbird-pegasus-large-bigpatent',
                'google/bigbird-pegasus-large-pubmed',
                'google/fnet-base',
                'google/fnet-large',
                'google/reformer-crime-and-punishment',
                'gpt2',
                'gpt2-large',
                'gpt2-medium',
                'gpt2-xl',
                'microsoft/deberta-v2-xlarge',
                'microsoft/deberta-v2-xlarge-mnli',
                'microsoft/deberta-v2-xxlarge',
                'microsoft/deberta-v2-xxlarge-mnli',
                'microsoft/deberta-xlarge',
                'microsoft/deberta-xlarge-mnli',
                'openai-gpt',
                'transfo-xl-wt103',
                'uw-madison/yoso-4096',
                'xlm-mlm-100-1280',
                'xlm-mlm-17-1280',
                'xlm-mlm-en-2048',
                'xlm-roberta-large',
                'xlm-roberta-large-finetuned-conll02-dutch',
                'xlm-roberta-large-finetuned-conll02-spanish',
                'xlm-roberta-large-finetuned-conll03-english',
                'xlm-roberta-large-finetuned-conll03-german',
                'xlnet-base-cased',
                'xlnet-large-cased',
            ],
        }
        if format not in unsupported:
            msg = (f'Invalid format "{format}". '
                   f'Currently supported formats: "pytorch", "torchscript", "onnx".')
            log.error(msg)
            raise ValueError(msg)

        to_remove = set(unsupported[format])
        # Internal consistency check: every excluded name must be a known model.
        assert to_remove.issubset(full_list)
        return sorted(set(full_list) - to_remove)