diff --git a/timm_image.py b/timm_image.py index 54f4099..026625f 100644 --- a/timm_image.py +++ b/timm_image.py @@ -58,7 +58,7 @@ class Model: self.device = device if checkpoint_path: assert os.path.exists(checkpoint_path), f'File not found: {checkpoint_path}' - self.model = create_model(model_name, checkpoint_path, num_classes=num_classes) + self.model = create_model(model_name, pretrained=False, checkpoint_path=checkpoint_path, num_classes=num_classes) else: self.model = create_model(model_name, pretrained=True, num_classes=num_classes) self.model.eval()