Shortcuts

Source code for mmdet3d.models.middle_encoders.pillar_scatter

# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.runner import auto_fp16
from torch import nn

from ..builder import MIDDLE_ENCODERS


@MIDDLE_ENCODERS.register_module()
class PointPillarsScatter(nn.Module):
    """Point Pillar's Scatter.

    Places per-voxel (pillar) feature vectors onto a zero-initialised 2D
    canvas, producing a pseudo image of shape
    ``(in_channels, output_shape[0], output_shape[1])`` per sample.

    Args:
        in_channels (int): Channels of input features.
        output_shape (list[int]): Required output shape of features.
            ``output_shape[0]`` is used as ``ny`` (canvas rows) and
            ``output_shape[1]`` as ``nx`` (canvas columns).
    """

    def __init__(self, in_channels, output_shape):
        super().__init__()
        self.in_channels = in_channels
        self.output_shape = output_shape
        self.ny, self.nx = output_shape[0], output_shape[1]
        # NOTE(review): presumably the switch consulted by mmcv's
        # ``auto_fp16`` decorator on ``forward`` -- confirm against mmcv.
        self.fp16_enabled = False

    @auto_fp16(apply_to=('voxel_features', ))
    def forward(self, voxel_features, coors, batch_size=None):
        """Forward function to scatter features.

        Dispatches to :meth:`forward_single` when ``batch_size`` is
        ``None`` and to :meth:`forward_batch` otherwise.

        Args:
            voxel_features (torch.Tensor): Voxel features in shape (N, C).
            coors (torch.Tensor): Coordinates of each voxel.
            batch_size (int, optional): Number of samples in the current
                batch. Defaults to None.

        Returns:
            torch.Tensor: Output of the selected scatter method.
        """
        # TODO: rewrite the function in a batch manner
        # no need to deal with different batch cases
        if batch_size is None:
            return self.forward_single(voxel_features, coors)
        return self.forward_batch(voxel_features, coors, batch_size)

    def _scatter_to_canvas(self, voxel_features, coors):
        """Scatter the given voxels onto one flattened canvas.

        Args:
            voxel_features (torch.Tensor): Voxel features in shape (N, C).
            coors (torch.Tensor): Coordinates of the same N voxels.
                Column 2 is scaled by ``nx`` and column 3 is added to it
                to obtain the flattened cell index.

        Returns:
            torch.Tensor: Canvas in shape (in_channels, nx * ny).
        """
        flat_canvas = torch.zeros(
            self.in_channels,
            self.nx * self.ny,
            dtype=voxel_features.dtype,
            device=voxel_features.device)
        # Row-major cell index: column 2 selects the row, column 3 the
        # column inside that row.
        flat_index = (coors[:, 2] * self.nx + coors[:, 3]).long()
        # Features arrive as (N, C); the canvas is indexed as (C, cell).
        flat_canvas[:, flat_index] = voxel_features.t()
        return flat_canvas

    def forward_single(self, voxel_features, coors):
        """Scatter features of single sample.

        Every voxel is written to the same canvas; the first column of
        ``coors`` is not read here.

        Args:
            voxel_features (torch.Tensor): Voxel features in shape (N, C).
            coors (torch.Tensor): Coordinates of each voxel.
                The first column indicates the sample ID.

        Returns:
            torch.Tensor: Pseudo image in shape (1, in_channels, ny, nx).
        """
        flat_canvas = self._scatter_to_canvas(voxel_features, coors)
        # Unflatten the cell axis into (ny, nx) and add the batch axis.
        return flat_canvas.view(1, self.in_channels, self.ny, self.nx)

    def forward_batch(self, voxel_features, coors, batch_size):
        """Scatter features of a batch of samples.

        Args:
            voxel_features (torch.Tensor): Voxel features in shape (N, C).
            coors (torch.Tensor): Coordinates of each voxel in shape (N, 4).
                The first column indicates the sample ID.
            batch_size (int): Number of samples in the current batch.

        Returns:
            torch.Tensor: Pseudo images in shape
            (batch_size, in_channels, ny, nx).
        """
        per_sample_canvases = []
        for sample_idx in range(batch_size):
            # Keep only the voxels whose sample ID matches this sample;
            # cells without a voxel stay zero.
            in_sample = coors[:, 0] == sample_idx
            per_sample_canvases.append(
                self._scatter_to_canvas(voxel_features[in_sample, :],
                                        coors[in_sample, :]))

        # (batch_size, in_channels, nx * ny)
        stacked = torch.stack(per_sample_canvases, 0)
        # Unflatten the cell axis into the final 4-dim tensor.
        return stacked.view(batch_size, self.in_channels, self.ny, self.nx)
Read the Docs v: dev
Versions
latest
stable
v1.0.0rc1
v1.0.0rc0
v0.18.1
v0.18.0
v0.17.3
v0.17.2
v0.17.1
v0.17.0
v0.16.0
v0.15.0
v0.14.0
v0.13.0
v0.12.0
v0.11.0
v0.10.0
v0.9.0
dev
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.