Browse Source
Checkout device available
Signed-off-by: junjie.jiang <junjie.jiang@zilliz.com>
main
1 changed files with
5 additions and
2 deletions
-
rerank.py
|
|
@ -42,8 +42,11 @@ class ReRank(NNOperator): |
|
|
|
super().__init__() |
|
|
|
self._model_name = model_name |
|
|
|
self.config = AutoConfig.from_pretrained(model_name) |
|
|
|
if isinstance(device, int) and device >=0 and torch.cuda.is_available(): |
|
|
|
self.device = device |
|
|
|
self.model = Model(model_name, checkpoint_path, self.config, device) |
|
|
|
else: |
|
|
|
self.device = 'cpu' |
|
|
|
self.model = Model(model_name, checkpoint_path, self.config, self.device) |
|
|
|
self.tokenizer = AutoTokenizer.from_pretrained(model_name) |
|
|
|
self.max_length = max_length |
|
|
|
self._threshold = threshold |
|
|
|