logo
Browse Source

Update device

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 2 years ago
parent
commit
f6033fcc81
  1. 10
      timm_image.py

10
timm_image.py

@ -51,12 +51,13 @@ def torch_no_grad(f):
# @accelerate # @accelerate
class Model: class Model:
def __init__(self, model_name, device, num_classes): def __init__(self, model_name, device, num_classes):
self.device = device
self.model = create_model(model_name, pretrained=True, num_classes=num_classes) self.model = create_model(model_name, pretrained=True, num_classes=num_classes)
self.model.eval() self.model.eval()
self.model.to(device) self.model.to(device)
def __call__(self, x: torch.Tensor): def __call__(self, x: torch.Tensor):
return self.model.forward_features(x)
return self.model.forward_features(x.to(self.device))
@register(output_schema=['vec']) @register(output_schema=['vec'])
@ -113,7 +114,7 @@ class TimmImage(NNOperator):
img = img if self.skip_tfms else self.tfms(img) img = img if self.skip_tfms else self.tfms(img)
img_list.append(img) img_list.append(img)
inputs = torch.stack(img_list) inputs = torch.stack(img_list)
inputs = inputs.to(self.device)
inputs = inputs
features = self.model(inputs) features = self.model(inputs)
if isinstance(features, list): if isinstance(features, list):
features = [self.post_proc(x) for x in features] features = [self.post_proc(x) for x in features]
@ -143,6 +144,7 @@ class TimmImage(NNOperator):
return img return img
def post_proc(self, features): def post_proc(self, features):
features = features.to(self.device)
if features.dim() == 3: if features.dim() == 3:
features = features[:, 0] features = features[:, 0]
if features.dim() == 4: if features.dim() == 4:
@ -166,7 +168,7 @@ class TimmImage(NNOperator):
path = path + '.onnx' path = path + '.onnx'
else: else:
raise AttributeError(f'Invalid format {format}.') raise AttributeError(f'Invalid format {format}.')
dummy_input = torch.rand((1,) + self.config['input_size']).to(self.device)
dummy_input = torch.rand((1,) + self.config['input_size'])
if format == 'pytorch': if format == 'pytorch':
torch.save(self._model, path) torch.save(self._model, path)
elif format == 'torchscript': elif format == 'torchscript':
@ -182,7 +184,7 @@ class TimmImage(NNOperator):
elif format == 'onnx': elif format == 'onnx':
self._model.forward = self._model.forward_features self._model.forward = self._model.forward_features
try: try:
torch.onnx.export(self._model,
torch.onnx.export(self._model.to('cpu'),
dummy_input, dummy_input,
path, path,
input_names=['input_0'], input_names=['input_0'],

Loading…
Cancel
Save