diff --git a/isc.py b/isc.py index 0253e11..f9a3d2e 100644 --- a/isc.py +++ b/isc.py @@ -48,9 +48,8 @@ _ = sys.modules[__name__] class Model: def __init__(self, timm_backbone, checkpoint_path, device): self.device = device - self.backbone = create_model(timm_backbone, features_only=True, pretrained=False) self.model = isc.create_model(pretrained=True, checkpoint_path=checkpoint_path, device=self.device, - backbone=self.backbone, p=1.0, eval_p=1.0) + timm_backbone=timm_backbone, p=1.0, eval_p=1.0) self.model.eval() def __call__(self, x):