logo
Browse Source

Refactor

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
test
Jael Gu 3 years ago
parent
commit
6a341403a7
  1. 78
      README.md
  2. 19
      __init__.py
  3. 62
      timm_image.py

78
README.md

@ -1,2 +1,78 @@
# timm-image
# Image Embedding with Timm
*author: Jael Gu, Filip*
## Desription
An image embedding operator implemented with pretrained models provided by [Timm](https://github.com/rwightman/pytorch-image-models).
```python
from towhee import ops
import numpy as np
img_encoder = ops.image_embedding.timm('resnet50')
fake_img = np.zeros((256, 256, 3))
image_embedding = img_encoder(fake_img)
```
## Factory Constructor
Create the operator via the following factory method
***ops.image_embedding.timm(model_name)***
## Interface
An image decode operator takes an image path as input. It decodes the image back to ndarray.
**Parameters:**
***img***: *numpy.ndarray*
​ The decoded image data in numpy.ndarray.
**Returns**: *numpy.ndarray*
​ The image embedding extracted by model.
## Code Example
Load an image from path './dog.jpg'
and use the pretrained ResNet50 model ('resnet50') to generate an image embedding.
*Write the pipeline in simplified style*:
```python
import towhee.DataCollection as dc
dc.glob(./dog.jpg)
.image_decode()
.image_embedding.timm('resnet50')
.show()
```
*Write a same pipeline with explicit inputs/outputs name specifications:*
```python
from towhee import DataCollection as dc
dc.glob['path'](./dog.jpg)
.image_decode['path', 'img']()
.image_embedding.timm['img', 'vec']('resnet50')
.select('img')
.show()
```

19
__init__.py

@ -0,0 +1,19 @@
# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .timm_image import TimmImage
def timm():
return TimmImage()

62
timm_image.py

@ -1,27 +1,50 @@
# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import numpy import numpy
import torch
from typing import NamedTuple
from towhee.operator.base import NNOperator
from towhee.utils.pil_utils import to_pil
from towhee.types.image import Image as towheeImage
from towhee.operator.base import NNOperator, OperatorFlag
from towhee import register
import torch
from torch import nn from torch import nn
from PIL import Image as PILImage
from timm.data.transforms_factory import create_transform from timm.data.transforms_factory import create_transform
from timm.data import resolve_data_config from timm.data import resolve_data_config
from timm.models.factory import create_model from timm.models.factory import create_model
import warnings import warnings
warnings.filterwarnings('ignore') warnings.filterwarnings('ignore')
log = logging.getLogger()
@register(output_schema=['vec'])
class TimmImage(NNOperator): class TimmImage(NNOperator):
""" """
Pytorch image embedding operator that uses the Pytorch Image Model (timm) collection. Pytorch image embedding operator that uses the Pytorch Image Model (timm) collection.
Args: Args:
model_name (`str`): model_name (`str`):
Which model to use for the embeddings. Which model to use for the embeddings.
num_classes (`int = 1000`):
Number of classes for classification.
""" """
def __init__(self, model_name: str, num_classes: int = 1000) -> None: def __init__(self, model_name: str, num_classes: int = 1000) -> None:
super().__init__() super().__init__()
self.device = 'cuda' if torch.cuda.is_available() else 'cpu' self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
@ -31,8 +54,16 @@ class TimmImage(NNOperator):
config = resolve_data_config({}, model=self.model) config = resolve_data_config({}, model=self.model)
self.tfms = create_transform(**config) self.tfms = create_transform(**config)
def __call__(self, image: 'towheeImage') -> NamedTuple('Outputs', [('vec', numpy.ndarray)]):
img = self.tfms(to_pil(image)).unsqueeze(0)
def __call__(self, img: numpy.ndarray) -> numpy.ndarray:
if hasattr(img, 'mode'):
if img.mode != 'RGB':
log.error(f'Invalid image mode: expect "RGB" but receive "{img.mode}".')
raise AssertionError(f'Invalid image mode "{img.mode}".')
else:
log.warning(f'Image mode is not specified. Using "RGB" now.')
img = PILImage.fromarray(img.astype('uint8'), 'RGB')
img = self.tfms(img).unsqueeze(0)
img = img.to(self.device) img = img.to(self.device)
features = self.model.forward_features(img) features = self.model.forward_features(img)
if features.dim() == 4: if features.dim() == 4:
@ -41,5 +72,18 @@ class TimmImage(NNOperator):
features = features.to('cpu') features = features.to('cpu')
feature_vector = features.flatten().detach().numpy() feature_vector = features.flatten().detach().numpy()
Outputs = NamedTuple('Outputs', [('vec', numpy.ndarray)])
return Outputs(feature_vector)
return feature_vector
# if __name__ == '__main__':
# import cv2
# from towhee._types import Image
#
#
# path = '/path/to/image'
# img = cv2.imread(path)
# img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
# img = Image(img, 'RGB')
#
# op = TimmImage('resnet50')
# out = op(img)

Loading…
Cancel
Save