@@ -119,47 +119,63 @@ class AutoTransformers(NNOperator):
         vec = features.cpu().detach().numpy()
         return vec
 
-    def save_model(self, format: str = 'pytorch', path: str = 'default'):
-        if path == 'default':
-            path = str(Path(__file__).parent)
-            path = os.path.join(path, 'saved', format)
-            os.makedirs(path, exist_ok=True)
+    def save_model(self, model_type: str = 'pytorch', output_file: str = 'default'):
+        if output_file == 'default':
+            output_file = str(Path(__file__).parent)
+            output_file = os.path.join(output_file, 'saved', model_type)
+            os.makedirs(output_file, exist_ok=True)
             name = self.model_name.replace('/', '-')
-            path = os.path.join(path, name)
+            output_file = os.path.join(output_file, name)
+            if model_type in ['pytorch', 'torchscript']:
+                output_file = output_file + '.pt'
+            elif model_type == 'onnx':
+                output_file = output_file + '.onnx'
+            else:
+                raise AttributeError('Unsupported model_type.')
 
         dummy_input = '[CLS]'
         inputs = self.tokenizer(dummy_input, return_tensors='pt') # a dictionary
-        if format == 'pytorch':
-            torch.save(self.model, path + '.pt')
-        elif format == 'torchscript':
+        if model_type == 'pytorch':
+            torch.save(self.model, output_file)
+        elif model_type == 'torchscript':
             inputs = list(inputs.values())
             try:
                 try:
                     jit_model = torch.jit.script(self.model)
                 except Exception:
                     jit_model = torch.jit.trace(self.model, inputs, strict=False)
-                torch.jit.save(jit_model, path + '.pt')
+                torch.jit.save(jit_model, output_file)
             except Exception as e:
                 log.error(f'Fail to save as torchscript: {e}.')
+                raise RuntimeError(f'Fail to save as torchscript: {e}.')
-        elif format == 'onnx':
+        elif model_type == 'onnx':
             from transformers.onnx.features import FeaturesManager
             from transformers.onnx import export
             model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise(
                 self.model, feature='default')
             onnx_config = model_onnx_config(self.configs)
-            if os.path.isdir(path):
-                shutil.rmtree(path)
+            # if os.path.isdir(output_file[:-5]):
+            #     shutil.rmtree(output_file[:-5])
+            # print('********', Path(output_file))
             onnx_inputs, onnx_outputs = export(
                 self.tokenizer,
                 self.model,
                 config=onnx_config,
                 opset=13,
-                output=Path(path+'.onnx')
+                output=Path(output_file)
             )
         # todo: elif format == 'tensorrt':
         else:
             log.error(f'Unsupported format "{format}".')
         return True
 
     @property
     def supported_formats(self):
         onnxes = self.supported_model_names(format='onnx')
         if self.model_name in onnxes:
             return ['onnx']
         else:
             return ['pytorch']
 
     @staticmethod
     def supported_model_names(format: str = None):
|