|
@ -84,8 +84,10 @@ class AutoTransformers(NNOperator): |
|
|
self.model_name = model_name |
|
|
self.model_name = model_name |
|
|
|
|
|
|
|
|
if self.model_name: |
|
|
if self.model_name: |
|
|
self.model = Model(model_name=self.model_name, device=self.device, checkpoint_path=checkpoint_path) |
|
|
|
|
|
self.configs = self.model.model.config |
|
|
|
|
|
|
|
|
self.accelerate_model = Model( |
|
|
|
|
|
model_name=self.model_name, device=self.device, checkpoint_path=checkpoint_path) |
|
|
|
|
|
self.model = self.accelerate_model.model |
|
|
|
|
|
self.configs = self.model.config |
|
|
if tokenizer is None: |
|
|
if tokenizer is None: |
|
|
try: |
|
|
try: |
|
|
self.tokenizer = AutoTokenizer.from_pretrained(model_name) |
|
|
self.tokenizer = AutoTokenizer.from_pretrained(model_name) |
|
@ -105,7 +107,7 @@ class AutoTransformers(NNOperator): |
|
|
log.error(f'Invalid input for the tokenizer: {self.model_name}') |
|
|
log.error(f'Invalid input for the tokenizer: {self.model_name}') |
|
|
raise e |
|
|
raise e |
|
|
try: |
|
|
try: |
|
|
outs = self.model(**inputs) |
|
|
|
|
|
|
|
|
outs = self.accelerate_model(**inputs) |
|
|
except Exception as e: |
|
|
except Exception as e: |
|
|
log.error(f'Invalid input for the model: {self.model_name}') |
|
|
log.error(f'Invalid input for the model: {self.model_name}') |
|
|
raise e |
|
|
raise e |
|
@ -144,13 +146,13 @@ class AutoTransformers(NNOperator): |
|
|
from transformers.onnx.features import FeaturesManager |
|
|
from transformers.onnx.features import FeaturesManager |
|
|
from transformers.onnx import export |
|
|
from transformers.onnx import export |
|
|
model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise( |
|
|
model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise( |
|
|
self.model.model, feature='default') |
|
|
|
|
|
|
|
|
self.model, feature='default') |
|
|
onnx_config = model_onnx_config(self.configs) |
|
|
onnx_config = model_onnx_config(self.configs) |
|
|
if os.path.isdir(path): |
|
|
if os.path.isdir(path): |
|
|
shutil.rmtree(path) |
|
|
shutil.rmtree(path) |
|
|
onnx_inputs, onnx_outputs = export( |
|
|
onnx_inputs, onnx_outputs = export( |
|
|
self.tokenizer, |
|
|
self.tokenizer, |
|
|
self.model.model, |
|
|
|
|
|
|
|
|
self.model, |
|
|
config=onnx_config, |
|
|
config=onnx_config, |
|
|
opset=13, |
|
|
opset=13, |
|
|
output=Path(path+'.onnx') |
|
|
output=Path(path+'.onnx') |
|
|