diff --git a/README.md b/README.md index 9ad9307..a82e499 100644 --- a/README.md +++ b/README.md @@ -14,6 +14,7 @@ __init__(self, model_name: str = 'efficientnet-b7', framework: str = 'pytorch', **Args:** + - model_name: - the model name for embedding - supported types: `str`, for example 'efficientnet-b7' diff --git a/efficientnet_image_embedding.py b/efficientnet_image_embedding.py index ba0857f..9f7dd62 100644 --- a/efficientnet_image_embedding.py +++ b/efficientnet_image_embedding.py @@ -11,23 +11,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import timm from typing import NamedTuple -from PIL import Image import torch -from torchvision import transforms -import sys import towhee -from pathlib import Path import numpy +import torch.nn as nn -from towhee.operator import Operator +from towhee.operator import NNOperator from towhee.utils.pil_utils import to_pil from timm.data import resolve_data_config from timm.data.transforms_factory import create_transform -import os -class EfficientnetImageEmbedding(Operator): + +class EfficientnetImageEmbedding(NNOperator): """ Embedding extractor using efficientnet. Args: @@ -37,22 +34,28 @@ class EfficientnetImageEmbedding(Operator): Path to local weights. """ - def __init__(self, model_name: str = '', framework: str = 'pytorch', weights_path: str = None) -> None: + def __init__(self, model_name: str = '', num_classes: int = 1000, framework: str = 'pytorch', + weights_path: str = None) -> None: + super().__init__(framework=framework) model_name = model_name.replace('efficientnet-b', 'tf_efficientnet_b') - super().__init__() - if framework == 'pytorch': - import importlib.util - path = os.path.join(str(Path(__file__).parent), 'pytorch', 'model.py') - opname = os.path.basename(str(Path(__file__))).split('.')[0] - spec = importlib.util.spec_from_file_location(opname, path) - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) - self.model = module.Model(model_name, weights_path) - config = resolve_data_config({}, model=self.model._model) + self.device = 'cuda' if torch.cuda.is_available() else 'cpu' + if weights_path: + self.model = timm.create_model(model_name, checkpoint_path=weights_path, num_classes=num_classes) + else: + self.model = timm.create_model(model_name, pretrained=True, num_classes=num_classes) + self.model.eval() + config = resolve_data_config({}, model=self.model) self.tfms = create_transform(**config) def __call__(self, image: 'towhee.types.Image') -> NamedTuple('Outputs', [('feature_vector', numpy.ndarray)]): Outputs = NamedTuple('Outputs', [('feature_vector', numpy.ndarray)]) - img = self.tfms(to_pil(image)).unsqueeze(0) - features = self.model(img) - return Outputs(features.flatten().detach().numpy()) + self.model.to(self.device) + self.model.eval() + img_tensor = self.tfms(to_pil(image)).unsqueeze(0) + features = self.model.forward_features(img_tensor) + if features.dim() == 4: # if the shape of feature map is [N, C, H, W], where H > 1 and W > 1 + global_pool = nn.AdaptiveAvgPool2d(1) + features = global_pool(features) + features = features.to('cpu') + features = features.flatten().detach().numpy() + return Outputs(features) diff --git a/pytorch/__init__.py b/pytorch/__init__.py deleted file mode 100644 index ed43fad..0000000 --- a/pytorch/__init__.py +++ /dev/null @@ -1,22 +0,0 @@ -# Copyright 2021 Zilliz. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import os - -# For requirements. -try: - import efficientnet_pytorch -except ModuleNotFoundError: - os.system('pip install efficientnet_pytorch') \ No newline at end of file diff --git a/pytorch/model.py b/pytorch/model.py deleted file mode 100644 index 007b645..0000000 --- a/pytorch/model.py +++ /dev/null @@ -1,39 +0,0 @@ -# Copyright 2021 Zilliz. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import torch -import timm - - -class Model(): - """ - PyTorch model class - """ - def __init__(self, model_name: str, weights_path: str): - super().__init__() - if weights_path: - self._model = timm.create_model(model_name, checkpoint_path=weights_path, num_classes=0) - else: - self._model = timm.create_model(model_name, pretrained=True, num_classes=0) - self._model.eval() - - def __call__(self, img_tensor: torch.Tensor): - return self._model(img_tensor) - - def train(self): - """ - For training model - """ - pass