diff --git a/README.md b/README.md
index dcf72cf..28196d9 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,100 @@
-# mpvit
+# Image Embedding with MPViT
+*author: Chen Zhang*
+
+<br />
+
+## Description
+
+This operator extracts features from images with [Multi-Path Vision Transformer (MPViT)](https://arxiv.org/abs/2112.11010), generating image embeddings. MPViT embeds features of the same size (i.e., sequence length) with patches of different scales simultaneously by using overlapping convolutional patch embedding. Tokens of different scales are then independently fed into the Transformer encoders via multiple paths, and the resulting features are aggregated, enabling both fine and coarse feature representations at the same feature level.
+
+<br />
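+The gist of the multi-path design can be sketched in a few lines of PyTorch. This is a toy illustration only (the layer sizes and kernel choices are made up for the example and are not taken from the MPViT code): overlapping convolutions with different kernel sizes but the same stride produce token sequences of identical length, one per scale, which can then be routed through parallel Transformer paths.
+
+```python
+import torch
+from torch import nn
+
+x = torch.randn(1, 3, 224, 224)  # dummy input image batch
+
+# Three overlapping conv patch embeddings with different receptive fields.
+# The same stride plus matching padding keeps every path at the same spatial
+# resolution, hence the same token sequence length.
+paths = [nn.Conv2d(3, 64, kernel_size=k, stride=4, padding=k // 2) for k in (3, 5, 7)]
+tokens = [p(x).flatten(2).transpose(1, 2) for p in paths]
+
+print([t.shape for t in tokens])  # each: torch.Size([1, 3136, 64])
+```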
+
+## Code Example
+
+Load an image from path './towhee.jpeg'
+and use the pre-trained mpvit_base model to generate an image embedding.
+
+*Write the pipeline in simplified style:*
+
+```python
+import towhee
+
+towhee.glob('./towhee.jpeg') \
+    .image_decode() \
+    .image_embedding.mpvit(model_name='mpvit_base') \
+    .show()
+```
+
+![](result1.png)
+
+*Write the same pipeline with explicit input/output name specifications:*
+
+```python
+import towhee
+
+towhee.glob['path']('./towhee.jpeg') \
+    .image_decode['path', 'img']() \
+    .image_embedding.mpvit['img', 'vec'](model_name='mpvit_base') \
+    .select['img', 'vec']() \
+    .show()
+```
+
+![](result2.png)
+
+<br />
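+*To work with the embedding outside of `.show()`, the same pipeline can be collected into a list. This is a sketch; it assumes the DataCollection `to_list()` call and entity attribute access available in the towhee version used above:*
+
+```python
+import towhee
+
+res = towhee.glob['path']('./towhee.jpeg') \
+    .image_decode['path', 'img']() \
+    .image_embedding.mpvit['img', 'vec'](model_name='mpvit_base') \
+    .to_list()
+
+print(res[0].vec.shape)  # expected: (480,) for mpvit_base
+```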
+
+## Factory Constructor
+
+Create the operator via the following factory method:
+
+***image_embedding.mpvit(model_name='mpvit_base', \*\*kwargs)***
+
+**Parameters:**
+
+***model_name:*** *str*
+
+The pretrained model name: `mpvit_tiny`, `mpvit_xsmall`, `mpvit_small`, or `mpvit_base`, all of which are pretrained on the ImageNet-1K dataset. For more information, please refer to the original [MPViT GitHub page](https://github.com/youngwanLEE/MPViT).
+
+***weights_path:*** *str*
+
+Your local weights path. The default is None, which means using the pretrained model weights.
+
+***device:*** *str*
+
+Model device, `cpu` or `cuda`.
+
+***num_classes:*** *int*
+
+The number of classes. The default value is 1000.
+It depends on the model and the dataset. If you want to fine-tune this operator, you can change this value to fit your dataset.
+
+***skip_preprocess:*** *bool*
+
+The flag that controls whether to skip image preprocessing.
+The default value is False.
+If set to True, the operator skips the image preprocessing steps (transforms).
+In this case, the input image data must be prepared in advance so that it properly fits the model.
+
+<br />
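+For example, the parameters above can be passed straight through the pipeline call used in the Code Example section. This is a sketch; `mpvit_small`, `device='cpu'`, and the explicitly spelled-out defaults are illustrative choices, not required values:
+
+```python
+import towhee
+
+towhee.glob('./towhee.jpeg') \
+    .image_decode() \
+    .image_embedding.mpvit(
+        model_name='mpvit_small',   # any of the four pretrained variants
+        device='cpu',               # or 'cuda' for GPU inference
+        num_classes=1000,           # ImageNet-1K head; change when fine-tuning
+        skip_preprocess=False,      # keep the built-in transforms
+    ) \
+    .show()
+```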
+
+## Interface
+
+An image embedding operator takes a towhee image as input.
+It uses the pre-trained model specified by the model name to generate an image embedding as an ndarray.
+
+**Parameters:**
+
+***data:*** *towhee._types.Image*
+
+The decoded image data in towhee Image (a subclass of numpy.ndarray).
+
+**Returns:** *numpy.ndarray*
+
+An image embedding generated by the model, with shape (feature_dim,).
+For the mpvit_tiny model, feature_dim = 216.
+For the mpvit_xsmall model, feature_dim = 256.
+For the mpvit_small model, feature_dim = 288.
+For the mpvit_base model, feature_dim = 480.
+
\ No newline at end of file
diff --git a/__init__.py b/__init__.py
new file mode 100644
index 0000000..78e6cb9
--- /dev/null
+++ b/__init__.py
@@ -0,0 +1,20 @@
+# Copyright 2021 Zilliz. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .mpvit import MPViT
+
+
+def mpvit(model_name: str, **kwargs):
+    return MPViT(model_name, **kwargs)
+
diff --git a/mpvit.py b/mpvit.py
new file mode 100644
index 0000000..5cd3fc1
--- /dev/null
+++ b/mpvit.py
@@ -0,0 +1,79 @@
+# Copyright 2021 Zilliz. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+from typing import List, Union
+
+from torch import nn
+from towhee.types.arg import arg, to_image_color
+
+from timm.data import resolve_data_config, create_transform
+from towhee.models import mpvit
+from towhee.operator.base import NNOperator
+from towhee import register
+import towhee
+from PIL import Image as PILImage
+
+
+@register(output_schema=['vec'])
+class MPViT(NNOperator):
+    """
+    MPViT embedding operator
+    """
+
+    def __init__(self, model_name: str,
+                 num_classes: int = 1000,
+                 weights_path: str = None,
+                 device: str = None,
+                 skip_preprocess: bool = False):
+        super().__init__()
+        # Build the MPViT backbone; weights_path overrides the pretrained weights.
+        self.model = mpvit.create_model(model_name=model_name,
+                                        num_classes=num_classes,
+                                        pretrained=True,
+                                        weights_path=weights_path,
+                                        device=device)
+        self.model.eval()
+        # Resolve the preprocessing (resize/crop/normalize) that matches the model.
+        self.config = resolve_data_config({}, model=self.model)
+        self.tfms = create_transform(**self.config)
+        self.skip_tfms = skip_preprocess
+        self.device = device
+
+    def __call__(self, data: Union[List[towhee._types.Image], towhee._types.Image]):
+        # Accept either a single towhee Image or a list of them.
+        if not isinstance(data, list):
+            imgs = [data]
+        else:
+            imgs = data
+        img_list = []
+        for img in imgs:
+            img = self.convert_img(img)
+            img = img if self.skip_tfms else self.tfms(img)
+            img_list.append(img)
+        inputs = torch.stack(img_list)
+        inputs = inputs.to(self.device)
+        # Extract the (batch, channels, height, width) feature map, then
+        # global-average-pool it into one (batch, feature_dim) vector per image.
+        features = self.model.forward_features(inputs)
+        global_pool = nn.AdaptiveAvgPool2d(1)
+        features = global_pool(features)
+        features = features.to('cpu').flatten(1)
+        if isinstance(data, list):
+            vecs = list(features.detach().numpy())
+        else:
+            vecs = features.squeeze(0).detach().numpy()
+        return vecs
+
+    @arg(1, to_image_color('RGB'))
+    def convert_img(self, img: towhee._types.Image):
+        # Convert the decoded towhee Image (RGB ndarray) into a PIL image for the timm transforms.
+        img = PILImage.fromarray(img.astype('uint8'), 'RGB')
+        return img
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..7e6cf10
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,5 @@
+towhee.models>=0.7.0
+towhee>=0.7.0
+timm
+Pillow
+einops
\ No newline at end of file
diff --git a/result1.png b/result1.png
new file mode 100644
index 0000000..a543dc9
Binary files /dev/null and b/result1.png differ
diff --git a/result2.png b/result2.png
new file mode 100644
index 0000000..43a5f0a
Binary files /dev/null and b/result2.png differ
diff --git a/towhee.jpeg b/towhee.jpeg
new file mode 100644
index 0000000..caf63b3
Binary files /dev/null and b/towhee.jpeg differ