# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy
import torch
import towhee

from PIL import Image as PILImage

from transformers import BeitFeatureExtractor, Data2VecVisionForImageClassification
from towhee.operator.base import NNOperator
from towhee.types.arg import arg, to_image_color


class Data2VecVision(NNOperator):
    """Operator that turns an image into a Data2Vec vision embedding.

    The pretrained classification model and its matching feature extractor
    are both loaded from ``model_name``; only the ``data2vec_vision``
    backbone of the model is used when the operator is called.
    """

    def __init__(self, model_name: str = 'facebook/data2vec-vision-base') -> None:
        """Load the pretrained model and feature extractor.

        Args:
            model_name: Identifier passed to ``from_pretrained`` for both
                the model and the feature extractor.
        """
        # Run the base-class initialiser so that whatever state NNOperator
        # sets up exists on this instance.
        super().__init__()
        self.model = Data2VecVisionForImageClassification.from_pretrained(model_name)
        self.feature_extractor = BeitFeatureExtractor.from_pretrained(model_name)

    @arg(1, to_image_color('RGB'))
    def __call__(self, img: 'towhee.types.Image') -> numpy.ndarray:
        """Compute the embedding of a single image.

        Args:
            img: Input image. The ``arg`` decorator applies
                ``to_image_color('RGB')`` to this argument before the body
                runs (presumably a colour-space conversion to RGB; confirm
                against towhee's ``to_image_color``).

        Returns:
            The backbone's ``pooler_output`` flattened into a 1-D
            ``numpy.ndarray``.
        """
        # PIL expects 8-bit data for an 'RGB' image, hence the cast.
        img = PILImage.fromarray(img.astype('uint8'), 'RGB')
        inputs = self.feature_extractor(img, return_tensors="pt")

        # Inference only: skip gradient tracking.
        with torch.no_grad():
            outputs = self.model.data2vec_vision(**inputs).pooler_output

        return outputs.detach().cpu().numpy().flatten()