logo
Browse Source

Update model list

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 1 year ago
parent
commit
a307a3196a
  1. 11
      README.md
  2. 68
      auto_transformers.py

11
README.md

@ -357,3 +357,14 @@ Get a list of all supported model names or supported model names for specified m
The model format, such as 'pytorch' or 'torchscript'.
```python
from towhee import ops
op = ops.text_embedding.transformers().get_op()
full_list = op.supported_model_names()
onnx_list = op.supported_model_names(format='onnx')
print(f'Onnx-support/Total Models: {len(onnx_list)}/{len(full_list)}')
```
2022-12-13 16:25:15,916 - 140704500614336 - auto_transformers.py-auto_transformers:68 - WARNING: The operator is initialized without specified model.
Onnx-support/Total Models: 111/126

68
auto_transformers.py

@ -39,29 +39,34 @@ class AutoTransformers(NNOperator):
Which model to use for the embeddings.
"""
def __init__(self, model_name: str = "bert-base-uncased", device=None) -> None:
def __init__(self, model_name: str = None, device: str = None) -> None:
super().__init__()
if device is None:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
self.device = device
self.model_name = model_name
try:
self.model = AutoModel.from_pretrained(model_name).to(self.device)
self.model.eval()
self.configs = self.model.config
except Exception as e:
model_list = self.supported_model_names()
if model_name not in model_list:
log.error(f"Invalid model name: {model_name}. Supported model names: {model_list}")
else:
log.error(f"Fail to load model by name: {self.model_name}")
raise e
try:
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
except Exception as e:
log.error(f'Fail to load tokenizer by name: {self.model_name}')
raise e
if self.model_name:
try:
self.model = AutoModel.from_pretrained(model_name).to(self.device)
self.model.eval()
self.configs = self.model.config
except Exception as e:
model_list = self.supported_model_names()
if model_name not in model_list:
log.error(f"Invalid model name: {model_name}. Supported model names: {model_list}")
else:
log.error(f"Fail to load model by name: {self.model_name}")
raise e
try:
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
except Exception as e:
log.error(f'Fail to load tokenizer by name: {self.model_name}')
raise e
else:
log.warning('The operator is initialized without specified model.')
pass
def __call__(self, txt: str) -> numpy.ndarray:
try:
@ -313,42 +318,19 @@ class AutoTransformers(NNOperator):
model_list = list(set(full_list) - set(to_remove))
elif format == 'onnx':
to_remove = [
'allenai/led-base-16384',
'ctrl',
'distilgpt2',
'EleutherAI/gpt-j-6B',
'EleutherAI/gpt-neo-1.3B',
'funnel-transformer/intermediate',
'funnel-transformer/large',
'funnel-transformer/medium',
'funnel-transformer/small',
'funnel-transformer/xlarge',
'google/bigbird-pegasus-large-arxiv',
'google/bigbird-pegasus-large-bigpatent',
'google/bigbird-pegasus-large-pubmed',
'google/canine-c',
'google/canine-s',
'google/fnet-base',
'google/fnet-large',
'google/reformer-crime-and-punishment',
'gpt2',
'gpt2-large',
'gpt2-medium',
'gpt2-xl',
'microsoft/deberta-v2-xlarge',
'microsoft/deberta-v2-xlarge-mnli',
'microsoft/deberta-v2-xxlarge',
'microsoft/deberta-v2-xxlarge-mnli',
'microsoft/deberta-xlarge',
'microsoft/deberta-xlarge-mnli',
'openai-gpt',
'transfo-xl-wt103',
'uw-madison/yoso-4096',
'xlm-mlm-100-1280',
'xlm-mlm-17-1280',
'xlm-mlm-en-2048',
'xlm-roberta-large',
'xlm-roberta-large-finetuned-conll02-dutch',
'xlm-roberta-large-finetuned-conll02-spanish',
'xlm-roberta-large-finetuned-conll03-english',
'xlm-roberta-large-finetuned-conll03-german',
'xlnet-base-cased',
'xlnet-large-cased'
]

Loading…
Cancel
Save