logo
Browse Source

update the operator.

main
wxywb 3 years ago
parent
commit
33deb1c307
  1. 78
      README.md
  2. 19
      __init__.py
  3. BIN
      pytorch_retinaface_mobilenet_widerface.pth
  4. 1
      requirements.txt
  5. 48
      retinaface.py
  6. 41
      retinaface_impl.py

78
README.md

@ -1,2 +1,78 @@
# retinaface
# Retinaface Face Detection (Pytorch)
*Authors: David Wang*
## Desription
This opertator detects faces in the images by using RetinaFace Detector[1]. It will returns the bounding box positions and the confidence scores of detected faces. This repo is a adopataion from [2].
```python
from towhee import ops
model = ops.face_detection.retinaface()
embedding = model(img)
```
## Factory Constructor
Create the operator via the following factory method
***ops.face_detection.retinaface()***
## Interface
A face detection operator takes an image as input. it generates the bounding box position and confidence score back to ndarray.
**Args:**
***framework***
​ the framework of the model
​ supported types: `str`, default is 'pytorch'
**Parameters:**
***image***
​ the image to detect faces.
​ supported types: numpy.ndarray
**Returns:**: *numpy.ndarray*
​ The detected face bounding boxes.
## Code Example
get detected face bounding boxes from './img1.jpg'.
*Write the pipeline in simplified style*:
```python
import towhee.DataCollection as dc
dc.glob('./img1.jpg')
.image_decode.cv2()
.face_detection.retinaface()
.to_list()
```
*Write a same pipeline with explicit inputs/outputs name specifications:*
```python
import towhee.DataCollection as dc
dc.glob['path']('./img1.jpg')
.image_decode.cv2['path', 'img']()
.face_detection.retinaface()
.to_list()
```
## Reference
[1]. https://arxiv.org/abs/1905.00641
[2]. https://github.com/biubug6/Pytorch_Retinaface

19
__init__.py

@ -0,0 +1,19 @@
# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .retinaface import Retinaface
def retinaface():
return Retinaface()

BIN
pytorch_retinaface_mobilenet_widerface.pth (Stored with Git LFS)

Binary file not shown.

1
requirements.txt

@ -0,0 +1 @@
torch

48
retinaface.py

@ -0,0 +1,48 @@
# Copyright 2022 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import os
from typing import NamedTuple, List
from PIL import Image
import torch
from torchvision import transforms
from pathlib import Path
import numpy
import towhee
from towhee.operator import Operator
from towhee.types.image_utils import to_pil
from towhee._types import Image
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
@register(output_schema=['bboxes', 'scores'])
class Retinaface(Operator):
"""
Retinaface
"""
def __init__(self, framework: str = 'pytorch') -> None:
super().__init__()
sys.path.append(str(Path(__file__).parent))
from retinaface_impl import Model
self.model = Model()
@arg(1, to_image_color('RGB') )
def __call__(self, image: 'towhee._types.Image'):
img = torch.FloatTensor(numpy.asarray(to_pil(image)))
bboxes, keypoints = self.model(img)
return bboxes[:,:4], bboxes[:,4]

41
retinaface_impl.py

@ -0,0 +1,41 @@
# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import torch
from towhee.models.retina_face.retinaface import RetinaFace
from towhee.models.retina_face.configs import build_configs
from towhee.models.utils.pretrained_utils import load_pretrained_weights
class Model:
"""
Pytorch model class
"""
def __init__(self):
model_name = 'cfg_mnet'
cfg = build_configs(model_name)
self._model = RetinaFace(cfg=cfg, phase='test')
load_pretrained_weights(self._model, 'mnet', None, os.path.dirname(__file__) + '/pytorch_retinaface_mobilenet_widerface.pth')
self._model.eval()
def __call__(self, img_tensor: torch.Tensor):
outputs = self._model.inference(img_tensor)
return outputs
def train(self):
"""
For training model
"""
pass
Loading…
Cancel
Save