|
|
@ -23,6 +23,7 @@ from towhee import register |
|
|
|
import warnings |
|
|
|
|
|
|
|
# Install a blanket 'ignore' filter so Python warnings are suppressed from this point on.
warnings.filterwarnings('ignore') |
|
|
|
# Raise the threshold of the "transformers" logger to ERROR so only its errors get through.
logging.getLogger("transformers").setLevel(logging.ERROR) |
|
|
|
# Called without a name, getLogger() returns the root logger; this file reports its errors through it.
log = logging.getLogger() |
|
|
|
|
|
|
|
|
|
|
@ -35,13 +36,17 @@ class AutoTransformers(NNOperator): |
|
|
|
Which model to use for the embeddings. |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, model_name: str) -> None: |
|
|
|
def __init__(self, model_name: str = "bert-base-uncased") -> None: |
|
|
|
super().__init__() |
|
|
|
self.model_name = model_name |
|
|
|
try: |
|
|
|
self.model = AutoModel.from_pretrained(model_name) |
|
|
|
except Exception as e: |
|
|
|
log.error(f'Fail to load model by name: {self.model_name}') |
|
|
|
model_list = get_model_list() |
|
|
|
if model_name not in model_list: |
|
|
|
log.error(f"Invalid model name: {model_name}. Supported model names: {model_list}") |
|
|
|
else: |
|
|
|
log.error(f"Fail to load model by name: {self.model_name}") |
|
|
|
raise e |
|
|
|
try: |
|
|
|
self.tokenizer = AutoTokenizer.from_pretrained(model_name) |
|
|
@ -65,6 +70,127 @@ class AutoTransformers(NNOperator): |
|
|
|
except Exception as e: |
|
|
|
log.error(f'Fail to extract features by model: {self.model_name}') |
|
|
|
raise e |
|
|
|
feature_vector = features.detach().numpy() |
|
|
|
return feature_vector |
|
|
|
vec = features.detach().numpy() |
|
|
|
return vec |
|
|
|
|
|
|
|
|
|
|
|
def get_model_list():
    """Return the names of the supported pretrained models, sorted.

    The names are Hugging Face model identifiers that can be passed as
    ``model_name`` to the operator. A new list is built on every call, so
    callers may mutate the result freely.

    Returns:
        list of str: the supported model names in ascending (case-sensitive)
        string order.
    """
    supported_names = [
        # The operator's default model; it must be listed so that a failed
        # load of the default is not reported as an invalid model name.
        "bert-base-uncased",
        "bert-large-uncased",
        "bert-base-cased",
        "bert-large-cased",
        "bert-base-multilingual-uncased",
        "bert-base-multilingual-cased",
        "bert-base-chinese",
        "bert-base-german-cased",
        "bert-large-uncased-whole-word-masking",
        "bert-large-cased-whole-word-masking",
        "bert-large-uncased-whole-word-masking-finetuned-squad",
        "bert-large-cased-whole-word-masking-finetuned-squad",
        "bert-base-cased-finetuned-mrpc",
        "bert-base-german-dbmdz-cased",
        "bert-base-german-dbmdz-uncased",
        "cl-tohoku/bert-base-japanese-whole-word-masking",
        "cl-tohoku/bert-base-japanese-char",
        "cl-tohoku/bert-base-japanese-char-whole-word-masking",
        "TurkuNLP/bert-base-finnish-cased-v1",
        "TurkuNLP/bert-base-finnish-uncased-v1",
        "wietsedv/bert-base-dutch-cased",
        "google/bigbird-roberta-base",
        "google/bigbird-roberta-large",
        "google/bigbird-base-trivia-itc",
        "albert-base-v1",
        "albert-large-v1",
        "albert-xlarge-v1",
        "albert-xxlarge-v1",
        "albert-base-v2",
        "albert-large-v2",
        "albert-xlarge-v2",
        "albert-xxlarge-v2",
        "facebook/bart-large",
        "google/bert_for_seq_generation_L-24_bbc_encoder",
        "google/bigbird-pegasus-large-arxiv",
        "google/bigbird-pegasus-large-pubmed",
        "google/bigbird-pegasus-large-bigpatent",
        "google/canine-s",
        "google/canine-c",
        "YituTech/conv-bert-base",
        "YituTech/conv-bert-medium-small",
        "YituTech/conv-bert-small",
        "ctrl",
        "microsoft/deberta-base",
        "microsoft/deberta-large",
        "microsoft/deberta-xlarge",
        "microsoft/deberta-base-mnli",
        "microsoft/deberta-large-mnli",
        "microsoft/deberta-xlarge-mnli",
        "distilbert-base-uncased",
        "distilbert-base-uncased-distilled-squad",
        "distilbert-base-cased",
        "distilbert-base-cased-distilled-squad",
        "distilbert-base-german-cased",
        "distilbert-base-multilingual-cased",
        "distilbert-base-uncased-finetuned-sst-2-english",
        "google/electra-small-generator",
        "google/electra-base-generator",
        "google/electra-large-generator",
        "google/electra-small-discriminator",
        "google/electra-base-discriminator",
        "google/electra-large-discriminator",
        "google/fnet-base",
        "google/fnet-large",
        "facebook/wmt19-ru-en",
        "funnel-transformer/small",
        "funnel-transformer/small-base",
        "funnel-transformer/medium",
        "funnel-transformer/medium-base",
        "funnel-transformer/intermediate",
        "funnel-transformer/intermediate-base",
        "funnel-transformer/large",
        "funnel-transformer/large-base",
        "funnel-transformer/xlarge-base",
        "funnel-transformer/xlarge",
        "gpt2",
        "gpt2-medium",
        "gpt2-large",
        "gpt2-xl",
        "distilgpt2",
        "EleutherAI/gpt-neo-1.3B",
        "EleutherAI/gpt-j-6B",
        "kssteven/ibert-roberta-base",
        "allenai/led-base-16384",
        "google/mobilebert-uncased",
        "microsoft/mpnet-base",
        "uw-madison/nystromformer-512",
        "openai-gpt",
        "google/reformer-crime-and-punishment",
        "tau/splinter-base",
        "tau/splinter-base-qass",
        "tau/splinter-large",
        "tau/splinter-large-qass",
        "squeezebert/squeezebert-uncased",
        "squeezebert/squeezebert-mnli",
        "squeezebert/squeezebert-mnli-headless",
        "transfo-xl-wt103",
        "xlm-mlm-en-2048",
        "xlm-mlm-ende-1024",
        "xlm-mlm-enfr-1024",
        "xlm-mlm-enro-1024",
        "xlm-mlm-tlm-xnli15-1024",
        "xlm-mlm-xnli15-1024",
        "xlm-clm-enfr-1024",
        "xlm-clm-ende-1024",
        "xlm-mlm-17-1280",
        "xlm-mlm-100-1280",
        "xlm-roberta-base",
        "xlm-roberta-large",
        "xlm-roberta-large-finetuned-conll02-dutch",
        "xlm-roberta-large-finetuned-conll02-spanish",
        "xlm-roberta-large-finetuned-conll03-english",
        "xlm-roberta-large-finetuned-conll03-german",
        "xlnet-base-cased",
        "xlnet-large-cased",
        "uw-madison/yoso-4096",
    ]
    # sorted() uses the same default string ordering as list.sort(), so
    # capitalised organisation prefixes still come before lowercase names.
    return sorted(supported_names)
|
|
|