retinaface
copied
wxywb
3 years ago
6 changed files with 189 additions and 1 deletions
@ -1,2 +1,78 @@ |
|||||
# retinaface |
|
||||
|
# Retinaface Face Detection (Pytorch) |
||||
|
|
||||
|
*Authors: David Wang* |
||||
|
|
||||
|
## Desription |
||||
|
|
||||
|
This opertator detects faces in the images by using RetinaFace Detector[1]. It will returns the bounding box positions and the confidence scores of detected faces. This repo is a adopataion from [2]. |
||||
|
|
||||
|
|
||||
|
```python |
||||
|
from towhee import ops |
||||
|
|
||||
|
model = ops.face_detection.retinaface() |
||||
|
embedding = model(img) |
||||
|
``` |
||||
|
|
||||
|
## Factory Constructor |
||||
|
|
||||
|
Create the operator via the following factory method |
||||
|
|
||||
|
***ops.face_detection.retinaface()*** |
||||
|
|
||||
|
|
||||
|
## Interface |
||||
|
|
||||
|
A face detection operator takes an image as input. it generates the bounding box position and confidence score back to ndarray. |
||||
|
|
||||
|
**Args:** |
||||
|
|
||||
|
***framework*** |
||||
|
|
||||
|
the framework of the model |
||||
|
|
||||
|
supported types: `str`, default is 'pytorch' |
||||
|
|
||||
|
**Parameters:** |
||||
|
|
||||
|
***image*** |
||||
|
|
||||
|
the image to detect faces. |
||||
|
|
||||
|
supported types: numpy.ndarray |
||||
|
|
||||
|
**Returns:**: *numpy.ndarray* |
||||
|
|
||||
|
The detected face bounding boxes. |
||||
|
|
||||
|
## Code Example |
||||
|
|
||||
|
get detected face bounding boxes from './img1.jpg'. |
||||
|
|
||||
|
*Write the pipeline in simplified style*: |
||||
|
|
||||
|
```python |
||||
|
import towhee.DataCollection as dc |
||||
|
|
||||
|
dc.glob('./img1.jpg') |
||||
|
.image_decode.cv2() |
||||
|
.face_detection.retinaface() |
||||
|
.to_list() |
||||
|
``` |
||||
|
|
||||
|
*Write a same pipeline with explicit inputs/outputs name specifications:* |
||||
|
|
||||
|
```python |
||||
|
import towhee.DataCollection as dc |
||||
|
|
||||
|
dc.glob['path']('./img1.jpg') |
||||
|
.image_decode.cv2['path', 'img']() |
||||
|
.face_detection.retinaface() |
||||
|
.to_list() |
||||
|
``` |
||||
|
|
||||
|
|
||||
|
## Reference |
||||
|
|
||||
|
[1]. https://arxiv.org/abs/1905.00641 |
||||
|
[2]. https://github.com/biubug6/Pytorch_Retinaface |
||||
|
@ -0,0 +1,19 @@ |
|||||
|
# Copyright 2021 Zilliz. All rights reserved. |
||||
|
# |
||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
|
# you may not use this file except in compliance with the License. |
||||
|
# You may obtain a copy of the License at |
||||
|
# |
||||
|
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
|
# |
||||
|
# Unless required by applicable law or agreed to in writing, software |
||||
|
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
|
# See the License for the specific language governing permissions and |
||||
|
# limitations under the License. |
||||
|
|
||||
|
from .retinaface import Retinaface |
||||
|
|
||||
|
def retinaface(): |
||||
|
return Retinaface() |
||||
|
|
Binary file not shown.
@ -0,0 +1 @@ |
|||||
|
torch |
@ -0,0 +1,48 @@ |
|||||
|
# Copyright 2022 Zilliz. All rights reserved. |
||||
|
# |
||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
|
# you may not use this file except in compliance with the License. |
||||
|
# You may obtain a copy of the License at |
||||
|
# |
||||
|
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
|
# |
||||
|
# Unless required by applicable law or agreed to in writing, software |
||||
|
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
|
# See the License for the specific language governing permissions and |
||||
|
# limitations under the License. |
||||
|
|
||||
|
import sys |
||||
|
import os |
||||
|
from typing import NamedTuple, List |
||||
|
|
||||
|
from PIL import Image |
||||
|
import torch |
||||
|
from torchvision import transforms |
||||
|
from pathlib import Path |
||||
|
import numpy |
||||
|
|
||||
|
import towhee |
||||
|
from towhee.operator import Operator |
||||
|
from towhee.types.image_utils import to_pil |
||||
|
from towhee._types import Image |
||||
|
|
||||
|
from timm.data import resolve_data_config |
||||
|
from timm.data.transforms_factory import create_transform |
||||
|
|
||||
|
@register(output_schema=['bboxes', 'scores']) |
||||
|
class Retinaface(Operator): |
||||
|
""" |
||||
|
Retinaface |
||||
|
""" |
||||
|
def __init__(self, framework: str = 'pytorch') -> None: |
||||
|
super().__init__() |
||||
|
sys.path.append(str(Path(__file__).parent)) |
||||
|
from retinaface_impl import Model |
||||
|
self.model = Model() |
||||
|
|
||||
|
@arg(1, to_image_color('RGB') ) |
||||
|
def __call__(self, image: 'towhee._types.Image'): |
||||
|
img = torch.FloatTensor(numpy.asarray(to_pil(image))) |
||||
|
bboxes, keypoints = self.model(img) |
||||
|
return bboxes[:,:4], bboxes[:,4] |
@ -0,0 +1,41 @@ |
|||||
|
# Copyright 2021 Zilliz. All rights reserved. |
||||
|
# |
||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
|
# you may not use this file except in compliance with the License. |
||||
|
# You may obtain a copy of the License at |
||||
|
# |
||||
|
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
|
# |
||||
|
# Unless required by applicable law or agreed to in writing, software |
||||
|
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
|
# See the License for the specific language governing permissions and |
||||
|
# limitations under the License. |
||||
|
import os |
||||
|
|
||||
|
import torch |
||||
|
|
||||
|
from towhee.models.retina_face.retinaface import RetinaFace |
||||
|
from towhee.models.retina_face.configs import build_configs |
||||
|
from towhee.models.utils.pretrained_utils import load_pretrained_weights |
||||
|
|
||||
|
class Model: |
||||
|
""" |
||||
|
Pytorch model class |
||||
|
""" |
||||
|
def __init__(self): |
||||
|
model_name = 'cfg_mnet' |
||||
|
cfg = build_configs(model_name) |
||||
|
self._model = RetinaFace(cfg=cfg, phase='test') |
||||
|
load_pretrained_weights(self._model, 'mnet', None, os.path.dirname(__file__) + '/pytorch_retinaface_mobilenet_widerface.pth') |
||||
|
self._model.eval() |
||||
|
|
||||
|
def __call__(self, img_tensor: torch.Tensor): |
||||
|
outputs = self._model.inference(img_tensor) |
||||
|
return outputs |
||||
|
|
||||
|
def train(self): |
||||
|
""" |
||||
|
For training model |
||||
|
""" |
||||
|
pass |
Loading…
Reference in new issue