lightningdot
copied
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions
120 lines
4.1 KiB
120 lines
4.1 KiB
import pickle
|
|
|
|
import numpy as np
|
|
import torch as t
|
|
import torch
|
|
from torch.nn import functional as F
|
|
from torch import nn
|
|
import torchvision
|
|
from torch import nn
|
|
|
|
from .generate_anchors import generate_anchors
|
|
from .bbox_transform import bbox_transform_inv, clip_boxes
|
|
|
|
|
|
class RegionProposalNetwork(nn.Module):
    """Convert RPN score/regression maps into NMS-filtered region proposals.

    The layer has no learnable parameters.  It enumerates anchors over the
    feature map, applies the predicted box deltas, clips and size-filters
    the resulting boxes, ranks them by score and runs non-maximum
    suppression.

    Args:
        pre_nms_topN: Number of top-scoring boxes kept before NMS.  Values
            ``<= 0`` disable this truncation.
        post_nms_topN: Number of boxes kept after NMS.  Values ``<= 0``
            disable this truncation.
        nms_thresh: IoU threshold handed to ``torchvision.ops.nms``.
        min_size: Minimum box side length; it is multiplied by
            ``img_size[2]`` in ``forward`` before filtering.
        anchor_scales: Scales forwarded to ``generate_anchors``.
        feat_stride: Pixel step between neighbouring feature-map cells.
    """

    def __init__(self, pre_nms_topN, post_nms_topN, nms_thresh, min_size,
                 anchor_scales, feat_stride):
        super().__init__()
        self._anchors = generate_anchors(scales=np.array(anchor_scales))
        self._num_anchors = self._anchors.shape[0]
        self._feat_stride = feat_stride
        self.pre_nms_topN = pre_nms_topN
        self.post_nms_topN = post_nms_topN
        self.nms_thresh = nms_thresh
        self.min_size = min_size
        self.anchor_scales = anchor_scales
        # Public duplicate of ``_feat_stride``; kept because external code
        # may read either attribute.
        self.feat_stride = feat_stride

    def forward(self, rpn_cls_prob, rpn_bbox_pred, img_size):
        """Produce region proposals for one set of RPN outputs.

        Args:
            rpn_cls_prob: 4-D tensor of RPN class probabilities.  Channels
                from index ``num_anchors`` onward are used as the box
                scores (presumably the foreground half -- confirm against
                the RPN head).
            rpn_bbox_pred: 4-D tensor of box deltas; it is permuted to
                channels-last and flattened to rows of 4.
            img_size: Indexable whose first two entries are passed to
                ``clip_boxes`` and whose third entry scales ``min_size``
                (presumably ``(height, width, scale)`` -- confirm against
                the caller).

        Returns:
            Float tensor of shape ``(num_rois, 5)`` on the device of
            ``rpn_cls_prob``.  Column 0 is the batch index (always 0) and
            columns 1-4 are the box coordinates.

        Note:
            Anchors are enumerated for a single feature map and every ROI
            gets batch index 0, so this looks like it expects a batch size
            of 1 -- TODO confirm with callers.
        """
        device = rpn_cls_prob.device

        scores = rpn_cls_prob[:, self._num_anchors:, :, :]
        _, _, hh, ww = scores.shape

        anchors = self._enumerate_shifted_anchor(
            self._anchors, self._feat_stride, hh, ww)

        # Channels-last, then flatten, so that row i of the deltas and of
        # the scores lines up with row i of ``anchors`` (cell-major,
        # anchor-minor).
        bbox_deltas = rpn_bbox_pred.permute((0, 2, 3, 1)).reshape((-1, 4))
        bbox_deltas = bbox_deltas.cpu().detach().numpy()
        scores = scores.permute((0, 2, 3, 1)).reshape((-1, 1))

        # Box decoding, clipping and size filtering run in NumPy on the CPU.
        proposals = bbox_transform_inv(anchors, bbox_deltas)
        proposals = clip_boxes(proposals, img_size[:2])

        keep = _filter_boxes(proposals, self.min_size * img_size[2])
        proposals = proposals[keep, :]
        scores = scores[keep]

        order = scores.ravel().argsort(descending=True)
        # ``t.as_tensor`` honours ``device``; the legacy ``t.FloatTensor``
        # constructor only accepts CPU and fails for CUDA inputs.
        proposals = t.as_tensor(proposals, dtype=t.float32, device=device)

        if self.pre_nms_topN > 0:
            order = order[:self.pre_nms_topN]
        proposals = proposals[order, :]
        scores = scores[order]

        keep = torchvision.ops.nms(proposals, scores.ravel(), self.nms_thresh)
        if self.post_nms_topN > 0:
            keep = keep[:self.post_nms_topN]
        proposals = proposals[keep, :]

        batch_inds = t.zeros((proposals.shape[0], 1),
                             dtype=proposals.dtype, device=proposals.device)
        return t.hstack([batch_inds, proposals])

    def _enumerate_shifted_anchor(self, anchor_base, feat_stride, height, width):
        """Replicate the base anchors at every feature-map cell.

        The ``A`` base anchors, viewed as ``(1, A, 4)``, are added to the
        ``K = height * width`` cell shifts, viewed as ``(K, 1, 4)``.  The
        broadcast result ``(K, A, 4)`` is flattened to ``(K * A, 4)``.

        Args:
            anchor_base: NumPy array of shape ``(A, 4)``.
            feat_stride: Pixel step between neighbouring cells.
            height: Number of feature-map rows.
            width: Number of feature-map columns.

        Returns:
            ``float32`` NumPy array of shape ``(K * A, 4)``, ordered
            cell-major (row by row) and anchor-minor.
        """
        # TODO: add support for CUDA tensors; the enumeration currently
        # always runs in NumPy on the CPU.
        shift_y = np.arange(0, height * feat_stride, feat_stride)
        shift_x = np.arange(0, width * feat_stride, feat_stride)
        shift_x, shift_y = np.meshgrid(shift_x, shift_y)
        # Each shift row is (dx, dy, dx, dy): the same offset is applied to
        # both corner coordinates of a box.
        shift = np.stack((shift_x.ravel(), shift_y.ravel(),
                          shift_x.ravel(), shift_y.ravel()), axis=1)

        A = anchor_base.shape[0]
        K = shift.shape[0]
        anchor = (anchor_base.reshape((1, A, 4))
                  + shift.reshape((1, K, 4)).transpose((1, 0, 2)))
        return anchor.reshape((K * A, 4)).astype(np.float32)
|
|
|
|
def _filter_boxes(boxes, min_size):
    """Return the row indices of boxes whose width and height are both >= min_size.

    Columns 0-3 of ``boxes`` are read as ``(x1, y1, x2, y2)``; a side
    length is computed as ``max - min + 1``.
    """
    x1, y1, x2, y2 = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
    wide_enough = (x2 - x1 + 1) >= min_size
    tall_enough = (y2 - y1 + 1) >= min_size
    return np.flatnonzero(wide_enough & tall_enough)
|
|
|