logo
Browse Source

fix dimension problem

main
zhang chen 4 years ago
parent
commit
c19a0952fc
  1. 5
      efficientnet_image_embedding.py
  2. 5
      pytorch/model.py

5
efficientnet_image_embedding.py

@ -43,8 +43,7 @@ class EfficientnetImageEmbedding(Operator):
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), ])
def __call__(self, img_path: str) -> NamedTuple('Outputs', [('feature_vector', numpy.ndarray)]):
Outputs = NamedTuple('Outputs', [('embedding', torch.Tensor)])
img = self.tfms(Image.open(img_path)).unsqueeze(0)
features = self.model._model.extract_features(img)
features = self.model(img)
Outputs = NamedTuple('Outputs', [('feature_vector', numpy.ndarray)])
return Outputs(features.flatten().detach().numpy())
return Outputs(features)

5
pytorch/model.py

@ -27,10 +27,13 @@ class Model():
def __init__(self, model_name: str, weights_path: str):
super().__init__()
self._model = EfficientNet.from_pretrained(model_name=model_name, weights_path=weights_path)
self._avg_pooling = torch.nn.AdaptiveAvgPool2d((1, 1))
self._model.eval()
def __call__(self, img_tensor: torch.Tensor):
return self._model(img_tensor).detach().numpy()
features = self._model.extract_features(img_tensor)
channels = features.shape[0]
return self._avg_pooling(features).view(channels, -1).detach().numpy()
def train(self):
"""

Loading…
Cancel
Save