# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from typing import Optional

import torch

from towhee.models.retina_face.retinaface import RetinaFace
from towhee.models.retina_face.configs import build_configs
from towhee.models.utils.pretrained_utils import load_pretrained_weights


class Model:
    """
    Pytorch model class wrapping a RetinaFace (MobileNet config) detector.

    The model is built from the ``'cfg_mnet'`` configuration in ``'test'``
    phase, loaded with pretrained weights, and switched to eval mode.
    """

    # Checkpoint file shipped next to this module; used when no explicit path is given.
    _DEFAULT_WEIGHTS_FILE = 'pytorch_retinaface_mobilenet_widerface.pth'

    def __init__(self, weights_path: Optional[str] = None):
        """
        Build the RetinaFace model and load its pretrained weights.

        Args:
            weights_path: Optional path to a checkpoint to load. When ``None``
                (the default), the ``pytorch_retinaface_mobilenet_widerface.pth``
                file located in this module's directory is used.
                NOTE(review): a custom checkpoint must match the 'cfg_mnet'
                architecture, since the config and the 'mnet' key are fixed here.
        """
        model_name = 'cfg_mnet'
        cfg = build_configs(model_name)
        self._model = RetinaFace(cfg=cfg, phase='test')
        if weights_path is None:
            # os.path.join instead of '+' so the separator is platform-correct.
            weights_path = os.path.join(os.path.dirname(__file__), self._DEFAULT_WEIGHTS_FILE)
        load_pretrained_weights(self._model, 'mnet', None, weights_path)
        # Put layers such as dropout/batch-norm into inference behaviour.
        self._model.eval()

    def __call__(self, img_tensor: torch.Tensor):
        """
        Run detection on ``img_tensor`` and return the model's raw outputs.

        The tensor is passed unchanged to ``RetinaFace.inference``; the
        expected layout/normalisation is whatever that method requires
        (presumably a batched image tensor — TODO confirm).
        """
        return self._model.inference(img_tensor)

    def train(self):
        """
        For training model.

        Not implemented: this is a no-op placeholder.
        """