From 862459b19a7ecfdc9a690fe06d8144cfd84acfba Mon Sep 17 00:00:00 2001 From: Jael Gu Date: Mon, 19 Dec 2022 15:06:55 +0800 Subject: [PATCH] Update forward method to support TritonServe Signed-off-by: Jael Gu --- auto_transformers.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/auto_transformers.py b/auto_transformers.py index f1219da..f21a5e8 100644 --- a/auto_transformers.py +++ b/auto_transformers.py @@ -52,7 +52,8 @@ class Model: self.model.eval() def __call__(self, *args, **kwargs): - return self.model(*args, **kwargs) + outs = self.model(*args, **kwargs) + return outs['last_hidden_state'] @register(output_schema=['vec']) @@ -110,7 +111,7 @@ class AutoTransformers(NNOperator): log.error(f'Invalid input for the model: {self.model_name}') raise e try: - features = outs['last_hidden_state'].squeeze(0) + features = outs.squeeze(0) except Exception as e: log.error(f'Fail to extract features by model: {self.model_name}') raise e