|
@ -22,10 +22,11 @@ from torchvision import transforms |
|
|
from pathlib import Path |
|
|
from pathlib import Path |
|
|
import numpy |
|
|
import numpy |
|
|
|
|
|
|
|
|
import towhee |
|
|
|
|
|
|
|
|
from towhee import register |
|
|
from towhee.operator import Operator |
|
|
from towhee.operator import Operator |
|
|
from towhee.types.image_utils import to_pil |
|
|
from towhee.types.image_utils import to_pil |
|
|
from towhee._types import Image |
|
|
from towhee._types import Image |
|
|
|
|
|
from towhee.types import arg, to_image_color |
|
|
|
|
|
|
|
|
from timm.data import resolve_data_config |
|
|
from timm.data import resolve_data_config |
|
|
from timm.data.transforms_factory import create_transform |
|
|
from timm.data.transforms_factory import create_transform |
|
@ -45,4 +46,5 @@ class Retinaface(Operator): |
|
|
def __call__(self, image: 'towhee._types.Image'): |
|
|
def __call__(self, image: 'towhee._types.Image'): |
|
|
img = torch.FloatTensor(numpy.asarray(to_pil(image))) |
|
|
img = torch.FloatTensor(numpy.asarray(to_pil(image))) |
|
|
bboxes, keypoints = self.model(img) |
|
|
bboxes, keypoints = self.model(img) |
|
|
return bboxes[:,:4], bboxes[:,4] |
|
|
|
|
|
|
|
|
bboxes = bboxes.cpu().detach().numpy() |
|
|
|
|
|
return bboxes[:,:4].astype(int), bboxes[:,4] |
|
|