|
|
@ -76,24 +76,25 @@ class AutoTransformers(NNOperator): |
|
|
|
vec = features.detach().numpy() |
|
|
|
return vec |
|
|
|
|
|
|
|
def save_model(self, jit: bool = True, destination: str = 'default'): |
|
|
|
if destination == 'default': |
|
|
|
def save_model(self, format: str = 'default', path: str = 'default'): |
|
|
|
if path == 'default': |
|
|
|
path = str(Path(__file__).parent) |
|
|
|
destination = os.path.join(path, self.model_name + '.pt') |
|
|
|
name = self.model_name.replace('/', '-') |
|
|
|
path = os.path.join(path, name + '.pt') |
|
|
|
inputs = self.tokenizer('[CLS]', return_tensors='pt') |
|
|
|
inputs = list(inputs.values()) |
|
|
|
if jit: |
|
|
|
if format == 'torchscript': |
|
|
|
try: |
|
|
|
try: |
|
|
|
traced_model = torch.jit.script(self.model) |
|
|
|
jit_model = torch.jit.script(self.model) |
|
|
|
except Exception: |
|
|
|
traced_model = torch.jit.trace(self.model, inputs, strict=False) |
|
|
|
torch.jit.save(traced_model, destination) |
|
|
|
jit_model = torch.jit.trace(self.model, inputs, strict=False) |
|
|
|
torch.jit.save(jit_model, path) |
|
|
|
except Exception as e: |
|
|
|
log.error(f'Fail to save as torchscript: {e}.') |
|
|
|
raise RuntimeError(f'Fail to save as torchscript: {e}.') |
|
|
|
else: |
|
|
|
torch.save(self.model, destination) |
|
|
|
torch.save(self.model, path) |
|
|
|
|
|
|
|
|
|
|
|
def get_model_list(): |
|
|
|