diff --git a/mpvit.py b/mpvit.py index 5cd3fc1..4058b4d 100644 --- a/mpvit.py +++ b/mpvit.py @@ -40,6 +40,9 @@ class MPViT(NNOperator): skip_preprocess: bool = False): super().__init__() + if device is None: + device = 'cuda' if torch.cuda.is_available() else 'cpu' + self.model = mpvit.create_model(model_name=model_name, num_classes=num_classes, pretrained=True,