diff --git a/README.md b/README.md index 0a7e661..eabb5ef 100644 --- a/README.md +++ b/README.md @@ -41,6 +41,7 @@ The Operator returns a tuple `Tuple[('feature_vector', numpy.ndarray)]` containi - feature_vector: - the embedding of the image - data type: `numpy.ndarray` + - shape: (dim,) ## Requirements diff --git a/pytorch/model.py b/pytorch/model.py index 507eee6..728d953 100644 --- a/pytorch/model.py +++ b/pytorch/model.py @@ -32,8 +32,7 @@ class Model(): def __call__(self, img_tensor: torch.Tensor): features = self._model.extract_features(img_tensor) - channels = features.shape[0] - return self._avg_pooling(features).view(channels, -1).detach().numpy() + return self._avg_pooling(features).flatten().detach().numpy() def train(self): """