diff --git a/clip_vision.py b/clip_vision.py index 3775cd5..dd341df 100644 --- a/clip_vision.py +++ b/clip_vision.py @@ -49,7 +49,10 @@ class ClipVision(NNOperator): @property def device(self): if self._device is None: - self._device = torch.device(self._device_id) + if self._device_id < 0: + self._device = torch.device('cpu') + else: + self._device = torch.device(self._device_id) return self._device def __call__(self, image: 'Image'):