diff --git a/isc.py b/isc.py index f9a3d2e..a187b6f 100644 --- a/isc.py +++ b/isc.py @@ -76,7 +76,7 @@ class Isc(NNOperator): super().__init__() if device is None: device = 'cuda' if torch.cuda.is_available() else 'cpu' - self.device = device + self.device = device if isinstance(device, str) else 'cpu' if device < 0 else torch.device(device) self.skip_tfms = skip_preprocess self.timm_backbone = timm_backbone