From a307a3196af1a32a394b4de345bd833656fe1f08 Mon Sep 17 00:00:00 2001 From: Jael Gu Date: Tue, 13 Dec 2022 16:29:19 +0800 Subject: [PATCH] Update model list Signed-off-by: Jael Gu --- README.md | 11 +++++++ auto_transformers.py | 68 ++++++++++++++++---------------------------- 2 files changed, 36 insertions(+), 43 deletions(-) diff --git a/README.md b/README.md index 52f72db..d37c853 100644 --- a/README.md +++ b/README.md @@ -357,3 +357,14 @@ Get a list of all supported model names or supported model names for specified m ​ The model format such as 'pytorch', 'torchscript'. +```python +from towhee import ops + + +op = ops.text_embedding.transformers().get_op() +full_list = op.supported_model_names() +onnx_list = op.supported_model_names(format='onnx') +print(f'Onnx-support/Total Models: {len(onnx_list)}/{len(full_list)}') +``` + 2022-12-13 16:25:15,916 - 140704500614336 - auto_transformers.py-auto_transformers:68 - WARNING: The operator is initialized without specified model. + Onnx-support/Total Models: 111/126 diff --git a/auto_transformers.py b/auto_transformers.py index a253784..189bc5c 100644 --- a/auto_transformers.py +++ b/auto_transformers.py @@ -39,29 +39,34 @@ class AutoTransformers(NNOperator): Which model to use for the embeddings. """ - def __init__(self, model_name: str = "bert-base-uncased", device=None) -> None: + def __init__(self, model_name: str = None, device: str = None) -> None: super().__init__() if device is None: device = 'cuda' if torch.cuda.is_available() else 'cpu' self.device = device self.model_name = model_name - try: - self.model = AutoModel.from_pretrained(model_name).to(self.device) - self.model.eval() - self.configs = self.model.config - except Exception as e: - model_list = self.supported_model_names() - if model_name not in model_list: - log.error(f"Invalid model name: {model_name}. Supported model names: {model_list}") - else: - log.error(f"Fail to load model by name: {self.model_name}") - raise e - try: - self.tokenizer = AutoTokenizer.from_pretrained(model_name) - except Exception as e: - log.error(f'Fail to load tokenizer by name: {self.model_name}') - raise e + if self.model_name: + try: + self.model = AutoModel.from_pretrained(model_name).to(self.device) + self.model.eval() + + self.configs = self.model.config + except Exception as e: + model_list = self.supported_model_names() + if model_name not in model_list: + log.error(f"Invalid model name: {model_name}. Supported model names: {model_list}") + else: + log.error(f"Fail to load model by name: {self.model_name}") + raise e + try: + self.tokenizer = AutoTokenizer.from_pretrained(model_name) + except Exception as e: + log.error(f'Fail to load tokenizer by name: {self.model_name}') + raise e + else: + log.warning('The operator is initialized without specified model.') + pass def __call__(self, txt: str) -> numpy.ndarray: try: @@ -313,42 +318,19 @@ class AutoTransformers(NNOperator): model_list = list(set(full_list) - set(to_remove)) elif format == 'onnx': to_remove = [ - 'allenai/led-base-16384', 'ctrl', - 'distilgpt2', - 'EleutherAI/gpt-j-6B', - 'EleutherAI/gpt-neo-1.3B', + 'funnel-transformer/intermediate', 'funnel-transformer/large', 'funnel-transformer/medium', 'funnel-transformer/small', 'funnel-transformer/xlarge', - 'google/bigbird-pegasus-large-arxiv', - 'google/bigbird-pegasus-large-bigpatent', - 'google/bigbird-pegasus-large-pubmed', + 'google/canine-c', + 'google/canine-s', 'google/fnet-base', 'google/fnet-large', 'google/reformer-crime-and-punishment', - 'gpt2', - 'gpt2-large', - 'gpt2-medium', - 'gpt2-xl', - 'microsoft/deberta-v2-xlarge', - 'microsoft/deberta-v2-xlarge-mnli', - 'microsoft/deberta-v2-xxlarge', - 'microsoft/deberta-v2-xxlarge-mnli', - 'microsoft/deberta-xlarge', - 'microsoft/deberta-xlarge-mnli', - 'openai-gpt', 'transfo-xl-wt103', 'uw-madison/yoso-4096', - 'xlm-mlm-100-1280', - 'xlm-mlm-17-1280', - 'xlm-mlm-en-2048', - 'xlm-roberta-large', - 'xlm-roberta-large-finetuned-conll02-dutch', - 'xlm-roberta-large-finetuned-conll02-spanish', - 'xlm-roberta-large-finetuned-conll03-english', - 'xlm-roberta-large-finetuned-conll03-german', 'xlnet-base-cased', 'xlnet-large-cased' ]