# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import NamedTuple

import numpy as np
import torch
import cv2
from facenet_pytorch import InceptionResnetV1

from towhee import register
from towhee.operator import NNOperator
from towhee.types.image_utils import to_pil
from towhee._types import Image
from towhee.types import arg, to_image_color


@register(output_schema=['vec'])
class Inceptionresnetv1(NNOperator):
    """
    Face embedding operator backed by facenet_pytorch's ``InceptionResnetV1``
    loaded with the ``vggface2`` pretrained weights.

    Args:
        image_size (`int`):
            Side length in pixels that a single [h, w, 3] input image is
            resized to before inference. Batched [n, h, w, 3] input is passed
            through without resizing.
    """

    def __init__(self, image_size=160):
        super().__init__()
        self.image_size = image_size
        self._model = InceptionResnetV1(pretrained='vggface2')
        # Inference only: switch the model's layers to evaluation behaviour.
        self._model.eval()

    @arg(1, to_image_color('RGB'))
    def __call__(self, img: Image) -> np.ndarray:
        """
        Compute embeddings for one RGB image or a batch of RGB images.

        Args:
            img: an image shaped [h, w, 3] or a batch shaped [n, h, w, 3].

        Returns:
            np.ndarray: the model output as a numpy array with one row per
            input image (a single image yields a batch of size 1).
        """
        batch = np.ascontiguousarray(self.preprocess(img), dtype=np.float32)
        # NHWC (numpy image layout) -> NCHW (layout the torch model consumes).
        tensor = torch.from_numpy(batch).permute(0, 3, 1, 2)
        # No gradients are needed for inference; skipping graph construction
        # saves memory and time without changing the output values.
        with torch.no_grad():
            embs = self._model(tensor)
        return embs.numpy()

    def preprocess(self, img: Image) -> np.ndarray:
        """
        Turn ``img`` into a standardized float batch shaped [n, h, w, 3].

        A single [h, w, 3] image is resized to ``image_size`` x ``image_size``
        and given a leading batch axis; a 4-D batch is used as-is.

        Raises:
            ValueError: if ``img`` is neither 3-D nor 4-D, or its last axis
                (channels) is not 3.
        """
        # Check channels up front: cv2.resize drops a trailing axis of size 1,
        # which would otherwise surface later as an opaque permute error.
        if img.ndim not in (3, 4) or img.shape[-1] != 3:
            raise ValueError('unknown tensor shape, need to be [n, h, w, c=3] or [h, w, c=3].')
        if img.ndim == 3:
            img = cv2.resize(img, (self.image_size, self.image_size))
            img = np.expand_dims(img, 0)
        # NOTE(review): batched input keeps its own spatial size; presumably
        # callers pre-resize it to image_size -- confirm.
        return self._fixed_image_standardization(img)

    def _fixed_image_standardization(self, image_tensor):
        """
        Shift and scale pixel values as ``(x - 127.5) / 128``, which maps the
        0..255 range to roughly -0.996..0.996 (same constants as
        facenet_pytorch's ``fixed_image_standardization``).
        """
        return (image_tensor - 127.5) / 128.0

    def train(self):
        """
        For training model (not implemented; this operator is inference-only).
        """