logo
Browse Source

add save_model

Signed-off-by: gexy5 <xinyu.ge@zilliz.com>
main
gexy5 3 years ago
parent
commit
13d94c49b8
  1. 20
      timm_image.py

20
timm_image.py

@ -14,6 +14,8 @@
import logging import logging
import numpy import numpy
import os
from pathlib import Path
import towhee import towhee
from towhee.operator.base import NNOperator, OperatorFlag from towhee.operator.base import NNOperator, OperatorFlag
@ -51,7 +53,8 @@ class TimmImage(NNOperator):
def __init__(self, model_name: str, num_classes: int = 1000, skip_preprocess: bool = False) -> None: def __init__(self, model_name: str, num_classes: int = 1000, skip_preprocess: bool = False) -> None:
super().__init__() super().__init__()
self.device = 'cuda' if torch.cuda.is_available() else 'cpu' self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
self.model = create_model(model_name, pretrained=True, num_classes=num_classes)
self.model_name = model_name
self.model = create_model(self.model_name, pretrained=True, num_classes=num_classes)
self.model.to(self.device) self.model.to(self.device)
self.model.eval() self.model.eval()
config = resolve_data_config({}, model=self.model) config = resolve_data_config({}, model=self.model)
@ -73,6 +76,18 @@ class TimmImage(NNOperator):
vec = features.flatten().detach().numpy() vec = features.flatten().detach().numpy()
return vec return vec
def save_model(self, jit: bool = True, destination: str = 'default'):
if destination == 'default':
path = str(Path(__file__).parent)
destination = os.path.join(path, self.model_name + '.pt')
if jit:
try:
traced_model = torch.jit.script(self.model)
torch.jit.save(traced_model, destination)
except Exception as e:
raise RuntimeError(f'Fail to save as torchscript: {e}.')
else:
torch.save(self.model, destination)
# if __name__ == '__main__': # if __name__ == '__main__':
# from towhee import ops # from towhee import ops
@ -85,3 +100,6 @@ class TimmImage(NNOperator):
# op = TimmImage('resnet50') # op = TimmImage('resnet50')
# out = op(img) # out = op(img)
# print(out) # print(out)
# op.model = torch.jit.load('resnet50.pt')
# out2 = op(img)
# print(out2)

Loading…
Cancel
Save