diff --git a/auto_transformers.py b/auto_transformers.py index 225b3e2..d9662a2 100644 --- a/auto_transformers.py +++ b/auto_transformers.py @@ -84,8 +84,10 @@ class AutoTransformers(NNOperator): self.model_name = model_name if self.model_name: - self.model = Model(model_name=self.model_name, device=self.device, checkpoint_path=checkpoint_path) - self.configs = self.model.model.config + self.accelerate_model = Model( + model_name=self.model_name, device=self.device, checkpoint_path=checkpoint_path) + self.model = self.accelerate_model.model + self.configs = self.model.config if tokenizer is None: try: self.tokenizer = AutoTokenizer.from_pretrained(model_name) @@ -105,7 +107,7 @@ class AutoTransformers(NNOperator): log.error(f'Invalid input for the tokenizer: {self.model_name}') raise e try: - outs = self.model(**inputs) + outs = self.accelerate_model(**inputs) except Exception as e: log.error(f'Invalid input for the model: {self.model_name}') raise e @@ -144,13 +146,13 @@ class AutoTransformers(NNOperator): from transformers.onnx.features import FeaturesManager from transformers.onnx import export model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise( - self.model.model, feature='default') + self.model, feature='default') onnx_config = model_onnx_config(self.configs) if os.path.isdir(path): shutil.rmtree(path) onnx_inputs, onnx_outputs = export( self.tokenizer, - self.model.model, + self.model, config=onnx_config, opset=13, output=Path(path+'.onnx')