|
|
@ -39,29 +39,34 @@ class AutoTransformers(NNOperator): |
|
|
|
Which model to use for the embeddings. |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, model_name: str = "bert-base-uncased", device=None) -> None: |
|
|
|
def __init__(self, model_name: str = None, device: str = None) -> None: |
|
|
|
super().__init__() |
|
|
|
if device is None: |
|
|
|
device = 'cuda' if torch.cuda.is_available() else 'cpu' |
|
|
|
self.device = device |
|
|
|
self.model_name = model_name |
|
|
|
try: |
|
|
|
self.model = AutoModel.from_pretrained(model_name).to(self.device) |
|
|
|
self.model.eval() |
|
|
|
|
|
|
|
self.configs = self.model.config |
|
|
|
except Exception as e: |
|
|
|
model_list = self.supported_model_names() |
|
|
|
if model_name not in model_list: |
|
|
|
log.error(f"Invalid model name: {model_name}. Supported model names: {model_list}") |
|
|
|
else: |
|
|
|
log.error(f"Fail to load model by name: {self.model_name}") |
|
|
|
raise e |
|
|
|
try: |
|
|
|
self.tokenizer = AutoTokenizer.from_pretrained(model_name) |
|
|
|
except Exception as e: |
|
|
|
log.error(f'Fail to load tokenizer by name: {self.model_name}') |
|
|
|
raise e |
|
|
|
if self.model_name: |
|
|
|
try: |
|
|
|
self.model = AutoModel.from_pretrained(model_name).to(self.device) |
|
|
|
self.model.eval() |
|
|
|
|
|
|
|
self.configs = self.model.config |
|
|
|
except Exception as e: |
|
|
|
model_list = self.supported_model_names() |
|
|
|
if model_name not in model_list: |
|
|
|
log.error(f"Invalid model name: {model_name}. Supported model names: {model_list}") |
|
|
|
else: |
|
|
|
log.error(f"Fail to load model by name: {self.model_name}") |
|
|
|
raise e |
|
|
|
try: |
|
|
|
self.tokenizer = AutoTokenizer.from_pretrained(model_name) |
|
|
|
except Exception as e: |
|
|
|
log.error(f'Fail to load tokenizer by name: {self.model_name}') |
|
|
|
raise e |
|
|
|
else: |
|
|
|
log.warning('The operator is initialized without specified model.') |
|
|
|
pass |
|
|
|
|
|
|
|
def __call__(self, txt: str) -> numpy.ndarray: |
|
|
|
try: |
|
|
@ -313,42 +318,19 @@ class AutoTransformers(NNOperator): |
|
|
|
model_list = list(set(full_list) - set(to_remove)) |
|
|
|
elif format == 'onnx': |
|
|
|
to_remove = [ |
|
|
|
'allenai/led-base-16384', |
|
|
|
'ctrl', |
|
|
|
'distilgpt2', |
|
|
|
'EleutherAI/gpt-j-6B', |
|
|
|
'EleutherAI/gpt-neo-1.3B', |
|
|
|
'funnel-transformer/intermediate', |
|
|
|
'funnel-transformer/large', |
|
|
|
'funnel-transformer/medium', |
|
|
|
'funnel-transformer/small', |
|
|
|
'funnel-transformer/xlarge', |
|
|
|
'google/bigbird-pegasus-large-arxiv', |
|
|
|
'google/bigbird-pegasus-large-bigpatent', |
|
|
|
'google/bigbird-pegasus-large-pubmed', |
|
|
|
'google/canine-c', |
|
|
|
'google/canine-s', |
|
|
|
'google/fnet-base', |
|
|
|
'google/fnet-large', |
|
|
|
'google/reformer-crime-and-punishment', |
|
|
|
'gpt2', |
|
|
|
'gpt2-large', |
|
|
|
'gpt2-medium', |
|
|
|
'gpt2-xl', |
|
|
|
'microsoft/deberta-v2-xlarge', |
|
|
|
'microsoft/deberta-v2-xlarge-mnli', |
|
|
|
'microsoft/deberta-v2-xxlarge', |
|
|
|
'microsoft/deberta-v2-xxlarge-mnli', |
|
|
|
'microsoft/deberta-xlarge', |
|
|
|
'microsoft/deberta-xlarge-mnli', |
|
|
|
'openai-gpt', |
|
|
|
'transfo-xl-wt103', |
|
|
|
'uw-madison/yoso-4096', |
|
|
|
'xlm-mlm-100-1280', |
|
|
|
'xlm-mlm-17-1280', |
|
|
|
'xlm-mlm-en-2048', |
|
|
|
'xlm-roberta-large', |
|
|
|
'xlm-roberta-large-finetuned-conll02-dutch', |
|
|
|
'xlm-roberta-large-finetuned-conll02-spanish', |
|
|
|
'xlm-roberta-large-finetuned-conll03-english', |
|
|
|
'xlm-roberta-large-finetuned-conll03-german', |
|
|
|
'xlnet-base-cased', |
|
|
|
'xlnet-large-cased' |
|
|
|
] |
|
|
|