logo
Browse Source

Fix device

Signed-off-by: junjie.jiang <junjie.jiang@zilliz.com>
main
junjie.jiang 8 months ago
parent
commit
6985827e85
  1. 8
      s_bert.py

8
s_bert.py

@ -79,9 +79,11 @@ class STransformers(NNOperator):
def __init__(self, model_name: str = None, device: str = None, return_usage: bool = False):
self.model_name = model_name
if device is None:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
self.device = device
if device:
self.device = device
else:
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
if self.model_name:
self.model = Model(model_name=self.model_name, device=self.device)
else:

Loading…
Cancel
Save