# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# adapted from https://github.com/cunjian/pytorch_face_landmark

import os
import sys
from typing import NamedTuple
from pathlib import Path

import numpy as np
import torch
from torchvision import transforms

from towhee import register
from towhee.operator import NNOperator
from towhee.types.arg import arg, to_image_color
from towhee.types.image_utils import to_pil
from towhee._types import Image


@register(output_schema=['landmark'])
class Mobilefacenet(NNOperator):
    """
    Facial-landmark operator backed by a MobileFaceNet regressor.

    The network is built for a 112x112 input with 136 outputs, which
    ``__call__`` reshapes into (x, y) pairs and rescales to pixel
    coordinates of the input image.

    Args:
        framework (`str`):
            Framework name forwarded to ``NNOperator``.
        pretrained (`bool`):
            When truthy, load weights from ``mobilefacenet_model_best.pth``
            located next to this file.
    """

    def __init__(self, framework: str = 'pytorch', pretrained=True):
        super().__init__(framework=framework)
        # mobilefacenet_impl lives beside this file; add that directory to
        # sys.path only once so repeated instantiation does not grow it.
        impl_dir = str(Path(__file__).parent)
        if impl_dir not in sys.path:
            sys.path.append(impl_dir)
        from mobilefacenet_impl import MobileFaceNet  # pylint: disable=import-outside-toplevel

        self.model = MobileFaceNet([112, 112], 136)
        if pretrained:
            weights_path = Path(__file__).parent / 'mobilefacenet_model_best.pth'
            checkpoint = torch.load(str(weights_path), map_location='cpu')
            self.model.load_state_dict(checkpoint['state_dict'])
        # The operator only runs inference; switch to eval mode once here
        # instead of on every call.
        self.model.eval()

        normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        # `transforms.Scale` was removed from torchvision; `Resize` replaces it.
        # An explicit (112, 112) size keeps non-square inputs compatible with
        # the network's fixed [112, 112] input, and matches the independent
        # per-axis rescaling done in __call__.
        self.tfms = transforms.Compose([
            transforms.Resize((112, 112)),
            transforms.ToTensor(),
            normalize,
        ])

    @arg(1, to_image_color('RGB'))
    def __call__(self, image: Image):
        """
        Predict facial landmarks for ``image``.

        Args:
            image (`Image`):
                Input image, converted to RGB by the ``arg`` decorator.

        Returns:
            `np.ndarray`: int32 array of shape (N, 2) holding (x, y) pixel
            coordinates in the original image.
        """
        image = to_pil(image)
        # PIL's Image.size is (width, height). Unpacking it as (h, w) swapped
        # the x/y scale factors for non-square images.
        w, h = image.size
        tensor = self._preprocess(image)
        if tensor.dim() == 3:
            tensor = tensor.unsqueeze(0)
        with torch.no_grad():
            # Presumably the model returns a tuple whose first item is the
            # (batch, 136) landmark tensor; take batch element 0.
            # TODO(review): confirm against mobilefacenet_impl.
            landmark = self.model(tensor)[0][0]
        landmark = landmark.reshape(-1, 2)
        # Outputs are assumed to be normalized to [0, 1] per axis; scale x by
        # width and y by height to get pixel coordinates.
        scale = torch.tensor([w, h], dtype=landmark.dtype, device=landmark.device)
        landmark = landmark * scale
        return landmark.detach().cpu().numpy().astype(np.int32)

    def _preprocess(self, image):
        """Apply the resize / to-tensor / normalize pipeline to a PIL image."""
        return self.tfms(image)

    def _postprocess(self, landmark):
        """Placeholder; post-processing is done inline in ``__call__``."""
        pass

    def train(self):
        """Training is not supported by this operator."""
        pass