diff --git a/README.md b/README.md
index 59fa07f..626d677 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,93 @@
-# data2vec-vision
+# Image Embedding with data2vec
+
+*author: David Wang*
+
+
+
+
+
+## Description
+
+This operator extracts features for images with [data2vec](https://arxiv.org/abs/2202.03555). The core idea is to predict latent representations of the full input data based on a masked view of the input, in a self-distillation setup, using a standard Transformer architecture.
+
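+The sketch below restates that training objective as plain PyTorch-style pseudocode. It is only a conceptual illustration of the idea described above (this operator itself only runs inference), and `student`, `teacher`, and `mask_view` are hypothetical placeholders rather than real data2vec or towhee APIs.
+
+```python
+import torch
+import torch.nn.functional as F
+
+def data2vec_step(student, teacher, images, mask_view, ema_decay=0.999):
+    # The teacher sees the full input and produces the latent targets.
+    with torch.no_grad():
+        targets = teacher(images)
+    # The student sees a masked view and learns to predict the teacher's latents.
+    predictions = student(mask_view(images))
+    loss = F.smooth_l1_loss(predictions, targets)
+    loss.backward()
+    # The teacher tracks the student as an exponential moving average of its weights.
+    with torch.no_grad():
+        for p_t, p_s in zip(teacher.parameters(), student.parameters()):
+            p_t.mul_(ema_decay).add_(p_s, alpha=1 - ema_decay)
+    return loss
+```
+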
+
+
+## Code Example
+
+Load an image from the path './towhee.jpg' and generate an image embedding.
+
+*Write the pipeline in the simplified style*:
+
+```python
+import towhee
+
+towhee.glob('./towhee.jpg') \
+    .image_decode.cv2() \
+    .image_embedding.data2vec_vision(model_name='facebook/data2vec-vision-base-ft1k') \
+    .show()
+```
+
+![result1](./result1.png)
+
+*Write the same pipeline with explicit input/output name specifications:*
+
+```python
+import towhee
+
+towhee.glob['path']('./towhee.jpg') \
+    .image_decode.cv2['path', 'img']() \
+    .image_embedding.data2vec_vision['img', 'vec'](model_name='facebook/data2vec-vision-base-ft1k') \
+    .select['img', 'vec']() \
+    .show()
+```
+
+![result2](./result2.png)
+
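+The returned embeddings are plain numpy vectors and can be compared directly. The sketch below scores two images by cosine similarity; it is a hedged example that assumes a second image exists at the hypothetical path './towhee2.jpg' and that the simplified-style pipeline results can be collected with `to_list()`.
+
+```python
+import numpy as np
+import towhee
+
+vec1, vec2 = (
+    towhee.glob('./towhee.jpg', './towhee2.jpg')
+    .image_decode.cv2()
+    .image_embedding.data2vec_vision(model_name='facebook/data2vec-vision-base-ft1k')
+    .to_list()
+)
+
+# Cosine similarity between the two image embeddings.
+score = float(np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2)))
+print(score)
+```
+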
+
+
+
+## Factory Constructor
+
+Create the operator via the following factory method:
+
+***data2vec_vision(model_name='facebook/data2vec-vision-base-ft1k')***
+
+**Parameters:**
+
+***model_name***: *str*
+
+The model name as a string.
+The default value is "facebook/data2vec-vision-base-ft1k".
+
+Supported model names:
+- facebook/data2vec-vision-base-ft1k
+- facebook/data2vec-vision-large-ft1k
+
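+For example, passing the other supported model name to the factory swaps in the large variant, which produces a 1024-dimensional embedding instead of the base model's 768 dimensions. This is the same pipeline as in the example above, shown here only as a sketch of changing the `model_name` parameter.
+
+```python
+import towhee
+
+towhee.glob('./towhee.jpg') \
+    .image_decode.cv2() \
+    .image_embedding.data2vec_vision(model_name='facebook/data2vec-vision-large-ft1k') \
+    .show()
+```
+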
+
+
+
+## Interface
+
+An image embedding operator takes a [towhee image](link/to/towhee/image/api/doc) as input.
+It uses the pre-trained model specified by the model name to generate an image embedding as an ndarray.
+
+
+**Parameters:**
+
+***img:*** *towhee.types.Image (a sub-class of numpy.ndarray)*
+
+The decoded image data in towhee.types.Image (numpy.ndarray).
+
+
+
+**Returns:** *numpy.ndarray*
+
+The image embedding extracted by the model.
+
+
diff --git a/__init__.py b/__init__.py
new file mode 100644
index 0000000..5fd45d1
--- /dev/null
+++ b/__init__.py
@@ -0,0 +1,19 @@
+# Copyright 2021 Zilliz. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .data2vec_vision import Data2VecVision
+
+
+def data2vec_vision(model_name='facebook/data2vec-vision-base-ft1k'):
+    return Data2VecVision(model_name)
diff --git a/data2vec_vision.py b/data2vec_vision.py
new file mode 100644
index 0000000..e23b8a9
--- /dev/null
+++ b/data2vec_vision.py
@@ -0,0 +1,37 @@
+# Copyright 2021 Zilliz. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import numpy
+import torch
+import towhee
+
+from PIL import Image as PILImage
+
+from transformers import BeitFeatureExtractor, Data2VecVisionForImageClassification
+from towhee.operator.base import NNOperator
+from towhee.types.arg import arg, to_image_color
+
+
+class Data2VecVision(NNOperator):
+    def __init__(self, model_name='facebook/data2vec-vision-base-ft1k'):
+        super().__init__()
+        # Load the pre-trained data2vec-vision model and its BEiT-style feature extractor.
+        self.model = Data2VecVisionForImageClassification.from_pretrained(model_name)
+        self.feature_extractor = BeitFeatureExtractor.from_pretrained(model_name)
+
+    @arg(1, to_image_color('RGB'))
+    def __call__(self, img: towhee._types.Image) -> numpy.ndarray:
+        img = PILImage.fromarray(img.astype('uint8'), 'RGB')
+        inputs = self.feature_extractor(img, return_tensors="pt")
+        with torch.no_grad():
+            # Use the pooled output of the backbone (skipping the classification head)
+            # as the image embedding.
+            outputs = self.model.data2vec_vision(**inputs).pooler_output
+
+        return outputs.detach().cpu().numpy().flatten()
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..0c49ac4
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,4 @@
+numpy
+torch
+Pillow
+transformers>4.19.0
diff --git a/result1.png b/result1.png
new file mode 100644
index 0000000..e246440
Binary files /dev/null and b/result1.png differ
diff --git a/result2.png b/result2.png
new file mode 100644
index 0000000..140e330
Binary files /dev/null and b/result2.png differ
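As a hedged sanity check of the interface described in the README, the sketch below runs the explicit-style pipeline once and inspects the returned embedding. It assumes `to_list()` yields entities exposing the named `vec` column as an attribute, and that the base-ft1k model produces a 768-dimensional vector.

```python
import numpy
import towhee

entities = towhee.glob['path']('./towhee.jpg') \
    .image_decode.cv2['path', 'img']() \
    .image_embedding.data2vec_vision['img', 'vec'](model_name='facebook/data2vec-vision-base-ft1k') \
    .to_list()

vec = entities[0].vec
assert isinstance(vec, numpy.ndarray)
print(vec.shape)  # expected: (768,) for the base model
```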