Shortcuts

Source code for mmdet3d.models.middle_encoders.sparse_encoder

# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.runner import auto_fp16
from torch import nn as nn

from mmdet3d.ops import SparseBasicBlock, make_sparse_convmodule
from mmdet3d.ops import spconv as spconv
from ..builder import MIDDLE_ENCODERS


@MIDDLE_ENCODERS.register_module()
class SparseEncoder(nn.Module):
    r"""Sparse encoder for SECOND and Part-A2.

    The module is a stack of sparse 3D convolutions: an input convolution
    (``conv_input``), a sequence of encoder stages (``encoder_layers``) and
    an output convolution (``conv_out``). The result is converted to a dense
    tensor whose channel and first spatial axes are merged.

    Args:
        in_channels (int): The number of input channels.
        sparse_shape (list[int]): The sparse shape of input tensor.
        order (tuple[str], optional): Order of conv module. Must be a
            3-tuple containing exactly 'conv', 'norm' and 'act'.
            Defaults to ('conv', 'norm', 'act').
        norm_cfg (dict, optional): Config of normalization layer. Defaults to
            dict(type='BN1d', eps=1e-3, momentum=0.01).
        base_channels (int, optional): Out channels for conv_input layer.
            Defaults to 16.
        output_channels (int, optional): Out channels for conv_out layer.
            Defaults to 128.
        encoder_channels (tuple[tuple[int]], optional):
            Convolutional channels of each encode block.
            Defaults to ((16, ), (32, 32, 32), (64, 64, 64), (64, 64, 64)).
        encoder_paddings (tuple[tuple[int]], optional):
            Paddings of each encode block. Must provide one padding for
            every entry of ``encoder_channels``.
            Defaults to ((1, ), (1, 1, 1), (1, 1, 1), ((0, 1, 1), 1, 1)).
        block_type (str, optional): Type of the block to use. Either
            'conv_module' or 'basicblock'. Defaults to 'conv_module'.
    """

    # NOTE(review): ``norm_cfg`` (and ``conv_cfg`` in ``make_encoder_layers``)
    # use dict defaults, which are shared between calls. They are only read
    # in this class; confirm the helpers they are passed to never mutate them.
    def __init__(self,
                 in_channels,
                 sparse_shape,
                 order=('conv', 'norm', 'act'),
                 norm_cfg=dict(type='BN1d', eps=1e-3, momentum=0.01),
                 base_channels=16,
                 output_channels=128,
                 encoder_channels=((16, ), (32, 32, 32), (64, 64, 64),
                                   (64, 64, 64)),
                 encoder_paddings=((1, ), (1, 1, 1), (1, 1, 1),
                                   ((0, 1, 1), 1, 1)),
                 block_type='conv_module'):
        super().__init__()
        assert block_type in ['conv_module', 'basicblock']
        self.sparse_shape = sparse_shape
        self.in_channels = in_channels
        self.order = order
        self.base_channels = base_channels
        self.output_channels = output_channels
        self.encoder_channels = encoder_channels
        self.encoder_paddings = encoder_paddings
        self.stage_num = len(self.encoder_channels)
        self.fp16_enabled = False
        # Spconv init all weight on its own

        assert isinstance(order, tuple) and len(order) == 3
        assert set(order) == {'conv', 'norm', 'act'}

        if self.order[0] != 'conv':  # pre activate
            # Only the bare convolution is built here (order=('conv', )).
            self.conv_input = make_sparse_convmodule(
                in_channels,
                self.base_channels,
                3,
                norm_cfg=norm_cfg,
                padding=1,
                indice_key='subm1',
                conv_type='SubMConv3d',
                order=('conv', ))
        else:  # post activate
            self.conv_input = make_sparse_convmodule(
                in_channels,
                self.base_channels,
                3,
                norm_cfg=norm_cfg,
                padding=1,
                indice_key='subm1',
                conv_type='SubMConv3d')

        # Builds ``self.encoder_layers`` as a side effect and reports the
        # channel count that ``conv_out`` has to consume.
        encoder_out_channels = self.make_encoder_layers(
            make_sparse_convmodule,
            norm_cfg,
            self.base_channels,
            block_type=block_type)

        # Strided only along the first spatial axis (kernel 3, stride 2).
        self.conv_out = make_sparse_convmodule(
            encoder_out_channels,
            self.output_channels,
            kernel_size=(3, 1, 1),
            stride=(2, 1, 1),
            norm_cfg=norm_cfg,
            padding=0,
            indice_key='spconv_down2',
            conv_type='SparseConv3d')

    @auto_fp16(apply_to=('voxel_features', ))
    def forward(self, voxel_features, coors, batch_size):
        """Forward of SparseEncoder.

        Args:
            voxel_features (torch.float32): Voxel features in shape (N, C).
            coors (torch.int32): Coordinates in shape (N, 4),
                the columns in the order of (batch_idx, z_idx, y_idx, x_idx).
            batch_size (int): Batch size.

        Returns:
            torch.Tensor: Dense output of ``conv_out`` with its channel and
                first spatial axes merged, i.e. reshaped from
                (N, C, D, H, W) to (N, C * D, H, W).
        """
        coors = coors.int()
        input_sp_tensor = spconv.SparseConvTensor(voxel_features, coors,
                                                  self.sparse_shape,
                                                  batch_size)
        x = self.conv_input(input_sp_tensor)

        # Each stage consumes the previous stage's output; only the final
        # stage output is needed below, so intermediates are not retained.
        for encoder_layer in self.encoder_layers:
            x = encoder_layer(x)

        # for detection head
        # [200, 176, 5] -> [200, 176, 2]
        out = self.conv_out(x)
        spatial_features = out.dense()

        N, C, D, H, W = spatial_features.shape
        spatial_features = spatial_features.view(N, C * D, H, W)

        return spatial_features

    def make_encoder_layers(self,
                            make_block,
                            norm_cfg,
                            in_channels,
                            block_type='conv_module',
                            conv_cfg=dict(type='SubMConv3d')):
        """Make encoder layers using sparse convs.

        Populates ``self.encoder_layers`` with one ``SparseSequential`` per
        entry of ``self.encoder_channels``, named ``encoder_layer{i + 1}``.

        Args:
            make_block (method): A bounded function to build blocks.
            norm_cfg (dict[str]): Config of normalization layer.
            in_channels (int): The number of encoder input channels.
            block_type (str, optional): Type of the block to use.
                Defaults to 'conv_module'.
            conv_cfg (dict, optional): Config of conv layer. Defaults to
                dict(type='SubMConv3d').

        Returns:
            int: The number of encoder output channels. Equals the input
                ``in_channels`` when no block is configured.
        """
        assert block_type in ['conv_module', 'basicblock']
        self.encoder_layers = spconv.SparseSequential()
        last_stage = len(self.encoder_channels) - 1

        for i, blocks in enumerate(self.encoder_channels):
            blocks_list = []
            for j, out_channels in enumerate(tuple(blocks)):
                padding = tuple(self.encoder_paddings[i])[j]
                # each stage started with a spconv layer
                # except the first stage
                if i != 0 and j == 0 and block_type == 'conv_module':
                    blocks_list.append(
                        make_block(
                            in_channels,
                            out_channels,
                            3,
                            norm_cfg=norm_cfg,
                            stride=2,
                            padding=padding,
                            indice_key=f'spconv{i + 1}',
                            conv_type='SparseConv3d'))
                elif block_type == 'basicblock':
                    # With basic blocks the strided conv closes every stage
                    # except the last one, instead of opening it.
                    if j == len(blocks) - 1 and i != last_stage:
                        blocks_list.append(
                            make_block(
                                in_channels,
                                out_channels,
                                3,
                                norm_cfg=norm_cfg,
                                stride=2,
                                padding=padding,
                                indice_key=f'spconv{i + 1}',
                                conv_type='SparseConv3d'))
                    else:
                        # NOTE(review): built with out_channels on both
                        # sides, so this assumes in_channels == out_channels
                        # for these entries - confirm against the configs.
                        blocks_list.append(
                            SparseBasicBlock(
                                out_channels,
                                out_channels,
                                norm_cfg=norm_cfg,
                                conv_cfg=conv_cfg))
                else:
                    blocks_list.append(
                        make_block(
                            in_channels,
                            out_channels,
                            3,
                            norm_cfg=norm_cfg,
                            padding=padding,
                            indice_key=f'subm{i + 1}',
                            conv_type='SubMConv3d'))
                in_channels = out_channels
            stage_name = f'encoder_layer{i + 1}'
            stage_layers = spconv.SparseSequential(*blocks_list)
            self.encoder_layers.add_module(stage_name, stage_layers)
        # ``in_channels`` tracks the latest block's output width, so it is
        # the same value as the last ``out_channels`` but is always bound.
        return in_channels
Read the Docs v: v1.0.0rc0
Versions
latest
stable
v1.0.0rc1
v1.0.0rc0
v0.18.1
v0.18.0
v0.17.3
v0.17.2
v0.17.1
v0.17.0
v0.16.0
v0.15.0
v0.14.0
v0.13.0
v0.12.0
v0.11.0
v0.10.0
v0.9.0
dev
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.