Source code for mmdet3d.models.model_utils.transformer

from mmcv.cnn.bricks.registry import ATTENTION
from mmcv.cnn.bricks.transformer import POSITIONAL_ENCODING, MultiheadAttention
from torch import nn as nn


@ATTENTION.register_module()
class GroupFree3DMHA(MultiheadAttention):
    """A wrapper for torch.nn.MultiheadAttention for GroupFree3D.

    This module implements MultiheadAttention with identity connection,
    and the positional encoding used in DETR is also passed as input.

    Args:
        embed_dims (int): The embedding dimension.
        num_heads (int): Parallel attention heads. Same as
            `nn.MultiheadAttention`.
        attn_drop (float): A Dropout layer on attn_output_weights.
            Defaults to 0.0.
        proj_drop (float): A Dropout layer. Defaults to 0.0.
        dropout_layer (obj:`ConfigDict`): The dropout_layer used
            when adding the shortcut.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Defaults to None.
        batch_first (bool): Whether Key, Query and Value are shaped as
            (batch, n, embed_dim) or (n, batch, embed_dim).
            Defaults to False.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 attn_drop=0.,
                 proj_drop=0.,
                 dropout_layer=dict(type='DropOut', drop_prob=0.),
                 init_cfg=None,
                 batch_first=False,
                 **kwargs):
        super().__init__(embed_dims, num_heads, attn_drop, proj_drop,
                         dropout_layer, init_cfg, batch_first, **kwargs)

    def forward(self,
                query,
                key,
                value,
                identity,
                query_pos=None,
                key_pos=None,
                attn_mask=None,
                key_padding_mask=None,
                **kwargs):
        """Forward function for `GroupFree3DMHA`.

        **kwargs allows passing a more general data flow when combining
        with other operations in `transformerlayer`.

        Args:
            query (Tensor): The input query with shape
                [num_queries, bs, embed_dims].
                Same in `nn.MultiheadAttention.forward`.
            key (Tensor): The key tensor with shape
                [num_keys, bs, embed_dims].
                Same in `nn.MultiheadAttention.forward`.
                If None, the ``query`` will be used.
            value (Tensor): The value tensor with the same shape as `key`.
                Same in `nn.MultiheadAttention.forward`.
                If None, the `key` will be used.
            identity (Tensor): This tensor, with the same shape as x,
                will be used for the identity link. If None, `x` will be
                used.
            query_pos (Tensor): The positional encoding for query, with
                the same shape as `x`. If not None, it will be added to
                `x` before the forward function. Defaults to None.
            key_pos (Tensor): The positional encoding for `key`, with the
                same shape as `key`. If not None, it will be added to
                `key` before the forward function. If None and `query_pos`
                has the same shape as `key`, then `query_pos` will be used
                for `key_pos`. Defaults to None.
            attn_mask (Tensor): ByteTensor mask with shape
                [num_queries, num_keys].
                Same in `nn.MultiheadAttention.forward`. Defaults to None.
            key_padding_mask (Tensor): ByteTensor with shape [bs, num_keys].
                Same in `nn.MultiheadAttention.forward`. Defaults to None.

        Returns:
            Tensor: Forwarded results with shape
                [num_queries, bs, embed_dims].
        """
        if hasattr(self, 'operation_name'):
            if self.operation_name == 'self_attn':
                value = value + query_pos
            elif self.operation_name == 'cross_attn':
                value = value + key_pos
            else:
                raise NotImplementedError(
                    f'{self.__class__.__name__} '
                    f"can't be used as {self.operation_name}")
        else:
            value = value + query_pos

        return super(GroupFree3DMHA, self).forward(
            query=query,
            key=key,
            value=value,
            identity=identity,
            query_pos=query_pos,
            key_pos=key_pos,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            **kwargs)
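

# ---------------------------------------------------------------------------
# Usage sketch for GroupFree3DMHA (illustrative, not part of the original
# module). It shows a standalone self-attention call with the default
# `batch_first=False` layout of [num_queries, bs, embed_dims]. The helper
# name, the tensor sizes and the explicit `dropout_layer` config (using the
# 'Dropout' name registered in mmcv) are assumptions for this example, not
# values taken from the code above.
# ---------------------------------------------------------------------------
def _group_free3d_mha_example():
    import torch

    self_attn = GroupFree3DMHA(
        embed_dims=288,
        num_heads=8,
        dropout_layer=dict(type='Dropout', drop_prob=0.1))

    num_queries, bs, embed_dims = 16, 2, 288
    query = torch.rand(num_queries, bs, embed_dims)
    query_pos = torch.rand(num_queries, bs, embed_dims)

    # Without an `operation_name` set by a surrounding transformer layer,
    # `query_pos` is added to `value` before the attention call, and the
    # `identity` tensor is added back to the attention output.
    out = self_attn(
        query=query,
        key=query,
        value=query,
        identity=query,
        query_pos=query_pos)
    assert out.shape == (num_queries, bs, embed_dims)
    return out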


@POSITIONAL_ENCODING.register_module()
class ConvBNPositionalEncoding(nn.Module):
    """Absolute position embedding with Conv learning.

    Args:
        input_channel (int): Input features dim.
        num_pos_feats (int): Output position features dim.
            Defaults to 288 to be consistent with the seed features dim.
    """

    def __init__(self, input_channel, num_pos_feats=288):
        super().__init__()
        self.position_embedding_head = nn.Sequential(
            nn.Conv1d(input_channel, num_pos_feats, kernel_size=1),
            nn.BatchNorm1d(num_pos_feats), nn.ReLU(inplace=True),
            nn.Conv1d(num_pos_feats, num_pos_feats, kernel_size=1))

    def forward(self, xyz):
        """Forward pass.

        Args:
            xyz (Tensor): (B, N, 3) the coordinates to embed.

        Returns:
            Tensor: (B, num_pos_feats, N) the embedded position features.
        """
        xyz = xyz.permute(0, 2, 1)
        position_embedding = self.position_embedding_head(xyz)
        return position_embedding
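

# ---------------------------------------------------------------------------
# Usage sketch for ConvBNPositionalEncoding (illustrative, not part of the
# original module). It embeds a small batch of 3D coordinates into
# 288-channel position features; the helper name, batch size and number of
# points are assumptions for this example.
# ---------------------------------------------------------------------------
def _conv_bn_positional_encoding_example():
    import torch

    pos_enc = ConvBNPositionalEncoding(input_channel=3, num_pos_feats=288)
    xyz = torch.rand(2, 1024, 3)  # (B, N, 3) point coordinates

    pos_enc.eval()  # use BatchNorm running stats for a deterministic pass
    with torch.no_grad():
        pos_feats = pos_enc(xyz)  # (B, num_pos_feats, N)
    assert pos_feats.shape == (2, 288, 1024)
    return pos_feats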