logo
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
Readme
Files and versions

120 lines
4.1 KiB

import pickle
import numpy as np
import torch as t
import torch
from torch.nn import functional as F
from torch import nn
import torchvision
from torch import nn
from .generate_anchors import generate_anchors
from .bbox_transform import bbox_transform_inv, clip_boxes
class RegionProposalNetwork(nn.Module):
    """Convert RPN score / box-delta maps into NMS-filtered region proposals.

    The layer enumerates the base anchors over every feature-map cell,
    applies the predicted box deltas, clips the boxes to the image, drops
    boxes that are too small, and finally keeps the best-scoring boxes that
    survive non-maximum suppression.

    Args:
        pre_nms_topN: Number of top-scoring boxes kept before NMS
            (no limit when <= 0).
        post_nms_topN: Number of boxes kept after NMS (no limit when <= 0).
        nms_thresh: IoU threshold handed to ``torchvision.ops.nms``.
        min_size: Minimum box side length; it is multiplied by
            ``img_size[2]`` before filtering.
        anchor_scales: Scales passed to ``generate_anchors``.
        feat_stride: Distance, in input-image pixels, between neighbouring
            feature-map cells.
    """

    def __init__(self, pre_nms_topN, post_nms_topN, nms_thresh, min_size, anchor_scales, feat_stride):
        super(RegionProposalNetwork, self).__init__()
        self._anchors = generate_anchors(scales=np.array(anchor_scales))
        self._num_anchors = self._anchors.shape[0]
        self._feat_stride = feat_stride
        self.pre_nms_topN = pre_nms_topN
        self.post_nms_topN = post_nms_topN
        self.nms_thresh = nms_thresh
        self.min_size = min_size
        self.anchor_scales = anchor_scales
        self.feat_stride = feat_stride

    def forward(self, rpn_cls_prob, rpn_bbox_pred, img_size):
        """Produce region proposals for one image.

        Args:
            rpn_cls_prob: 4-D tensor of RPN class probabilities. The channels
                from index ``num_anchors`` onwards are used as the ranking
                scores (presumably the foreground probabilities -- confirm
                against the RPN head).
            rpn_bbox_pred: 4-D tensor of box deltas with ``4 * num_anchors``
                channels, laid out like ``rpn_cls_prob``.
            img_size: Indexable whose first two entries are passed to
                ``clip_boxes`` and whose third entry scales ``min_size``
                (presumably ``(height, width, scale)`` -- confirm with the
                caller).

        Returns:
            Float tensor of shape ``(R, 5)`` on the same device as
            ``rpn_cls_prob``; column 0 is the batch index (always 0) and
            columns 1-4 are the proposal coordinates.

        Note:
            Anchors are enumerated once and the batch index is hard-coded to
            zero, so this effectively assumes a batch size of 1.
        """
        device = rpn_cls_prob.device

        scores = rpn_cls_prob[:, self._num_anchors:, :, :]
        _, _, hh, ww = scores.shape

        anchors = self._enumerate_shifted_anchor(
            self._anchors, self._feat_stride, hh, ww)

        # Move channels last so each row lines up with one shifted anchor.
        bbox_deltas = rpn_bbox_pred.permute((0, 2, 3, 1)).reshape((-1, 4))
        bbox_deltas = bbox_deltas.cpu().detach().numpy()
        scores = scores.permute((0, 2, 3, 1)).reshape((-1, 1))

        # Box decoding, clipping and size filtering run in NumPy on the CPU.
        proposals = bbox_transform_inv(anchors, bbox_deltas)
        proposals = clip_boxes(proposals, img_size[:2])
        keep = _filter_boxes(proposals, self.min_size * img_size[2])

        # torch.as_tensor honours the target device; the legacy
        # FloatTensor(..., device=...) constructor rejects non-CPU devices.
        proposals = torch.as_tensor(
            proposals[keep, :], dtype=torch.float32, device=device)
        keep = torch.as_tensor(keep, dtype=torch.long, device=scores.device)
        scores = scores[keep]

        # Rank by score and keep the best candidates before running NMS.
        order = scores.ravel().argsort(descending=True)
        if self.pre_nms_topN > 0:
            order = order[:self.pre_nms_topN]
        proposals = proposals[order, :]
        scores = scores[order]

        keep = torchvision.ops.nms(proposals, scores.ravel(), self.nms_thresh)
        if self.post_nms_topN > 0:
            keep = keep[:self.post_nms_topN]
        proposals = proposals[keep, :]

        batch_inds = torch.zeros(
            (proposals.shape[0], 1),
            dtype=proposals.dtype,
            device=proposals.device,
        )
        return torch.hstack([batch_inds, proposals])

    def _enumerate_shifted_anchor(self, anchor_base, feat_stride, height, width):
        """Replicate the base anchors at every feature-map cell.

        Adds the A base anchors, viewed as (1, A, 4), to the K = height * width
        cell shifts, viewed as (K, 1, 4), and flattens the (K, A, 4) result.

        Args:
            anchor_base: NumPy array of shape (A, 4) holding the base anchors.
            feat_stride: Pixel distance between neighbouring cells.
            height: Number of feature-map rows.
            width: Number of feature-map columns.

        Returns:
            ``float32`` NumPy array of shape (K * A, 4); the x shift varies
            fastest across cells and the anchor index varies fastest overall.
        """
        # TODO: add support for CUDA tensors; this runs in NumPy on the CPU.
        shift_y = np.arange(0, height * feat_stride, feat_stride)
        shift_x = np.arange(0, width * feat_stride, feat_stride)
        shift_x, shift_y = np.meshgrid(shift_x, shift_y)
        shift = np.stack(
            (shift_x.ravel(), shift_y.ravel(), shift_x.ravel(), shift_y.ravel()),
            axis=1,
        )

        A = anchor_base.shape[0]
        K = shift.shape[0]
        anchor = (anchor_base.reshape((1, A, 4))
                  + shift.reshape((1, K, 4)).transpose((1, 0, 2)))
        return anchor.reshape((K * A, 4)).astype(np.float32)
def _filter_boxes(boxes, min_size):
"""Remove all boxes with any side smaller than min_size."""
ws = boxes[:, 2] - boxes[:, 0] + 1
hs = boxes[:, 3] - boxes[:, 1] + 1
keep = np.where((ws >= min_size) & (hs >= min_size))[0]
return keep