diff --git a/s_bert.py b/s_bert.py index 6a367af..ac25c2e 100644 --- a/s_bert.py +++ b/s_bert.py @@ -79,9 +79,11 @@ class STransformers(NNOperator): def __init__(self, model_name: str = None, device: str = None, return_usage: bool = False): self.model_name = model_name - if device is None: - device = 'cuda' if torch.cuda.is_available() else 'cpu' - self.device = device + if device: + self.device = device + else: + self.device = 'cuda' if torch.cuda.is_available() else 'cpu' + if self.model_name: self.model = Model(model_name=self.model_name, device=self.device) else: