# Source code for mmdet3d.models.fusion_layers.vote_fusion

import torch
from torch import nn as nn

from mmdet3d.core.bbox import points_cam2img
from ..builder import FUSION_LAYERS
from . import apply_3d_transformation, bbox_2d_transform, coord_2d_transform

# Small additive constant used by VoteFusion: it is added under the square
# root when normalizing rays and to the ray's y-component before dividing by it.
EPS = 1.0e-6


@FUSION_LAYERS.register_module()
class VoteFusion(nn.Module):
    """Fuse 2D image cues into 3D seed points (image votes).

    For every seed, up to ``max_imvote_per_pixel`` 2D boxes are selected
    (boxes that contain the seed's projection rank first, then by score).
    For each selected box a feature column is built from:

    * a 5-d geometric cue (lifted image-vote ``x``/``z`` offset followed by
      the normalized ray direction),
    * a ``num_classes``-d semantic cue (box class one-hot scaled by the box
      score), zeroed when the seed is not inside the box,
    * a 3-d texture cue (the seed's pixel value divided by 255).

    Args:
        num_classes (int): Number of classes, i.e. the width of the semantic
            cue. Defaults to 10.
        max_imvote_per_pixel (int): Maximum number of image votes kept per
            seed. Defaults to 3.
    """

    def __init__(self, num_classes=10, max_imvote_per_pixel=3):
        super(VoteFusion, self).__init__()
        self.num_classes = num_classes
        self.max_imvote_per_pixel = max_imvote_per_pixel

    def forward(self, imgs, bboxes_2d_rescaled, seeds_3d_depth, img_metas):
        """Forward function.

        Args:
            imgs (list[torch.Tensor]): Images, one ``(3, H, W)`` tensor per
                sample, possibly padded beyond ``img_meta['img_shape']``.
            bboxes_2d_rescaled (list[torch.Tensor]): 2D bboxes per sample,
                each row ``(x1, y1, x2, y2, score, class)`` in the rescaled
                image frame.
            seeds_3d_depth (torch.Tensor): 3D seeds in depth coordinates,
                ``(B, num_seed, 3)``.
            img_metas (list[dict]): Meta information of images; read keys
                are ``'img_shape'`` and ``'depth2img'`` (plus whatever the
                transformation helpers consume).

        Returns:
            torch.Tensor: Concatenated cues of each point, shaped
                ``(B, 5 + num_classes + 3, max_imvote_per_pixel * num_seed)``.
            torch.Tensor: Boolean validity mask of each feature, shaped
                ``(B, max_imvote_per_pixel * num_seed)``.
        """
        img_features = []
        masks = []
        for img, bbox_2d_rescaled, seed_3d_depth, img_meta in zip(
                imgs, bboxes_2d_rescaled, seeds_3d_depth, img_metas):
            seed_num = seed_3d_depth.shape[0]
            img_h, img_w, _ = img_meta['img_shape']

            # first reverse the data transformations
            xyz_depth = apply_3d_transformation(
                seed_3d_depth, 'DEPTH', img_meta, reverse=True)

            # project points from depth to image
            depth2img = xyz_depth.new_tensor(img_meta['depth2img'])
            uvz_origin = points_cam2img(xyz_depth, depth2img, True)
            z_cam = uvz_origin[..., 2]
            # NOTE(review): the -1 presumably converts 1-based projected
            # pixel coordinates to 0-based indices; confirm against the
            # dataset's depth2img convention.
            uv_origin = (uvz_origin[..., :2] - 1).round()

            # rescale 2d coordinates and bboxes
            uv_rescaled = coord_2d_transform(img_meta, uv_origin, True)
            bbox_2d_origin = bbox_2d_transform(img_meta, bbox_2d_rescaled,
                                               False)

            if bbox_2d_rescaled.shape[0] == 0:
                two_cues, mask = self._empty_cues(seed_num,
                                                  seed_3d_depth.device)
            else:
                two_cues, mask = self._bbox_cues(bbox_2d_origin, uv_origin,
                                                 z_cam, seed_3d_depth,
                                                 depth2img, img_meta)

            txt_cue = self._texture_cue(img, uv_rescaled, img_h, img_w)

            # append texture cue
            img_features.append(torch.cat([two_cues, txt_cue], dim=0))
            masks.append(mask)

        return torch.stack(img_features, 0), torch.stack(masks, 0)

    def _empty_cues(self, seed_num, device):
        """Build placeholder cues for a sample that has no 2D boxes.

        Args:
            seed_num (int): Number of seeds in the sample.
            device (torch.device): Device for the returned tensors.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: All-zero geometric + semantic
                cues of shape ``(5 + num_classes, max_imvote_per_pixel *
                seed_num)`` and a boolean mask whose first ``seed_num``
                entries are True and the rest False.
        """
        imvote_num = seed_num * self.max_imvote_per_pixel
        # 5 geometric dims + num_classes semantic dims. This must match the
        # width produced by ``_bbox_cues`` so that samples can be stacked.
        two_cues = torch.zeros((5 + self.num_classes, imvote_num),
                               device=device)
        mask = torch.zeros(imvote_num, dtype=torch.bool, device=device)
        mask[:seed_num] = True
        return two_cues, mask

    def _bbox_cues(self, bbox_2d_origin, uv_origin, z_cam, seed_3d_depth,
                   depth2img, img_meta):
        """Compute geometric and semantic cues from a sample's 2D boxes.

        Args:
            bbox_2d_origin (torch.Tensor): ``(num_bbox, 6)`` boxes
                ``(x1, y1, x2, y2, score, class)`` in the original frame.
            uv_origin (torch.Tensor): ``(num_seed, 2)`` projected seed pixels.
            z_cam (torch.Tensor): ``(num_seed,)`` depth of each projection.
            seed_3d_depth (torch.Tensor): ``(num_seed, 3)`` seed coordinates.
            depth2img (torch.Tensor): Depth-to-image projection matrix.
            img_meta (dict): Meta information of the image.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: Cues of shape
                ``(5 + num_classes, K * num_seed)`` and a boolean validity
                mask of shape ``(K * num_seed,)``, where
                ``K = max_imvote_per_pixel``. Both are ordered vote-major:
                the best vote of every seed first, then the second, etc.
        """
        seed_num = seed_3d_depth.shape[0]
        bbox_num = bbox_2d_origin.shape[0]

        # expand bboxes and seeds into (seed_num, bbox_num, C) pairs
        bbox_expanded = bbox_2d_origin.view(1, bbox_num, -1).expand(
            seed_num, -1, -1)
        seed_2d_expanded = uv_origin.view(seed_num, 1, -1).expand(
            -1, bbox_num, -1)
        seed_x, seed_y = seed_2d_expanded.split(1, dim=-1)
        bbox_l, bbox_t, bbox_r, bbox_b, bbox_conf, bbox_cls = \
            bbox_expanded.split(1, dim=-1)
        bbox_midx = (bbox_l + bbox_r) / 2
        bbox_midy = (bbox_t + bbox_b) / 2

        # a seed is paired with a box only if it projects strictly inside it
        seed_in_bbox = ((seed_x > bbox_l) & (seed_x < bbox_r) &
                        (seed_y > bbox_t) & (seed_y < bbox_b))

        # semantic cues, dim=num_classes: class one-hot scaled by box score
        sem_cue = torch.zeros_like(bbox_conf).expand(-1, -1,
                                                     self.num_classes)
        sem_cue = sem_cue.scatter(-1, bbox_cls.long(), bbox_conf)

        # bbox center - uv
        delta_u = bbox_midx - seed_x
        delta_v = bbox_midy - seed_y

        seed_3d_expanded = seed_3d_depth.view(seed_num, 1, -1).expand(
            -1, bbox_num, -1)

        # lift the pixel offsets back through the projection, scaled by
        # each seed's projected depth
        z_cam = z_cam.view(seed_num, 1, 1).expand(-1, bbox_num, -1)
        imvote = torch.cat(
            [delta_u, delta_v, torch.zeros_like(delta_v)], dim=-1).view(-1, 3)
        imvote = imvote * z_cam.reshape(-1, 1)
        imvote = imvote @ torch.inverse(depth2img.t())

        # apply transformation to lifted imvotes
        imvote = apply_3d_transformation(
            imvote, 'DEPTH', img_meta, reverse=False)

        seed_3d_expanded = seed_3d_expanded.reshape(imvote.shape)

        # ray angle: normalized direction of seed + imvote. Out of place so
        # the tensor saved by ``ray_angle**2`` for autograd is not modified.
        ray_angle = seed_3d_expanded + imvote
        ray_angle = ray_angle / torch.sqrt(
            torch.sum(ray_angle**2, -1) + EPS).unsqueeze(-1)

        # imvote lifted to 3d
        xz = ray_angle[:, [0, 2]] / (ray_angle[:, [1]] + EPS) \
            * seed_3d_expanded[:, [1]] - seed_3d_expanded[:, [0, 2]]

        # geometric cues, dim=5
        geo_cue = torch.cat([xz, ray_angle], dim=-1).view(seed_num, -1, 5)

        two_cues = torch.cat([geo_cue, sem_cue], dim=-1)
        # mask to 0 if seed not in bbox
        two_cues = two_cues * seed_in_bbox.float()
        feature_size = two_cues.shape[-1]

        # if bbox number is too small, pad with zero (invalid) pairs so
        # topk can always select max_imvote_per_pixel entries per seed
        if bbox_num < self.max_imvote_per_pixel:
            append_num = self.max_imvote_per_pixel - bbox_num
            seed_in_bbox = torch.cat(
                [seed_in_bbox,
                 seed_in_bbox.new_zeros((seed_num, append_num, 1))],
                dim=1)
            two_cues = torch.cat(
                [two_cues,
                 two_cues.new_zeros((seed_num, append_num, feature_size))],
                dim=1)
            bbox_conf = torch.cat(
                [bbox_conf,
                 bbox_conf.new_zeros((seed_num, append_num, 1))],
                dim=1)

        # sort the seed-bbox pairs: a pair inside its box scores 1 + conf,
        # a pair outside scores conf alone
        pair_score = seed_in_bbox.float() + bbox_conf
        # and find the largests
        mask, indices = pair_score.topk(
            self.max_imvote_per_pixel, dim=1, largest=True, sorted=True)

        indices_img = indices.expand(-1, -1, feature_size)
        two_cues = two_cues.gather(dim=1, index=indices_img)
        # (seed_num, K, F) -> (K, seed_num, F) -> (F, K * seed_num)
        two_cues = two_cues.transpose(1, 0)
        two_cues = two_cues.reshape(-1, feature_size).transpose(
            1, 0).contiguous()

        # since conf is ~ (0, 1), floor gives us validity
        mask = mask.floor().int()
        mask = mask.transpose(1, 0).reshape(-1).bool()
        return two_cues, mask

    def _texture_cue(self, img, uv_rescaled, img_h, img_w):
        """Sample each seed's pixel value as a texture cue.

        Args:
            img (torch.Tensor): ``(3, H, W)`` image, possibly padded.
            uv_rescaled (torch.Tensor): ``(num_seed, 2)`` pixel coordinates
                ``(u, v)`` in the rescaled image frame.
            img_h (int): Unpadded image height.
            img_w (int): Unpadded image width.

        Returns:
            torch.Tensor: ``(3, max_imvote_per_pixel * num_seed)`` pixel
                values divided by 255, repeated once per image-vote slot
                (vote-major order).
        """
        # clear the padding. Divide out of place: slicing, reshape and
        # float() may all return views of the caller's image tensor.
        img_flatten = img[:, :img_h, :img_w].reshape(3, -1).float() / 255.

        # take the normalized pixel value at the nearest in-bounds pixel
        u = uv_rescaled[:, 0].round().clamp(0, img_w - 1)
        v = uv_rescaled[:, 1].round().clamp(0, img_h - 1)
        uv_flatten = (v * img_w + u).long()
        uv_expanded = uv_flatten.unsqueeze(0).expand(3, -1)
        txt_cue = torch.gather(img_flatten, dim=-1, index=uv_expanded)
        return txt_cue.unsqueeze(1).expand(
            -1, self.max_imvote_per_pixel, -1).reshape(3, -1)