diff --git a/rerank.py b/rerank.py index cb663c2..1365df6 100644 --- a/rerank.py +++ b/rerank.py @@ -42,8 +42,11 @@ class ReRank(NNOperator): super().__init__() self._model_name = model_name self.config = AutoConfig.from_pretrained(model_name) - self.device = device - self.model = Model(model_name, checkpoint_path, self.config, device) + if isinstance(device, int) and device >=0 and torch.cuda.is_available(): + self.device = device + else: + self.device = 'cpu' + self.model = Model(model_name, checkpoint_path, self.config, self.device) self.tokenizer = AutoTokenizer.from_pretrained(model_name) self.max_length = max_length self._threshold = threshold