diff --git a/README.md b/README.md index c5e9757..6b45e09 100644 --- a/README.md +++ b/README.md @@ -316,10 +316,11 @@ The operator takes a piece of text in string as input. It loads tokenizer and pre-trained model using model name. and then return text embedding in ndarray. +***__call__(txt)*** **Parameters:** -***text***: *str* +***txt***: *str* ​ The text in string. @@ -329,3 +330,29 @@ and then return text embedding in ndarray. *numpy.ndarray* ​ The text embedding extracted by model. + + +***save_model(format='pytorch', path='default')*** + +Save model to local with specified format. + +**Parameters:** + +***format***: *str* + +​ The format of saved model, defaults to 'pytorch'. + +***path***: *str* + +​ The path where model is saved to. By default, it will save model to the operator directory. + + +***supported_model_names(format=None)*** + +Get a list of all supported model names or supported model names for specified model format. + +**Parameters:** + +***format***: *str* + +​ The model format such as 'pytorch', 'torchscript'. 
diff --git a/auto_transformers.py b/auto_transformers.py index 3d14f52..ce5b15c 100644 --- a/auto_transformers.py +++ b/auto_transformers.py @@ -76,13 +76,15 @@ class AutoTransformers(NNOperator): vec = features.detach().numpy() return vec - def save_model(self, format: str = 'default', path: str = 'default'): + def save_model(self, format: str = 'pytorch', path: str = 'default'): if path == 'default': path = str(Path(__file__).parent) name = self.model_name.replace('/', '-') path = os.path.join(path, name) inputs = self.tokenizer('[CLS]', return_tensors='pt') - if format == 'torchscript': + if format == 'pytorch': + torch.save(self.model, path) + elif format == 'torchscript': path = path + '.pt' inputs = list(inputs.values()) try: @@ -95,12 +97,13 @@ class AutoTransformers(NNOperator): log.error(f'Fail to save as torchscript: {e}.') raise RuntimeError(f'Fail to save as torchscript: {e}.') elif format == 'onxx': - pass + pass # todo else: - torch.save(self.model, path) + log.error(f'Unsupported format "{format}".') - def get_model_name(self): - full_list = [ + @staticmethod + def supported_model_names(format: str = None): + model_list = [ "bert-large-uncased", "bert-base-cased", "bert-large-cased", @@ -228,5 +231,59 @@ class AutoTransformers(NNOperator): "Musixmatch/umberto-commoncrawl-cased-v1", "Musixmatch/umberto-wikipedia-uncased-v1", ] - full_list.sort() - return full_list + model_list.sort() + if format is None: + pass + elif format == 'pytorch': + pass + elif format == 'torchscript': + model_list.remove( + 'EleutherAI/gpt-j-6B', + 'EleutherAI/gpt-neo-1.3B', + 'allenai/led-base-16384', + 'ctrl', + 'distilgpt2', + 'facebook/bart-large', + 'google/bigbird-pegasus-large-arxiv', + 'google/bigbird-pegasus-large-bigpatent', + 'google/bigbird-pegasus-large-pubmed', + 'google/canine-c', + 'google/canine-s', + 'google/reformer-crime-and-punishment', + 'gpt2', + 'gpt2-large', + 'gpt2-medium', + 'gpt2-xl', + 'microsoft/deberta-base', + 'microsoft/deberta-base-mnli', + 
'microsoft/deberta-large', + 'microsoft/deberta-large-mnli', + 'microsoft/deberta-xlarge', + 'microsoft/deberta-xlarge-mnli', + 'openai-gpt', + 'transfo-xl-wt103', + 'uw-madison/yoso-4096', + 'xlm-clm-ende-1024', + 'xlm-clm-enfr-1024', + 'xlm-mlm-100-1280', + 'xlm-mlm-17-1280', + 'xlm-mlm-en-2048', + 'xlm-mlm-ende-1024', + 'xlm-mlm-enfr-1024', + 'xlm-mlm-enro-1024', + 'xlm-mlm-tlm-xnli15-1024', + 'xlm-mlm-xnli15-1024', + 'xlnet-base-cased', + 'xlnet-large-cased', + 'microsoft/deberta-v2-xlarge', + 'microsoft/deberta-v2-xxlarge', + 'microsoft/deberta-v2-xlarge-mnli', + 'microsoft/deberta-v2-xxlarge-mnli', + 'flaubert/flaubert_small_cased', + 'flaubert/flaubert_base_uncased', + 'flaubert/flaubert_base_cased', + 'flaubert/flaubert_large_cased' + ) + else: # todo: format in {'onnx', 'tensorrt'} + log.error(f'Invalid format "{format}". Currently supported formats: "pytorch", "torchscript".') + return model_list