Source code for mmdet3d.models.model_utils.transformer

# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.cnn.bricks.registry import ATTENTION
from mmcv.cnn.bricks.transformer import POSITIONAL_ENCODING, MultiheadAttention
from torch import nn as nn


@ATTENTION.register_module()
class GroupFree3DMHA(MultiheadAttention):
    """A wrapper around torch.nn.MultiheadAttention for GroupFree3D.

    This module implements MultiheadAttention with an identity connection,
    and the positional encoding used in DETR is also passed as input.

    Args:
        embed_dims (int): The embedding dimension.
        num_heads (int): Parallel attention heads. Same as
            `nn.MultiheadAttention`.
        attn_drop (float): A Dropout layer on attn_output_weights.
            Defaults to 0.0.
        proj_drop (float): A Dropout layer after the attention output.
            Defaults to 0.0.
        dropout_layer (obj:`ConfigDict`): The dropout_layer used when adding
            the shortcut.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Defaults to None.
        batch_first (bool): Whether Key, Query and Value are of shape
            (batch, n, embed_dim) instead of (n, batch, embed_dim).
            Defaults to False.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 attn_drop=0.,
                 proj_drop=0.,
                 dropout_layer=dict(type='DropOut', drop_prob=0.),
                 init_cfg=None,
                 batch_first=False,
                 **kwargs):
        super().__init__(embed_dims, num_heads, attn_drop, proj_drop,
                         dropout_layer, init_cfg, batch_first, **kwargs)
    def forward(self,
                query,
                key,
                value,
                identity,
                query_pos=None,
                key_pos=None,
                attn_mask=None,
                key_padding_mask=None,
                **kwargs):
        """Forward function for `GroupFree3DMHA`.

        **kwargs allow passing a more general data flow when combining
        with other operations in `transformerlayer`.

        Args:
            query (Tensor): The input query with shape
                [num_queries, bs, embed_dims].
                Same as in `nn.MultiheadAttention.forward`.
            key (Tensor): The key tensor with shape
                [num_keys, bs, embed_dims].
                Same as in `nn.MultiheadAttention.forward`.
                If None, `query` will be used. Defaults to None.
            value (Tensor): The value tensor with the same shape as `key`.
                Same as in `nn.MultiheadAttention.forward`.
                If None, `key` will be used. Defaults to None.
            identity (Tensor): This tensor, with the same shape as `query`,
                will be used for the identity link.
                If None, `query` will be used. Defaults to None.
            query_pos (Tensor): The positional encoding for `query`, with
                the same shape as `query`. If not None, it will be added
                to `query` before the forward function. Defaults to None.
            key_pos (Tensor): The positional encoding for `key`, with the
                same shape as `key`. If not None, it will be added to `key`
                before the forward function. If None and `query_pos` has
                the same shape as `key`, then `query_pos` will be used as
                `key_pos`. Defaults to None.
            attn_mask (Tensor): ByteTensor mask with shape
                [num_queries, num_keys].
                Same as in `nn.MultiheadAttention.forward`. Defaults to None.
            key_padding_mask (Tensor): ByteTensor with shape [bs, num_keys].
                Same as in `nn.MultiheadAttention.forward`. Defaults to None.

        Returns:
            Tensor: Forwarded results with shape
                [num_queries, bs, embed_dims].
        """
        if hasattr(self, 'operation_name'):
            if self.operation_name == 'self_attn':
                value = value + query_pos
            elif self.operation_name == 'cross_attn':
                value = value + key_pos
            else:
                raise NotImplementedError(
                    f'{self.__class__.__name__} '
                    f"can't be used as {self.operation_name}")
        else:
            value = value + query_pos

        return super(GroupFree3DMHA, self).forward(
            query=query,
            key=key,
            value=value,
            identity=identity,
            query_pos=query_pos,
            key_pos=key_pos,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            **kwargs)
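A minimal usage sketch (not part of the module above): the tensor shapes, config values and the explicit `dropout_layer` are assumptions chosen to match the docstring, using the default batch_first=False layout.

import torch

# Build the attention module directly; in GroupFree3D it is normally
# constructed from a config via the ATTENTION registry.
self_attn = GroupFree3DMHA(
    embed_dims=288,
    num_heads=8,
    attn_drop=0.1,
    proj_drop=0.1,
    dropout_layer=dict(type='Dropout', drop_prob=0.1))  # assumed config

num_queries, bs, embed_dims = 256, 2, 288
query = torch.rand(num_queries, bs, embed_dims)
query_pos = torch.rand(num_queries, bs, embed_dims)

# Without an `operation_name` set by an enclosing transformer layer,
# `query_pos` is added to `value` before the parent forward, which
# returns `identity` plus the attended features.
out = self_attn(
    query=query,
    key=query,
    value=query,
    identity=query,
    query_pos=query_pos)
assert out.shape == (num_queries, bs, embed_dims)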
@POSITIONAL_ENCODING.register_module()
class ConvBNPositionalEncoding(nn.Module):
    """Absolute position embedding with Conv learning.

    Args:
        input_channel (int): Input features dim.
        num_pos_feats (int): Output position features dim.
            Defaults to 288 to be consistent with the seed features dim.
    """

    def __init__(self, input_channel, num_pos_feats=288):
        super().__init__()
        self.position_embedding_head = nn.Sequential(
            nn.Conv1d(input_channel, num_pos_feats, kernel_size=1),
            nn.BatchNorm1d(num_pos_feats), nn.ReLU(inplace=True),
            nn.Conv1d(num_pos_feats, num_pos_feats, kernel_size=1))

    def forward(self, xyz):
        """Forward pass.

        Args:
            xyz (Tensor): (B, N, 3) the coordinates to embed.

        Returns:
            Tensor: (B, num_pos_feats, N) the embedded position features.
        """
        xyz = xyz.permute(0, 2, 1)
        position_embedding = self.position_embedding_head(xyz)
        return position_embedding
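A companion sketch of applying the positional encoding to point coordinates; the batch size and number of points are illustrative assumptions.

import torch

# Embed 3D point coordinates into 288-dim positional features,
# matching the default seed feature dimension.
pos_enc = ConvBNPositionalEncoding(input_channel=3, num_pos_feats=288)

xyz = torch.rand(2, 1024, 3)   # (B, N, 3) point coordinates, assumed sizes
pos_feats = pos_enc(xyz)       # (B, num_pos_feats, N) -> (2, 288, 1024)
assert pos_feats.shape == (2, 288, 1024)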