diff --git a/auto_transformers.py b/auto_transformers.py
index 2e15d5a..7d88542 100644
--- a/auto_transformers.py
+++ b/auto_transformers.py
@@ -84,10 +84,8 @@ class AutoTransformers(NNOperator):
         self.model_name = model_name
 
         if self.model_name:
-            self.accelerate_model = Model(
+            self.model = Model(
                 model_name=self.model_name, device=self.device, checkpoint_path=checkpoint_path)
-            self.model = self.accelerate_model.model
-            self.configs = self.model.config
         if tokenizer is None:
             try:
                 self.tokenizer = AutoTokenizer.from_pretrained(model_name)
@@ -107,7 +105,7 @@ class AutoTransformers(NNOperator):
             log.error(f'Invalid input for the tokenizer: {self.model_name}')
             raise e
         try:
-            outs = self.accelerate_model(**inputs)
+            outs = self.model(**inputs)
         except Exception as e:
             log.error(f'Invalid input for the model: {self.model_name}')
             raise e
@@ -119,6 +117,10 @@ class AutoTransformers(NNOperator):
         vec = features.cpu().detach().numpy()
         return vec
 
+    @property
+    def _model(self):
+        return self.model.model
+
     def save_model(self, model_type: str = 'pytorch', output_file: str = 'default'):
         if output_file == 'default':
             output_file = str(Path(__file__).parent)
@@ -152,14 +154,14 @@ class AutoTransformers(NNOperator):
             from transformers.onnx.features import FeaturesManager
             from transformers.onnx import export
             model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise(
-                self.model, feature='default')
-            onnx_config = model_onnx_config(self.configs)
+                self._model, feature='default')
+            onnx_config = model_onnx_config(self._model.config)
             # if os.path.isdir(output_file[:-5]):
             #     shutil.rmtree(output_file[:-5])
             # print('********', Path(output_file))
             onnx_inputs, onnx_outputs = export(
                 self.tokenizer,
-                self.model,
+                self._model,
                 config=onnx_config,
                 opset=13,
                 output=Path(output_file)