From 6985827e855905f08071251aef2928efe38a3aef Mon Sep 17 00:00:00 2001 From: "junjie.jiang" Date: Wed, 26 Jun 2024 13:26:43 +0800 Subject: [PATCH] Fix device Signed-off-by: junjie.jiang --- s_bert.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/s_bert.py b/s_bert.py index 6a367af..ac25c2e 100644 --- a/s_bert.py +++ b/s_bert.py @@ -79,9 +79,11 @@ class STransformers(NNOperator): def __init__(self, model_name: str = None, device: str = None, return_usage: bool = False): self.model_name = model_name - if device is None: - device = 'cuda' if torch.cuda.is_available() else 'cpu' - self.device = device + if device: + self.device = device + else: + self.device = 'cuda' if torch.cuda.is_available() else 'cpu' + if self.model_name: self.model = Model(model_name=self.model_name, device=self.device) else: