logo
Browse Source

Remove pooler

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 2 years ago
parent
commit
824ecf2154
  1. 4
      auto_transformers.py

4
auto_transformers.py

@ -40,6 +40,8 @@ class Model:
def __init__(self, model_name, device, checkpoint_path): def __init__(self, model_name, device, checkpoint_path):
try: try:
self.model = AutoModel.from_pretrained(model_name).to(device) self.model = AutoModel.from_pretrained(model_name).to(device)
if hasattr(self.model, 'pooler') and self.model.pooler:
self.model.pooler = None
except Exception as e: except Exception as e:
log.error(f"Fail to load model by name: {self.model_name}") log.error(f"Fail to load model by name: {self.model_name}")
raise e raise e
@ -52,7 +54,7 @@ class Model:
self.model.eval() self.model.eval()
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
outs = self.model(*args, **kwargs)
outs = self.model(*args, **kwargs, return_dict=True)
return outs['last_hidden_state'] return outs['last_hidden_state']

Loading…
Cancel
Save