logo
Browse Source

Modify save_model to support tritonserver

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 2 years ago
parent
commit
5328585568
  1. 44
      auto_transformers.py
  2. 4
      test_onnx.py

44
auto_transformers.py

@ -119,47 +119,63 @@ class AutoTransformers(NNOperator):
vec = features.cpu().detach().numpy() vec = features.cpu().detach().numpy()
return vec return vec
def save_model(self, format: str = 'pytorch', path: str = 'default'):
if path == 'default':
path = str(Path(__file__).parent)
path = os.path.join(path, 'saved', format)
os.makedirs(path, exist_ok=True)
def save_model(self, model_type: str = 'pytorch', output_file: str = 'default'):
if output_file == 'default':
output_file = str(Path(__file__).parent)
output_file = os.path.join(output_file, 'saved', model_type)
os.makedirs(output_file, exist_ok=True)
name = self.model_name.replace('/', '-') name = self.model_name.replace('/', '-')
path = os.path.join(path, name)
output_file = os.path.join(output_file, name)
if model_type in ['pytorch', 'torchscript']:
output_file = output_file + '.pt'
elif model_type == 'onnx':
output_file = output_file + '.onnx'
else:
raise AttributeError('Unsupported model_type.')
dummy_input = '[CLS]' dummy_input = '[CLS]'
inputs = self.tokenizer(dummy_input, return_tensors='pt') # a dictionary inputs = self.tokenizer(dummy_input, return_tensors='pt') # a dictionary
if format == 'pytorch':
torch.save(self.model, path + '.pt')
elif format == 'torchscript':
if model_type == 'pytorch':
torch.save(self.model, output_file)
elif model_type == 'torchscript':
inputs = list(inputs.values()) inputs = list(inputs.values())
try: try:
try: try:
jit_model = torch.jit.script(self.model) jit_model = torch.jit.script(self.model)
except Exception: except Exception:
jit_model = torch.jit.trace(self.model, inputs, strict=False) jit_model = torch.jit.trace(self.model, inputs, strict=False)
torch.jit.save(jit_model, path + '.pt')
torch.jit.save(jit_model, output_file)
except Exception as e: except Exception as e:
log.error(f'Fail to save as torchscript: {e}.') log.error(f'Fail to save as torchscript: {e}.')
raise RuntimeError(f'Fail to save as torchscript: {e}.') raise RuntimeError(f'Fail to save as torchscript: {e}.')
elif format == 'onnx':
elif model_type == 'onnx':
from transformers.onnx.features import FeaturesManager from transformers.onnx.features import FeaturesManager
from transformers.onnx import export from transformers.onnx import export
model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise( model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise(
self.model, feature='default') self.model, feature='default')
onnx_config = model_onnx_config(self.configs) onnx_config = model_onnx_config(self.configs)
if os.path.isdir(path):
shutil.rmtree(path)
# if os.path.isdir(output_file[:-5]):
# shutil.rmtree(output_file[:-5])
# print('********', Path(output_file))
onnx_inputs, onnx_outputs = export( onnx_inputs, onnx_outputs = export(
self.tokenizer, self.tokenizer,
self.model, self.model,
config=onnx_config, config=onnx_config,
opset=13, opset=13,
output=Path(path+'.onnx')
output=Path(output_file)
) )
# todo: elif format == 'tensorrt': # todo: elif format == 'tensorrt':
else: else:
log.error(f'Unsupported format "{format}".') log.error(f'Unsupported format "{format}".')
return True
@property
def supported_formats(self):
onnxes = self.supported_model_names(format='onnx')
if self.model_name in onnxes:
return ['onnx']
else:
return ['pytorch']
@staticmethod @staticmethod
def supported_model_names(format: str = None): def supported_model_names(format: str = None):

4
test_onnx.py

@ -12,7 +12,7 @@ import psutil
# full_models = AutoTransformers.supported_model_names() # full_models = AutoTransformers.supported_model_names()
# checked_models = AutoTransformers.supported_model_names(format='onnx') # checked_models = AutoTransformers.supported_model_names(format='onnx')
# models = [x for x in full_models if x not in checked_models] # models = [x for x in full_models if x not in checked_models]
models = ['bert-base-cased', 'distilbert-base-cased']
models = ['distilbert-base-cased']
test_txt = 'hello, world.' test_txt = 'hello, world.'
atol = 1e-3 atol = 1e-3
log_path = 'transformers_onnx.log' log_path = 'transformers_onnx.log'
@ -55,7 +55,7 @@ for name in models:
logger.error(f'FAIL TO LOAD OP: {e}') logger.error(f'FAIL TO LOAD OP: {e}')
continue continue
try: try:
op.save_model(format='onnx')
op.save_model(model_type='onnx')
logger.info('ONNX SAVED.') logger.info('ONNX SAVED.')
status[2] = 'success' status[2] = 'success'
except Exception as e: except Exception as e:

Loading…
Cancel
Save