|
@ -51,12 +51,13 @@ def torch_no_grad(f): |
|
|
# @accelerate |
|
|
# @accelerate |
|
|
class Model: |
|
|
class Model: |
|
|
def __init__(self, model_name, device, num_classes): |
|
|
def __init__(self, model_name, device, num_classes): |
|
|
|
|
|
self.device = device |
|
|
self.model = create_model(model_name, pretrained=True, num_classes=num_classes) |
|
|
self.model = create_model(model_name, pretrained=True, num_classes=num_classes) |
|
|
self.model.eval() |
|
|
self.model.eval() |
|
|
self.model.to(device) |
|
|
self.model.to(device) |
|
|
|
|
|
|
|
|
def __call__(self, x: torch.Tensor): |
|
|
def __call__(self, x: torch.Tensor): |
|
|
return self.model.forward_features(x) |
|
|
|
|
|
|
|
|
return self.model.forward_features(x.to(self.device)) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register(output_schema=['vec']) |
|
|
@register(output_schema=['vec']) |
|
@ -113,7 +114,7 @@ class TimmImage(NNOperator): |
|
|
img = img if self.skip_tfms else self.tfms(img) |
|
|
img = img if self.skip_tfms else self.tfms(img) |
|
|
img_list.append(img) |
|
|
img_list.append(img) |
|
|
inputs = torch.stack(img_list) |
|
|
inputs = torch.stack(img_list) |
|
|
inputs = inputs.to(self.device) |
|
|
|
|
|
|
|
|
inputs = inputs |
|
|
features = self.model(inputs) |
|
|
features = self.model(inputs) |
|
|
if isinstance(features, list): |
|
|
if isinstance(features, list): |
|
|
features = [self.post_proc(x) for x in features] |
|
|
features = [self.post_proc(x) for x in features] |
|
@ -143,6 +144,7 @@ class TimmImage(NNOperator): |
|
|
return img |
|
|
return img |
|
|
|
|
|
|
|
|
def post_proc(self, features): |
|
|
def post_proc(self, features): |
|
|
|
|
|
features = features.to(self.device) |
|
|
if features.dim() == 3: |
|
|
if features.dim() == 3: |
|
|
features = features[:, 0] |
|
|
features = features[:, 0] |
|
|
if features.dim() == 4: |
|
|
if features.dim() == 4: |
|
@ -166,7 +168,7 @@ class TimmImage(NNOperator): |
|
|
path = path + '.onnx' |
|
|
path = path + '.onnx' |
|
|
else: |
|
|
else: |
|
|
raise AttributeError(f'Invalid format {format}.') |
|
|
raise AttributeError(f'Invalid format {format}.') |
|
|
dummy_input = torch.rand((1,) + self.config['input_size']).to(self.device) |
|
|
|
|
|
|
|
|
dummy_input = torch.rand((1,) + self.config['input_size']) |
|
|
if format == 'pytorch': |
|
|
if format == 'pytorch': |
|
|
torch.save(self._model, path) |
|
|
torch.save(self._model, path) |
|
|
elif format == 'torchscript': |
|
|
elif format == 'torchscript': |
|
@ -182,7 +184,7 @@ class TimmImage(NNOperator): |
|
|
elif format == 'onnx': |
|
|
elif format == 'onnx': |
|
|
self._model.forward = self._model.forward_features |
|
|
self._model.forward = self._model.forward_features |
|
|
try: |
|
|
try: |
|
|
torch.onnx.export(self._model, |
|
|
|
|
|
|
|
|
torch.onnx.export(self._model.to('cpu'), |
|
|
dummy_input, |
|
|
dummy_input, |
|
|
path, |
|
|
path, |
|
|
input_names=['input_0'], |
|
|
input_names=['input_0'], |
|
|