@@ -37,6 +37,7 @@ warnings.filterwarnings('ignore')
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
t_logging.set_verbosity_error()


def create_model(model_name, checkpoint_path, device):
    model = AutoModel.from_pretrained(model_name).to(device)
    if hasattr(model, 'pooler') and model.pooler:
@@ -50,6 +51,7 @@ def create_model(model_name, checkpoint_path, device):
    model.eval()
    return model


# @accelerate
class Model:
    def __init__(self, model_name, checkpoint_path, device):
@@ -131,7 +133,7 @@ class AutoTransformers(NNOperator):

    @property
    def _model(self):
        return self.model.model
        return self.model.model.to('cpu')

    def save_model(self, model_type: str = 'pytorch', output_file: str = 'default'):
        if output_file == 'default':
@@ -165,7 +167,7 @@ class AutoTransformers(NNOperator):
        elif model_type == 'onnx':
            from transformers.onnx.features import FeaturesManager
            from transformers.onnx import export
            self._model = self._model.to('cpu')

            model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise(
                self._model, feature='default')
            onnx_config = model_onnx_config(self._model.config)
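For context, a minimal sketch of the ONNX export flow that the last hunk touches, assuming a 4.x `transformers` release that still ships `transformers.onnx`, a plain `bert-base-uncased` checkpoint, and an output path chosen purely for illustration; the operator's own naming, device handling, and feature selection may differ:

```python
from pathlib import Path

from transformers import AutoModel, AutoTokenizer
from transformers.onnx import export
from transformers.onnx.features import FeaturesManager

model_name = 'bert-base-uncased'  # illustrative checkpoint, not the operator's default
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name).to('cpu')
model.eval()

# Resolve the ONNX config class for this architecture, then instantiate it from the
# model's config, mirroring the two lines at the end of the diff.
model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise(
    model, feature='default')
onnx_config = model_onnx_config(model.config)

# Export with the opset the config declares as its default.
onnx_inputs, onnx_outputs = export(
    tokenizer, model, onnx_config, onnx_config.default_onnx_opset, Path('model.onnx'))
```

The exporter traces the model with dummy inputs placed on CPU by default, which is presumably why the diff moves the model to CPU before building the ONNX config and calling `export`.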