From 08584eca028141aae5c378db81f22418312e5c13 Mon Sep 17 00:00:00 2001 From: Jael Gu Date: Thu, 8 Sep 2022 17:49:07 +0800 Subject: [PATCH] Fix issue of mismatched device Signed-off-by: Jael Gu --- mpvit.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/mpvit.py b/mpvit.py index 5cd3fc1..4058b4d 100644 --- a/mpvit.py +++ b/mpvit.py @@ -40,6 +40,9 @@ class MPViT(NNOperator): skip_preprocess: bool = False): super().__init__() + if device is None: + device = 'cuda' if torch.cuda.is_available() else 'cpu' + self.model = mpvit.create_model(model_name=model_name, num_classes=num_classes, pretrained=True,