logo
rerank
repo-copy-icon

copied

Browse Source

Checkout device available

Signed-off-by: junjie.jiang <junjie.jiang@zilliz.com>
main
junjie.jiang 2 years ago
parent
commit
21e37c0b3e
  1. 5
      rerank.py

5
rerank.py

@ -42,8 +42,11 @@ class ReRank(NNOperator):
super().__init__() super().__init__()
self._model_name = model_name self._model_name = model_name
self.config = AutoConfig.from_pretrained(model_name) self.config = AutoConfig.from_pretrained(model_name)
if isinstance(device, int) and device >=0 and torch.cuda.is_available():
self.device = device self.device = device
self.model = Model(model_name, checkpoint_path, self.config, device)
else:
self.device = 'cpu'
self.model = Model(model_name, checkpoint_path, self.config, self.device)
self.tokenizer = AutoTokenizer.from_pretrained(model_name) self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.max_length = max_length self.max_length = max_length
self._threshold = threshold self._threshold = threshold

Loading…
Cancel
Save