diff --git a/auto_transformers.py b/auto_transformers.py index f1219da..f21a5e8 100644 --- a/auto_transformers.py +++ b/auto_transformers.py @@ -52,7 +52,8 @@ class Model: self.model.eval() def __call__(self, *args, **kwargs): - return self.model(*args, **kwargs) + outs = self.model(*args, **kwargs) + return outs['last_hidden_state'] @register(output_schema=['vec']) @@ -110,7 +111,7 @@ class AutoTransformers(NNOperator): log.error(f'Invalid input for the model: {self.model_name}') raise e try: - features = outs['last_hidden_state'].squeeze(0) + features = outs.squeeze(0) except Exception as e: log.error(f'Fail to extract features by model: {self.model_name}') raise e