
Source code for mmdet3d.core.bbox.structures.utils

# Copyright (c) OpenMMLab. All rights reserved.
from logging import warning

import numpy as np
import torch


def limit_period(val, offset=0.5, period=np.pi):
    """Limit the value into a period for a periodic function.

    Args:
        val (torch.Tensor): The value to be converted.
        offset (float, optional): Offset to set the value range.
            Defaults to 0.5.
        period (float, optional): Period of the value. Defaults to np.pi.

    Returns:
        torch.Tensor: Value in the range of
            [-offset * period, (1 - offset) * period].
    """
    return val - torch.floor(val / period + offset) * period
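Example usage with illustrative angle values; with the defaults, every output lands in [-0.5 * pi, 0.5 * pi):

angles = torch.tensor([0.3, 2.0, -3.5])
wrapped = limit_period(angles)
# wrapped is approximately tensor([ 0.3000, -1.1416, -0.3584])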
def rotation_3d_in_axis(points, angles, axis=0):
    """Rotate points by angles according to axis.

    Args:
        points (torch.Tensor): Points of shape (N, M, 3).
        angles (torch.Tensor): Vector of angles with shape (N,).
        axis (int, optional): The axis to rotate around. Defaults to 0.

    Raises:
        ValueError: If the axis is not in [0, 1, 2].

    Returns:
        torch.Tensor: Rotated points of shape (N, M, 3).
    """
    rot_sin = torch.sin(angles)
    rot_cos = torch.cos(angles)
    ones = torch.ones_like(rot_cos)
    zeros = torch.zeros_like(rot_cos)
    if axis == 1:
        rot_mat_T = torch.stack([
            torch.stack([rot_cos, zeros, -rot_sin]),
            torch.stack([zeros, ones, zeros]),
            torch.stack([rot_sin, zeros, rot_cos])
        ])
    elif axis == 2 or axis == -1:
        rot_mat_T = torch.stack([
            torch.stack([rot_cos, -rot_sin, zeros]),
            torch.stack([rot_sin, rot_cos, zeros]),
            torch.stack([zeros, zeros, ones])
        ])
    elif axis == 0:
        rot_mat_T = torch.stack([
            torch.stack([zeros, rot_cos, -rot_sin]),
            torch.stack([zeros, rot_sin, rot_cos]),
            torch.stack([ones, zeros, zeros])
        ])
    else:
        raise ValueError(f'axis should be in range [0, 1, 2], got {axis}')

    return torch.einsum('aij,jka->aik', (points, rot_mat_T))
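Example usage with an illustrative point and angle. Note the convention: the points are multiplied as row vectors against rot_mat_T, so a positive angle about the z-axis maps (1, 0, 0) to approximately (0, -1, 0):

points = torch.tensor([[[1.0, 0.0, 0.0]]])  # shape (N=1, M=1, 3)
angles = torch.tensor([np.pi / 2])          # shape (N=1,)
rotated = rotation_3d_in_axis(points, angles, axis=2)
# rotated is approximately tensor([[[0., -1., 0.]]])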
def xywhr2xyxyr(boxes_xywhr):
    """Convert rotated boxes in XYWHR format to XYXYR format.

    Args:
        boxes_xywhr (torch.Tensor): Rotated boxes in XYWHR format.

    Returns:
        torch.Tensor: Converted boxes in XYXYR format.
    """
    boxes = torch.zeros_like(boxes_xywhr)
    half_w = boxes_xywhr[:, 2] / 2
    half_h = boxes_xywhr[:, 3] / 2

    boxes[:, 0] = boxes_xywhr[:, 0] - half_w
    boxes[:, 1] = boxes_xywhr[:, 1] - half_h
    boxes[:, 2] = boxes_xywhr[:, 0] + half_w
    boxes[:, 3] = boxes_xywhr[:, 1] + half_h
    boxes[:, 4] = boxes_xywhr[:, 4]
    return boxes
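Example usage with one illustrative box: a box centered at (2, 3) with width 4, height 2 and yaw 0.5 becomes (x_min, y_min, x_max, y_max, yaw):

box_xywhr = torch.tensor([[2.0, 3.0, 4.0, 2.0, 0.5]])
box_xyxyr = xywhr2xyxyr(box_xywhr)
# box_xyxyr is tensor([[0.0000, 2.0000, 4.0000, 4.0000, 0.5000]])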
def get_box_type(box_type):
    """Get the type and mode of a box structure.

    Args:
        box_type (str): The type of the box structure. Valid values are
            "LiDAR", "Camera" and "Depth".

    Returns:
        tuple: Box type and box mode.
    """
    from .box_3d_mode import (Box3DMode, CameraInstance3DBoxes,
                              DepthInstance3DBoxes, LiDARInstance3DBoxes)
    box_type_lower = box_type.lower()
    if box_type_lower == 'lidar':
        box_type_3d = LiDARInstance3DBoxes
        box_mode_3d = Box3DMode.LIDAR
    elif box_type_lower == 'camera':
        box_type_3d = CameraInstance3DBoxes
        box_mode_3d = Box3DMode.CAM
    elif box_type_lower == 'depth':
        box_type_3d = DepthInstance3DBoxes
        box_mode_3d = Box3DMode.DEPTH
    else:
        raise ValueError('Only "box_type" of "camera", "lidar", "depth"'
                         f' are supported, got {box_type}')

    return box_type_3d, box_mode_3d
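Example usage; the lookup is case-insensitive since the input is lower-cased first:

box_type_3d, box_mode_3d = get_box_type('LiDAR')
# box_type_3d is LiDARInstance3DBoxes and box_mode_3d is Box3DMode.LIDAR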
def points_cam2img(points_3d, proj_mat, with_depth=False):
    """Project points from camera coordinates to image coordinates.

    Args:
        points_3d (torch.Tensor): Points in shape (N, 3).
        proj_mat (torch.Tensor): Transformation matrix between coordinates.
        with_depth (bool, optional): Whether to keep depth in the output.
            Defaults to False.

    Returns:
        torch.Tensor: Points in image coordinates with shape [N, 2]
            ([N, 3] if `with_depth=True`).
    """
    points_num = list(points_3d.shape)[:-1]
    points_shape = np.concatenate([points_num, [1]], axis=0).tolist()
    assert len(proj_mat.shape) == 2, 'The dimension of the projection'\
        f' matrix should be 2 instead of {len(proj_mat.shape)}.'
    d1, d2 = proj_mat.shape[:2]
    assert (d1 == 3 and d2 == 3) or (d1 == 3 and d2 == 4) or (
        d1 == 4 and d2 == 4), 'The shape of the projection matrix'\
        f' ({d1}*{d2}) is not supported.'
    if d1 == 3:
        proj_mat_expanded = torch.eye(
            4, device=proj_mat.device, dtype=proj_mat.dtype)
        proj_mat_expanded[:d1, :d2] = proj_mat
        proj_mat = proj_mat_expanded

    # the previous implementation used new_zeros; new_ones yields
    # better results
    points_4 = torch.cat(
        [points_3d, points_3d.new_ones(*points_shape)], dim=-1)
    point_2d = torch.matmul(points_4, proj_mat.t())
    point_2d_res = point_2d[..., :2] / point_2d[..., 2:3]

    if with_depth:
        return torch.cat([point_2d_res, point_2d[..., 2:3]], dim=-1)
    return point_2d_res
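Example usage with an illustrative 3x3 pinhole intrinsic matrix (fx = fy = 1000, principal point at (320, 240)); the matrix values are assumptions for the example, not from the library:

cam2img = torch.tensor([[1000.0, 0.0, 320.0],
                        [0.0, 1000.0, 240.0],
                        [0.0, 0.0, 1.0]])
points = torch.tensor([[1.0, 2.0, 10.0]])  # one point 10 m in front
uv = points_cam2img(points, cam2img)
# uv is tensor([[420., 440.]])
uvd = points_cam2img(points, cam2img, with_depth=True)
# uvd is tensor([[420., 440., 10.]])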
def mono_cam_box2vis(cam_box):
    """Post-process bboxes from the Mono-3D task for visualization.

    To perform projection visualization, we need to:

    1. rotate the box along the x-axis by np.pi / 2 (roll)
    2. change the orientation from local yaw to global yaw
    3. convert the yaw by (-yaw - np.pi / 2)

    After applying this function, we can project and draw the box on 2D
    images.

    Args:
        cam_box (:obj:`CameraInstance3DBoxes`): 3D bbox in camera
            coordinate system before conversion. Could be a gt bbox loaded
            from the dataset or a network prediction output.

    Returns:
        :obj:`CameraInstance3DBoxes`: Box after conversion.
    """
    warning('DeprecationWarning: The hack of yaw and dimension in the '
            'monocular 3D detection on nuScenes has been removed. The '
            'function mono_cam_box2vis will be deprecated.')
    from . import CameraInstance3DBoxes
    assert isinstance(cam_box, CameraInstance3DBoxes), \
        'input bbox should be CameraInstance3DBoxes!'

    loc = cam_box.gravity_center
    dim = cam_box.dims
    yaw = cam_box.yaw
    feats = cam_box.tensor[:, 7:]
    # rotate along the x-axis by np.pi / 2
    # see also here: https://github.com/open-mmlab/mmdetection3d/blob/master/mmdet3d/datasets/nuscenes_mono_dataset.py#L557  # noqa
    dim[:, [1, 2]] = dim[:, [2, 1]]
    # change local yaw to global yaw for visualization
    # refer to https://github.com/open-mmlab/mmdetection3d/blob/master/mmdet3d/datasets/nuscenes_mono_dataset.py#L164-L166  # noqa
    yaw += torch.atan2(loc[:, 0], loc[:, 2])
    # convert yaw by (-yaw - np.pi / 2)
    # this is because mono 3D box classes such as `NuScenesBox` define
    # rotation differently from our `CameraInstance3DBoxes`
    yaw = -yaw - np.pi / 2
    cam_box = torch.cat([loc, dim, yaw[:, None], feats], dim=1)
    cam_box = CameraInstance3DBoxes(
        cam_box, box_dim=cam_box.shape[-1], origin=(0.5, 0.5, 0.5))

    return cam_box
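A minimal usage sketch, assuming a single 7-dimensional camera-coordinate box with illustrative values:

from mmdet3d.core.bbox import CameraInstance3DBoxes

cam_box = CameraInstance3DBoxes(
    torch.tensor([[1.0, 1.5, 10.0, 4.0, 1.5, 1.6, 0.2]]))
vis_box = mono_cam_box2vis(cam_box)
# vis_box can now be projected onto the 2D image for drawing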
def get_proj_mat_by_coord_type(img_meta, coord_type):
    """Obtain the projection matrix according to the coordinate type.

    Args:
        img_meta (dict): Meta info.
        coord_type (str): 'DEPTH' or 'CAMERA' or 'LIDAR'.
            Can be case-insensitive.

    Returns:
        torch.Tensor: Transformation matrix.
    """
    coord_type = coord_type.upper()
    mapping = {'LIDAR': 'lidar2img', 'DEPTH': 'depth2img',
               'CAMERA': 'cam2img'}
    assert coord_type in mapping.keys()
    return img_meta[mapping[coord_type]]
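A minimal usage sketch; `img_meta` below is an illustrative meta dict, whereas real metas come from the dataset pipeline:

img_meta = {'lidar2img': torch.eye(4)}
proj_mat = get_proj_mat_by_coord_type(img_meta, 'lidar')  # case-insensitive
# proj_mat is the lidar-to-image matrix stored in the meta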