zhang chen
4 years ago
2 changed files with
2 additions and
2 deletions
-
README.md
-
pytorch/model.py
|
|
@ -41,6 +41,7 @@ The Operator returns a tuple `Tuple[('feature_vector', numpy.ndarray)]` containi |
|
|
|
- feature_vector: |
|
|
|
- the embedding of the image |
|
|
|
- data type: `numpy.ndarray` |
|
|
|
- shape: (dim,) |
|
|
|
|
|
|
|
## Requirements |
|
|
|
|
|
|
|
|
|
@ -32,8 +32,7 @@ class Model(): |
|
|
|
|
|
|
|
def __call__(self, img_tensor: torch.Tensor): |
|
|
|
features = self._model.extract_features(img_tensor) |
|
|
|
channels = features.shape[0] |
|
|
|
return self._avg_pooling(features).view(channels, -1).detach().numpy() |
|
|
|
return self._avg_pooling(features).flatten().detach().numpy() |
|
|
|
|
|
|
|
def train(self): |
|
|
|
""" |
|
|
|