diff --git a/auto_transformers.py b/auto_transformers.py index a5e741f..02422d2 100644 --- a/auto_transformers.py +++ b/auto_transformers.py @@ -37,6 +37,7 @@ warnings.filterwarnings('ignore') os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' t_logging.set_verbosity_error() + def create_model(model_name, checkpoint_path, device): model = AutoModel.from_pretrained(model_name).to(device) if hasattr(model, 'pooler') and model.pooler: @@ -50,6 +51,7 @@ def create_model(model_name, checkpoint_path, device): model.eval() return model + # @accelerate class Model: def __init__(self, model_name, checkpoint_path, device): @@ -131,7 +133,7 @@ class AutoTransformers(NNOperator): @property def _model(self): - return self.model.model + return self.model.model.to('cpu') def save_model(self, model_type: str = 'pytorch', output_file: str = 'default'): if output_file == 'default': @@ -165,7 +167,7 @@ class AutoTransformers(NNOperator): elif model_type == 'onnx': from transformers.onnx.features import FeaturesManager from transformers.onnx import export - self._model = self._model.to('cpu') + model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise( self._model, feature='default') onnx_config = model_onnx_config(self._model.config)