
Source code for mmdet3d.models.model_utils.transformer

# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.cnn.bricks.registry import ATTENTION
from mmcv.cnn.bricks.transformer import POSITIONAL_ENCODING, MultiheadAttention
from torch import nn as nn


@ATTENTION.register_module()
class GroupFree3DMHA(MultiheadAttention):
    """A wrapper for torch.nn.MultiheadAttention for GroupFree3D.

    This module implements MultiheadAttention with identity connection,
    and the positional encoding used in DETR is also passed as input.

    Args:
        embed_dims (int): The embedding dimension.
        num_heads (int): Parallel attention heads. Same as
            `nn.MultiheadAttention`.
        attn_drop (float, optional): A Dropout layer on attn_output_weights.
            Defaults to 0.0.
        proj_drop (float, optional): A Dropout layer. Defaults to 0.0.
        dropout_layer (obj:`ConfigDict`, optional): The dropout_layer used
            when adding the shortcut.
        init_cfg (obj:`mmcv.ConfigDict`, optional): The Config for
            initialization. Defaults to None.
        batch_first (bool, optional): Key, Query and Value are shape of
            (batch, n, embed_dim) or (n, batch, embed_dim).
            Defaults to False.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 attn_drop=0.,
                 proj_drop=0.,
                 dropout_layer=dict(type='Dropout', drop_prob=0.),
                 init_cfg=None,
                 batch_first=False,
                 **kwargs):
        super().__init__(embed_dims, num_heads, attn_drop, proj_drop,
                         dropout_layer, init_cfg, batch_first, **kwargs)

    def forward(self,
                query,
                key,
                value,
                identity,
                query_pos=None,
                key_pos=None,
                attn_mask=None,
                key_padding_mask=None,
                **kwargs):
        """Forward function for `GroupFree3DMHA`.

        **kwargs allow passing a more general data flow when combining
        with other operations in `transformerlayer`.

        Args:
            query (Tensor): The input query with shape
                [num_queries, bs, embed_dims].
                Same in `nn.MultiheadAttention.forward`.
            key (Tensor): The key tensor with shape
                [num_keys, bs, embed_dims].
                Same in `nn.MultiheadAttention.forward`.
                If None, the ``query`` will be used.
            value (Tensor): The value tensor with same shape as `key`.
                Same in `nn.MultiheadAttention.forward`.
                If None, the `key` will be used.
            identity (Tensor): This tensor, with the same shape as x,
                will be used for the identity link. If None, `x` will be used.
            query_pos (Tensor, optional): The positional encoding for query,
                with the same shape as `x`. Defaults to None.
                If not None, it will be added to `x` before forward function.
            key_pos (Tensor, optional): The positional encoding for `key`,
                with the same shape as `key`. If not None, it will be added
                to `key` before forward function. If None, and `query_pos`
                has the same shape as `key`, then `query_pos` will be used
                for `key_pos`. Defaults to None.
            attn_mask (Tensor, optional): ByteTensor mask with shape
                [num_queries, num_keys].
                Same in `nn.MultiheadAttention.forward`. Defaults to None.
            key_padding_mask (Tensor, optional): ByteTensor with shape
                [bs, num_keys]. Same in `nn.MultiheadAttention.forward`.
                Defaults to None.

        Returns:
            Tensor: Forwarded results with shape
                [num_queries, bs, embed_dims].
        """
        if hasattr(self, 'operation_name'):
            if self.operation_name == 'self_attn':
                value = value + query_pos
            elif self.operation_name == 'cross_attn':
                value = value + key_pos
            else:
                raise NotImplementedError(
                    f'{self.__class__.__name__} '
                    f"can't be used as {self.operation_name}")
        else:
            value = value + query_pos

        return super(GroupFree3DMHA, self).forward(
            query=query,
            key=key,
            value=value,
            identity=identity,
            query_pos=query_pos,
            key_pos=key_pos,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            **kwargs)
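
A minimal usage sketch for GroupFree3DMHA (not part of the module above). The tensor sizes, the explicit dropout_layer config and the manually set operation_name attribute are illustrative assumptions rather than values taken from a GroupFree3D config; the shapes follow the (num_queries, bs, embed_dims) layout used when batch_first=False.

import torch

from mmdet3d.models.model_utils.transformer import GroupFree3DMHA

# Build the attention module directly; in practice it is built by the
# enclosing transformer layer through the ATTENTION registry.
self_attn = GroupFree3DMHA(
    embed_dims=288,
    num_heads=8,
    attn_drop=0.1,
    proj_drop=0.1,
    dropout_layer=dict(type='Dropout', drop_prob=0.1))

# `operation_name` is normally set on the module by the transformer layer;
# it decides whether `query_pos` or `key_pos` is added to `value`.
self_attn.operation_name = 'self_attn'

query = torch.rand(256, 2, 288)      # (num_queries, bs, embed_dims)
query_pos = torch.rand(256, 2, 288)  # positional encoding for the query

out = self_attn(
    query=query,
    key=query,
    value=query,
    identity=query,
    query_pos=query_pos,
    key_pos=query_pos)
assert out.shape == (256, 2, 288)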

@POSITIONAL_ENCODING.register_module()
class ConvBNPositionalEncoding(nn.Module):
    """Absolute position embedding with Conv learning.

    Args:
        input_channel (int): input features dim.
        num_pos_feats (int, optional): output position features dim.
            Defaults to 288 to be consistent with seed features dim.
    """

    def __init__(self, input_channel, num_pos_feats=288):
        super().__init__()
        self.position_embedding_head = nn.Sequential(
            nn.Conv1d(input_channel, num_pos_feats, kernel_size=1),
            nn.BatchNorm1d(num_pos_feats), nn.ReLU(inplace=True),
            nn.Conv1d(num_pos_feats, num_pos_feats, kernel_size=1))

    def forward(self, xyz):
        """Forward pass.

        Args:
            xyz (Tensor): (B, N, 3) the coordinates to embed.

        Returns:
            Tensor: (B, num_pos_feats, N) the embedded position features.
        """
        xyz = xyz.permute(0, 2, 1)
        position_embedding = self.position_embedding_head(xyz)
        return position_embedding
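
A minimal usage sketch for ConvBNPositionalEncoding (illustrative; the batch size and number of points below are arbitrary assumptions): embed (B, N, 3) coordinates into learned position features and note the channel-first (B, num_pos_feats, N) output layout.

import torch

from mmdet3d.models.model_utils.transformer import ConvBNPositionalEncoding

# Map 3D coordinates to 288-dim position features with a Conv-BN-ReLU-Conv head.
pos_embed = ConvBNPositionalEncoding(input_channel=3, num_pos_feats=288)

xyz = torch.rand(2, 1024, 3)  # (B, N, 3) point coordinates
features = pos_embed(xyz)     # (B, num_pos_feats, N), channel-first
assert features.shape == (2, 288, 1024)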