|
@ -140,14 +140,14 @@ class AutoTransformers(NNOperator): |
|
|
dummy_input = '[CLS]' |
|
|
dummy_input = '[CLS]' |
|
|
inputs = self.tokenizer(dummy_input, return_tensors='pt') # a dictionary |
|
|
inputs = self.tokenizer(dummy_input, return_tensors='pt') # a dictionary |
|
|
if model_type == 'pytorch': |
|
|
if model_type == 'pytorch': |
|
|
torch.save(self.model, output_file) |
|
|
|
|
|
|
|
|
torch.save(self._model, output_file) |
|
|
elif model_type == 'torchscript': |
|
|
elif model_type == 'torchscript': |
|
|
inputs = list(inputs.values()) |
|
|
inputs = list(inputs.values()) |
|
|
try: |
|
|
try: |
|
|
try: |
|
|
try: |
|
|
jit_model = torch.jit.script(self.model) |
|
|
|
|
|
|
|
|
jit_model = torch.jit.script(self._model) |
|
|
except Exception: |
|
|
except Exception: |
|
|
jit_model = torch.jit.trace(self.model, inputs, strict=False) |
|
|
|
|
|
|
|
|
jit_model = torch.jit.trace(self._model, inputs, strict=False) |
|
|
torch.jit.save(jit_model, output_file) |
|
|
torch.jit.save(jit_model, output_file) |
|
|
except Exception as e: |
|
|
except Exception as e: |
|
|
log.error(f'Fail to save as torchscript: {e}.') |
|
|
log.error(f'Fail to save as torchscript: {e}.') |
|
@ -171,7 +171,7 @@ class AutoTransformers(NNOperator): |
|
|
# todo: elif format == 'tensorrt': |
|
|
# todo: elif format == 'tensorrt': |
|
|
else: |
|
|
else: |
|
|
log.error(f'Unsupported format "{format}".') |
|
|
log.error(f'Unsupported format "{format}".') |
|
|
return True |
|
|
|
|
|
|
|
|
return Path(output_file).resolve() |
|
|
|
|
|
|
|
|
@property |
|
|
@property |
|
|
def supported_formats(self): |
|
|
def supported_formats(self): |
|
|