|
|
@ -55,8 +55,8 @@ class Swag(NNOperator): |
|
|
|
self.model_name = model_name |
|
|
|
self.model = torch.hub.load("facebookresearch/swag", model=model_name) |
|
|
|
self.model.to(self.device) |
|
|
|
self.model.head = None # To extract features without model head |
|
|
|
self.model.eval() |
|
|
|
self.extract_features = FeatureExtractor(self.model) |
|
|
|
|
|
|
|
@arg(1, to_image_color('RGB')) |
|
|
|
def __call__(self, img: towhee._types.Image) -> numpy.ndarray: |
|
|
@ -64,7 +64,7 @@ class Swag(NNOperator): |
|
|
|
if not self.skip_tfms: |
|
|
|
img = self.tfms(img).unsqueeze(0) |
|
|
|
img = img.to(self.device) |
|
|
|
features, _ = self.extract_features(img) |
|
|
|
features = self.model(img) |
|
|
|
if features.dim() == 4: |
|
|
|
global_pool = nn.AdaptiveAvgPool2d(1) |
|
|
|
features = global_pool(features) |
|
|
@ -147,27 +147,6 @@ class Swag(NNOperator): |
|
|
|
return transform |
|
|
|
|
|
|
|
|
|
|
|
class FeatureExtractor(nn.Module): |
|
|
|
def __init__(self, model: nn.Module): |
|
|
|
super().__init__() |
|
|
|
self.model = model |
|
|
|
self.features = None |
|
|
|
|
|
|
|
for name, child in self.model.named_children(): |
|
|
|
if name == 'trunk_output': |
|
|
|
self.handler = child.register_forward_hook(self.save_outputs_hook()) |
|
|
|
|
|
|
|
def save_outputs_hook(self): |
|
|
|
def fn(_, __, output): |
|
|
|
self.features = output |
|
|
|
return fn |
|
|
|
|
|
|
|
def forward(self, x): |
|
|
|
outs = self.model(x) |
|
|
|
self.handler.remove() |
|
|
|
return self.features, outs |
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
|
from towhee import ops |
|
|
|
|
|
|
@ -175,7 +154,7 @@ if __name__ == '__main__': |
|
|
|
|
|
|
|
decoder = ops.image_decode.cv2() |
|
|
|
img = decoder(path) |
|
|
|
op = Swag('vit_b16_in1k') |
|
|
|
# op = Swag('regnety_16gf_in1k') |
|
|
|
# op = Swag('vit_b16_in1k') |
|
|
|
op = Swag('regnety_16gf_in1k') |
|
|
|
out = op(img) |
|
|
|
print(out.shape) |
|
|
|