|
|
@ -23,6 +23,7 @@ from towhee.operator.base import NNOperator, OperatorFlag |
|
|
|
from towhee.types.arg import arg, to_image_color |
|
|
|
from towhee import register |
|
|
|
from towhee.models import isc |
|
|
|
# from towhee.dc2 import accelerate |
|
|
|
|
|
|
|
import torch |
|
|
|
from torch import nn |
|
|
@ -36,6 +37,20 @@ warnings.filterwarnings('ignore') |
|
|
|
log = logging.getLogger('isc_op') |
|
|
|
|
|
|
|
|
|
|
|
# @accelerate |
|
|
|
class Model: |
|
|
|
def __init__(self, timm_backbone, checkpoint_path, device): |
|
|
|
self.device = device |
|
|
|
self.backbone = timm.create_model(timm_backbone, features_only=True, pretrained=False) |
|
|
|
self.model = isc.create_model(pretrained=True, checkpoint_path=checkpoint_path, device=self.device, |
|
|
|
backbone=self.backbone, p=3.0, eval_p=1.0) |
|
|
|
self.model.eval() |
|
|
|
|
|
|
|
def __call__(self, x): |
|
|
|
x = x.to(self.device) |
|
|
|
return self.model(x) |
|
|
|
|
|
|
|
|
|
|
|
@register(output_schema=['vec']) |
|
|
|
class Isc(NNOperator): |
|
|
|
""" |
|
|
@ -62,16 +77,13 @@ class Isc(NNOperator): |
|
|
|
if checkpoint_path is None: |
|
|
|
checkpoint_path = os.path.join(str(Path(__file__).parent), 'checkpoints', timm_backbone + '.pth') |
|
|
|
|
|
|
|
backbone = timm.create_model(timm_backbone, features_only=True, pretrained=False) |
|
|
|
self.model = isc.create_model(pretrained=True, checkpoint_path=checkpoint_path, device=self.device, |
|
|
|
backbone=backbone, p=3.0, eval_p=1.0) |
|
|
|
self.model.eval() |
|
|
|
self.model = Model(self.timm_backbone, checkpoint_path, self.device) |
|
|
|
|
|
|
|
self.tfms = transforms.Compose([ |
|
|
|
transforms.Resize((img_size, img_size)), |
|
|
|
transforms.ToTensor(), |
|
|
|
transforms.Normalize(mean=backbone.default_cfg['mean'], |
|
|
|
std=backbone.default_cfg['std']) |
|
|
|
transforms.Normalize(mean=self.backbone.default_cfg['mean'], |
|
|
|
std=self.backbone.default_cfg['std']) |
|
|
|
]) |
|
|
|
|
|
|
|
def __call__(self, data: Union[List[towhee._types.Image], towhee._types.Image]): |
|
|
@ -96,6 +108,14 @@ class Isc(NNOperator): |
|
|
|
vecs = features.squeeze(0).detach().numpy() |
|
|
|
return vecs |
|
|
|
|
|
|
|
@property |
|
|
|
def _model(self): |
|
|
|
return self.model.model |
|
|
|
|
|
|
|
@property |
|
|
|
def backbone(self): |
|
|
|
return self.model.backbone |
|
|
|
|
|
|
|
def save_model(self, format: str = 'pytorch', path: str = 'default'): |
|
|
|
if path == 'default': |
|
|
|
path = str(Path(__file__).parent) |
|
|
@ -103,25 +123,28 @@ class Isc(NNOperator): |
|
|
|
os.makedirs(path, exist_ok=True) |
|
|
|
name = self.timm_backbone.replace('/', '-') |
|
|
|
path = os.path.join(path, name) |
|
|
|
if format in ['pytorch', 'torchscript']: |
|
|
|
path = path + '.pt' |
|
|
|
elif format == 'onnx': |
|
|
|
path = path + '.onnx' |
|
|
|
else: |
|
|
|
raise ValueError(f'Invalid format {format}.') |
|
|
|
dummy_input = torch.rand(1, 3, 224, 224) |
|
|
|
if format == 'pytorch': |
|
|
|
path = path + '.pt' |
|
|
|
torch.save(self.model, path) |
|
|
|
torch.save(self._model, path) |
|
|
|
elif format == 'torchscript': |
|
|
|
path = path + '.pt' |
|
|
|
try: |
|
|
|
try: |
|
|
|
jit_model = torch.jit.script(self.model) |
|
|
|
jit_model = torch.jit.script(self._model) |
|
|
|
except Exception: |
|
|
|
jit_model = torch.jit.trace(self.model, dummy_input, strict=False) |
|
|
|
jit_model = torch.jit.trace(self._model, dummy_input, strict=False) |
|
|
|
torch.jit.save(jit_model, path) |
|
|
|
except Exception as e: |
|
|
|
log.error(f'Fail to save as torchscript: {e}.') |
|
|
|
raise RuntimeError(f'Fail to save as torchscript: {e}.') |
|
|
|
elif format == 'onnx': |
|
|
|
path = path + '.onnx' |
|
|
|
try: |
|
|
|
torch.onnx.export(self.model, |
|
|
|
torch.onnx.export(self._model, |
|
|
|
dummy_input, |
|
|
|
path, |
|
|
|
input_names=['input_0'], |
|
|
@ -139,12 +162,17 @@ class Isc(NNOperator): |
|
|
|
# todo: elif format == 'tensorrt': |
|
|
|
else: |
|
|
|
log.error(f'Unsupported format "{format}".') |
|
|
|
return path |
|
|
|
|
|
|
|
@arg(1, to_image_color('RGB')) |
|
|
|
def convert_img(self, img: towhee._types.Image): |
|
|
|
img = PILImage.fromarray(img.astype('uint8'), 'RGB') |
|
|
|
return img |
|
|
|
|
|
|
|
@property |
|
|
|
def supported_formats(self): |
|
|
|
return ['onnx'] |
|
|
|
|
|
|
|
|
|
|
|
# if __name__ == '__main__': |
|
|
|
# from towhee import ops |
|
|
|