Browse Source
Allow to pass checkpoint path
Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
1 changed files with
10 additions and
4 deletions
-
timm_image.py
|
|
@ -54,8 +54,12 @@ def torch_no_grad(f): |
|
|
|
|
|
|
|
@accelerate |
|
|
|
class Model: |
|
|
|
def __init__(self, model_name, device, num_classes): |
|
|
|
def __init__(self, model_name, device, num_classes, checkpoint_path=None): |
|
|
|
self.device = device |
|
|
|
if checkpoint_path: |
|
|
|
assert os.path.exists(checkpoint_path), f'File not found: {checkpoint_path}' |
|
|
|
self.model = create_model(model_name, checkpoint_path, num_classes=num_classes) |
|
|
|
else: |
|
|
|
self.model = create_model(model_name, pretrained=True, num_classes=num_classes) |
|
|
|
self.model.eval() |
|
|
|
self.model.to(device) |
|
|
@ -81,7 +85,8 @@ class TimmImage(NNOperator): |
|
|
|
model_name: str = None, |
|
|
|
num_classes: int = 1000, |
|
|
|
skip_preprocess: bool = False, |
|
|
|
device: str = None |
|
|
|
device: str = None, |
|
|
|
checkpoint_path: str = None |
|
|
|
) -> None: |
|
|
|
super().__init__() |
|
|
|
if device is None: |
|
|
@ -92,7 +97,8 @@ class TimmImage(NNOperator): |
|
|
|
self.model = Model( |
|
|
|
model_name=model_name, |
|
|
|
device=self.device, |
|
|
|
num_classes=num_classes |
|
|
|
num_classes=num_classes, |
|
|
|
checkpoint_path=checkpoint_path |
|
|
|
) |
|
|
|
try: |
|
|
|
self.tfms = create_transform( |
|
|
|