Browse Source
Fix issue of mismatched device
Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
1 changed files with
3 additions and
0 deletions
-
mpvit.py
|
@ -40,6 +40,9 @@ class MPViT(NNOperator): |
|
|
skip_preprocess: bool = False): |
|
|
skip_preprocess: bool = False): |
|
|
super().__init__() |
|
|
super().__init__() |
|
|
|
|
|
|
|
|
|
|
|
if device is None: |
|
|
|
|
|
device = 'cuda' if torch.cuda.is_available() else 'cpu' |
|
|
|
|
|
|
|
|
self.model = mpvit.create_model(model_name=model_name, |
|
|
self.model = mpvit.create_model(model_name=model_name, |
|
|
num_classes=num_classes, |
|
|
num_classes=num_classes, |
|
|
pretrained=True, |
|
|
pretrained=True, |
|
|