image-embedding
copied
Fix issue of mismatched device
Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
@ -40,6 +40,9 @@ class MPViT(NNOperator):
skip_preprocess: bool = False):
super().__init__()
if device is None:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
self.model = mpvit.create_model(model_name=model_name,
num_classes=num_classes,
pretrained=True,