logo
Browse Source

Update op to support both models

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 2 years ago
parent
commit
9c26196077
  1. 12
      auto_transformers.py

12
auto_transformers.py

@ -84,8 +84,10 @@ class AutoTransformers(NNOperator):
self.model_name = model_name self.model_name = model_name
if self.model_name: if self.model_name:
self.model = Model(model_name=self.model_name, device=self.device, checkpoint_path=checkpoint_path)
self.configs = self.model.model.config
self.accelerate_model = Model(
model_name=self.model_name, device=self.device, checkpoint_path=checkpoint_path)
self.model = self.accelerate_model.model
self.configs = self.model.config
if tokenizer is None: if tokenizer is None:
try: try:
self.tokenizer = AutoTokenizer.from_pretrained(model_name) self.tokenizer = AutoTokenizer.from_pretrained(model_name)
@ -105,7 +107,7 @@ class AutoTransformers(NNOperator):
log.error(f'Invalid input for the tokenizer: {self.model_name}') log.error(f'Invalid input for the tokenizer: {self.model_name}')
raise e raise e
try: try:
outs = self.model(**inputs)
outs = self.accelerate_model(**inputs)
except Exception as e: except Exception as e:
log.error(f'Invalid input for the model: {self.model_name}') log.error(f'Invalid input for the model: {self.model_name}')
raise e raise e
@ -144,13 +146,13 @@ class AutoTransformers(NNOperator):
from transformers.onnx.features import FeaturesManager from transformers.onnx.features import FeaturesManager
from transformers.onnx import export from transformers.onnx import export
model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise( model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise(
self.model.model, feature='default')
self.model, feature='default')
onnx_config = model_onnx_config(self.configs) onnx_config = model_onnx_config(self.configs)
if os.path.isdir(path): if os.path.isdir(path):
shutil.rmtree(path) shutil.rmtree(path)
onnx_inputs, onnx_outputs = export( onnx_inputs, onnx_outputs = export(
self.tokenizer, self.tokenizer,
self.model.model,
self.model,
config=onnx_config, config=onnx_config,
opset=13, opset=13,
output=Path(path+'.onnx') output=Path(path+'.onnx')

Loading…
Cancel
Save