From 6a341403a771f5c993a7f7253d34cfc9a0ea9c9b Mon Sep 17 00:00:00 2001 From: Jael Gu Date: Fri, 25 Mar 2022 17:14:17 +0800 Subject: [PATCH] Refactor Signed-off-by: Jael Gu --- README.md | 78 ++++++++++++++++++++++++++++++++++++++++++++++++++- __init__.py | 19 +++++++++++++ timm_image.py | 62 ++++++++++++++++++++++++++++++++++------ 3 files changed, 149 insertions(+), 10 deletions(-) create mode 100644 __init__.py diff --git a/README.md b/README.md index 99fa16b..bd9fbaa 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,78 @@ -# timm-image +# Image Embedding with Timm + +*author: Jael Gu, Filip* + + + +## Desription + +An image embedding operator implemented with pretrained models provided by [Timm](https://github.com/rwightman/pytorch-image-models). + + + +```python +from towhee import ops +import numpy as np + +img_encoder = ops.image_embedding.timm('resnet50') +fake_img = np.zeros((256, 256, 3)) +image_embedding = img_encoder(fake_img) +``` + +## Factory Constructor + +Create the operator via the following factory method + +***ops.image_embedding.timm(model_name)*** + + + +## Interface + +An image decode operator takes an image path as input. It decodes the image back to ndarray. + + + +**Parameters:** + +​ ***img***: *numpy.ndarray* + +​ The decoded image data in numpy.ndarray. + + + +**Returns**: *numpy.ndarray* + +​ The image embedding extracted by model. + + + +## Code Example + +Load an image from path './dog.jpg' +and use the pretrained ResNet50 model ('resnet50') to generate an image embedding. + + *Write the pipeline in simplified style*: + +```python +import towhee.DataCollection as dc + +dc.glob(./dog.jpg) + .image_decode() + .image_embedding.timm('resnet50') + .show() +``` + +*Write a same pipeline with explicit inputs/outputs name specifications:* + +```python +from towhee import DataCollection as dc + +dc.glob['path'](./dog.jpg) + .image_decode['path', 'img']() + .image_embedding.timm['img', 'vec']('resnet50') + .select('img') + .show() +``` + diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..85a7c81 --- /dev/null +++ b/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2021 Zilliz. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .timm_image import TimmImage + + +def timm(): + return TimmImage() diff --git a/timm_image.py b/timm_image.py index fef714f..09361e7 100644 --- a/timm_image.py +++ b/timm_image.py @@ -1,27 +1,50 @@ +# Copyright 2021 Zilliz. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging import numpy -import torch -from typing import NamedTuple -from towhee.operator.base import NNOperator -from towhee.utils.pil_utils import to_pil -from towhee.types.image import Image as towheeImage +from towhee.operator.base import NNOperator, OperatorFlag +from towhee import register +import torch from torch import nn +from PIL import Image as PILImage + from timm.data.transforms_factory import create_transform from timm.data import resolve_data_config from timm.models.factory import create_model import warnings + warnings.filterwarnings('ignore') + +log = logging.getLogger() + +@register(output_schema=['vec']) class TimmImage(NNOperator): """ Pytorch image embedding operator that uses the Pytorch Image Model (timm) collection. Args: model_name (`str`): Which model to use for the embeddings. + num_classes (`int = 1000`): + Number of classes for classification. """ + def __init__(self, model_name: str, num_classes: int = 1000) -> None: super().__init__() self.device = 'cuda' if torch.cuda.is_available() else 'cpu' @@ -31,8 +54,16 @@ class TimmImage(NNOperator): config = resolve_data_config({}, model=self.model) self.tfms = create_transform(**config) - def __call__(self, image: 'towheeImage') -> NamedTuple('Outputs', [('vec', numpy.ndarray)]): - img = self.tfms(to_pil(image)).unsqueeze(0) + def __call__(self, img: numpy.ndarray) -> numpy.ndarray: + if hasattr(img, 'mode'): + if img.mode != 'RGB': + log.error(f'Invalid image mode: expect "RGB" but receive "{img.mode}".') + raise AssertionError(f'Invalid image mode "{img.mode}".') + else: + log.warning(f'Image mode is not specified. Using "RGB" now.') + + img = PILImage.fromarray(img.astype('uint8'), 'RGB') + img = self.tfms(img).unsqueeze(0) img = img.to(self.device) features = self.model.forward_features(img) if features.dim() == 4: @@ -41,5 +72,18 @@ class TimmImage(NNOperator): features = features.to('cpu') feature_vector = features.flatten().detach().numpy() - Outputs = NamedTuple('Outputs', [('vec', numpy.ndarray)]) - return Outputs(feature_vector) + return feature_vector + + +# if __name__ == '__main__': +# import cv2 +# from towhee._types import Image +# +# +# path = '/path/to/image' +# img = cv2.imread(path) +# img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) +# img = Image(img, 'RGB') +# +# op = TimmImage('resnet50') +# out = op(img)