From 13d94c49b8c5b44dc4f2f0d33e992e6635edabf9 Mon Sep 17 00:00:00 2001 From: gexy5 Date: Thu, 2 Jun 2022 10:25:40 +0800 Subject: [PATCH] add save_model Signed-off-by: gexy5 --- timm_image.py | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/timm_image.py b/timm_image.py index 1c2a70f..2683b7c 100644 --- a/timm_image.py +++ b/timm_image.py @@ -14,6 +14,8 @@ import logging import numpy +import os +from pathlib import Path import towhee from towhee.operator.base import NNOperator, OperatorFlag @@ -51,7 +53,8 @@ class TimmImage(NNOperator): def __init__(self, model_name: str, num_classes: int = 1000, skip_preprocess: bool = False) -> None: super().__init__() self.device = 'cuda' if torch.cuda.is_available() else 'cpu' - self.model = create_model(model_name, pretrained=True, num_classes=num_classes) + self.model_name = model_name + self.model = create_model(self.model_name, pretrained=True, num_classes=num_classes) self.model.to(self.device) self.model.eval() config = resolve_data_config({}, model=self.model) @@ -73,6 +76,18 @@ class TimmImage(NNOperator): vec = features.flatten().detach().numpy() return vec + def save_model(self, jit: bool = True, destination: str = 'default'): + if destination == 'default': + path = str(Path(__file__).parent) + destination = os.path.join(path, self.model_name + '.pt') + if jit: + try: + traced_model = torch.jit.script(self.model) + torch.jit.save(traced_model, destination) + except Exception as e: + raise RuntimeError(f'Fail to save as torchscript: {e}.') + else: + torch.save(self.model, destination) # if __name__ == '__main__': # from towhee import ops @@ -85,3 +100,6 @@ class TimmImage(NNOperator): # op = TimmImage('resnet50') # out = op(img) # print(out) +# op.model = torch.jit.load('resnet50.pt') +# out2 = op(img) +# print(out2)