diff --git a/auto_transformers.py b/auto_transformers.py index f21a5e8..6fca0b8 100644 --- a/auto_transformers.py +++ b/auto_transformers.py @@ -40,6 +40,8 @@ class Model: def __init__(self, model_name, device, checkpoint_path): try: self.model = AutoModel.from_pretrained(model_name).to(device) + if hasattr(self.model, 'pooler') and self.model.pooler: + self.model.pooler = None except Exception as e: log.error(f"Fail to load model by name: {self.model_name}") raise e @@ -52,7 +54,7 @@ class Model: self.model.eval() def __call__(self, *args, **kwargs): - outs = self.model(*args, **kwargs) + outs = self.model(*args, **kwargs, return_dict=True) return outs['last_hidden_state']