|
@ -37,14 +37,33 @@ warnings.filterwarnings('ignore') |
|
|
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' |
|
|
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' |
|
|
t_logging.set_verbosity_error() |
|
|
t_logging.set_verbosity_error() |
|
|
|
|
|
|
|
|
|
|
|
def create_model(model_name, checkpoint_path, device): |
|
|
|
|
|
model = AutoModel.from_pretrained(model_name).to(device) |
|
|
|
|
|
if hasattr(model, 'pooler') and model.pooler: |
|
|
|
|
|
model.pooler = None |
|
|
|
|
|
if checkpoint_path: |
|
|
|
|
|
try: |
|
|
|
|
|
state_dict = torch.load(checkpoint_path, map_location=device) |
|
|
|
|
|
model.load_state_dict(state_dict) |
|
|
|
|
|
except Exception: |
|
|
|
|
|
log.error(f'Fail to load weights from {checkpoint_path}') |
|
|
|
|
|
model.eval() |
|
|
|
|
|
return model |
|
|
|
|
|
|
|
|
# @accelerate |
|
|
# @accelerate |
|
|
class Model: |
|
|
class Model: |
|
|
def __init__(self, model): |
|
|
|
|
|
self.model = model |
|
|
|
|
|
|
|
|
def __init__(self, model_name, checkpoint_path, device): |
|
|
|
|
|
self.device = device |
|
|
|
|
|
self.model = create_model(model_name, checkpoint_path, device) |
|
|
|
|
|
|
|
|
def __call__(self, *args, **kwargs): |
|
|
def __call__(self, *args, **kwargs): |
|
|
outs = self.model(*args, **kwargs, return_dict=True) |
|
|
|
|
|
|
|
|
new_args = [] |
|
|
|
|
|
for x in args: |
|
|
|
|
|
new_args.append(x.to(self.device)) |
|
|
|
|
|
new_kwargs = {} |
|
|
|
|
|
for k, v in kwargs.items(): |
|
|
|
|
|
new_kwargs[k] = v.to(self.device) |
|
|
|
|
|
outs = self.model(*new_args, **new_kwargs, return_dict=True) |
|
|
return outs['last_hidden_state'] |
|
|
return outs['last_hidden_state'] |
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -75,17 +94,13 @@ class AutoTransformers(NNOperator): |
|
|
self.checkpoint_path = checkpoint_path |
|
|
self.checkpoint_path = checkpoint_path |
|
|
|
|
|
|
|
|
if self.model_name: |
|
|
if self.model_name: |
|
|
model_list = self.supported_model_names() |
|
|
|
|
|
assert model_name in model_list, f"Invalid model name: {model_name}. Supported model names: {model_list}" |
|
|
|
|
|
self.model = Model(self._model) |
|
|
|
|
|
if tokenizer is None: |
|
|
|
|
|
try: |
|
|
|
|
|
self.tokenizer = AutoTokenizer.from_pretrained(model_name) |
|
|
|
|
|
except Exception as e: |
|
|
|
|
|
log.error(f'Fail to load default tokenizer by name: {self.model_name}') |
|
|
|
|
|
raise e |
|
|
|
|
|
else: |
|
|
|
|
|
|
|
|
# model_list = self.supported_model_names() |
|
|
|
|
|
# assert model_name in model_list, f"Invalid model name: {model_name}. Supported model names: {model_list}" |
|
|
|
|
|
self.model = Model(model_name=self.model_name, checkpoint_path=self.checkpoint_path, device=self.device) |
|
|
|
|
|
if tokenizer: |
|
|
self.tokenizer = tokenizer |
|
|
self.tokenizer = tokenizer |
|
|
|
|
|
else: |
|
|
|
|
|
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name) |
|
|
if not self.tokenizer.pad_token: |
|
|
if not self.tokenizer.pad_token: |
|
|
self.tokenizer.pad_token = '[PAD]' |
|
|
self.tokenizer.pad_token = '[PAD]' |
|
|
else: |
|
|
else: |
|
@ -98,7 +113,7 @@ class AutoTransformers(NNOperator): |
|
|
else: |
|
|
else: |
|
|
txt = data |
|
|
txt = data |
|
|
try: |
|
|
try: |
|
|
inputs = self.tokenizer(txt, padding=True, truncation=True, return_tensors="pt").to(self.device) |
|
|
|
|
|
|
|
|
inputs = self.tokenizer(txt, padding=True, truncation=True, return_tensors='pt') |
|
|
except Exception as e: |
|
|
except Exception as e: |
|
|
log.error(f'Fail to tokenize inputs: {e}') |
|
|
log.error(f'Fail to tokenize inputs: {e}') |
|
|
raise e |
|
|
raise e |
|
@ -116,17 +131,7 @@ class AutoTransformers(NNOperator): |
|
|
|
|
|
|
|
|
@property |
|
|
@property |
|
|
def _model(self): |
|
|
def _model(self): |
|
|
model = AutoModel.from_pretrained(self.model_name).to(self.device) |
|
|
|
|
|
if hasattr(model, 'pooler') and model.pooler: |
|
|
|
|
|
model.pooler = None |
|
|
|
|
|
if self.checkpoint_path: |
|
|
|
|
|
try: |
|
|
|
|
|
state_dict = torch.load(self.checkpoint_path, map_location=self.device) |
|
|
|
|
|
model.load_state_dict(state_dict) |
|
|
|
|
|
except Exception: |
|
|
|
|
|
log.error(f'Fail to load weights from {self.checkpoint_path}') |
|
|
|
|
|
model.eval() |
|
|
|
|
|
return model |
|
|
|
|
|
|
|
|
return self.model.model |
|
|
|
|
|
|
|
|
def save_model(self, model_type: str = 'pytorch', output_file: str = 'default'): |
|
|
def save_model(self, model_type: str = 'pytorch', output_file: str = 'default'): |
|
|
if output_file == 'default': |
|
|
if output_file == 'default': |
|
@ -160,6 +165,7 @@ class AutoTransformers(NNOperator): |
|
|
elif model_type == 'onnx': |
|
|
elif model_type == 'onnx': |
|
|
from transformers.onnx.features import FeaturesManager |
|
|
from transformers.onnx.features import FeaturesManager |
|
|
from transformers.onnx import export |
|
|
from transformers.onnx import export |
|
|
|
|
|
self._model = self._model.to('cpu') |
|
|
model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise( |
|
|
model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise( |
|
|
self._model, feature='default') |
|
|
self._model, feature='default') |
|
|
onnx_config = model_onnx_config(self._model.config) |
|
|
onnx_config = model_onnx_config(self._model.config) |
|
|