logo
Browse Source

Bugfix and README update

main
Frank Liu 3 years ago
parent
commit
6c2b265dc0
  1. 37
      README.md
  2. 8
      detectron2.py
  3. BIN
      example.jpg

37
README.md

@ -16,11 +16,17 @@ This operator uses Facebook's [Detectron2](https://github.com/facebookresearch/d
import towhee import towhee
towhee.glob('./towhee.jpg') \ towhee.glob('./towhee.jpg') \
.image_decode.cv2() \
.image_decode() \
.object_detection.detectron2(model_name='retinanet_resnet50') \ .object_detection.detectron2(model_name='retinanet_resnet50') \
.show() .show()
``` ```
| Image | `boxes` | `classes` | `scores` |
| ----- | ------- | --------- | -------- |
| ![](example.jpg) | array([2645.9973 , 1200.3245 , 3176.163 , 2722.6785 ]) | array([0]) | array([0.9998573]) |
## Factory Constructor ## Factory Constructor
Create the operator via the following factory method Create the operator via the following factory method
@ -29,10 +35,33 @@ Create the operator via the following factory method
**Parameters:** **Parameters:**
***model_name:*** *str*
***model_name:*** `str`
A string indicating which model to use.
A string indicating which model to use. Available options:
***thresh:*** *float*
1. `faster_rcnn_resnet50_c4`
2. `faster_rcnn_resnet50_dc5`
3. `faster_rcnn_resnet50_fpn`
4. `faster_rcnn_resnet101_c4`
5. `faster_rcnn_resnet101_dc5`
6. `faster_rcnn_resnet101_fpn`
7. `faster_rcnn_resnext101`
8. `retinanet_resnet50`
9. `retinanet_resnet101`
***thresh:*** `float`
The threshold value for which an object is detected (default value: `0.5`). Set this value lower to detect more objects at the expense of accuracy, or higher to reduce the total number of detections but increase the quality of detected objects. The threshold value for which an object is detected (default value: `0.5`). Set this value lower to detect more objects at the expense of accuracy, or higher to reduce the total number of detections but increase the quality of detected objects.
### Interface
This operator takes an image as input. It detects the objects that appear in the image and generates a bounding box around each one.
**Parameters:**
**img**: `towhee._types.Image`
Image data wrapped as a Towhee `Image`.
**Return**: `List[numpy.ndarray[4], ...], List[str], numpy.ndarray`
The return value is a tuple of `(boxes, classes, scores)`. `boxes` is a list of bounding boxes. Each bounding box is represented as a 1-dimensional numpy array consisting of the top-left and the bottom-right corners, i.e. `numpy.ndarray([x1, y1, x2, y2])`. `classes` is a list of prediction labels for each bounding box. `scores` is a list of confidence scores corresponding to each class and bounding box.

8
detectron2.py

@ -23,8 +23,8 @@ CFG_YAMLS = {
} }
@register(outputschema=['boxes', 'classes', 'scores'],
flag=OperatorFlag.STATELESS | OperatorFlag.REUSABLE)
@register(output_schema=['boxes', 'classes', 'scores'],
flag=OperatorFlag.STATELESS | OperatorFlag.REUSEABLE)
class Detectron2(NNOperator): class Detectron2(NNOperator):
""" """
This Operator implements object detection using the Detectron2 library. This Operator implements object detection using the Detectron2 library.
@ -34,11 +34,12 @@ class Detectron2(NNOperator):
Detectron2-based model to use. For a full list, see `CFG_YAMLS`. Detectron2-based model to use. For a full list, see `CFG_YAMLS`.
""" """
def __init__(self, model_name: str = 'retinanet_resnet50', thresh: int = 0.5):
def __init__(self, model_name: str = 'faster_rcnn_resnet50_c4', thresh: int = 0.5):
super().__init__() super().__init__()
cfg = get_cfg() cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file(CFG_YAMLS[model_name])) cfg.merge_from_file(model_zoo.get_config_file(CFG_YAMLS[model_name]))
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(CFG_YAMLS[model_name])
cfg.MODEL.DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu' cfg.MODEL.DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
cfg.MODEL.RETINANET.SCORE_THRESH_TEST = thresh cfg.MODEL.RETINANET.SCORE_THRESH_TEST = thresh
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5 cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
@ -51,6 +52,7 @@ class Detectron2(NNOperator):
# Detectron2 uses BGR-formatted images # Detectron2 uses BGR-formatted images
res = self._predictor(image.to_ndarray()[:,:,::-1]) res = self._predictor(image.to_ndarray()[:,:,::-1])
res = res['instances']
boxes = res.get('pred_boxes').tensor.numpy() boxes = res.get('pred_boxes').tensor.numpy()
classes = res.get('pred_classes').numpy() classes = res.get('pred_classes').numpy()
scores = res.get('scores').numpy() scores = res.get('scores').numpy()

BIN
example.jpg

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.5 MiB

Loading…
Cancel
Save