
build mpvit

main
ChengZi 2 years ago
parent
commit
fc0c0f2c3d
  1. 100
      README.md
  2. 20
      __init__.py
  3. 79
      mpvit.py
  4. 5
      requirements.txt
  5. BIN
      result1.png
  6. BIN
      result2.png
  7. BIN
      towhee.jpeg

100
README.md

@ -1,2 +1,100 @@
# mpvit
# Image Embedding with MPViT
*author: Chen Zhang*
<br />
## Description
This operator extracts image features with the [Multi-Path Vision Transformer (MPViT)](https://arxiv.org/abs/2112.11010) and returns them as an embedding. MPViT embeds features of the same size (i.e., sequence length) with patches of different scales simultaneously by using overlapping convolutional patch embedding. Tokens of different scales are then independently fed into the Transformer encoders via multiple paths, and the resulting features are aggregated, enabling both fine and coarse feature representations at the same feature level.
<br />
## Code Example
Load an image from path './towhee.jpeg'
and use the pre-trained mpvit_base model to generate an image embedding.
*Write the pipeline in simplified style:*
```python
import towhee
towhee.glob('./towhee.jpeg') \
.image_decode() \
.image_embedding.mpvit(model_name='mpvit_base') \
.show()
```
![](result1.png)
*Write the same pipeline with explicit input/output name specifications:*
```python
import towhee
towhee.glob['path']('./towhee.jpeg') \
.image_decode['path', 'img']() \
.image_embedding.mpvit['img', 'vec'](model_name='mpvit_base') \
.select['img', 'vec']() \
.show()
```
![](result2.png)
<br />
## Factory Constructor
Create the operator via the following factory method:
***image_embedding.mpvit(model_name='mpvit_base', \*\*kwargs)***
**Parameters:**
***model_name:*** *str*
Name of the pretrained model: `mpvit_tiny`, `mpvit_xsmall`, `mpvit_small`, or `mpvit_base`. All of them are pretrained on the ImageNet-1K dataset. For more information, please refer to the original [MPViT GitHub page](https://github.com/youngwanLEE/MPViT).
***weights_path:*** *str*
Path to local weights. The default is None, which means the pretrained model weights are used.
***device:*** *str*
Model device, `cpu` or `cuda`.
***num_classes:*** *int*
The number of classes. The default value is 1000.
It depends on the model and the dataset. If you want to fine-tune this operator, change this value to match your dataset.
***skip_preprocess:*** *bool*
Flag that controls whether image preprocessing is skipped.
The default value is False.
If set to True, it will skip image preprocessing steps (transforms).
In this case, input image data must be prepared in advance in order to properly fit the model.
<br />
## Interface
An image embedding operator takes a towhee image as input.
It uses the pre-trained model specified by the model name to generate an image embedding as a numpy.ndarray.
**Parameters:**
***data:*** *towhee._types.Image*
The decoded image data as a towhee Image (a subclass of numpy.ndarray).
**Returns:** *numpy.ndarray*
An image embedding generated by the model, with shape (feature_dim,).
For mpvit_tiny model, feature_dim = 216.
For mpvit_xsmall model, feature_dim = 256.
For mpvit_small model, feature_dim = 288.
For mpvit_base model, feature_dim = 480.
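
The dimensions listed above can be kept in a small lookup table to sanity-check an embedding before it is stored or indexed. The following is a minimal sketch, not part of this operator: `MPVIT_FEATURE_DIMS` and `check_embedding` are hypothetical names, and the helper only relies on `len()`, so it works for a 1-D `numpy.ndarray` as well as for a plain list.

```python
# Output dimensions as listed in the Interface section of this README.
MPVIT_FEATURE_DIMS = {
    'mpvit_tiny': 216,
    'mpvit_xsmall': 256,
    'mpvit_small': 288,
    'mpvit_base': 480,
}


def check_embedding(model_name, vec):
    """Hypothetical helper: raise if `vec` does not have the documented length."""
    expected = MPVIT_FEATURE_DIMS[model_name]
    if len(vec) != expected:
        raise ValueError(
            f'{model_name} should yield {expected} values, got {len(vec)}'
        )
    return expected


# A placeholder vector stands in for a real embedding here.
print(check_embedding('mpvit_base', [0.0] * 480))  # 480
```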

20
__init__.py

@ -0,0 +1,20 @@
# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .mpvit import MPViT
def mpvit(model_name: str, **kwargs):
    return MPViT(model_name, **kwargs)
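
The factory above forwards every keyword argument unchanged to the operator's constructor, so only the parameters that `MPViT.__init__` declares are accepted. The sketch below illustrates that behaviour with a stand-in class; `FakeMPViT` is a hypothetical name that copies the constructor signature from `mpvit.py` and loads no model.

```python
class FakeMPViT:
    """Stand-in with the constructor signature of MPViT; it loads no model."""

    def __init__(self, model_name,
                 num_classes: int = 1000,
                 weights_path: str = None,
                 device: str = None,
                 skip_preprocess: bool = False):
        self.model_name = model_name
        self.num_classes = num_classes
        self.weights_path = weights_path
        self.device = device
        self.skip_preprocess = skip_preprocess


def mpvit(model_name: str, **kwargs):
    # Same shape as the real factory, pointed at the stand-in class.
    return FakeMPViT(model_name, **kwargs)


op = mpvit('mpvit_small', device='cpu', skip_preprocess=True)
print(op.model_name, op.device, op.skip_preprocess)  # mpvit_small cpu True

try:
    mpvit('mpvit_small', batch_size=8)  # not a constructor parameter
except TypeError as err:
    print('rejected:', err)
```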

79
mpvit.py

@ -0,0 +1,79 @@
# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from typing import List, Union
from torch import nn
from towhee.types.arg import arg, to_image_color
from timm.data import resolve_data_config, create_transform
from towhee.models import mpvit
from towhee.operator.base import NNOperator
from towhee import register
import towhee
from PIL import Image as PILImage
@register(output_schema=['vec'])
class MPViT(NNOperator):
    """
    MPViT embedding operator
    """
    def __init__(self, model_name,
                 num_classes: int = 1000,
                 weights_path: str = None,
                 device: str = None,
                 skip_preprocess: bool = False):
        super().__init__()
        self.model = mpvit.create_model(model_name=model_name,
                                        num_classes=num_classes,
                                        pretrained=True,
                                        weights_path=weights_path,
                                        device=device)
        self.model.eval()
        self.config = resolve_data_config({}, model=self.model)
        self.tfms = create_transform(**self.config)
        self.skip_tfms = skip_preprocess
        self.device = device

    def __call__(self, data: Union[List[towhee._types.Image], towhee._types.Image]):
        if not isinstance(data, list):
            imgs = [data]
        else:
            imgs = data
        img_list = []
        for img in imgs:
            img = self.convert_img(img)
            img = img if self.skip_tfms else self.tfms(img)
            img_list.append(img)
        inputs = torch.stack(img_list)
        inputs = inputs.to(self.device)
        features = self.model.forward_features(inputs)
        global_pool = nn.AdaptiveAvgPool2d(1)
        features = global_pool(features)
        features = features.to('cpu').flatten(1)
        if isinstance(data, list):
            vecs = list(features.detach().numpy())
        else:
            vecs = features.squeeze(0).detach().numpy()
        return vecs

    @arg(1, to_image_color('RGB'))
    def convert_img(self, img: towhee._types.Image):
        img = PILImage.fromarray(img.astype('uint8'), 'RGB')
        return img
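
In `__call__`, the feature map returned by `forward_features` is reduced to one value per channel by `nn.AdaptiveAvgPool2d(1)` and then flattened, which is how the `(feature_dim,)` embedding comes about. The sketch below reproduces that reduction for a single image on plain nested lists, assuming a `(C, H, W)` layout; `global_avg_pool` is a hypothetical helper written for illustration and does not use PyTorch.

```python
def global_avg_pool(feature_map):
    """Average each channel of a (C, H, W) nested list down to one value."""
    pooled = []
    for channel in feature_map:
        # Collect all H * W values of this channel and take their mean.
        values = [v for row in channel for v in row]
        pooled.append(sum(values) / len(values))
    return pooled


# Two channels, each a 2 x 2 spatial grid.
fmap = [
    [[1.0, 2.0], [3.0, 4.0]],
    [[0.0, 0.0], [0.0, 8.0]],
]
print(global_avg_pool(fmap))  # [2.5, 2.0]
```

The output has one entry per channel, so a backbone whose last stage has 480 channels yields a 480-dimensional vector regardless of the spatial size of the feature map.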

5
requirements.txt

@ -0,0 +1,5 @@
towhee.models>=0.7.0
towhee>=0.7.0
timm
Pillow
einops

BIN
result1.png

Binary file not shown. Size: 16 KiB

BIN
result2.png

Binary file not shown. Size: 136 KiB

BIN
towhee.jpeg

Binary file not shown. Size: 49 KiB
