Shortcuts

Source code for mmdet3d.models.model_utils.vote_module

# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv import is_tuple_of
from mmcv.cnn import ConvModule
from torch import nn as nn

from mmdet3d.models.builder import build_loss


class VoteModule(nn.Module):
    """Vote module.

    Generate votes from seed point features.

    Args:
        in_channels (int): Number of channels of seed point features.
        vote_per_seed (int, optional): Number of votes generated from each
            seed point. Default: 1.
        gt_per_seed (int, optional): Number of ground truth votes generated
            from each seed point. Default: 3.
        num_points (int, optional): Number of seed points to be used for
            voting (the first ``num_points`` seeds are kept). ``-1`` means
            all seed points are used. Default: -1.
        conv_channels (tuple[int], optional): Out channels of vote
            generating convolution. Default: (16, 16).
        conv_cfg (dict, optional): Config of convolution.
            Default: dict(type='Conv1d').
        norm_cfg (dict, optional): Config of normalization.
            Default: dict(type='BN1d').
        act_cfg (dict, optional): Config of activation.
            Default: dict(type='ReLU').
        norm_feats (bool, optional): Whether to L2-normalize the vote
            features along the channel dimension. Only applied when
            ``with_res_feat`` is True. Default: True.
        with_res_feat (bool, optional): Whether to predict residual
            features. Default: True.
        vote_xyz_range (tuple[float], optional): Per-axis bound ``r`` on
            the predicted xyz offsets; the offset along each axis is clamped
            to ``[-r, r]`` before being added to the seed points.
            Default: None (no clamping).
        vote_loss (dict, optional): Config of vote loss. Must be given to
            use :meth:`get_loss`. Default: None.
    """

    def __init__(self,
                 in_channels,
                 vote_per_seed=1,
                 gt_per_seed=3,
                 num_points=-1,
                 conv_channels=(16, 16),
                 conv_cfg=dict(type='Conv1d'),
                 norm_cfg=dict(type='BN1d'),
                 act_cfg=dict(type='ReLU'),
                 norm_feats=True,
                 with_res_feat=True,
                 vote_xyz_range=None,
                 vote_loss=None):
        # NOTE: the dict defaults above are shared between calls; this class
        # only forwards them to ConvModule and never mutates them itself.
        super().__init__()
        self.in_channels = in_channels
        self.vote_per_seed = vote_per_seed
        self.gt_per_seed = gt_per_seed
        self.num_points = num_points
        self.norm_feats = norm_feats
        self.with_res_feat = with_res_feat

        assert vote_xyz_range is None or is_tuple_of(vote_xyz_range, float)
        self.vote_xyz_range = vote_xyz_range

        # Only built when configured; get_loss relies on this attribute.
        if vote_loss is not None:
            self.vote_loss = build_loss(vote_loss)

        prev_channels = in_channels
        vote_conv_list = []
        for out_channels in conv_channels:
            vote_conv_list.append(
                ConvModule(
                    prev_channels,
                    out_channels,
                    1,
                    padding=0,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg,
                    bias=True,
                    inplace=True))
            prev_channels = out_channels
        self.vote_conv = nn.Sequential(*vote_conv_list)

        # conv_out predicts, for every vote of every seed, an xyz offset
        # (3 channels) and optionally residual features (in_channels).
        if with_res_feat:
            out_channel = (3 + in_channels) * self.vote_per_seed
        else:
            out_channel = 3 * self.vote_per_seed
        self.conv_out = nn.Conv1d(prev_channels, out_channel, 1)

    def forward(self, seed_points, seed_feats):
        """Generate vote points and vote features from the seeds.

        Args:
            seed_points (torch.Tensor): Coordinate of the seed points in shape
                (B, N, 3).
            seed_feats (torch.Tensor): Features of the seed points in shape
                (B, C, N).

        Returns:
            tuple[torch.Tensor]:

                - vote_points: Voted xyz based on the seed points with shape
                  (B, M, 3), ``M=num_seed*vote_per_seed``. Offsets are
                  clamped by ``vote_xyz_range`` when it is set.
                - vote_feats: Voted features with shape (B, C, M). Without
                  residual features these are the seed features repeated
                  once per vote (not normalized).
                - offset: Predicted (unclamped) xyz offsets with shape
                  (B, 3, M).
        """
        if self.num_points != -1:
            assert self.num_points < seed_points.shape[1], \
                f'Number of vote points ({self.num_points}) should be '\
                f'smaller than seed points size ({seed_points.shape[1]})'
            seed_points = seed_points[:, :self.num_points]
            seed_feats = seed_feats[..., :self.num_points]

        batch_size, feat_channels, num_seed = seed_feats.shape
        num_vote = num_seed * self.vote_per_seed
        x = self.vote_conv(seed_feats)
        # (batch_size, (3+out_dim)*vote_per_seed, num_seed)
        votes = self.conv_out(x)

        # -> (batch_size, num_seed, vote_per_seed, 3[+in_channels])
        votes = votes.transpose(2, 1).view(batch_size, num_seed,
                                           self.vote_per_seed, -1)

        offset = votes[:, :, :, 0:3]
        if self.vote_xyz_range is not None:
            limited_offset = torch.stack([
                offset[..., axis].clamp(min=-bound, max=bound)
                for axis, bound in enumerate(self.vote_xyz_range)
            ], -1)
            vote_points = (seed_points.unsqueeze(2) +
                           limited_offset).contiguous()
        else:
            vote_points = (seed_points.unsqueeze(2) + offset).contiguous()
        vote_points = vote_points.view(batch_size, num_vote, 3)
        offset = offset.reshape(batch_size, num_vote, 3).transpose(2, 1)

        if self.with_res_feat:
            res_feats = votes[:, :, :, 3:]
            vote_feats = (seed_feats.transpose(2, 1).unsqueeze(2) +
                          res_feats).contiguous()
            vote_feats = vote_feats.view(batch_size, num_vote,
                                         feat_channels).transpose(
                                             2, 1).contiguous()

            if self.norm_feats:
                # Divides by max(||f||_2, eps): identical to a plain division
                # by the norm, but all-zero features stay zero instead of
                # becoming NaN.
                vote_feats = nn.functional.normalize(vote_feats, p=2, dim=1)
        else:
            vote_feats = seed_feats
            if self.vote_per_seed > 1:
                # Votes of seed i occupy positions i*V .. i*V+V-1 (see the
                # view above), so repeat each seed's features V times in
                # place to keep vote_feats aligned with vote_points.
                vote_feats = seed_feats.repeat_interleave(
                    self.vote_per_seed, dim=2)
        return vote_points, vote_feats, offset

    def get_loss(self, seed_points, vote_points, seed_indices,
                 vote_targets_mask, vote_targets):
        """Calculate loss of voting module.

        Args:
            seed_points (torch.Tensor): Coordinate of the seed points,
                shape (B, num_seed, 3).
            vote_points (torch.Tensor): Coordinate of the vote points; must
                be viewable as (B * num_seed, -1, 3).
            seed_indices (torch.Tensor): Indices of seed points in raw
                points, shape (B, num_seed).
            vote_targets_mask (torch.Tensor): Mask of valid vote targets,
                indexed along dim 1 by ``seed_indices``.
            vote_targets (torch.Tensor): Targets of votes, with
                ``3 * gt_per_seed`` values per raw point along the last dim,
                interpreted as offsets relative to the seed points.

        Returns:
            torch.Tensor: Weighted vote loss.
        """
        batch_size, num_seed = seed_points.shape[:2]

        seed_gt_votes_mask = torch.gather(vote_targets_mask, 1,
                                          seed_indices).float()

        seed_indices_expand = seed_indices.unsqueeze(-1).repeat(
            1, 1, 3 * self.gt_per_seed)
        seed_gt_votes = torch.gather(vote_targets, 1, seed_indices_expand)
        # Add the seed xyz (tiled gt_per_seed times) so that the targets
        # become absolute vote positions comparable with vote_points.
        seed_gt_votes += seed_points.repeat(1, 1, self.gt_per_seed)

        # Each valid seed contributes equally; the weights sum to ~1.
        weight = seed_gt_votes_mask / (torch.sum(seed_gt_votes_mask) + 1e-6)
        # Element [1] of the loss output is used as per-target distances.
        # NOTE(review): assumed to be (B * num_seed, gt_per_seed) — confirm
        # against the configured loss.
        distance = self.vote_loss(
            vote_points.view(batch_size * num_seed, -1, 3),
            seed_gt_votes.view(batch_size * num_seed, -1, 3),
            dst_weight=weight.view(batch_size * num_seed, 1))[1]
        # Only the closest ground-truth vote counts for each seed.
        vote_loss = torch.sum(torch.min(distance, dim=1)[0])

        return vote_loss
Read the Docs v: dev
Versions
latest
stable
v1.0.0rc1
v1.0.0rc0
v0.18.1
v0.18.0
v0.17.3
v0.17.2
v0.17.1
v0.17.0
v0.16.0
v0.15.0
v0.14.0
v0.13.0
v0.12.0
v0.11.0
v0.10.0
v0.9.0
dev
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.