logo
rerank
repo-copy-icon

copied

Browse Source

optimizer for supporting local checkpoint_path

Signed-off-by: ChengZi <chen.zhang@zilliz.com>
main
ChengZi 2 years ago
parent
commit
37eadfa858
  1. 6
      rerank.py

6
rerank.py

@ -18,10 +18,10 @@ class Model:
def __init__(self, model_name, checkpoint_path, config, device):
self.device = device
self.config = config
self.model = AutoModelForSequenceClassification.from_pretrained(model_name, config=self.config)
if checkpoint_path:
state_dict = torch.load(checkpoint_path, map_location=device)
self.model.load_state_dict(state_dict)
self.model = AutoModelForSequenceClassification.from_pretrained(checkpoint_path, config=self.config)
else:
self.model = AutoModelForSequenceClassification.from_pretrained(model_name, config=self.config)
self.model.to(self.device)
self.model.eval()

Loading…
Cancel
Save