from typing import List, Tuple from detectron2 import model_zoo from detectron2.config import get_cfg from detectron2.engine.defaults import DefaultPredictor import numpy as np import torch from towhee._types import Image from towhee.operator import NNOperator, OperatorFlag from towhee import register CFG_YAMLS = { 'faster_rcnn_resnet50_c4': 'COCO-Detection/faster_rcnn_R_50_C4_3x.yaml', 'faster_rcnn_resnet50_dc5': 'COCO-Detection/faster_rcnn_R_50_DC5_3x.yaml', 'faster_rcnn_resnet50_fpn': 'COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml', 'faster_rcnn_resnet101_c4': 'COCO-Detection/faster_rcnn_R_101_C4_3x.yaml', 'faster_rcnn_resnet101_dc5': 'COCO-Detection/faster_rcnn_R_101_DC5_3x.yaml', 'faster_rcnn_resnet101_fpn': 'COCO-Detection/faster_rcnn_R_101_FPN_3x.yaml', 'faster_rcnn_resnext101': 'COCO-Detection/faster_rcnn_X_101_32x8d_FPN_3x.yaml', 'retinanet_resnet50': 'COCO-Detection/retinanet_R_50_FPN_3x.yaml', 'retinanet_resnet101': 'COCO-Detection/retinanet_R_101_FPN_3x.yaml' } @register(output_schema=['boxes', 'classes', 'scores'], flag=OperatorFlag.STATELESS | OperatorFlag.REUSEABLE) class Detectron2(NNOperator): """ This Operator implements object detection using the Detectron2 library. Args: model_name (`str`): Detectron2-based model to use. For a full list, see `CFG_YAMLS`. """ def __init__(self, model_name: str = 'faster_rcnn_resnet50_c4', thresh: int = 0.5): super().__init__() cfg = get_cfg() cfg.merge_from_file(model_zoo.get_config_file(CFG_YAMLS[model_name])) cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(CFG_YAMLS[model_name]) cfg.MODEL.DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu' cfg.MODEL.RETINANET.SCORE_THRESH_TEST = thresh cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5 cfg.MODEL.PANOPTIC_FPN.COMBINE.INSTANCES_CONFIDENCE_THRESH = 0.5 cfg.freeze() self._predictor = DefaultPredictor(cfg) def __call__(self, image: 'towhee._types.Image') -> Tuple[List]: # Detectron2 uses BGR-formatted images res = self._predictor(image[:,:,::-1]) res = res['instances'] boxes = res.get('pred_boxes').tensor.numpy() classes = res.get('pred_classes').numpy() scores = res.get('scores').numpy() return (boxes, classes, scores)