From 33deb1c307ebfe19c52db4270865d9b16b490f43 Mon Sep 17 00:00:00 2001 From: wxywb Date: Thu, 31 Mar 2022 14:55:55 +0800 Subject: [PATCH] update the operator. --- README.md | 78 +++++++++++++++++++++- __init__.py | 19 ++++++ pytorch_retinaface_mobilenet_widerface.pth | 3 + requirements.txt | 1 + retinaface.py | 48 +++++++++++++ retinaface_impl.py | 41 ++++++++++++ 6 files changed, 189 insertions(+), 1 deletion(-) create mode 100644 __init__.py create mode 100644 pytorch_retinaface_mobilenet_widerface.pth create mode 100644 requirements.txt create mode 100644 retinaface.py create mode 100644 retinaface_impl.py diff --git a/README.md b/README.md index 22c1c7f..a76829d 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,78 @@ -# retinaface +# Retinaface Face Detection (Pytorch) +*Authors: David Wang* + +## Desription + +This opertator detects faces in the images by using RetinaFace Detector[1]. It will returns the bounding box positions and the confidence scores of detected faces. This repo is a adopataion from [2]. + + +```python +from towhee import ops + +model = ops.face_detection.retinaface() +embedding = model(img) +``` + +## Factory Constructor + +Create the operator via the following factory method + +***ops.face_detection.retinaface()*** + + +## Interface + +A face detection operator takes an image as input. it generates the bounding box position and confidence score back to ndarray. + +**Args:** + +​ ***framework*** + +​ the framework of the model + +​ supported types: `str`, default is 'pytorch' + +**Parameters:** + +​ ***image*** + +​ the image to detect faces. + +​ supported types: numpy.ndarray + +**Returns:**: *numpy.ndarray* + +​ The detected face bounding boxes. + +## Code Example + +get detected face bounding boxes from './img1.jpg'. + + *Write the pipeline in simplified style*: + +```python +import towhee.DataCollection as dc + +dc.glob('./img1.jpg') + .image_decode.cv2() + .face_detection.retinaface() + .to_list() +``` + +*Write a same pipeline with explicit inputs/outputs name specifications:* + +```python +import towhee.DataCollection as dc + +dc.glob['path']('./img1.jpg') + .image_decode.cv2['path', 'img']() + .face_detection.retinaface() + .to_list() +``` + + +## Reference + +[1]. https://arxiv.org/abs/1905.00641 +[2]. https://github.com/biubug6/Pytorch_Retinaface diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..98ce600 --- /dev/null +++ b/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2021 Zilliz. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .retinaface import Retinaface + +def retinaface(): + return Retinaface() + diff --git a/pytorch_retinaface_mobilenet_widerface.pth b/pytorch_retinaface_mobilenet_widerface.pth new file mode 100644 index 0000000..287f82f --- /dev/null +++ b/pytorch_retinaface_mobilenet_widerface.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4b1e031a840c18c24b3f5229c2971285eb4f897cf05f6568f955ae1110b60c88 +size 1821417 diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..08ed5ee --- /dev/null +++ b/requirements.txt @@ -0,0 +1 @@ +torch \ No newline at end of file diff --git a/retinaface.py b/retinaface.py new file mode 100644 index 0000000..58ab283 --- /dev/null +++ b/retinaface.py @@ -0,0 +1,48 @@ +# Copyright 2022 Zilliz. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import os +from typing import NamedTuple, List + +from PIL import Image +import torch +from torchvision import transforms +from pathlib import Path +import numpy + +import towhee +from towhee.operator import Operator +from towhee.types.image_utils import to_pil +from towhee._types import Image + +from timm.data import resolve_data_config +from timm.data.transforms_factory import create_transform + +@register(output_schema=['bboxes', 'scores']) +class Retinaface(Operator): + """ + Retinaface + """ + def __init__(self, framework: str = 'pytorch') -> None: + super().__init__() + sys.path.append(str(Path(__file__).parent)) + from retinaface_impl import Model + self.model = Model() + + @arg(1, to_image_color('RGB') ) + def __call__(self, image: 'towhee._types.Image'): + img = torch.FloatTensor(numpy.asarray(to_pil(image))) + bboxes, keypoints = self.model(img) + return bboxes[:,:4], bboxes[:,4] diff --git a/retinaface_impl.py b/retinaface_impl.py new file mode 100644 index 0000000..4921173 --- /dev/null +++ b/retinaface_impl.py @@ -0,0 +1,41 @@ +# Copyright 2021 Zilliz. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os + +import torch + +from towhee.models.retina_face.retinaface import RetinaFace +from towhee.models.retina_face.configs import build_configs +from towhee.models.utils.pretrained_utils import load_pretrained_weights + +class Model: + """ + Pytorch model class + """ + def __init__(self): + model_name = 'cfg_mnet' + cfg = build_configs(model_name) + self._model = RetinaFace(cfg=cfg, phase='test') + load_pretrained_weights(self._model, 'mnet', None, os.path.dirname(__file__) + '/pytorch_retinaface_mobilenet_widerface.pth') + self._model.eval() + + def __call__(self, img_tensor: torch.Tensor): + outputs = self._model.inference(img_tensor) + return outputs + + def train(self): + """ + For training model + """ + pass