logo
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions

66 lines
2.5 KiB

3 years ago
from typing import List, Tuple
import importlib
import os
import subprocess
import sys

try:
    import detectron2
except ModuleNotFoundError:
    # Detectron2 is not published on PyPI, so bootstrap it from source.
    # Skip the clone when a checkout is already present (e.g. a previous
    # attempt cloned successfully but the install step failed).
    if not os.path.exists(os.path.join("detectron2", "setup.py")):
        subprocess.check_call(
            ["git", "clone", "https://github.com/facebookresearch/detectron2.git"]
        )
    # Use the running interpreter rather than whatever `python` resolves to on
    # PATH, and fail loudly if the install does not succeed.
    subprocess.check_call(
        [sys.executable, "-m", "pip", "install", "-e", "detectron2"]
    )
    # Let the import system notice packages installed after start-up.
    importlib.invalidate_caches()

from detectron2 import model_zoo
from detectron2.config import get_cfg
from detectron2.engine.defaults import DefaultPredictor
import numpy as np
import torch

from towhee._types import Image
from towhee.operator import NNOperator, OperatorFlag
from towhee import register

# Maps the short model names accepted by this operator to the Detectron2
# model-zoo config files (COCO detection, 3x schedule) they correspond to.
CFG_YAMLS = {
    # Faster R-CNN, ResNet-50 backbone
    'faster_rcnn_resnet50_c4': 'COCO-Detection/faster_rcnn_R_50_C4_3x.yaml',
    'faster_rcnn_resnet50_dc5': 'COCO-Detection/faster_rcnn_R_50_DC5_3x.yaml',
    'faster_rcnn_resnet50_fpn': 'COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml',
    # Faster R-CNN, ResNet-101 backbone
    'faster_rcnn_resnet101_c4': 'COCO-Detection/faster_rcnn_R_101_C4_3x.yaml',
    'faster_rcnn_resnet101_dc5': 'COCO-Detection/faster_rcnn_R_101_DC5_3x.yaml',
    'faster_rcnn_resnet101_fpn': 'COCO-Detection/faster_rcnn_R_101_FPN_3x.yaml',
    # Faster R-CNN, ResNeXt-101 backbone
    'faster_rcnn_resnext101': 'COCO-Detection/faster_rcnn_X_101_32x8d_FPN_3x.yaml',
    # RetinaNet
    'retinanet_resnet50': 'COCO-Detection/retinanet_R_50_FPN_3x.yaml',
    'retinanet_resnet101': 'COCO-Detection/retinanet_R_101_FPN_3x.yaml',
}
@register(output_schema=['boxes', 'classes', 'scores'],
          flag=OperatorFlag.STATELESS | OperatorFlag.REUSEABLE)
class Detectron2(NNOperator):
    """
    This Operator implements object detection using the Detectron2 library.

    Args:
        model_name (`str`):
            Detectron2-based model to use. For a full list, see `CFG_YAMLS`.
            An unknown name raises `KeyError`.
        thresh (`float`):
            Minimum confidence score a detection must reach to be returned.
            Applied to every detection head so that it takes effect for
            whichever model family `model_name` selects.
    """

    def __init__(self, model_name: str = 'faster_rcnn_resnet50_c4', thresh: float = 0.5):
        super().__init__()
        thresh = float(thresh)

        cfg = get_cfg()
        cfg.merge_from_file(model_zoo.get_config_file(CFG_YAMLS[model_name]))
        cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(CFG_YAMLS[model_name])
        cfg.MODEL.DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'

        # Set the threshold on all heads: RetinaNet configs read the first
        # key, Faster R-CNN configs read the ROI_HEADS key. Previously only
        # RetinaNet honoured `thresh`; the others were pinned to 0.5.
        cfg.MODEL.RETINANET.SCORE_THRESH_TEST = thresh
        cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = thresh
        cfg.MODEL.PANOPTIC_FPN.COMBINE.INSTANCES_CONFIDENCE_THRESH = thresh
        cfg.freeze()

        self._predictor = DefaultPredictor(cfg)

    def __call__(self, image: 'towhee._types.Image') -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Run detection on a single image.

        Args:
            image (`towhee._types.Image`):
                Input image; its channel axis (the last one) is reversed
                before inference. Presumably RGB, height x width x channel;
                confirm against callers.

        Returns:
            A `(boxes, classes, scores)` tuple of NumPy arrays with one entry
            per detected instance.
        """
        # Detectron2 uses BGR-formatted images
        res = self._predictor(image[:, :, ::-1])
        # Predictions live on the model's device; bring them to host memory
        # first, since `.numpy()` raises on CUDA tensors.
        instances = res['instances'].to('cpu')

        boxes = instances.get('pred_boxes').tensor.numpy()
        classes = instances.get('pred_classes').numpy()
        scores = instances.get('scores').numpy()
        return (boxes, classes, scores)