diff --git a/efficientnet_image_embedding.py b/efficientnet_image_embedding.py
index bb62091..418ca4f 100644
--- a/efficientnet_image_embedding.py
+++ b/efficientnet_image_embedding.py
@@ -43,8 +43,7 @@ class EfficientnetImageEmbedding(Operator):
             transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
         ])
     def __call__(self, img_path: str) -> NamedTuple('Outputs', [('feature_vector', numpy.ndarray)]):
-        Outputs = NamedTuple('Outputs', [('embedding', torch.Tensor)])
         img = self.tfms(Image.open(img_path)).unsqueeze(0)
-        features = self.model._model.extract_features(img)
+        features = self.model(img)
         Outputs = NamedTuple('Outputs', [('feature_vector', numpy.ndarray)])
-        return Outputs(features.flatten().detach().numpy())
+        return Outputs(features)
diff --git a/pytorch/model.py b/pytorch/model.py
index d973dac..507eee6 100644
--- a/pytorch/model.py
+++ b/pytorch/model.py
@@ -27,10 +27,13 @@ class Model():
     def __init__(self, model_name: str, weights_path: str):
         super().__init__()
         self._model = EfficientNet.from_pretrained(model_name=model_name, weights_path=weights_path)
+        self._avg_pooling = torch.nn.AdaptiveAvgPool2d((1, 1))
         self._model.eval()
 
     def __call__(self, img_tensor: torch.Tensor):
-        return self._model(img_tensor).detach().numpy()
+        features = self._model.extract_features(img_tensor)
+        channels = features.shape[0]
+        return self._avg_pooling(features).view(channels, -1).detach().numpy()
 
     def train(self):
         """