diff --git a/README.md b/README.md index a450f76..5cbda02 100644 --- a/README.md +++ b/README.md @@ -16,11 +16,17 @@ This operator uses Facebook's [Detectron2](https://github.com/facebookresearch/d import towhee towhee.glob('./towhee.jpg') \ - .image_decode.cv2() \ + .image_decode() \ .object_detection.detectron2(model_name='retinanet_resnet50') \ .show() ``` +| Image | `boxes` | `classes` | `scores` | +| ----- | ------- | --------- | -------- | +| ![](example.jpg) | array([2645.9973 , 1200.3245 , 3176.163 , 2722.6785 ]) | array([0]) | array([0.9998573]) | + + + ## Factory Constructor Create the operator via the following factory method @@ -29,10 +35,33 @@ Create the operator via the following factory method **Parameters:** -***model_name:*** *str* +***model_name:*** `str` -A string indicating which model to use. +A string indicating which model to use. Available options: -***thresh:*** *float* +1. `faster_rcnn_resnet50_c4` +2. `faster_rcnn_resnet50_dc5` +3. `faster_rcnn_resnet50_fpn` +4. `faster_rcnn_resnet101_c4` +5. `faster_rcnn_resnet101_dc5` +6. `faster_rcnn_resnet101_fpn` +7. `faster_rcnn_resnext101` +8. `retinanet_resnet50` +9. `retinanet_resnet101` + +***thresh:*** `float` The threshold value for which an object is detected (default value: `0.5`). Set this value lower to detect more objects at the expense of accuracy, or higher to reduce the total number of detections but increase the quality of detected objects. + +### Interface + +This operator takes an image as input. It first detects the objects that appear in the image, then generates a bounding box around each object. + +**Parameters:** + +​ **img**: `towhee._types.Image` + Image data wrapped as a Towhee `Image`. + +**Return**: `List[numpy.ndarray[4], ...], List[str], numpy.ndarray` + +The return value is a tuple of `(boxes, classes, scores)`. `boxes` is a list of bounding boxes. 
Each bounding box is represented as a 1-dimensional numpy array consisting of the top-left and the bottom-right corners, i.e. `numpy.ndarray([x1, y1, x2, y2])`. `classes` is a list of prediction labels for each bounding box. `scores` is a list of confidence scores corresponding to each class and bounding box. diff --git a/detectron2.py b/detectron2.py index a3f64a3..52e4c93 100644 --- a/detectron2.py +++ b/detectron2.py @@ -23,8 +23,8 @@ CFG_YAMLS = { } -@register(outputschema=['boxes', 'classes', 'scores'], - flag=OperatorFlag.STATELESS | OperatorFlag.REUSABLE) +@register(output_schema=['boxes', 'classes', 'scores'], + flag=OperatorFlag.STATELESS | OperatorFlag.REUSEABLE) class Detectron2(NNOperator): """ This Operator implements object detection using the Detectron2 library. @@ -34,11 +34,12 @@ class Detectron2(NNOperator): Detectron2-based model to use. For a full list, see `CFG_YAMLS`. """ - def __init__(self, model_name: str = 'retinanet_resnet50', thresh: int = 0.5): + def __init__(self, model_name: str = 'faster_rcnn_resnet50_c4', thresh: int = 0.5): super().__init__() cfg = get_cfg() cfg.merge_from_file(model_zoo.get_config_file(CFG_YAMLS[model_name])) + cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(CFG_YAMLS[model_name]) cfg.MODEL.DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu' cfg.MODEL.RETINANET.SCORE_THRESH_TEST = thresh cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5 @@ -51,6 +52,7 @@ class Detectron2(NNOperator): # Detectron2 uses BGR-formatted images res = self._predictor(image.to_ndarray()[:,:,::-1]) + res = res['instances'] boxes = res.get('pred_boxes').tensor.numpy() classes = res.get('pred_classes').numpy() scores = res.get('scores').numpy() diff --git a/example.jpg b/example.jpg new file mode 100644 index 0000000..f36248d Binary files /dev/null and b/example.jpg differ