diff --git a/auto_transformers.py b/auto_transformers.py index c032711..46ab490 100644 --- a/auto_transformers.py +++ b/auto_transformers.py @@ -76,24 +76,25 @@ class AutoTransformers(NNOperator): vec = features.detach().numpy() return vec - def save_model(self, jit: bool = True, destination: str = 'default'): - if destination == 'default': + def save_model(self, format: str = 'default', path: str = 'default'): + if path == 'default': path = str(Path(__file__).parent) - destination = os.path.join(path, self.model_name + '.pt') + name = self.model_name.replace('/', '-') + path = os.path.join(path, name + '.pt') inputs = self.tokenizer('[CLS]', return_tensors='pt') inputs = list(inputs.values()) - if jit: + if format == 'torchscript': try: try: - traced_model = torch.jit.script(self.model) + jit_model = torch.jit.script(self.model) except Exception: - traced_model = torch.jit.trace(self.model, inputs, strict=False) - torch.jit.save(traced_model, destination) + jit_model = torch.jit.trace(self.model, inputs, strict=False) + torch.jit.save(jit_model, path) except Exception as e: log.error(f'Fail to save as torchscript: {e}.') raise RuntimeError(f'Fail to save as torchscript: {e}.') else: - torch.save(self.model, destination) + torch.save(self.model, path) def get_model_list(): diff --git a/test_save.py b/test_save.py index 3cf21d0..5b4d2ae 100644 --- a/test_save.py +++ b/test_save.py @@ -2,14 +2,19 @@ from auto_transformers import AutoTransformers import torch -models = ['bert-base-cased', 'distilbert-base-cased', 'distilgpt2'] +models = [ + 'bert-base-cased', + 'distilbert-base-cased', + 'distilgpt2', + 'google/fnet-base' +] for name in models: try: op = AutoTransformers(model_name=name) out1 = op('hello, world.') - op.save_model() - op.model = torch.jit.load(name + '.pt') + op.save_model(format='torchscript') + op.model = torch.jit.load(name.replace('/', '-') + '.pt') out2 = op('hello, world.') assert (out1 == out2).all() print(f'[SUCCESS] Saved torchscript for model "{name}"')