logo
Browse Source

Fix issue of mismatched device

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 3 years ago
parent
commit
08584eca02
  1. 3
      mpvit.py

3
mpvit.py

@ -40,6 +40,9 @@ class MPViT(NNOperator):
skip_preprocess: bool = False):
super().__init__()
if device is None:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
self.model = mpvit.create_model(model_name=model_name,
num_classes=num_classes,
pretrained=True,

Loading…
Cancel
Save