logo
Browse Source

Debug performance issue

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 3 years ago
parent
commit
1f59441ad4
  1. 29
      swag.py

29
swag.py

@ -55,8 +55,8 @@ class Swag(NNOperator):
self.model_name = model_name
self.model = torch.hub.load("facebookresearch/swag", model=model_name)
self.model.to(self.device)
self.model.head = None # To extract features without model head
self.model.eval()
self.extract_features = FeatureExtractor(self.model)
@arg(1, to_image_color('RGB'))
def __call__(self, img: towhee._types.Image) -> numpy.ndarray:
@ -64,7 +64,7 @@ class Swag(NNOperator):
if not self.skip_tfms:
img = self.tfms(img).unsqueeze(0)
img = img.to(self.device)
features, _ = self.extract_features(img)
features = self.model(img)
if features.dim() == 4:
global_pool = nn.AdaptiveAvgPool2d(1)
features = global_pool(features)
@ -147,27 +147,6 @@ class Swag(NNOperator):
return transform
class FeatureExtractor(nn.Module):
def __init__(self, model: nn.Module):
super().__init__()
self.model = model
self.features = None
for name, child in self.model.named_children():
if name == 'trunk_output':
self.handler = child.register_forward_hook(self.save_outputs_hook())
def save_outputs_hook(self):
def fn(_, __, output):
self.features = output
return fn
def forward(self, x):
outs = self.model(x)
self.handler.remove()
return self.features, outs
if __name__ == '__main__':
from towhee import ops
@ -175,7 +154,7 @@ if __name__ == '__main__':
decoder = ops.image_decode.cv2()
img = decoder(path)
op = Swag('vit_b16_in1k')
# op = Swag('regnety_16gf_in1k')
# op = Swag('vit_b16_in1k')
op = Swag('regnety_16gf_in1k')
out = op(img)
print(out.shape)

Loading…
Cancel
Save