logo
Browse Source

Fix for self._model device

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 1 year ago
parent
commit
5d235e994d
  1. 6
      auto_transformers.py

6
auto_transformers.py

@ -37,6 +37,7 @@ warnings.filterwarnings('ignore')
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
t_logging.set_verbosity_error()
def create_model(model_name, checkpoint_path, device):
model = AutoModel.from_pretrained(model_name).to(device)
if hasattr(model, 'pooler') and model.pooler:
@ -50,6 +51,7 @@ def create_model(model_name, checkpoint_path, device):
model.eval()
return model
# @accelerate
class Model:
def __init__(self, model_name, checkpoint_path, device):
@ -131,7 +133,7 @@ class AutoTransformers(NNOperator):
@property
def _model(self):
return self.model.model
return self.model.model.to('cpu')
def save_model(self, model_type: str = 'pytorch', output_file: str = 'default'):
if output_file == 'default':
@ -165,7 +167,7 @@ class AutoTransformers(NNOperator):
elif model_type == 'onnx':
from transformers.onnx.features import FeaturesManager
from transformers.onnx import export
self._model = self._model.to('cpu')
model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise(
self._model, feature='default')
onnx_config = model_onnx_config(self._model.config)

Loading…
Cancel
Save