From 27666263eeb21ca6ae952afd2a5f7856743df515 Mon Sep 17 00:00:00 2001 From: zhang chen Date: Fri, 17 Dec 2021 16:52:45 +0800 Subject: [PATCH] add return shape introduction --- README.md | 1 + pytorch/model.py | 3 +-- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 0a7e661..eabb5ef 100644 --- a/README.md +++ b/README.md @@ -41,6 +41,7 @@ The Operator returns a tuple `Tuple[('feature_vector', numpy.ndarray)]` containi - feature_vector: - the embedding of the image - data type: `numpy.ndarray` + - shape: (dim,) ## Requirements diff --git a/pytorch/model.py b/pytorch/model.py index 507eee6..728d953 100644 --- a/pytorch/model.py +++ b/pytorch/model.py @@ -32,8 +32,7 @@ class Model(): def __call__(self, img_tensor: torch.Tensor): features = self._model.extract_features(img_tensor) - channels = features.shape[0] - return self._avg_pooling(features).view(channels, -1).detach().numpy() + return self._avg_pooling(features).flatten().detach().numpy() def train(self): """