|
@ -14,6 +14,8 @@ |
|
|
|
|
|
|
|
|
import logging |
|
|
import logging |
|
|
import numpy |
|
|
import numpy |
|
|
|
|
|
import os |
|
|
|
|
|
from pathlib import Path |
|
|
|
|
|
|
|
|
import towhee |
|
|
import towhee |
|
|
from towhee.operator.base import NNOperator, OperatorFlag |
|
|
from towhee.operator.base import NNOperator, OperatorFlag |
|
@ -51,7 +53,8 @@ class TimmImage(NNOperator): |
|
|
def __init__(self, model_name: str, num_classes: int = 1000, skip_preprocess: bool = False) -> None: |
|
|
def __init__(self, model_name: str, num_classes: int = 1000, skip_preprocess: bool = False) -> None: |
|
|
super().__init__() |
|
|
super().__init__() |
|
|
self.device = 'cuda' if torch.cuda.is_available() else 'cpu' |
|
|
self.device = 'cuda' if torch.cuda.is_available() else 'cpu' |
|
|
self.model = create_model(model_name, pretrained=True, num_classes=num_classes) |
|
|
|
|
|
|
|
|
self.model_name = model_name |
|
|
|
|
|
self.model = create_model(self.model_name, pretrained=True, num_classes=num_classes) |
|
|
self.model.to(self.device) |
|
|
self.model.to(self.device) |
|
|
self.model.eval() |
|
|
self.model.eval() |
|
|
config = resolve_data_config({}, model=self.model) |
|
|
config = resolve_data_config({}, model=self.model) |
|
@ -73,6 +76,18 @@ class TimmImage(NNOperator): |
|
|
vec = features.flatten().detach().numpy() |
|
|
vec = features.flatten().detach().numpy() |
|
|
return vec |
|
|
return vec |
|
|
|
|
|
|
|
|
|
|
|
def save_model(self, jit: bool = True, destination: str = 'default'): |
|
|
|
|
|
if destination == 'default': |
|
|
|
|
|
path = str(Path(__file__).parent) |
|
|
|
|
|
destination = os.path.join(path, self.model_name + '.pt') |
|
|
|
|
|
if jit: |
|
|
|
|
|
try: |
|
|
|
|
|
traced_model = torch.jit.script(self.model) |
|
|
|
|
|
torch.jit.save(traced_model, destination) |
|
|
|
|
|
except Exception as e: |
|
|
|
|
|
raise RuntimeError(f'Fail to save as torchscript: {e}.') |
|
|
|
|
|
else: |
|
|
|
|
|
torch.save(self.model, destination) |
|
|
|
|
|
|
|
|
# if __name__ == '__main__': |
|
|
# if __name__ == '__main__': |
|
|
# from towhee import ops |
|
|
# from towhee import ops |
|
@ -85,3 +100,6 @@ class TimmImage(NNOperator): |
|
|
# op = TimmImage('resnet50') |
|
|
# op = TimmImage('resnet50') |
|
|
# out = op(img) |
|
|
# out = op(img) |
|
|
# print(out) |
|
|
# print(out) |
|
|
|
|
|
# op.model = torch.jit.load('resnet50.pt') |
|
|
|
|
|
# out2 = op(img) |
|
|
|
|
|
# print(out2) |
|
|