Browse Source
Remove pooler
Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
1 changed files with
3 additions and
1 deletions
-
auto_transformers.py
|
@ -40,6 +40,8 @@ class Model: |
|
|
def __init__(self, model_name, device, checkpoint_path): |
|
|
def __init__(self, model_name, device, checkpoint_path): |
|
|
try: |
|
|
try: |
|
|
self.model = AutoModel.from_pretrained(model_name).to(device) |
|
|
self.model = AutoModel.from_pretrained(model_name).to(device) |
|
|
|
|
|
if hasattr(self.model, 'pooler') and self.model.pooler: |
|
|
|
|
|
self.model.pooler = None |
|
|
except Exception as e: |
|
|
except Exception as e: |
|
|
log.error(f"Fail to load model by name: {self.model_name}") |
|
|
log.error(f"Fail to load model by name: {self.model_name}") |
|
|
raise e |
|
|
raise e |
|
@ -52,7 +54,7 @@ class Model: |
|
|
self.model.eval() |
|
|
self.model.eval() |
|
|
|
|
|
|
|
|
def __call__(self, *args, **kwargs): |
|
|
def __call__(self, *args, **kwargs): |
|
|
outs = self.model(*args, **kwargs) |
|
|
|
|
|
|
|
|
outs = self.model(*args, **kwargs, return_dict=True) |
|
|
return outs['last_hidden_state'] |
|
|
return outs['last_hidden_state'] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|