logo
Browse Source

Update forward method to support TritonServe

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 2 years ago
parent
commit
862459b19a
  1. 5
      auto_transformers.py

5
auto_transformers.py

@ -52,7 +52,8 @@ class Model:
self.model.eval()
def __call__(self, *args, **kwargs):
return self.model(*args, **kwargs)
outs = self.model(*args, **kwargs)
return outs['last_hidden_state']
@register(output_schema=['vec'])
@ -110,7 +111,7 @@ class AutoTransformers(NNOperator):
log.error(f'Invalid input for the model: {self.model_name}')
raise e
try:
features = outs['last_hidden_state'].squeeze(0)
features = outs.squeeze(0)
except Exception as e:
log.error(f'Fail to extract features by model: {self.model_name}')
raise e

Loading…
Cancel
Save