Shortcuts

Source code for mmdet3d.models.dense_heads.base_conv_bbox_head

# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.cnn import ConvModule
from mmcv.cnn.bricks import build_conv_layer
from mmcv.runner import BaseModule
from torch import nn as nn

from ..builder import HEADS


@HEADS.register_module()
class BaseConvBboxHead(BaseModule):
    r"""More general bbox head, with shared conv layers and two optional
    separated branches.

    .. code-block:: none

                     /-> cls convs -> cls_score
        shared convs
                     \-> reg convs -> bbox_pred

    Every stage except the two final prediction layers is optional: an empty
    channel sequence means the corresponding stack is not built and the
    features pass through unchanged.

    Args:
        in_channels (int): Number of channels of the input features.
            Must be greater than 0.
        shared_conv_channels (Sequence[int]): Output channels of each layer
            in the shared stack. Empty means no shared convs.
        cls_conv_channels (Sequence[int]): Output channels of each layer in
            the classification-specific stack. Empty means none.
        num_cls_out_channels (int): Output channels of the final
            classification layer. Must be greater than 0.
        reg_conv_channels (Sequence[int]): Output channels of each layer in
            the regression-specific stack. Empty means none.
        num_reg_out_channels (int): Output channels of the final regression
            layer. Must be greater than 0.
        conv_cfg (dict): Conv layer config, used both for the intermediate
            ``ConvModule`` layers and for the final prediction layers.
        norm_cfg (dict): Norm layer config passed to every ``ConvModule``.
        act_cfg (dict): Activation config passed to every ``ConvModule``.
        bias (bool | str): ``bias`` argument passed to every ``ConvModule``.
        init_cfg (dict, optional): Initialization config forwarded to
            ``BaseModule``.
    """

    def __init__(self,
                 in_channels=0,
                 shared_conv_channels=(),
                 cls_conv_channels=(),
                 num_cls_out_channels=0,
                 reg_conv_channels=(),
                 num_reg_out_channels=0,
                 conv_cfg=dict(type='Conv1d'),
                 norm_cfg=dict(type='BN1d'),
                 act_cfg=dict(type='ReLU'),
                 bias='auto',
                 init_cfg=None,
                 *args,
                 **kwargs):
        super(BaseConvBboxHead, self).__init__(
            init_cfg=init_cfg, *args, **kwargs)
        assert in_channels > 0
        assert num_cls_out_channels > 0
        assert num_reg_out_channels > 0
        self.in_channels = in_channels
        self.shared_conv_channels = shared_conv_channels
        self.cls_conv_channels = cls_conv_channels
        self.num_cls_out_channels = num_cls_out_channels
        self.reg_conv_channels = reg_conv_channels
        self.num_reg_out_channels = num_reg_out_channels
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.bias = bias

        # add shared convs
        if len(self.shared_conv_channels) > 0:
            self.shared_convs = self._add_conv_branch(
                self.in_channels, self.shared_conv_channels)
            out_channels = self.shared_conv_channels[-1]
        else:
            # No shared stack: both branches consume the raw input features.
            out_channels = self.in_channels

        # add cls specific branch
        prev_channel = out_channels
        if len(self.cls_conv_channels) > 0:
            self.cls_convs = self._add_conv_branch(prev_channel,
                                                   self.cls_conv_channels)
            prev_channel = self.cls_conv_channels[-1]

        # Final prediction layer is a bare conv built from ``conv_cfg``
        # (not a ConvModule, so no norm/activation is attached here).
        self.conv_cls = build_conv_layer(
            conv_cfg,
            in_channels=prev_channel,
            out_channels=num_cls_out_channels,
            kernel_size=1)

        # add reg specific branch
        prev_channel = out_channels
        if len(self.reg_conv_channels) > 0:
            self.reg_convs = self._add_conv_branch(prev_channel,
                                                   self.reg_conv_channels)
            prev_channel = self.reg_conv_channels[-1]

        self.conv_reg = build_conv_layer(
            conv_cfg,
            in_channels=prev_channel,
            out_channels=num_reg_out_channels,
            kernel_size=1)

    def _add_conv_branch(self, in_channels, conv_channels):
        """Add shared or separable branch.

        Args:
            in_channels (int): Number of channels entering the branch.
            conv_channels (Sequence[int]): Output channels of each
                successive layer in the branch.

        Returns:
            nn.Sequential: One kernel-size-1 ``ConvModule`` per entry of
            ``conv_channels``, registered as ``layer0``, ``layer1``, ...
        """
        # conv_spec[i] -> conv_spec[i + 1] are the in/out channels of layer i.
        conv_spec = [in_channels] + list(conv_channels)
        # add branch specific conv layers
        conv_layers = nn.Sequential()
        for i in range(len(conv_spec) - 1):
            # Submodule names are kept as ``layer{i}``; they end up in the
            # state-dict keys.
            conv_layers.add_module(
                f'layer{i}',
                ConvModule(
                    conv_spec[i],
                    conv_spec[i + 1],
                    kernel_size=1,
                    padding=0,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg,
                    bias=self.bias,
                    inplace=True))
        return conv_layers

    def forward(self, feats):
        """Forward.

        Args:
            feats (Tensor): Input features. With the default ``Conv1d``
                config this is presumably shaped (batch, in_channels,
                num_points) -- TODO confirm against callers.

        Returns:
            tuple[Tensor, Tensor]: ``(cls_score, bbox_pred)`` -- the class
            score predictions and the regression predictions.
        """
        # Seed with the raw input so ``x`` is defined even when no shared
        # convs were configured; previously that configuration (which is the
        # constructor default) raised UnboundLocalError here.
        x = feats

        # shared part
        if len(self.shared_conv_channels) > 0:
            x = self.shared_convs(x)

        # separate branches
        x_cls = x
        x_reg = x

        if len(self.cls_conv_channels) > 0:
            x_cls = self.cls_convs(x_cls)
        cls_score = self.conv_cls(x_cls)

        if len(self.reg_conv_channels) > 0:
            x_reg = self.reg_convs(x_reg)
        bbox_pred = self.conv_reg(x_reg)

        return cls_score, bbox_pred
Read the Docs v: dev
Versions
latest
stable
v1.0.0rc1
v1.0.0rc0
v0.18.1
v0.18.0
v0.17.3
v0.17.2
v0.17.1
v0.17.0
v0.16.0
v0.15.0
v0.14.0
v0.13.0
v0.12.0
v0.11.0
v0.10.0
v0.9.0
dev
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.