logo
Browse Source

Remove sentence post_proc

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 2 years ago
parent
commit
83e096c6db
  1. 45
      auto_transformers.py

45
auto_transformers.py

@ -40,21 +40,8 @@ t_logging.set_verbosity_error()
# @accelerate # @accelerate
class Model: class Model:
def __init__(self, model_name, device, checkpoint_path):
try:
self.model = AutoModel.from_pretrained(model_name).to(device)
if hasattr(self.model, 'pooler') and self.model.pooler:
self.model.pooler = None
except Exception as e:
log.error(f"Fail to load model by name: {self.model_name}")
raise e
if checkpoint_path:
try:
state_dict = torch.load(checkpoint_path, map_location=device)
self.model.load_state_dict(state_dict)
except Exception as e:
log.error(f"Fail to load state dict from {checkpoint_path}: {e}")
self.model.eval()
def __init__(self, model):
self.model = model
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
outs = self.model(*args, **kwargs, return_dict=True) outs = self.model(*args, **kwargs, return_dict=True)
@ -79,20 +66,18 @@ class AutoTransformers(NNOperator):
checkpoint_path: str = None, checkpoint_path: str = None,
tokenizer: object = None, tokenizer: object = None,
device: str = None, device: str = None,
return_sentence_emb: bool = True
): ):
super().__init__() super().__init__()
if device is None: if device is None:
device = 'cuda' if torch.cuda.is_available() else 'cpu' device = 'cuda' if torch.cuda.is_available() else 'cpu'
self.device = device self.device = device
self.model_name = model_name self.model_name = model_name
self.return_sentence_emb = return_sentence_emb
self.checkpoint_path = checkpoint_path
if self.model_name: if self.model_name:
model_list = self.supported_model_names() model_list = self.supported_model_names()
assert model_name in model_list, f"Invalid model name: {model_name}. Supported model names: {model_list}" assert model_name in model_list, f"Invalid model name: {model_name}. Supported model names: {model_list}"
self.model = Model(
model_name=self.model_name, device=self.device, checkpoint_path=checkpoint_path)
self.model = Model(self._model)
if tokenizer is None: if tokenizer is None:
try: try:
self.tokenizer = AutoTokenizer.from_pretrained(model_name) self.tokenizer = AutoTokenizer.from_pretrained(model_name)
@ -122,8 +107,6 @@ class AutoTransformers(NNOperator):
except Exception as e: except Exception as e:
log.error(f'Invalid input for the model: {self.model_name}') log.error(f'Invalid input for the model: {self.model_name}')
raise e raise e
if self.return_sentence_emb:
outs = self.post_proc(outs, inputs)
features = outs.cpu().detach().numpy() features = outs.cpu().detach().numpy()
if isinstance(data, str): if isinstance(data, str):
features = features.squeeze(0) features = features.squeeze(0)
@ -133,15 +116,17 @@ class AutoTransformers(NNOperator):
@property @property
def _model(self): def _model(self):
return self.model.model
def post_proc(self, token_embeddings, inputs):
token_embeddings = token_embeddings.to(self.device)
attention_mask = inputs['attention_mask'].to(self.device)
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
sentence_embs = torch.sum(
token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
return sentence_embs
model = AutoModel.from_pretrained(self.model_name).to(self.device)
if hasattr(model, 'pooler') and model.pooler:
model.pooler = None
if self.checkpoint_path:
try:
state_dict = torch.load(self.checkpoint_path, map_location=self.device)
model.load_state_dict(state_dict)
except Exception:
log.error(f'Fail to load weights from {self.checkpoint_path}')
model.eval()
return model
def save_model(self, model_type: str = 'pytorch', output_file: str = 'default'): def save_model(self, model_type: str = 'pytorch', output_file: str = 'default'):
if output_file == 'default': if output_file == 'default':

Loading…
Cancel
Save