Browse Source
optimizer for supporting local checkpoint_path
Signed-off-by: ChengZi <chen.zhang@zilliz.com>
main
ChengZi
2 years ago
1 changed files with
3 additions and
3 deletions
-
rerank.py
|
|
@ -18,10 +18,10 @@ class Model: |
|
|
|
def __init__(self, model_name, checkpoint_path, config, device): |
|
|
|
self.device = device |
|
|
|
self.config = config |
|
|
|
self.model = AutoModelForSequenceClassification.from_pretrained(model_name, config=self.config) |
|
|
|
if checkpoint_path: |
|
|
|
state_dict = torch.load(checkpoint_path, map_location=device) |
|
|
|
self.model.load_state_dict(state_dict) |
|
|
|
self.model = AutoModelForSequenceClassification.from_pretrained(checkpoint_path, config=self.config) |
|
|
|
else: |
|
|
|
self.model = AutoModelForSequenceClassification.from_pretrained(model_name, config=self.config) |
|
|
|
self.model.to(self.device) |
|
|
|
self.model.eval() |
|
|
|
|
|
|
|