logo
Browse Source

Update

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 2 years ago
parent
commit
28419844ff
  1. 8
      README.md
  2. 56
      auto_transformers.py

8
README.md

@ -48,7 +48,7 @@ Create the operator via the following factory method:
The model name in string, defaults to None. The model name in string, defaults to None.
If None, the operator will be initialized without specified model. If None, the operator will be initialized without specified model.
Supported model names:
Please note only supported models are tested by us:
<details><summary>Albert</summary> <details><summary>Albert</summary>
@ -309,6 +309,12 @@ If None, the operator will download and load pretrained model by `model_name` fr
<br /> <br />
***device***: *str*
The device in string, defaults to None. If None, it will enable "cuda" automatically when cuda is available.
<br />
***tokenizer***: *object* ***tokenizer***: *object*
The method to tokenize input text, defaults to None. The method to tokenize input text, defaults to None.

56
auto_transformers.py

@ -37,14 +37,33 @@ warnings.filterwarnings('ignore')
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
t_logging.set_verbosity_error() t_logging.set_verbosity_error()
def create_model(model_name, checkpoint_path, device):
model = AutoModel.from_pretrained(model_name).to(device)
if hasattr(model, 'pooler') and model.pooler:
model.pooler = None
if checkpoint_path:
try:
state_dict = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(state_dict)
except Exception:
log.error(f'Fail to load weights from {checkpoint_path}')
model.eval()
return model
# @accelerate # @accelerate
class Model: class Model:
def __init__(self, model):
self.model = model
def __init__(self, model_name, checkpoint_path, device):
self.device = device
self.model = create_model(model_name, checkpoint_path, device)
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
outs = self.model(*args, **kwargs, return_dict=True)
new_args = []
for x in args:
new_args.append(x.to(self.device))
new_kwargs = {}
for k, v in kwargs.items():
new_kwargs[k] = v.to(self.device)
outs = self.model(*new_args, **new_kwargs, return_dict=True)
return outs['last_hidden_state'] return outs['last_hidden_state']
@ -75,17 +94,13 @@ class AutoTransformers(NNOperator):
self.checkpoint_path = checkpoint_path self.checkpoint_path = checkpoint_path
if self.model_name: if self.model_name:
model_list = self.supported_model_names()
assert model_name in model_list, f"Invalid model name: {model_name}. Supported model names: {model_list}"
self.model = Model(self._model)
if tokenizer is None:
try:
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
except Exception as e:
log.error(f'Fail to load default tokenizer by name: {self.model_name}')
raise e
else:
# model_list = self.supported_model_names()
# assert model_name in model_list, f"Invalid model name: {model_name}. Supported model names: {model_list}"
self.model = Model(model_name=self.model_name, checkpoint_path=self.checkpoint_path, device=self.device)
if tokenizer:
self.tokenizer = tokenizer self.tokenizer = tokenizer
else:
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
if not self.tokenizer.pad_token: if not self.tokenizer.pad_token:
self.tokenizer.pad_token = '[PAD]' self.tokenizer.pad_token = '[PAD]'
else: else:
@ -98,7 +113,7 @@ class AutoTransformers(NNOperator):
else: else:
txt = data txt = data
try: try:
inputs = self.tokenizer(txt, padding=True, truncation=True, return_tensors="pt").to(self.device)
inputs = self.tokenizer(txt, padding=True, truncation=True, return_tensors='pt')
except Exception as e: except Exception as e:
log.error(f'Fail to tokenize inputs: {e}') log.error(f'Fail to tokenize inputs: {e}')
raise e raise e
@ -116,17 +131,7 @@ class AutoTransformers(NNOperator):
@property @property
def _model(self): def _model(self):
model = AutoModel.from_pretrained(self.model_name).to(self.device)
if hasattr(model, 'pooler') and model.pooler:
model.pooler = None
if self.checkpoint_path:
try:
state_dict = torch.load(self.checkpoint_path, map_location=self.device)
model.load_state_dict(state_dict)
except Exception:
log.error(f'Fail to load weights from {self.checkpoint_path}')
model.eval()
return model
return self.model.model
def save_model(self, model_type: str = 'pytorch', output_file: str = 'default'): def save_model(self, model_type: str = 'pytorch', output_file: str = 'default'):
if output_file == 'default': if output_file == 'default':
@ -160,6 +165,7 @@ class AutoTransformers(NNOperator):
elif model_type == 'onnx': elif model_type == 'onnx':
from transformers.onnx.features import FeaturesManager from transformers.onnx.features import FeaturesManager
from transformers.onnx import export from transformers.onnx import export
self._model = self._model.to('cpu')
model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise( model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise(
self._model, feature='default') self._model, feature='default')
onnx_config = model_onnx_config(self._model.config) onnx_config = model_onnx_config(self._model.config)

Loading…
Cancel
Save