Shortcuts

Source code for mmdet3d.models.backbones.dgcnn

# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.runner import BaseModule, auto_fp16
from torch import nn as nn

from mmdet3d.ops import DGCNNFAModule, DGCNNGFModule
from ..builder import BACKBONES


@BACKBONES.register_module()
class DGCNNBackbone(BaseModule):
    """Backbone network for DGCNN.

    The backbone stacks ``len(gf_channels)`` graph feature (GF) modules, each
    consuming the output of the previous one, followed by a single feature
    aggregation (FA) module that receives the list of all intermediate
    outputs (the raw input points first, then every GF output).

    Args:
        in_channels (int): Input channels of point cloud.
        num_samples (tuple[int], optional): The number of samples for knn or
            ball query in each graph feature (GF) module.
            Defaults to (20, 20, 20).
        knn_modes (tuple[str], optional): Mode of KNN of each knn module.
            Defaults to ('D-KNN', 'F-KNN', 'F-KNN').
        radius (tuple[float], optional): Sampling radii of each GF module.
            Defaults to (None, None, None).
        gf_channels (tuple[tuple[int]], optional): Out channels of each mlp in
            GF module. Defaults to ((64, 64), (64, 64), (64, )).
        fa_channels (tuple[int], optional): Out channels of each mlp in FA
            module. Defaults to (1024, ).
        act_cfg (dict, optional): Config of activation layer.
            Defaults to dict(type='ReLU').
        init_cfg (dict, optional): Initialization config.
            Defaults to None.

    Raises:
        AssertionError: If ``num_samples``, ``knn_modes``, ``radius`` and
            ``gf_channels`` do not all have the same length.
    """

    def __init__(self,
                 in_channels,
                 num_samples=(20, 20, 20),
                 knn_modes=('D-KNN', 'F-KNN', 'F-KNN'),
                 radius=(None, None, None),
                 gf_channels=((64, 64), (64, 64), (64, )),
                 fa_channels=(1024, ),
                 act_cfg=dict(type='ReLU'),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.num_gf = len(gf_channels)

        # Adjacent string literals (instead of a backslash continuation
        # inside the literal) keep source indentation out of the message.
        assert len(num_samples) == len(knn_modes) == len(radius) == len(
            gf_channels), ('Num_samples, knn_modes, radius and gf_channels '
                           'should have the same length.')

        self.GF_modules = nn.ModuleList()
        # Each GF module's input width is twice the width of the features it
        # is fed (in_channels for the first, the previous GF output after
        # that). NOTE(review): presumably DGCNNGFModule concatenates two
        # feature sets per neighbour pair -- confirm against its definition.
        gf_in_channel = in_channels * 2
        gf_out_channels = []  # output width of every GF module, in order
        for gf_index, gf_mlp in enumerate(gf_channels):
            cur_gf_mlps = [gf_in_channel] + list(gf_mlp)
            gf_out_channel = cur_gf_mlps[-1]
            self.GF_modules.append(
                DGCNNGFModule(
                    mlp_channels=cur_gf_mlps,
                    num_sample=num_samples[gf_index],
                    knn_mode=knn_modes[gf_index],
                    radius=radius[gf_index],
                    act_cfg=act_cfg))
            gf_out_channels.append(gf_out_channel)
            gf_in_channel = gf_out_channel * 2

        # The FA input width counts only the GF outputs, not the raw input
        # points. NOTE(review): this assumes DGCNNFAModule skips the first
        # entry of the list passed in forward() -- confirm.
        fa_in_channel = sum(gf_out_channels)
        cur_fa_mlps = [fa_in_channel] + list(fa_channels)
        self.FA_module = DGCNNFAModule(
            mlp_channels=cur_fa_mlps, act_cfg=act_cfg)

    @auto_fp16(apply_to=('points', ))
    def forward(self, points):
        """Forward pass.

        Args:
            points (torch.Tensor): point coordinates with features,
                with shape (B, N, in_channels).

        Returns:
            dict[str, list[torch.Tensor]]: Outputs after graph feature (GF)
                and feature aggregation (FA) modules.

                - gf_points (list[torch.Tensor]): The input ``points``
                  followed by the output of each GF module, in order
                  (``num_gf + 1`` entries).
                - fa_points (torch.Tensor): Outputs after FA module.
        """
        gf_points = [points]
        for gf_module in self.GF_modules:
            # Each GF module consumes the most recent output in the chain.
            gf_points.append(gf_module(gf_points[-1]))

        fa_points = self.FA_module(gf_points)
        return dict(gf_points=gf_points, fa_points=fa_points)
Read the Docs v: dev
Versions
latest
stable
v1.0.0rc1
v1.0.0rc0
v0.18.1
v0.18.0
v0.17.3
v0.17.2
v0.17.1
v0.17.0
v0.16.0
v0.15.0
v0.14.0
v0.13.0
v0.12.0
v0.11.0
v0.10.0
v0.9.0
dev
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.