From 1f59441ad4513325a015232f0b9a8800ba27aba0 Mon Sep 17 00:00:00 2001 From: Jael Gu Date: Thu, 16 Jun 2022 19:22:11 +0800 Subject: [PATCH] Debug performance issue Signed-off-by: Jael Gu --- swag.py | 29 ++++------------------------- 1 file changed, 4 insertions(+), 25 deletions(-) diff --git a/swag.py b/swag.py index 25cbe52..7bbc412 100644 --- a/swag.py +++ b/swag.py @@ -55,8 +55,8 @@ class Swag(NNOperator): self.model_name = model_name self.model = torch.hub.load("facebookresearch/swag", model=model_name) self.model.to(self.device) + self.model.head = None # To extract features without model head self.model.eval() - self.extract_features = FeatureExtractor(self.model) @arg(1, to_image_color('RGB')) def __call__(self, img: towhee._types.Image) -> numpy.ndarray: @@ -64,7 +64,7 @@ class Swag(NNOperator): if not self.skip_tfms: img = self.tfms(img).unsqueeze(0) img = img.to(self.device) - features, _ = self.extract_features(img) + features = self.model(img) if features.dim() == 4: global_pool = nn.AdaptiveAvgPool2d(1) features = global_pool(features) @@ -147,27 +147,6 @@ class Swag(NNOperator): return transform -class FeatureExtractor(nn.Module): - def __init__(self, model: nn.Module): - super().__init__() - self.model = model - self.features = None - - for name, child in self.model.named_children(): - if name == 'trunk_output': - self.handler = child.register_forward_hook(self.save_outputs_hook()) - - def save_outputs_hook(self): - def fn(_, __, output): - self.features = output - return fn - - def forward(self, x): - outs = self.model(x) - self.handler.remove() - return self.features, outs - - if __name__ == '__main__': from towhee import ops @@ -175,7 +154,7 @@ if __name__ == '__main__': decoder = ops.image_decode.cv2() img = decoder(path) - op = Swag('vit_b16_in1k') - # op = Swag('regnety_16gf_in1k') + # op = Swag('vit_b16_in1k') + op = Swag('regnety_16gf_in1k') out = op(img) print(out.shape)