logo
Browse Source

Modify save_model

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 2 years ago
parent
commit
0ee984c935
  1. 17
      auto_transformers.py
  2. 11
      test_save.py

17
auto_transformers.py

@ -76,24 +76,25 @@ class AutoTransformers(NNOperator):
vec = features.detach().numpy()
return vec
def save_model(self, jit: bool = True, destination: str = 'default'):
if destination == 'default':
def save_model(self, format: str = 'default', path: str = 'default'):
if path == 'default':
path = str(Path(__file__).parent)
destination = os.path.join(path, self.model_name + '.pt')
name = self.model_name.replace('/', '-')
path = os.path.join(path, name + '.pt')
inputs = self.tokenizer('[CLS]', return_tensors='pt')
inputs = list(inputs.values())
if jit:
if format == 'torchscript':
try:
try:
traced_model = torch.jit.script(self.model)
jit_model = torch.jit.script(self.model)
except Exception:
traced_model = torch.jit.trace(self.model, inputs, strict=False)
torch.jit.save(traced_model, destination)
jit_model = torch.jit.trace(self.model, inputs, strict=False)
torch.jit.save(jit_model, path)
except Exception as e:
log.error(f'Fail to save as torchscript: {e}.')
raise RuntimeError(f'Fail to save as torchscript: {e}.')
else:
torch.save(self.model, destination)
torch.save(self.model, path)
def get_model_list():

11
test_save.py

@ -2,14 +2,19 @@ from auto_transformers import AutoTransformers
import torch
models = ['bert-base-cased', 'distilbert-base-cased', 'distilgpt2']
models = [
'bert-base-cased',
'distilbert-base-cased',
'distilgpt2',
'google/fnet-base'
]
for name in models:
try:
op = AutoTransformers(model_name=name)
out1 = op('hello, world.')
op.save_model()
op.model = torch.jit.load(name + '.pt')
op.save_model(format='torchscript')
op.model = torch.jit.load(name.replace('/', '-') + '.pt')
out2 = op('hello, world.')
assert (out1 == out2).all()
print(f'[SUCCESS] Saved torchscript for model "{name}"')

Loading…
Cancel
Save