logo
Browse Source

Bugfix and README update

main
Frank Liu 3 years ago
parent
commit
6c2b265dc0
  1. 37
      README.md
  2. 8
      detectron2.py
  3. BIN
      example.jpg

37
README.md

@ -16,11 +16,17 @@ This operator uses Facebook's [Detectron2](https://github.com/facebookresearch/d
import towhee
towhee.glob('./towhee.jpg') \
.image_decode.cv2() \
.image_decode() \
.object_detection.detectron2(model_name='retinanet_resnet50') \
.show()
```
| Image | `boxes` | `classes` | `scores` |
| ----- | ------- | --------- | -------- |
| ![](example.jpg) | array([2645.9973 , 1200.3245 , 3176.163 , 2722.6785 ]) | array([0]) | array([0.9998573]) |
## Factory Constructor
Create the operator via the following factory method
@ -29,10 +35,33 @@ Create the operator via the following factory method
**Parameters:**
***model_name:*** *str*
***model_name:*** `str`
A string indicating which model to use.
A string indicating which model to use. Available options:
***thresh:*** *float*
1. `faster_rcnn_resnet50_c4`
2. `faster_rcnn_resnet50_dc5`
3. `faster_rcnn_resnet50_fpn`
4. `faster_rcnn_resnet101_c4`
5. `faster_rcnn_resnet101_dc5`
6. `faster_rcnn_resnet101_fpn`
7. `faster_rcnn_resnext101`
8. `retinanet_resnet50`
9. `retinanet_resnet101`
***thresh:*** `float`
The threshold value for which an object is detected (default value: `0.5`). Set this value lower to detect more objects at the expense of accuracy, or higher to reduce the total number of detections but increase the quality of detected objects.
### Interface
This operator takes an image as input. It first detects the objects appeared in the image, and generates a bounding box around each object.
**Parameters:**
**img**: `towhee._types.Image`
Image data wrapped in a (as a Towhee `Image`).
**Return**: `List[numpy.ndarray[4], ...], List[str], numpy.ndarray`
The return value is a tuple of `(boxes, classes, scores)`. `boxes` is a list of bounding boxes. Each bounding box is represented as a 1-dimensional numpy array consisting of the top-left and the bottom-right corners, i.e. `numpy.ndarray([x1, y1, x2, y2])`. `classes` is a list of prediction labels for each bounding box. `*scores*` is a list of confidence scores corresponding to each class and bounding box.

8
detectron2.py

@ -23,8 +23,8 @@ CFG_YAMLS = {
}
@register(outputschema=['boxes', 'classes', 'scores'],
flag=OperatorFlag.STATELESS | OperatorFlag.REUSABLE)
@register(output_schema=['boxes', 'classes', 'scores'],
flag=OperatorFlag.STATELESS | OperatorFlag.REUSEABLE)
class Detectron2(NNOperator):
"""
This Operator implements object detection using the Detectron2 library.
@ -34,11 +34,12 @@ class Detectron2(NNOperator):
Detectron2-based model to use. For a full list, see `CFG_YAMLS`.
"""
def __init__(self, model_name: str = 'retinanet_resnet50', thresh: int = 0.5):
def __init__(self, model_name: str = 'faster_rcnn_resnet50_c4', thresh: int = 0.5):
super().__init__()
cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file(CFG_YAMLS[model_name]))
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(CFG_YAMLS[model_name])
cfg.MODEL.DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
cfg.MODEL.RETINANET.SCORE_THRESH_TEST = thresh
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
@ -51,6 +52,7 @@ class Detectron2(NNOperator):
# Detectron2 uses BGR-formatted images
res = self._predictor(image.to_ndarray()[:,:,::-1])
res = res['instances']
boxes = res.get('pred_boxes').tensor.numpy()
classes = res.get('pred_classes').numpy()
scores = res.get('scores').numpy()

BIN
example.jpg

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.5 MiB

Loading…
Cancel
Save