import pickle

import numpy as np
import torch as t
import torch
from torch.nn import functional as F
from torch import nn
import torchvision

from .generate_anchors import generate_anchors
from .bbox_transform import bbox_transform_inv, clip_boxes


class RegionProposalNetwork(nn.Module):
    """Turn RPN outputs into a list of region proposals (RoIs).

    The module has no learnable parameters. It decodes the predicted box
    deltas against a grid of shifted anchors, clips and size-filters the
    resulting boxes, keeps the best-scoring ones and applies NMS.

    Args:
        pre_nms_topN: Number of top-scoring boxes kept before NMS; values
            ``<= 0`` disable this truncation.
        post_nms_topN: Number of boxes kept after NMS; values ``<= 0``
            disable this truncation.
        nms_thresh: IoU threshold handed to ``torchvision.ops.nms``.
        min_size: Minimum box side length; it is multiplied by
            ``img_size[2]`` in ``forward`` before filtering.
        anchor_scales: Scales passed to ``generate_anchors`` to build the
            base anchors.
        feat_stride: Step, in input-image pixels, between neighbouring
            feature-map cells when shifting the base anchors.
    """

    def __init__(self, pre_nms_topN, post_nms_topN, nms_thresh, min_size,
                 anchor_scales, feat_stride):
        super().__init__()
        self._anchors = generate_anchors(scales=np.array(anchor_scales))
        self._num_anchors = self._anchors.shape[0]
        self._feat_stride = feat_stride

        self.pre_nms_topN = pre_nms_topN
        self.post_nms_topN = post_nms_topN
        self.nms_thresh = nms_thresh
        self.min_size = min_size
        self.anchor_scales = anchor_scales
        # Public alias of ``_feat_stride``; both attributes are kept so that
        # existing readers of either name keep working.
        self.feat_stride = feat_stride

    def forward(self, rpn_cls_prob, rpn_bbox_pred, img_size):
        """Compute region proposals from RPN class scores and box deltas.

        Args:
            rpn_cls_prob: 4-D tensor; the channels from index
                ``num_anchors`` onward are used as the per-anchor scores
                (presumably the foreground probabilities -- confirm against
                the RPN head).
            rpn_bbox_pred: 4-D tensor of box deltas whose channel dimension
                is regrouped into rows of 4 values.
            img_size: Indexable; ``img_size[:2]`` is passed to
                ``clip_boxes`` and ``img_size[2]`` scales ``min_size``
                (presumably ``(height, width, scale)`` -- confirm against
                the caller).

        Returns:
            Float tensor on ``rpn_cls_prob.device`` with one row per kept
            proposal: a leading zero batch index followed by the box
            coordinates.

        Note:
            The batch index column is always zero and the anchors are
            enumerated once, so this presumably expects a batch of exactly
            one image.
        """
        scores = rpn_cls_prob[:, self._num_anchors:, :, :]
        _, _, hh, ww = scores.shape

        anchors = self._enumerate_shifted_anchor(
            self._anchors, self._feat_stride, hh, ww)

        # Bring the channel dimension last so every row holds the values
        # belonging to one anchor at one feature-map cell.
        bbox_deltas = rpn_bbox_pred.permute((0, 2, 3, 1)).reshape((-1, 4))
        bbox_deltas = bbox_deltas.cpu().detach().numpy()
        scores = scores.permute((0, 2, 3, 1)).reshape((-1, 1))

        # Decode deltas into boxes, then clip them using img_size[:2].
        proposals = bbox_transform_inv(anchors, bbox_deltas)
        proposals = clip_boxes(proposals, img_size[:2])

        # Drop boxes whose width or height is below the scaled minimum.
        keep = _filter_boxes(proposals, self.min_size * img_size[2])
        proposals = proposals[keep, :]
        scores = scores[keep]

        order = scores.ravel().argsort(descending=True)
        # ``as_tensor`` honours the target device; the legacy
        # ``FloatTensor`` constructor only accepts CPU devices.
        proposals = t.as_tensor(
            proposals, dtype=t.float32, device=rpn_cls_prob.device)
        if self.pre_nms_topN > 0:
            order = order[:self.pre_nms_topN]
        proposals = proposals[order, :]
        scores = scores[order]

        keep = torchvision.ops.nms(proposals, scores.ravel(), self.nms_thresh)
        if self.post_nms_topN > 0:
            keep = keep[:self.post_nms_topN]
        proposals = proposals[keep, :]

        batch_inds = t.zeros((proposals.shape[0], 1),
                             dtype=proposals.dtype,
                             device=proposals.device)
        rois = t.hstack([batch_inds, proposals])
        return rois

    def _enumerate_shifted_anchor(self, anchor_base, feat_stride, height, width):
        """Replicate the base anchors at every feature-map cell.

        The ``A`` base anchors, shaped ``(1, A, 4)``, are added to the ``K``
        cell shifts, shaped ``(K, 1, 4)``, giving ``(K, A, 4)`` shifted
        anchors that are flattened to ``(K * A, 4)``.

        Args:
            anchor_base: NumPy array of shape ``(A, 4)``.
            feat_stride: Step between neighbouring cells.
            height: Number of feature-map rows.
            width: Number of feature-map columns.

        Returns:
            ``float32`` NumPy array of shape ``(height * width * A, 4)``.
        """
        # TODO: add support for CUDA tensors; this runs in NumPy on the CPU.
        shift_y = np.arange(0, height * feat_stride, feat_stride)
        shift_x = np.arange(0, width * feat_stride, feat_stride)
        shift_x, shift_y = np.meshgrid(shift_x, shift_y)
        # Each shift row is (x, y, x, y) so it offsets both box corners.
        shift = np.stack((shift_x.ravel(), shift_y.ravel(),
                          shift_x.ravel(), shift_y.ravel()), axis=1)

        A = anchor_base.shape[0]
        K = shift.shape[0]
        anchor = anchor_base.reshape((1, A, 4)) + \
            shift.reshape((1, K, 4)).transpose((1, 0, 2))
        anchor = anchor.reshape((K * A, 4)).astype(np.float32)
        return anchor


def _filter_boxes(boxes, min_size):
    """Return the indices of boxes whose sides are all at least ``min_size``.

    Args:
        boxes: NumPy array with at least 4 columns; width is computed as
            ``col2 - col0 + 1`` and height as ``col3 - col1 + 1``.
        min_size: Minimum allowed width and height.

    Returns:
        1-D NumPy array of the row indices to keep.
    """
    ws = boxes[:, 2] - boxes[:, 0] + 1
    hs = boxes[:, 3] - boxes[:, 1] + 1
    keep = np.where((ws >= min_size) & (hs >= min_size))[0]
    return keep