detectron2
copied
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions
61 lines
2.4 KiB
61 lines
2.4 KiB
from typing import List, Tuple
|
|
import os
|
|
import detectron2
|
|
from detectron2 import model_zoo
|
|
from detectron2.config import get_cfg
|
|
from detectron2.engine.defaults import DefaultPredictor
|
|
import numpy as np
|
|
import torch
|
|
from towhee._types import Image
|
|
from towhee.operator import NNOperator, OperatorFlag
|
|
from towhee import register
|
|
|
|
|
|
# Short model names accepted by the operator, mapped to the Detectron2
# model-zoo config path that is used to resolve both the config file and the
# pretrained checkpoint.
CFG_YAMLS = {
    # Faster R-CNN, ResNet-50 variants
    'faster_rcnn_resnet50_c4': 'COCO-Detection/faster_rcnn_R_50_C4_3x.yaml',
    'faster_rcnn_resnet50_dc5': 'COCO-Detection/faster_rcnn_R_50_DC5_3x.yaml',
    'faster_rcnn_resnet50_fpn': 'COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml',
    # Faster R-CNN, ResNet-101 variants
    'faster_rcnn_resnet101_c4': 'COCO-Detection/faster_rcnn_R_101_C4_3x.yaml',
    'faster_rcnn_resnet101_dc5': 'COCO-Detection/faster_rcnn_R_101_DC5_3x.yaml',
    'faster_rcnn_resnet101_fpn': 'COCO-Detection/faster_rcnn_R_101_FPN_3x.yaml',
    # Faster R-CNN, ResNeXt-101
    'faster_rcnn_resnext101': 'COCO-Detection/faster_rcnn_X_101_32x8d_FPN_3x.yaml',
    # RetinaNet
    'retinanet_resnet50': 'COCO-Detection/retinanet_R_50_FPN_3x.yaml',
    'retinanet_resnet101': 'COCO-Detection/retinanet_R_101_FPN_3x.yaml',
}
|
|
|
|
|
|
@register(output_schema=['boxes', 'classes', 'scores'],
          flag=OperatorFlag.STATELESS | OperatorFlag.REUSEABLE)
class Detectron2(NNOperator):
    """
    This Operator implements object detection using the Detectron2 library.

    Args:
        model_name (`str`):
            Detectron2-based model to use. For a full list, see `CFG_YAMLS`.
        thresh (`float`):
            Score threshold written into the Detectron2 config
            (`SCORE_THRESH_TEST` for RetinaNet and ROI heads, and the panoptic
            instance confidence threshold). Defaults to 0.5.

    Raises:
        KeyError: If `model_name` is not a key of `CFG_YAMLS`.
    """

    def __init__(self, model_name: str = 'faster_rcnn_resnet50_c4', thresh: float = 0.5):
        super().__init__()

        # Fail early with the list of valid names instead of a bare KeyError.
        try:
            cfg_yaml = CFG_YAMLS[model_name]
        except KeyError:
            supported = ', '.join(CFG_YAMLS)
            raise KeyError(
                f'Unsupported model_name {model_name!r}; expected one of: {supported}'
            ) from None

        cfg = get_cfg()
        cfg.merge_from_file(model_zoo.get_config_file(cfg_yaml))
        cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(cfg_yaml)
        cfg.MODEL.DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'

        # Apply the caller's threshold to every head type. Previously only the
        # RetinaNet threshold honoured `thresh`, so it was silently ignored for
        # all Faster R-CNN models (including the default one).
        cfg.MODEL.RETINANET.SCORE_THRESH_TEST = thresh
        cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = thresh
        cfg.MODEL.PANOPTIC_FPN.COMBINE.INSTANCES_CONFIDENCE_THRESH = thresh
        cfg.freeze()

        self._predictor = DefaultPredictor(cfg)

    def __call__(self, image: 'towhee._types.Image') -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Run the detector on a single image.

        Args:
            image (`towhee._types.Image`):
                Image indexed as (row, column, channel); the channel axis is
                reversed before inference (presumably RGB input -> BGR).

        Returns:
            A `(boxes, classes, scores)` tuple of NumPy arrays taken from the
            predictor's `instances` result and moved to the CPU: the tensor
            behind `pred_boxes`, the `pred_classes` and the `scores`.
        """
        # Detectron2 uses BGR-formatted images
        outputs = self._predictor(image[:, :, ::-1])
        instances = outputs['instances']

        boxes = instances.get('pred_boxes').tensor.cpu().numpy()
        classes = instances.get('pred_classes').cpu().numpy()
        scores = instances.get('scores').cpu().numpy()

        return (boxes, classes, scores)
|