diff --git a/README.md b/README.md index a76829d..ac1d8e3 100644 --- a/README.md +++ b/README.md @@ -41,10 +41,16 @@ A face detection operator takes an image as input. it generates the bounding box ​ supported types: numpy.ndarray -**Returns:**: *numpy.ndarray* +**Returns:**: + +*numpy.ndarray* ​ The detected face bounding boxes. +*numpy.ndarray* + +​ The detected face bounding boxes confident scores. + ## Code Example get detected face bounding boxes from './img1.jpg'. diff --git a/retinaface.py b/retinaface.py index 58ab283..aab7280 100644 --- a/retinaface.py +++ b/retinaface.py @@ -22,10 +22,11 @@ from torchvision import transforms from pathlib import Path import numpy -import towhee +from towhee import register from towhee.operator import Operator from towhee.types.image_utils import to_pil from towhee._types import Image +from towhee.types import arg, to_image_color from timm.data import resolve_data_config from timm.data.transforms_factory import create_transform @@ -45,4 +46,5 @@ class Retinaface(Operator): def __call__(self, image: 'towhee._types.Image'): img = torch.FloatTensor(numpy.asarray(to_pil(image))) bboxes, keypoints = self.model(img) - return bboxes[:,:4], bboxes[:,4] + bboxes = bboxes.cpu().detach().numpy() + return bboxes[:,:4].astype(int), bboxes[:,4]