logo
Browse Source

Allow to pass checkpoint path

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 10 months ago
parent
commit
db6727933a
  1. 14
      timm_image.py

14
timm_image.py

@ -54,9 +54,13 @@ def torch_no_grad(f):
@accelerate
class Model:
def __init__(self, model_name, device, num_classes):
def __init__(self, model_name, device, num_classes, checkpoint_path=None):
self.device = device
self.model = create_model(model_name, pretrained=True, num_classes=num_classes)
if checkpoint_path:
assert os.path.exists(checkpoint_path), f'File not found: {checkpoint_path}'
self.model = create_model(model_name, checkpoint_path, num_classes=num_classes)
else:
self.model = create_model(model_name, pretrained=True, num_classes=num_classes)
self.model.eval()
self.model.to(device)
@ -81,7 +85,8 @@ class TimmImage(NNOperator):
model_name: str = None,
num_classes: int = 1000,
skip_preprocess: bool = False,
device: str = None
device: str = None,
checkpoint_path: str = None
) -> None:
super().__init__()
if device is None:
@ -92,7 +97,8 @@ class TimmImage(NNOperator):
self.model = Model(
model_name=model_name,
device=self.device,
num_classes=num_classes
num_classes=num_classes,
checkpoint_path=checkpoint_path
)
try:
self.tfms = create_transform(

Loading…
Cancel
Save