logo
Browse Source

Update

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 2 years ago
parent
commit
28419844ff
  1. 8
      README.md
  2. 56
      auto_transformers.py

8
README.md

@ -48,7 +48,7 @@ Create the operator via the following factory method:
The model name in string, defaults to None.
If None, the operator will be initialized without specified model.
Supported model names:
Please note only supported models are tested by us:
<details><summary>Albert</summary>
@ -309,6 +309,12 @@ If None, the operator will download and load pretrained model by `model_name` fr
<br />
***device***: *str*
The device in string, defaults to None. If None, it will enable "cuda" automatically when cuda is available.
<br />
***tokenizer***: *object*
The method to tokenize input text, defaults to None.

56
auto_transformers.py

@ -37,14 +37,33 @@ warnings.filterwarnings('ignore')
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
t_logging.set_verbosity_error()
def create_model(model_name, checkpoint_path, device):
model = AutoModel.from_pretrained(model_name).to(device)
if hasattr(model, 'pooler') and model.pooler:
model.pooler = None
if checkpoint_path:
try:
state_dict = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(state_dict)
except Exception:
log.error(f'Fail to load weights from {checkpoint_path}')
model.eval()
return model
# @accelerate
class Model:
def __init__(self, model):
self.model = model
def __init__(self, model_name, checkpoint_path, device):
self.device = device
self.model = create_model(model_name, checkpoint_path, device)
def __call__(self, *args, **kwargs):
outs = self.model(*args, **kwargs, return_dict=True)
new_args = []
for x in args:
new_args.append(x.to(self.device))
new_kwargs = {}
for k, v in kwargs.items():
new_kwargs[k] = v.to(self.device)
outs = self.model(*new_args, **new_kwargs, return_dict=True)
return outs['last_hidden_state']
@ -75,17 +94,13 @@ class AutoTransformers(NNOperator):
self.checkpoint_path = checkpoint_path
if self.model_name:
model_list = self.supported_model_names()
assert model_name in model_list, f"Invalid model name: {model_name}. Supported model names: {model_list}"
self.model = Model(self._model)
if tokenizer is None:
try:
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
except Exception as e:
log.error(f'Fail to load default tokenizer by name: {self.model_name}')
raise e
else:
# model_list = self.supported_model_names()
# assert model_name in model_list, f"Invalid model name: {model_name}. Supported model names: {model_list}"
self.model = Model(model_name=self.model_name, checkpoint_path=self.checkpoint_path, device=self.device)
if tokenizer:
self.tokenizer = tokenizer
else:
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
if not self.tokenizer.pad_token:
self.tokenizer.pad_token = '[PAD]'
else:
@ -98,7 +113,7 @@ class AutoTransformers(NNOperator):
else:
txt = data
try:
inputs = self.tokenizer(txt, padding=True, truncation=True, return_tensors="pt").to(self.device)
inputs = self.tokenizer(txt, padding=True, truncation=True, return_tensors='pt')
except Exception as e:
log.error(f'Fail to tokenize inputs: {e}')
raise e
@ -116,17 +131,7 @@ class AutoTransformers(NNOperator):
@property
def _model(self):
model = AutoModel.from_pretrained(self.model_name).to(self.device)
if hasattr(model, 'pooler') and model.pooler:
model.pooler = None
if self.checkpoint_path:
try:
state_dict = torch.load(self.checkpoint_path, map_location=self.device)
model.load_state_dict(state_dict)
except Exception:
log.error(f'Fail to load weights from {self.checkpoint_path}')
model.eval()
return model
return self.model.model
def save_model(self, model_type: str = 'pytorch', output_file: str = 'default'):
if output_file == 'default':
@ -160,6 +165,7 @@ class AutoTransformers(NNOperator):
elif model_type == 'onnx':
from transformers.onnx.features import FeaturesManager
from transformers.onnx import export
self._model = self._model.to('cpu')
model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise(
self._model, feature='default')
onnx_config = model_onnx_config(self._model.config)

Loading…
Cancel
Save