diff --git a/efficientnet_image_embedding.py b/efficientnet_image_embedding.py index 418ca4f..1dccb50 100644 --- a/efficientnet_image_embedding.py +++ b/efficientnet_image_embedding.py @@ -21,7 +21,8 @@ from pathlib import Path import numpy from towhee.operator import Operator - +from timm.data import resolve_data_config +from timm.data.transforms_factory import create_transform class EfficientnetImageEmbedding(Operator): """ @@ -33,17 +34,19 @@ class EfficientnetImageEmbedding(Operator): Path to local weights. """ - def __init__(self, model_name: str = 'efficientnet-b7', framework: str = 'pytorch', weights_path: str = None) -> None: + def __init__(self, model_name: str = '', framework: str = 'pytorch', weights_path: str = None) -> None: + model_name = model_name.replace('efficientnet-b', 'tf_efficientnet_b') super().__init__() sys.path.append(str(Path(__file__).parent)) if framework == 'pytorch': + import pytorch from pytorch.model import Model self.model = Model(model_name, weights_path) - self.tfms = transforms.Compose([transforms.Resize(224), transforms.ToTensor(), - transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), ]) + config = resolve_data_config({}, model=self.model._model) + self.tfms = create_transform(**config) def __call__(self, img_path: str) -> NamedTuple('Outputs', [('feature_vector', numpy.ndarray)]): + Outputs = NamedTuple('Outputs', [('feature_vector', numpy.ndarray)]) img = self.tfms(Image.open(img_path)).unsqueeze(0) features = self.model(img) - Outputs = NamedTuple('Outputs', [('feature_vector', numpy.ndarray)]) - return Outputs(features) + return Outputs(features.flatten().detach().numpy()) diff --git a/pytorch/model.py b/pytorch/model.py index 728d953..007b645 100644 --- a/pytorch/model.py +++ b/pytorch/model.py @@ -13,11 +13,8 @@ # limitations under the License. -from typing import NamedTuple - -import numpy import torch -from efficientnet_pytorch import EfficientNet +import timm class Model(): @@ -26,13 +23,14 @@ class Model(): """ def __init__(self, model_name: str, weights_path: str): super().__init__() - self._model = EfficientNet.from_pretrained(model_name=model_name, weights_path=weights_path) - self._avg_pooling = torch.nn.AdaptiveAvgPool2d((1, 1)) + if weights_path: + self._model = timm.create_model(model_name, checkpoint_path=weights_path, num_classes=0) + else: + self._model = timm.create_model(model_name, pretrained=True, num_classes=0) self._model.eval() def __call__(self, img_tensor: torch.Tensor): - features = self._model.extract_features(img_tensor) - return self._avg_pooling(features).flatten().detach().numpy() + return self._model(img_tensor) def train(self): """