logo
Browse Source

Allow to pass checkpoint path

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 1 year ago
parent
commit
db6727933a
  1. 12
      timm_image.py

12
timm_image.py

@ -54,8 +54,12 @@ def torch_no_grad(f):
@accelerate @accelerate
class Model: class Model:
def __init__(self, model_name, device, num_classes):
def __init__(self, model_name, device, num_classes, checkpoint_path=None):
self.device = device self.device = device
if checkpoint_path:
assert os.path.exists(checkpoint_path), f'File not found: {checkpoint_path}'
self.model = create_model(model_name, checkpoint_path, num_classes=num_classes)
else:
self.model = create_model(model_name, pretrained=True, num_classes=num_classes) self.model = create_model(model_name, pretrained=True, num_classes=num_classes)
self.model.eval() self.model.eval()
self.model.to(device) self.model.to(device)
@ -81,7 +85,8 @@ class TimmImage(NNOperator):
model_name: str = None, model_name: str = None,
num_classes: int = 1000, num_classes: int = 1000,
skip_preprocess: bool = False, skip_preprocess: bool = False,
device: str = None
device: str = None,
checkpoint_path: str = None
) -> None: ) -> None:
super().__init__() super().__init__()
if device is None: if device is None:
@ -92,7 +97,8 @@ class TimmImage(NNOperator):
self.model = Model( self.model = Model(
model_name=model_name, model_name=model_name,
device=self.device, device=self.device,
num_classes=num_classes
num_classes=num_classes,
checkpoint_path=checkpoint_path
) )
try: try:
self.tfms = create_transform( self.tfms = create_transform(

Loading…
Cancel
Save