diff --git a/README.md b/README.md index 0c5d656..bd578d1 100644 --- a/README.md +++ b/README.md @@ -63,7 +63,9 @@ Create the operator via the following factory method: The model name in string, defaults to None. If None, the operator will be initialized without specified model. -Supported model names: refer to `supported_model_names` below. +Supported model names: NLP transformers models listed in [Huggingface Models](https://huggingface.co/models). +Please note that only models listed in `supported_model_names` are tested. +You can refer to [Towhee Pipeline]() for benchmark. ***checkpoint_path***: *str* diff --git a/auto_transformers.py b/auto_transformers.py index bf278d7..b91d264 100644 --- a/auto_transformers.py +++ b/auto_transformers.py @@ -73,19 +73,19 @@ class AutoTransformers(NNOperator): else: self.device = 'cuda' if torch.cuda.is_available() else 'cpu' self.model_name = self.map_model_names(model_name) - if tokenizer: - self.tokenizer = tokenizer - else: - self.tokenizer = AutoTokenizer.from_pretrained(self.model_name) - if not self.tokenizer.pad_token: - self.tokenizer.pad_token = '[PAD]' self.norm = norm self.checkpoint_path = checkpoint_path if self.model_name: - model_list = self.supported_model_names() + # model_list = self.supported_model_names() # assert model_name in model_list, f"Invalid model name: {model_name}. Supported model names: {model_list}" self.model = Model(self._model) + if tokenizer: + self.tokenizer = tokenizer + else: + self.tokenizer = AutoTokenizer.from_pretrained(self.model_name) + if not self.tokenizer.pad_token: + self.tokenizer.pad_token = '[PAD]' else: log.warning('The operator is initialized without specified model.') pass @@ -205,14 +205,6 @@ class AutoTransformers(NNOperator): log.error(f'Unsupported format "{format}".') return Path(output_file).resolve() - @property - def supported_formats(self): - onnxes = self.supported_model_names(format='onnx') - if self.model_name in onnxes: - return ['onnx'] - else: - return ['pytorch'] - @property def supported_formats(self): return ['onnx'] @@ -220,7 +212,60 @@ class AutoTransformers(NNOperator): @staticmethod def supported_model_names(format: str = None): full_list = [ - + 'sentence-t5-xxl', + 'sentence-t5-xl', + 'sentence-t5-large', + 'text-embedding-ada-002', + 'text-similarity-davinci-001', + 'text-similarity-babbage-001', + 'text-similarity-curie-001', + 'paraphrase-mpnet-base-v2', + 'text-similarity-ada-001', + 'gtr-t5-xxl', + 'gtr-t5-large', + 'gtr-t5-xl', + 'paraphrase-multilingual-mpnet-base-v2', + 'paraphrase-distilroberta-base-v2', + 'all-mpnet-base-v1', + 'all-roberta-large-v1', + 'all-mpnet-base-v2', + 'all-MiniLM-L12-v2', + 'all-distilroberta-v1', + 'all-MiniLM-L12-v1', + 'gtr-t5-base', + 'paraphrase-multilingual-MiniLM-L12-v2', + 'paraphrase-MiniLM-L12-v2', + 'all-MiniLM-L6-v1', + 'paraphrase-TinyBERT-L6-v2', + 'all-MiniLM-L6-v2', + 'paraphrase-albert-small-v2', + 'multi-qa-mpnet-base-cos-v1', + 'paraphrase-MiniLM-L3-v2', + 'multi-qa-distilbert-cos-v1', + 'msmarco-distilbert-base-v4', + 'multi-qa-mpnet-base-dot-v1', + 'msmarco-distilbert-base-tas-b', + 'distiluse-base-multilingual-cased-v2', + 'multi-qa-distilbert-dot-v1', + 'multi-qa-MiniLM-L6-cos-v1', + 'distiluse-base-multilingual-cased-v1', + 'msmarco-bert-base-dot-v5', + 'paraphrase-MiniLM-L6-v2', + 'multi-qa-MiniLM-L6-dot-v1', + 'msmarco-distilbert-dot-v5', + 'bert-base-nli-mean-tokens', + 'bert-large-uncased-whole-word-masking', + 'average_word_embeddings_komninos', + 'distilbert-base-uncased', + 'average_word_embeddings_glove.6B.300d', + 'dpr-ctx_encoder-multiset-base', + 'dpr-ctx_encoder-single-nq-base', + 'microsoft/deberta-xlarge', + 'facebook/bart-large', + 'bert-base-uncased', + 'microsoft/deberta-xlarge-mnli', + 'gpt2-xl', + 'bert-large-uncased' ] full_list.sort() if format is None: @@ -230,13 +275,11 @@ class AutoTransformers(NNOperator): assert set(to_remove).issubset(set(full_list)) model_list = list(set(full_list) - set(to_remove)) elif format == 'torchscript': - to_remove = [ - ] + to_remove = [] assert set(to_remove).issubset(set(full_list)) model_list = list(set(full_list) - set(to_remove)) elif format == 'onnx': - to_remove = [ - ] + to_remove = [] assert set(to_remove).issubset(set(full_list)) model_list = list(set(full_list) - set(to_remove)) # todo: elif format == 'tensorrt':