An analysis of the build_rpn_targets function
```python
import numpy as np
# `utils` is the Mask R-CNN repository's utils module (provides compute_overlaps).


def build_rpn_targets(anchors, gt_class_ids, gt_boxes, config):
    """Given the anchors and GT boxes, compute overlaps and identify positive
    anchors and deltas to refine them to match their corresponding GT boxes.

    anchors: [num_anchors, (y1, x1, y2, x2)]
    gt_class_ids: [num_gt_boxes] Integer class IDs.
    gt_boxes: [num_gt_boxes, (y1, x1, y2, x2)]

    Returns:
    rpn_match: [N] (int32) matches between anchors and GT boxes.
               1 = positive anchor, -1 = negative anchor, 0 = neutral
    rpn_bbox: [N, (dy, dx, log(dh), log(dw))] Anchor bbox deltas.
    """
    # RPN Match: 1 = positive anchor, -1 = negative anchor, 0 = neutral
    rpn_match = np.zeros([anchors.shape[0]], dtype=np.int32)
    # RPN bounding boxes: [max anchors per image, (dy, dx, log(dh), log(dw))]
    rpn_bbox = np.zeros((config.RPN_TRAIN_ANCHORS_PER_IMAGE, 4))

    # Handle COCO crowds
    # A crowd box in COCO is a bounding box around several instances. Exclude
    # them from training. A crowd box is given a negative class ID.
    crowd_ix = np.where(gt_class_ids < 0)[0]
    if crowd_ix.shape[0] > 0:
        # Filter out crowds from ground truth class IDs and boxes
        non_crowd_ix = np.where(gt_class_ids > 0)[0]
        crowd_boxes = gt_boxes[crowd_ix]
        gt_class_ids = gt_class_ids[non_crowd_ix]
        gt_boxes = gt_boxes[non_crowd_ix]
        # Compute overlaps with crowd boxes [anchors, crowds]
        crowd_overlaps = utils.compute_overlaps(anchors, crowd_boxes)
        crowd_iou_max = np.amax(crowd_overlaps, axis=1)
        no_crowd_bool = (crowd_iou_max < 0.001)
    else:
        # All anchors don't intersect a crowd
        no_crowd_bool = np.ones([anchors.shape[0]], dtype=bool)

    # Compute overlaps [num_anchors, num_gt_boxes]
    overlaps = utils.compute_overlaps(anchors, gt_boxes)

    # Match anchors to GT Boxes
    # If an anchor overlaps a GT box with IoU >= 0.7 then it's positive.
    # If an anchor overlaps a GT box with IoU < 0.3 then it's negative.
    # Neutral anchors are those that don't match the conditions above,
    # and they don't influence the loss function.
    # However, don't keep any GT box unmatched (rare, but happens). Instead,
    # match it to the closest anchor (even if its max IoU is < 0.3).
    #
    # 1. Set negative anchors first. They get overwritten below if a GT box is
    # matched to them. Skip boxes in crowd areas.
    anchor_iou_argmax = np.argmax(overlaps, axis=1)
    anchor_iou_max = overlaps[np.arange(overlaps.shape[0]), anchor_iou_argmax]
    rpn_match[(anchor_iou_max < 0.3) & (no_crowd_bool)] = -1
    # 2. Set an anchor for each GT box (regardless of IoU value).
    # TODO: If multiple anchors have the same IoU match all of them
    gt_iou_argmax = np.argmax(overlaps, axis=0)
    rpn_match[gt_iou_argmax] = 1
    # 3. Set anchors with high overlap as positive.
    rpn_match[anchor_iou_max >= 0.7] = 1

    # Subsample to balance positive and negative anchors
    # Don't let positives be more than half the anchors
    pos_ids = np.where(rpn_match == 1)[0]
    extra = len(pos_ids) - (config.RPN_TRAIN_ANCHORS_PER_IMAGE // 2)
    if extra > 0:
        # Reset the extra ones to neutral
        pos_ids = np.random.choice(pos_ids, extra, replace=False)
        rpn_match[pos_ids] = 0
    # Same for negative proposals
    neg_ids = np.where(rpn_match == -1)[0]
    extra = len(neg_ids) - (config.RPN_TRAIN_ANCHORS_PER_IMAGE -
                            np.sum(rpn_match == 1))
    if extra > 0:
        # Reset the extra ones to neutral
        neg_ids = np.random.choice(neg_ids, extra, replace=False)
        rpn_match[neg_ids] = 0

    # For positive anchors, compute shift and scale needed to transform them
    # to match the corresponding GT boxes.
    ids = np.where(rpn_match == 1)[0]
    ix = 0  # index into rpn_bbox
    # TODO: use box_refinement() rather than duplicating the code here
    for idx, anchor in zip(ids, anchors[ids]):
        # Closest gt box (it might have IoU < 0.7)
        gt = gt_boxes[anchor_iou_argmax[idx]]

        # Convert coordinates to center plus width/height.
        # GT Box
        gt_h = gt[2] - gt[0]
        gt_w = gt[3] - gt[1]
        gt_center_y = gt[0] + 0.5 * gt_h
        gt_center_x = gt[1] + 0.5 * gt_w
        # Anchor
        a_h = anchor[2] - anchor[0]
        a_w = anchor[3] - anchor[1]
        a_center_y = anchor[0] + 0.5 * a_h
        a_center_x = anchor[1] + 0.5 * a_w

        # Compute the bbox refinement that the RPN should predict.
        rpn_bbox[ix] = [
            (gt_center_y - a_center_y) / a_h,
            (gt_center_x - a_center_x) / a_w,
            np.log(gt_h / a_h),
            np.log(gt_w / a_w),
        ]
        # Normalize the row just written, before advancing ix. Doing it after
        # ix += 1 would divide the next (still zero) row and leave the deltas
        # unnormalized.
        rpn_bbox[ix] /= config.BBOX_STD_DEV
        ix += 1

    return rpn_match, rpn_bbox
```
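The matching rule described in the code comments (negative below IoU 0.3, positive at 0.7 or above, and every GT box keeps its best anchor) is easier to see on a toy example. The sketch below is only an illustration: `compute_iou_matrix` and `match_anchors` are hypothetical helpers written for this post, with `compute_iou_matrix` standing in for `utils.compute_overlaps`, and crowd handling and subsampling are left out.

```python
import numpy as np


def compute_iou_matrix(anchors, gt_boxes):
    """Hypothetical stand-in for utils.compute_overlaps.

    Boxes are (y1, x1, y2, x2). Returns IoU with shape [num_anchors, num_gt].
    """
    # Broadcast anchors against GT boxes to get pairwise intersections.
    y1 = np.maximum(anchors[:, None, 0], gt_boxes[None, :, 0])
    x1 = np.maximum(anchors[:, None, 1], gt_boxes[None, :, 1])
    y2 = np.minimum(anchors[:, None, 2], gt_boxes[None, :, 2])
    x2 = np.minimum(anchors[:, None, 3], gt_boxes[None, :, 3])
    inter = np.maximum(y2 - y1, 0) * np.maximum(x2 - x1, 0)
    area_a = (anchors[:, 2] - anchors[:, 0]) * (anchors[:, 3] - anchors[:, 1])
    area_g = (gt_boxes[:, 2] - gt_boxes[:, 0]) * (gt_boxes[:, 3] - gt_boxes[:, 1])
    union = area_a[:, None] + area_g[None, :] - inter
    return inter / union


def match_anchors(anchors, gt_boxes, low=0.3, high=0.7):
    """Apply the three matching steps in the same order as the function above."""
    overlaps = compute_iou_matrix(anchors, gt_boxes)
    rpn_match = np.zeros(anchors.shape[0], dtype=np.int32)
    anchor_iou_max = overlaps.max(axis=1)
    rpn_match[anchor_iou_max < low] = -1          # 1. negatives first
    rpn_match[np.argmax(overlaps, axis=0)] = 1    # 2. best anchor per GT box
    rpn_match[anchor_iou_max >= high] = 1         # 3. high-overlap anchors
    return rpn_match, overlaps


anchors = np.array([
    [0, 0, 10, 10],    # identical to GT 0, IoU 1.0
    [0, 0, 10, 8],     # IoU 0.8 with GT 0
    [0, 5, 10, 15],    # IoU 1/3 with GT 0, stays neutral
    [50, 50, 60, 60],  # overlaps nothing, negative
    [20, 20, 30, 25],  # IoU 0.25 with GT 1, but it is GT 1's best anchor
], dtype=np.float64)
gt_boxes = np.array([
    [0, 0, 10, 10],
    [20, 20, 30, 40],
], dtype=np.float64)

rpn_match, overlaps = match_anchors(anchors, gt_boxes)
print(rpn_match)  # [ 1  1  0 -1  1]
```

The last anchor shows why step 2 runs after step 1: its IoU of 0.25 first marks it negative, and it is then overwritten to positive so that GT box 1 is not left unmatched.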
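The inputs are all of the generated anchors (roughly 60,000), gt_class_ids holding the ground-truth class IDs, and gt_boxes holding the corresponding bounding boxes.
1. Handle crowd boxes. If there are no crowd boxes, no_crowd_bool is set to True for every anchor.
2. Compute the IoU between every anchor and every ground-truth box, then label the anchors in rpn_match: an anchor whose best IoU is below 0.3 (and that is not in a crowd area) is negative (-1), the best anchor for each ground-truth box and any anchor with IoU >= 0.7 is positive (1), and everything else stays neutral (0). The labels are then subsampled so that positives make up at most half of config.RPN_TRAIN_ANCHORS_PER_IMAGE and negatives fill the remainder.
3. Finally, for each positive anchor (marked 1 in rpn_match), find its corresponding ground-truth box, compute the center shift and the width/height scaling between the two, and store the result in rpn_bbox.
4. Return rpn_match and rpn_bbox.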
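The center shift and scaling in step 3 can be checked with a small round trip: encode one anchor/GT pair into deltas, then apply the deltas back to the anchor and confirm the GT box is recovered. This is a minimal sketch, not the repository's code. `encode_deltas` and `decode_deltas` are hypothetical names, and the `STD_DEV` value is an assumed setting used here in place of config.BBOX_STD_DEV.

```python
import numpy as np

# Assumed normalization constants; the real value comes from the config.
STD_DEV = np.array([0.1, 0.1, 0.2, 0.2])


def encode_deltas(anchor, gt, std_dev):
    """Deltas (dy, dx, log(dh), log(dw)) that move `anchor` onto `gt`."""
    a_h, a_w = anchor[2] - anchor[0], anchor[3] - anchor[1]
    a_cy, a_cx = anchor[0] + 0.5 * a_h, anchor[1] + 0.5 * a_w
    g_h, g_w = gt[2] - gt[0], gt[3] - gt[1]
    g_cy, g_cx = gt[0] + 0.5 * g_h, gt[1] + 0.5 * g_w
    deltas = np.array([
        (g_cy - a_cy) / a_h,   # center shift, in units of anchor height
        (g_cx - a_cx) / a_w,   # center shift, in units of anchor width
        np.log(g_h / a_h),     # log height ratio
        np.log(g_w / a_w),     # log width ratio
    ])
    return deltas / std_dev


def decode_deltas(anchor, deltas, std_dev):
    """Inverse of encode_deltas: apply the deltas to the anchor."""
    dy, dx, dh, dw = deltas * std_dev
    a_h, a_w = anchor[2] - anchor[0], anchor[3] - anchor[1]
    cy = anchor[0] + 0.5 * a_h + dy * a_h
    cx = anchor[1] + 0.5 * a_w + dx * a_w
    h, w = a_h * np.exp(dh), a_w * np.exp(dw)
    return np.array([cy - 0.5 * h, cx - 0.5 * w, cy + 0.5 * h, cx + 0.5 * w])


anchor = np.array([0.0, 0.0, 10.0, 10.0])
gt = np.array([2.0, 1.0, 12.0, 21.0])

deltas = encode_deltas(anchor, gt, STD_DEV)
recovered = decode_deltas(anchor, deltas, STD_DEV)
print(deltas)     # dy = 2.0, dx = 6.0, dh = 0.0, dw = log(2) / 0.2
print(recovered)  # matches gt up to floating-point error
```

Dividing the shift by the anchor size and taking the log of the size ratio makes the targets independent of the anchor's absolute scale, so large and small anchors produce regression targets in a similar range.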