From f6033fcc81af8b4ec87c33d02903051e2e61ec42 Mon Sep 17 00:00:00 2001 From: Jael Gu Date: Wed, 8 Feb 2023 19:03:47 +0800 Subject: [PATCH] Update device Signed-off-by: Jael Gu --- timm_image.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/timm_image.py b/timm_image.py index f59c7bd..286c9bc 100644 --- a/timm_image.py +++ b/timm_image.py @@ -51,12 +51,13 @@ def torch_no_grad(f): # @accelerate class Model: def __init__(self, model_name, device, num_classes): + self.device = device self.model = create_model(model_name, pretrained=True, num_classes=num_classes) self.model.eval() self.model.to(device) def __call__(self, x: torch.Tensor): - return self.model.forward_features(x) + return self.model.forward_features(x.to(self.device)) @register(output_schema=['vec']) @@ -113,7 +114,7 @@ class TimmImage(NNOperator): img = img if self.skip_tfms else self.tfms(img) img_list.append(img) inputs = torch.stack(img_list) - inputs = inputs.to(self.device) + inputs = inputs features = self.model(inputs) if isinstance(features, list): features = [self.post_proc(x) for x in features] @@ -143,6 +144,7 @@ class TimmImage(NNOperator): return img def post_proc(self, features): + features = features.to(self.device) if features.dim() == 3: features = features[:, 0] if features.dim() == 4: @@ -166,7 +168,7 @@ class TimmImage(NNOperator): path = path + '.onnx' else: raise AttributeError(f'Invalid format {format}.') - dummy_input = torch.rand((1,) + self.config['input_size']).to(self.device) + dummy_input = torch.rand((1,) + self.config['input_size']) if format == 'pytorch': torch.save(self._model, path) elif format == 'torchscript': @@ -182,7 +184,7 @@ class TimmImage(NNOperator): elif format == 'onnx': self._model.forward = self._model.forward_features try: - torch.onnx.export(self._model, + torch.onnx.export(self._model.to('cpu'), dummy_input, path, input_names=['input_0'],