From db6727933aa00a8afe1f5062e8ac73098c8d590b Mon Sep 17 00:00:00 2001 From: Jael Gu Date: Mon, 26 Jun 2023 18:37:58 +0800 Subject: [PATCH] Allow to pass checkpoint path Signed-off-by: Jael Gu --- timm_image.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/timm_image.py b/timm_image.py index 515f5c2..96cc0b1 100644 --- a/timm_image.py +++ b/timm_image.py @@ -54,9 +54,13 @@ def torch_no_grad(f): @accelerate class Model: - def __init__(self, model_name, device, num_classes): + def __init__(self, model_name, device, num_classes, checkpoint_path=None): self.device = device - self.model = create_model(model_name, pretrained=True, num_classes=num_classes) + if checkpoint_path: + assert os.path.exists(checkpoint_path), f'File not found: {checkpoint_path}' + self.model = create_model(model_name, checkpoint_path, num_classes=num_classes) + else: + self.model = create_model(model_name, pretrained=True, num_classes=num_classes) self.model.eval() self.model.to(device) @@ -81,7 +85,8 @@ class TimmImage(NNOperator): model_name: str = None, num_classes: int = 1000, skip_preprocess: bool = False, - device: str = None + device: str = None, + checkpoint_path: str = None ) -> None: super().__init__() if device is None: @@ -92,7 +97,8 @@ class TimmImage(NNOperator): self.model = Model( model_name=model_name, device=self.device, - num_classes=num_classes + num_classes=num_classes, + checkpoint_path=checkpoint_path ) try: self.tfms = create_transform(