logo
Browse Source

Update device

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 1 year ago
parent
commit
f6033fcc81
  1. 10
      timm_image.py

10
timm_image.py

@ -51,12 +51,13 @@ def torch_no_grad(f):
# @accelerate
class Model:
def __init__(self, model_name, device, num_classes):
self.device = device
self.model = create_model(model_name, pretrained=True, num_classes=num_classes)
self.model.eval()
self.model.to(device)
def __call__(self, x: torch.Tensor):
return self.model.forward_features(x)
return self.model.forward_features(x.to(self.device))
@register(output_schema=['vec'])
@ -113,7 +114,7 @@ class TimmImage(NNOperator):
img = img if self.skip_tfms else self.tfms(img)
img_list.append(img)
inputs = torch.stack(img_list)
inputs = inputs.to(self.device)
inputs = inputs
features = self.model(inputs)
if isinstance(features, list):
features = [self.post_proc(x) for x in features]
@ -143,6 +144,7 @@ class TimmImage(NNOperator):
return img
def post_proc(self, features):
features = features.to(self.device)
if features.dim() == 3:
features = features[:, 0]
if features.dim() == 4:
@ -166,7 +168,7 @@ class TimmImage(NNOperator):
path = path + '.onnx'
else:
raise AttributeError(f'Invalid format {format}.')
dummy_input = torch.rand((1,) + self.config['input_size']).to(self.device)
dummy_input = torch.rand((1,) + self.config['input_size'])
if format == 'pytorch':
torch.save(self._model, path)
elif format == 'torchscript':
@ -182,7 +184,7 @@ class TimmImage(NNOperator):
elif format == 'onnx':
self._model.forward = self._model.forward_features
try:
torch.onnx.export(self._model,
torch.onnx.export(self._model.to('cpu'),
dummy_input,
path,
input_names=['input_0'],

Loading…
Cancel
Save