import pickle

import cv2
import numpy as np
import torch
import torchvision
from torch import nn

from .rpn import RegionProposalNetwork


# Layout of the residual stages, one entry per stage:
#   (stage, in_channels, bottleneck_channels, out_channels,
#    stride of the first unit, dilation of the 3x3 conv, unit suffixes)
# The first unit of every stage ("a") owns a projection shortcut named
# ``res<unit>_branch1``; every later unit uses an identity shortcut.
_STAGE_SPECS = (
    (2, 64, 64, 256, 1, 1, ("a", "b", "c")),
    (3, 256, 128, 512, 2, 1, ("a", "b1", "b2", "b3")),
    (4, 512, 256, 1024, 2, 1, ("a",) + tuple("b%d" % i for i in range(1, 23))),
    (5, 1024, 512, 2048, 1, 2, ("a", "b", "c")),
)


def _iter_units(stages=None):
    """Yield ``(unit_label, has_projection)`` for the requested stages.

    ``unit_label`` is e.g. ``"2a"`` or ``"4b17"``; ``has_projection`` is True
    only for the first unit of a stage. ``stages`` is an optional collection
    of stage numbers; ``None`` means all stages, in definition order.
    """
    for stage, _c_in, _c_mid, _c_out, _stride, _dilation, suffixes in _STAGE_SPECS:
        if stages is not None and stage not in stages:
            continue
        for index, suffix in enumerate(suffixes):
            yield "%d%s" % (stage, suffix), index == 0


def _branch_suffixes(has_projection):
    """Return the branch suffixes of one unit, shortcut first."""
    return (("1",) if has_projection else ()) + ("2a", "2b", "2c")


class ConvBlock(nn.Module):
    """``Conv2d`` followed by ``BatchNorm2d`` and an optional ``ReLU``.

    Args:
        i: Number of input channels.
        o: Number of output channels.
        k: Kernel size.
        s: Stride.
        p: Padding.
        d: Dilation.
        use_relu: Apply a ReLU after batch normalisation when truthy.
    """

    def __init__(self, i, o, k, s, p, d, use_relu=True):
        super().__init__()
        self.conv = nn.Conv2d(i, o, k, s, p, d)
        self.bn = nn.BatchNorm2d(o)
        self.use_relu = use_relu
        if self.use_relu:
            self.relu = nn.ReLU()

    def forward(self, x):
        """Apply conv -> batch norm -> (optional) ReLU to ``x``."""
        x = self.bn(self.conv(x))
        if self.use_relu:
            x = self.relu(x)
        return x


def load_convblock(block, convname, bnname, scalename, weights):
    """Copy one conv/bn/scale weight triple from ``weights`` into ``block``.

    Args:
        block: The ``ConvBlock`` to fill.
        convname: Key in ``weights`` whose entry 0 is the convolution kernel.
        bnname: Key in ``weights`` holding three entries; entries 0 and 1 are
            divided by entry 2 to obtain the running mean and variance.
            (Presumably the Caffe BatchNorm blob layout -- confirm against
            the exporter.)
        scalename: Key in ``weights`` whose entries 0 and 1 become the
            batch-norm weight and bias.
        weights: Mapping from layer name to a sequence of arrays.

    The convolution bias is reset to zeros; the source weights carry no
    separate conv bias for these layers.
    """
    block.conv.weight = nn.Parameter(torch.FloatTensor(weights[convname][0]))
    block.conv.bias = nn.Parameter(torch.zeros_like(block.conv.bias))
    # Running statistics are assigned as plain tensors so they stay module
    # buffers; wrapping them in nn.Parameter would turn them into trainable
    # parameters that show up in ``model.parameters()``.
    block.bn.running_mean = torch.FloatTensor(weights[bnname][0] / weights[bnname][2])
    block.bn.running_var = torch.FloatTensor(weights[bnname][1] / weights[bnname][2])
    block.bn.weight = nn.Parameter(torch.FloatTensor(weights[scalename][0]))
    block.bn.bias = nn.Parameter(torch.FloatTensor(weights[scalename][1]))


class Net(nn.Module):
    """Region-feature extractor: residual backbone, RPN and classifier head.

    Stages 2-4 run on the whole image, the RPN proposes regions on the
    stage-4 output, and stage 5 runs on the RoI-pooled regions. Layer
    attribute names (``res2a_branch1``, ``res4b22_branch2c``, ...) mirror the
    keys expected by :meth:`load_weights_from_pkl`.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = ConvBlock(3, 64, 7, 2, 3, 1, True)
        self.pool1 = nn.MaxPool2d(3, 2, 0, ceil_mode=True)

        # Build every bottleneck unit from the stage table. Modules are
        # registered in the same order as the hand-written original.
        for stage, c_in, c_mid, c_out, stride, dilation, suffixes in _STAGE_SPECS:
            for index, suffix in enumerate(suffixes):
                prefix = "res%d%s" % (stage, suffix)
                if index == 0:
                    # Projection shortcut; it and branch2a carry the stride.
                    setattr(self, prefix + "_branch1",
                            ConvBlock(c_in, c_out, 1, stride, 0, 1, False))
                    unit_in, unit_stride = c_in, stride
                else:
                    unit_in, unit_stride = c_out, 1
                setattr(self, prefix + "_branch2a",
                        ConvBlock(unit_in, c_mid, 1, unit_stride, 0, 1, True))
                setattr(self, prefix + "_branch2b",
                        ConvBlock(c_mid, c_mid, 3, 1, dilation, dilation, True))
                setattr(self, prefix + "_branch2c",
                        ConvBlock(c_mid, c_out, 1, 1, 0, 1, False))

        self.rpn_conv_3x3 = nn.Conv2d(1024, 512, 3, 1, 1, 1)
        self.rpn_cls_score = nn.Conv2d(512, 24, 1, 1, 0, 1)
        self.rpn_bbox_pred = nn.Conv2d(512, 48, 1, 1, 0, 1)
        self.rpn = RegionProposalNetwork(pre_nms_topN=6000, post_nms_topN=300,
                                         nms_thresh=0.7, min_size=16,
                                         anchor_scales=(4, 8, 16, 32),
                                         feat_stride=16)
        self.cls_score = nn.Linear(2048, 1601)

    def infer_resblock(self, l, r, x):
        """Run ``x`` through two chains of callables and add the results.

        Args:
            l: Callables applied in order to form the shortcut path; an empty
                list means an identity shortcut.
            r: Callables applied in order to form the residual path.
            x: Input tensor shared by both paths.

        Returns:
            The element-wise sum of both path outputs (no activation).
        """
        xl = x
        xr = x
        for b in l:
            xl = b(xl)
        for b in r:
            xr = b(xr)
        return xl + xr

    def _run_units(self, x, stages):
        """Apply every bottleneck unit of ``stages`` to ``x``, ReLU after each."""
        for unit, has_projection in _iter_units(stages):
            prefix = "res%s_branch" % unit
            shortcut = [getattr(self, prefix + "1")] if has_projection else []
            residual = [getattr(self, prefix + s) for s in ("2a", "2b", "2c")]
            x = nn.functional.relu(self.infer_resblock(shortcut, residual, x))
        return x

    def forward(self, x, im_size, conf_thresh=0.2, min_boxes=10, max_boxes=100):
        """Detect regions in an image batch and return their features.

        Args:
            x: Image tensor with 3 channels, as expected by ``conv1``
                (presumably ``(N, 3, H, W)`` produced from ``process_img`` --
                confirm against the caller).
            im_size: Indexable triple as returned by ``process_img``
                (height, width, scale). It is forwarded to the RPN, and
                ``im_size[2]`` is divided out of the box coordinates.
            conf_thresh: Minimum per-box confidence for a box to be kept.
            min_boxes: If fewer boxes pass the threshold, the ``min_boxes``
                most confident ones are returned instead.
            max_boxes: If more boxes pass the threshold, only the
                ``max_boxes`` most confident ones are returned.

        Returns:
            Tuple ``(boxes, features, confidence)`` for the selected regions:
            box coordinates divided by the image scale, the pooled
            2048-dimensional features, and the per-box maximum class
            probability.
        """
        x = self.conv1(x)
        x = self.pool1(x)
        x = self._run_units(x, (2, 3, 4))

        # Region proposal head on the stage-4 feature map.
        rpn_hidden = nn.functional.relu(self.rpn_conv_3x3(x))
        rpn_cls_score = self.rpn_cls_score(rpn_hidden)
        rpn_bbox_pred = self.rpn_bbox_pred(rpn_hidden)

        # Fold the channel axis to two groups so the softmax runs over that
        # pair, then restore the original channel count.
        n, c, _h, w = rpn_cls_score.shape
        rpn_cls_prob = nn.functional.softmax(rpn_cls_score.reshape(n, 2, -1, w), dim=1)
        rpn_cls_prob = rpn_cls_prob.reshape(n, c, -1, w)

        rois = self.rpn.forward(rpn_cls_prob, rpn_bbox_pred, im_size)
        feats = torchvision.ops.roi_pool(x, rois, output_size=[14, 14],
                                         spatial_scale=0.0625)

        # Stage 5 runs per region on the pooled features.
        x = self._run_units(feats, (5,))
        x = nn.functional.adaptive_avg_pool2d(x, (1, 1))
        pool5_flat = x.reshape((x.shape[0], -1))

        cls_prob = nn.functional.softmax(self.cls_score(pool5_flat), dim=-1)
        # Columns 1..4 of each RoI are the box; column 0 is skipped
        # (presumably the batch index -- confirm against the RPN).
        cls_boxes = rois[:, 1:5] / im_size[2]

        max_conf, keep_boxes = self.post_process(rois, cls_boxes, cls_prob, conf_thresh)
        if len(keep_boxes) < min_boxes:
            keep_boxes = torch.argsort(max_conf, dim=0, descending=True)[:min_boxes]
        elif len(keep_boxes) > max_boxes:
            keep_boxes = torch.argsort(max_conf, dim=0, descending=True)[:max_boxes]

        boxes = cls_boxes[keep_boxes]
        features = pool5_flat[keep_boxes]
        confidence = max_conf[keep_boxes]
        return boxes, features, confidence

    def post_process(self, rois, cls_boxes, cls_prob, conf_thresh=0.2):
        """Run per-class NMS and compute each box's best surviving score.

        Args:
            rois: RoI tensor; only its row count and device are used.
            cls_boxes: ``(num_rois, 4)`` box coordinates.
            cls_prob: ``(num_rois, num_classes)`` class probabilities; class
                index 0 is skipped.
            conf_thresh: Boxes with a best score at or above this are kept.

        Returns:
            Tuple ``(max_conf, keep_boxes)``: for every box the highest class
            score among the classes where it survived NMS (0 if none), and
            the indices of boxes whose ``max_conf >= conf_thresh``.
        """
        max_conf = torch.zeros((rois.shape[0]), device=rois.device)
        for cls_ind in range(1, cls_prob.shape[1]):
            cls_scores = cls_prob[:, cls_ind]
            dets = torch.hstack((cls_boxes, cls_scores[:, None]))
            # Keep the indices as a tensor: converting to NumPy fails for
            # CUDA tensors and tensor indexing works on every device.
            keep = torchvision.ops.nms(dets[:, :4], dets[:, 4], 0.3)
            max_conf[keep] = torch.where(cls_scores[keep] > max_conf[keep],
                                         cls_scores[keep], max_conf[keep])
        keep_boxes = torch.where(max_conf >= conf_thresh)[0]
        return max_conf, keep_boxes

    def load_weights_from_pkl(self, weights):
        """Populate all layers from a name -> arrays mapping.

        Args:
            weights: Mapping keyed by layer name (``'conv1'``,
                ``'bn2a_branch1'``, ``'scale2a_branch1'``, ``'rpn_conv/3x3'``,
                ``'cls_score'``, ...). Conv-block entries follow the layout
                read by ``load_convblock``; the RPN and classifier entries
                hold ``[weight, bias]``.

        Raises:
            KeyError: If an expected layer name is missing from ``weights``.
        """
        with torch.no_grad():
            load_convblock(self.conv1, 'conv1', 'bn_conv1', 'scale_conv1', weights)

            for unit, has_projection in _iter_units():
                for branch in _branch_suffixes(has_projection):
                    tag = "%s_branch%s" % (unit, branch)
                    load_convblock(getattr(self, "res" + tag),
                                   "res" + tag, "bn" + tag, "scale" + tag,
                                   weights)

            plain_layers = (
                (self.rpn_conv_3x3, 'rpn_conv/3x3'),
                (self.rpn_cls_score, 'rpn_cls_score'),
                (self.rpn_bbox_pred, 'rpn_bbox_pred'),
                (self.cls_score, 'cls_score'),
            )
            for layer, key in plain_layers:
                layer.weight = nn.Parameter(torch.FloatTensor(weights[key][0]))
                layer.bias = nn.Parameter(torch.FloatTensor(weights[key][1]))


def process_img(img, target_size=600, max_size=1000):
    """Mean-subtract and rescale an image for the network.

    Args:
        img: Image array of shape ``(H, W, 3)``; the per-channel mean is
            broadcast over the last axis (presumably BGR channel order, as
            loaded by OpenCV -- confirm against the caller).
        target_size: Desired length of the shorter image side after scaling.
        max_size: Upper bound for the longer side; when scaling to
            ``target_size`` would exceed it, the scale is reduced so the
            longer side equals ``max_size``.

    Returns:
        Tuple ``(im, im_info)`` where ``im`` is the resized, mean-subtracted
        image and ``im_info`` is ``np.array([height, width, scale])``.
    """
    mean = np.array([[[102.9801, 115.9465, 122.7717]]])
    img = img - mean

    im_shape = img.shape
    im_size_min = np.min(im_shape[0:2])
    im_size_max = np.max(im_shape[0:2])

    im_scale = float(target_size) / float(im_size_min)
    # Cap the scale so the longer side never exceeds max_size.
    if np.round(im_scale * im_size_max) > max_size:
        im_scale = float(max_size) / float(im_size_max)

    im = cv2.resize(img, None, None, fx=im_scale, fy=im_scale,
                    interpolation=cv2.INTER_LINEAR)
    return im, np.array([im.shape[0], im.shape[1], im_scale])