@@ -97,15 +97,22 @@ class AutoTransformers(NNOperator):
            The local checkpoint path.
        tokenizer (`object`):
            The tokenizer to tokenize input text as model inputs.
        pool (`str`):
            The type of post-process pooling after token embeddings, defaults to "mean". Options: "mean", "cls"
    """

    def __init__(self,
                 model_name: str = None,
                 checkpoint_path: str = None,
                 tokenizer: object = None,
                 pool: str = 'mean',
                 device: str = None,
                 ):
        super().__init__()
        if pool not in ['mean', 'cls']:
            log.warning('Invalid pool %s, using mean pooling instead.', pool)
            pool = 'mean'
        self.pool = pool
        if device:
            self.device = device
        else:
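
The new pool argument is validated with a warn-and-fallback pattern in __init__. A minimal standalone sketch of that pattern (the helper name and logger setup below are illustrative, not part of the operator):

import logging

logging.basicConfig(level=logging.WARNING)
log = logging.getLogger(__name__)

def validate_pool(pool: str) -> str:
    # Mirrors the __init__ check: anything other than 'mean' or 'cls'
    # is logged and replaced by the default 'mean' strategy.
    if pool not in ['mean', 'cls']:
        log.warning('Invalid pool %s, using mean pooling instead.', pool)
        pool = 'mean'
    return pool

print(validate_pool('max'))   # falls back to 'mean' (with a warning)
print(validate_pool('cls'))   # kept as-is
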
@@ -145,7 +152,10 @@ class AutoTransformers(NNOperator):
        except Exception as e:
            log.error(f'Invalid input for the model: {self.model_name}')
            raise e
        outs = self.post_proc(outs, inputs)
        if self.pool == 'mean':
            outs = self.mean_pool(outs, inputs)
        elif self.pool == 'cls':
            outs = self.cls_pool(outs)
        features = outs.detach().numpy()
        if isinstance(data, str):
            features = features.squeeze(0)
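
For reference, a self-contained sketch of what the two pooling branches compute, using dummy tensors (the shapes and values are illustrative assumptions, not taken from the operator):

import torch

token_embeddings = torch.randn(2, 4, 8)          # (batch, seq_len, hidden_dim)
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])     # 1 = real token, 0 = padding

# "cls": take the first token's embedding -> (batch, hidden_dim)
cls_features = token_embeddings[:, 0, :]

# "mean": average only the non-padded tokens -> (batch, hidden_dim)
mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
mean_features = (token_embeddings * mask).sum(1) / mask.sum(1).clamp(min=1e-9)

print(cls_features.shape, mean_features.shape)    # torch.Size([2, 8]) for both
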
@@ -184,7 +194,7 @@ class AutoTransformers(NNOperator):
        }
        return onnx_config

    def post_proc(self, token_embeddings, inputs):
    def mean_pool(self, token_embeddings, inputs):
        token_embeddings = token_embeddings
        attention_mask = inputs['attention_mask']
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
@@ -192,6 +202,15 @@ class AutoTransformers(NNOperator):
            token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        return sentence_embs

    def cls_pool(self, token_embeddings):
        _shape = token_embeddings.shape
        if len(_shape) == 3:
            return token_embeddings[:, 0, :]
        elif len(_shape) == 2:
            return token_embeddings[0]
        else:
            raise RuntimeError(f'Invalid shape of token embeddings: {_shape}')

    def save_model(self, model_type: str = 'pytorch', output_file: str = 'default'):
        if output_file == 'default':
            output_file = str(Path(__file__).parent)
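
Assuming the file defines AutoTransformers as shown in the hunks above, usage with the new pool option might look like the following sketch (the model name and the direct-call pattern are assumptions for illustration, not taken from the diff):

# Hypothetical usage; AutoTransformers comes from the patched file above.
op_mean = AutoTransformers(model_name='bert-base-uncased', pool='mean')
op_cls = AutoTransformers(model_name='bert-base-uncased', pool='cls')

vec = op_mean('hello, world')   # a single string yields a 1-D numpy feature vector
print(vec.shape)
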