diff --git a/timm_image.py b/timm_image.py index 026625f..e7f936e 100644 --- a/timm_image.py +++ b/timm_image.py @@ -89,9 +89,14 @@ class TimmImage(NNOperator): checkpoint_path: str = None ) -> None: super().__init__() - if device is None: - device = 'cuda' if torch.cuda.is_available() else 'cpu' - self.device = device + if not torch.cuda.is_available(): + log.warning('Gpu is not available, use cpu') + self.device = 'cpu' + else: + if device is None: + self.device = 'cuda' + else: + self.device = device self.model_name = model_name if self.model_name: self.model = Model(