@@ -40,21 +40,8 @@ t_logging.set_verbosity_error()
 
 # @accelerate
 class Model:
-    def __init__(self, model_name, device, checkpoint_path):
-        try:
-            self.model = AutoModel.from_pretrained(model_name).to(device)
-            if hasattr(self.model, 'pooler') and self.model.pooler:
-                self.model.pooler = None
-        except Exception as e:
-            log.error(f"Fail to load model by name: {self.model_name}")
-            raise e
-        if checkpoint_path:
-            try:
-                state_dict = torch.load(checkpoint_path, map_location=device)
-                self.model.load_state_dict(state_dict)
-            except Exception as e:
-                log.error(f"Fail to load state dict from {checkpoint_path}: {e}")
-        self.model.eval()
+    def __init__(self, model):
+        self.model = model
 
     def __call__(self, *args, **kwargs):
        outs = self.model(*args, **kwargs, return_dict=True)
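The hunk above reduces `Model` to a thin wrapper that only forwards calls to a prebuilt module, which is the shape the `# @accelerate` decorator expects to wrap. A minimal sketch of that pattern, assuming a generic Hugging Face encoder (`bert-base-uncased` is illustrative here, not this operator's default):

```python
import torch
from transformers import AutoModel, AutoTokenizer

class Wrapper:
    """Mirrors the new Model: hold a prebuilt module, forward every call."""
    def __init__(self, model):
        self.model = model

    def __call__(self, *args, **kwargs):
        # return_dict=True yields a ModelOutput with .last_hidden_state
        return self.model(*args, **kwargs, return_dict=True)

encoder = AutoModel.from_pretrained('bert-base-uncased').eval()
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
wrapped = Wrapper(encoder)
with torch.no_grad():
    outs = wrapped(**tokenizer('hello world', return_tensors='pt'))
print(outs.last_hidden_state.shape)  # torch.Size([1, 4, 768])
```

Keeping construction out of the wrapper means the same class can hold a stock model, a checkpointed one, or an accelerate-converted one without change.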
@@ -79,20 +66,18 @@ class AutoTransformers(NNOperator):
         checkpoint_path: str = None,
         tokenizer: object = None,
         device: str = None,
-        return_sentence_emb: bool = True
     ):
         super().__init__()
         if device is None:
             device = 'cuda' if torch.cuda.is_available() else 'cpu'
         self.device = device
         self.model_name = model_name
-        self.return_sentence_emb = return_sentence_emb
+        self.checkpoint_path = checkpoint_path
 
         if self.model_name:
             model_list = self.supported_model_names()
             assert model_name in model_list, f"Invalid model name: {model_name}. Supported model names: {model_list}"
-            self.model = Model(
-                model_name=self.model_name, device=self.device, checkpoint_path=checkpoint_path)
+            self.model = Model(self._model)
             if tokenizer is None:
                 try:
                     self.tokenizer = AutoTokenizer.from_pretrained(model_name)
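Note that `Model(self._model)` reads the `_model` property eagerly: the property (redefined in the last hunk below) builds, optionally checkpoints, and returns the module before the wrapper ever sees it. A toy sketch of that evaluation order, with made-up names:

```python
class Model:
    """Thin wrapper, as in the first hunk."""
    def __init__(self, model):
        self.model = model

class Op:
    """Stand-in for AutoTransformers, showing when _model runs."""
    def __init__(self):
        self.model = Model(self._model)  # property is evaluated right here

    @property
    def _model(self):
        print('building module')  # fires once, during __init__
        return 'module'           # stand-in for the real torch module

op = Op()              # prints: building module
print(op.model.model)  # module
```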
@@ -122,8 +107,6 @@ class AutoTransformers(NNOperator):
         except Exception as e:
             log.error(f'Invalid input for the model: {self.model_name}')
             raise e
-        if self.return_sentence_emb:
-            outs = self.post_proc(outs, inputs)
         features = outs.cpu().detach().numpy()
         if isinstance(data, str):
             features = features.squeeze(0)
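The `squeeze(0)` in this context exists because the tokenizer always produces a batch dimension, even for a single string, so a `str` input would otherwise come back as a 1×d array. A small shape sketch (toy array, not real model output):

```python
import numpy as np

features = np.zeros((1, 768))        # batch of one sentence embedding
single = features.squeeze(0)         # (768,) once the batch dim is dropped
print(features.shape, single.shape)  # (1, 768) (768,)
```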
@@ -133,15 +116,17 @@ class AutoTransformers(NNOperator):
 
     @property
     def _model(self):
-        return self.model.model
-
-    def post_proc(self, token_embeddings, inputs):
-        token_embeddings = token_embeddings.to(self.device)
-        attention_mask = inputs['attention_mask'].to(self.device)
-        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
-        sentence_embs = torch.sum(
-            token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
-        return sentence_embs
+        model = AutoModel.from_pretrained(self.model_name).to(self.device)
+        if hasattr(model, 'pooler') and model.pooler:
+            model.pooler = None
+        if self.checkpoint_path:
+            try:
+                state_dict = torch.load(self.checkpoint_path, map_location=self.device)
+                model.load_state_dict(state_dict)
+            except Exception as e:
+                log.error(f'Failed to load weights from {self.checkpoint_path}: {e}')
+        model.eval()
+        return model
 
     def save_model(self, model_type: str = 'pytorch', output_file: str = 'default'):
         if output_file == 'default':
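For reference, the `post_proc` deleted above is attention-mask mean pooling: token embeddings are summed wherever the mask is 1, then divided by the clamped token count. A self-contained sketch of the same computation, with made-up shapes:

```python
import torch

def mean_pool(token_embeddings: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Masked mean over the sequence dim, matching the removed post_proc."""
    mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    summed = torch.sum(token_embeddings * mask, dim=1)
    counts = torch.clamp(mask.sum(dim=1), min=1e-9)  # guard against all-pad rows
    return summed / counts

emb = torch.randn(2, 5, 8)            # (batch, seq_len, hidden)
mask = torch.tensor([[1, 1, 1, 0, 0],
                     [1, 1, 1, 1, 1]])
print(mean_pool(emb, mask).shape)     # torch.Size([2, 8])
```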