logo
Browse Source

Add cls pooling

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 1 year ago
parent
commit
34f1a1e4b8
  1. 23
      auto_transformers.py

23
auto_transformers.py

@ -97,15 +97,22 @@ class AutoTransformers(NNOperator):
The local checkpoint path.
tokenizer (`object`):
The tokenizer to tokenize input text as model inputs.
pool (`str`):
The type of post-process pooling after token embeddings, defaults to "mean". Options: "mean", "cls"
"""
def __init__(self,
model_name: str = None,
checkpoint_path: str = None,
tokenizer: object = None,
pool: str = 'mean',
device: str = None,
):
super().__init__()
if pool not in ['mean', 'cls']:
log.warning('Invalid pool %s, using mean pooling instead.', pool)
pool = 'mean'
self.pool = pool
if device:
self.device = device
else:
@ -145,7 +152,10 @@ class AutoTransformers(NNOperator):
except Exception as e:
log.error(f'Invalid input for the model: {self.model_name}')
raise e
outs = self.post_proc(outs, inputs)
if self.pool == 'mean':
outs = self.mean_pool(outs, inputs)
elif self.pool == 'cls':
outs = self.cls_pool(outs)
features = outs.detach().numpy()
if isinstance(data, str):
features = features.squeeze(0)
@ -184,7 +194,7 @@ class AutoTransformers(NNOperator):
}
return onnx_config
def post_proc(self, token_embeddings, inputs):
def mean_pool(self, token_embeddings, inputs):
token_embeddings = token_embeddings
attention_mask = inputs['attention_mask']
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
@ -192,6 +202,15 @@ class AutoTransformers(NNOperator):
token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
return sentence_embs
def cls_pool(self, token_embeddings):
_shape = token_embeddings.shape
if len(_shape) == 3:
return token_embeddings[:, 0, :]
elif len(_shape) == 2:
return token_embeddings[0]
else:
raise RuntimeError(f'Invalid shape of token embeddings: {_shape}')
def save_model(self, model_type: str = 'pytorch', output_file: str = 'default'):
if output_file == 'default':
output_file = str(Path(__file__).parent)

Loading…
Cancel
Save