diff --git a/clip.py b/clip.py index 8bedb00..cf110dc 100644 --- a/clip.py +++ b/clip.py @@ -96,7 +96,11 @@ class Clip(NNOperator): def __init__(self, model_name: str, modality: str, device: str = 'cpu', checkpoint_path: str = None): self.model_name = model_name self.modality = modality - self.device = device + if not torch.cuda.is_available(): + log.warning("Gpu not available, use cpu") + self.device = 'cpu' + else: + self.device = device self.checkpoint_path = checkpoint_path config = self._configs() real_name = model_name