@@ -59,10 +59,17 @@ def create_model(model_name, checkpoint_path, device):
 
+# @accelerate
 class Model:
     def __init__(self, model_name, checkpoint_path, device):
         self.device = device
         self.model = create_model(model_name, checkpoint_path, device)
 
     def __call__(self, *args, **kwargs):
-        outs = self.model(*args, **kwargs, return_dict=True)
+        new_args = []
+        for x in args:
+            new_args.append(x.to(self.device))
+        new_kwargs = {}
+        for k, v in kwargs.items():
+            new_kwargs[k] = v.to(self.device)
+        outs = self.model(*new_args, **new_kwargs, return_dict=True)
         return outs['last_hidden_state']
 
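The hunk above moves every positional and keyword tensor onto the wrapper's device before the forward pass, so callers of the accelerated Model can hand over CPU tensors. A minimal standalone sketch of the same pattern, assuming a generic torch.nn.Module (the DeviceWrapper name is illustrative, not from this repo):

    import torch

    class DeviceWrapper:
        """Call a module after moving all tensor args/kwargs to its device."""
        def __init__(self, model: torch.nn.Module, device: str):
            self.device = device
            self.model = model.to(device)

        def __call__(self, *args, **kwargs):
            args = [x.to(self.device) for x in args]
            kwargs = {k: v.to(self.device) for k, v in kwargs.items()}
            return self.model(*args, **kwargs)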
@@ -84,7 +91,6 @@ class AutoTransformers(NNOperator):
                  checkpoint_path: str = None,
                  tokenizer: object = None,
                  device: str = None,
                  norm: bool = False
                  ):
         super().__init__()
         if device:
@@ -95,7 +101,6 @@ class AutoTransformers(NNOperator):
             self.model_name = 'sentence-transformers/' + model_name
         else:
             self.model_name = model_name
         self.norm = norm
         self.checkpoint_path = checkpoint_path
 
         if self.model_name:
@@ -118,7 +123,7 @@ class AutoTransformers(NNOperator):
         else:
             txt = data
         try:
-            inputs = self.tokenizer(txt, padding=True, truncation=True, return_tensors='pt').to(self.device)
+            inputs = self.tokenizer(txt, padding=True, truncation=True, return_tensors='pt')
         except Exception as e:
             log.error(f'Fail to tokenize inputs: {e}')
             raise e
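Note what the tokenizer change drops: a Hugging Face tokenizer returns a BatchEncoding, whose .to(device) moves every contained tensor at once. After this hunk the inputs stay on CPU and the accelerated Model.__call__ performs the move per tensor. A hedged sketch of the call site (the checkpoint name is only an example):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')  # example checkpoint
    inputs = tokenizer(['a test sentence'], padding=True, truncation=True, return_tensors='pt')
    # before this PR: inputs = inputs.to(device)   # moved all tensors here
    # after this PR:  tensors are moved inside Model.__call__ instead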
@@ -128,8 +133,6 @@ class AutoTransformers(NNOperator):
             log.error(f'Invalid input for the model: {self.model_name}')
             raise e
         outs = self.post_proc(outs, inputs)
         if self.norm:
             outs = torch.nn.functional.normalize(outs, )
         features = outs.cpu().detach().numpy()
         if isinstance(data, str):
             features = features.squeeze(0)
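When norm=True, torch.nn.functional.normalize(outs) L2-normalizes along dim=1 by default, giving each embedding unit length. For example:

    import torch

    x = torch.tensor([[3.0, 4.0]])
    print(torch.nn.functional.normalize(x))  # tensor([[0.6000, 0.8000]]), since ||[3, 4]|| = 5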
@@ -169,8 +172,8 @@ class AutoTransformers(NNOperator):
         return onnx_config
 
     def post_proc(self, token_embeddings, inputs):
-        token_embeddings = token_embeddings.to(self.device)
-        attention_mask = inputs['attention_mask'].to(self.device)
+        token_embeddings = token_embeddings
+        attention_mask = inputs['attention_mask']
         input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
         sentence_embs = torch.sum(
             token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
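post_proc is attention-mask-weighted mean pooling: token embeddings are summed where the mask is 1 and divided by the clamped token count, so padding positions never dilute the sentence embedding. A self-contained sketch with toy shapes:

    import torch

    token_embeddings = torch.randn(2, 4, 8)        # (batch, seq_len, hidden)
    attention_mask = torch.tensor([[1, 1, 1, 0],
                                   [1, 1, 0, 0]])  # 0 marks padding
    mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    sentence_embs = (token_embeddings * mask).sum(1) / mask.sum(1).clamp(min=1e-9)
    print(sentence_embs.shape)                     # torch.Size([2, 8])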
@@ -191,7 +194,7 @@ class AutoTransformers(NNOperator):
             raise AttributeError('Unsupported model_type.')
 
         dummy_input = 'test sentence'
-        inputs = self.tokenizer(dummy_input, padding=True, truncation=True, return_tensors='pt').to(self.device)
+        inputs = self.tokenizer(dummy_input, padding=True, truncation=True, return_tensors='pt')
         if model_type == 'pytorch':
             torch.save(self._model, output_file)
         elif model_type == 'torchscript':
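The 'pytorch' branch pickles the whole module, so the class definition must be importable at load time. A quick illustration:

    import torch

    model = torch.nn.Linear(4, 2)
    torch.save(model, 'model.pt')                            # pickles the full module, not just weights
    reloaded = torch.load('model.pt', weights_only=False)    # needs the class importable at load time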
@@ -214,7 +217,7 @@ class AutoTransformers(NNOperator):
         if hasattr(self._model.config, 'use_cache'):
             self._model.config.use_cache = False
         torch.onnx.export(
-            self._model,
+            self._model.to('cpu'),
             tuple(inputs.values()),
             output_file,
             input_names=list(self.onnx_config['inputs'].keys()),
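Pinning the model to CPU before torch.onnx.export keeps the traced module on the same device as the tokenizer output, which stays on CPU after this PR. A hedged, self-contained sketch of an equivalent export (checkpoint, axis names, and opset are examples, not this operator's exact config):

    import torch
    from transformers import AutoModel, AutoTokenizer

    name = 'bert-base-uncased'  # example checkpoint
    model = AutoModel.from_pretrained(name).eval()
    tokenizer = AutoTokenizer.from_pretrained(name)
    inputs = tokenizer('test sentence', return_tensors='pt')
    torch.onnx.export(
        model.to('cpu'),                                  # trace on CPU to match CPU inputs
        (inputs['input_ids'], inputs['attention_mask']),
        'model.onnx',
        input_names=['input_ids', 'attention_mask'],
        output_names=['last_hidden_state'],
        dynamic_axes={'input_ids': {0: 'batch', 1: 'seq'},
                      'attention_mask': {0: 'batch', 1: 'seq'}},
        opset_version=14,
    )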