From 21e37c0b3e4abae7fa3261c677d5bad2e5bc5b25 Mon Sep 17 00:00:00 2001 From: "junjie.jiang" Date: Fri, 15 Sep 2023 16:12:54 +0800 Subject: [PATCH] Checkout device available Signed-off-by: junjie.jiang --- rerank.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/rerank.py b/rerank.py index cb663c2..1365df6 100644 --- a/rerank.py +++ b/rerank.py @@ -42,8 +42,11 @@ class ReRank(NNOperator): super().__init__() self._model_name = model_name self.config = AutoConfig.from_pretrained(model_name) - self.device = device - self.model = Model(model_name, checkpoint_path, self.config, device) + if isinstance(device, int) and device >=0 and torch.cuda.is_available(): + self.device = device + else: + self.device = 'cpu' + self.model = Model(model_name, checkpoint_path, self.config, self.device) self.tokenizer = AutoTokenizer.from_pretrained(model_name) self.max_length = max_length self._threshold = threshold