Browse Source
Update forward method to support TritonServe
Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
1 changed files with
3 additions and
2 deletions
-
auto_transformers.py
|
@ -52,7 +52,8 @@ class Model: |
|
|
self.model.eval() |
|
|
self.model.eval() |
|
|
|
|
|
|
|
|
def __call__(self, *args, **kwargs): |
|
|
def __call__(self, *args, **kwargs): |
|
|
return self.model(*args, **kwargs) |
|
|
|
|
|
|
|
|
outs = self.model(*args, **kwargs) |
|
|
|
|
|
return outs['last_hidden_state'] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register(output_schema=['vec']) |
|
|
@register(output_schema=['vec']) |
|
@ -110,7 +111,7 @@ class AutoTransformers(NNOperator): |
|
|
log.error(f'Invalid input for the model: {self.model_name}') |
|
|
log.error(f'Invalid input for the model: {self.model_name}') |
|
|
raise e |
|
|
raise e |
|
|
try: |
|
|
try: |
|
|
features = outs['last_hidden_state'].squeeze(0) |
|
|
|
|
|
|
|
|
features = outs.squeeze(0) |
|
|
except Exception as e: |
|
|
except Exception as e: |
|
|
log.error(f'Fail to extract features by model: {self.model_name}') |
|
|
log.error(f'Fail to extract features by model: {self.model_name}') |
|
|
raise e |
|
|
raise e |
|
|