diff --git a/README.md b/README.md
index d727b449..0a8210b3 100644
--- a/README.md
+++ b/README.md
@@ -26,9 +26,23 @@ If you found this codebase useful in your research, please consider citing
 ### Preparation
 Download nuscenes data from [https://www.nuscenes.org/](https://www.nuscenes.org/). Install dependencies.
 
+Get Miniconda in your dev environment:
+
+```
+wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
+chmod +x Miniconda3-latest-Linux-x86_64.sh
+bash Miniconda3-latest-Linux-x86_64.sh
 ```
-pip install nuscenes-devkit tensorboardX efficientnet_pytorch==0.7.0
+
+Install dependencies:
+
 ```
+conda env create -f environment.yml
+conda activate l3d
+```
+
+Ensure the `dataroot` you pass to training matches the directory where the nuScenes archives were extracted; `compile_data` now uses it as-is.
 
 ### Pre-trained Model
 Download a pre-trained BEV vehicle segmentation model from here: [https://drive.google.com/file/d/18fy-6beTFTZx5SrYLs9Xk7cY-fGSm7kw/view?usp=sharing](https://drive.google.com/file/d/18fy-6beTFTZx5SrYLs9Xk7cY-fGSm7kw/view?usp=sharing)
diff --git a/environment.yml b/environment.yml
new file mode 100644
index 00000000..acc85e08
--- /dev/null
+++ b/environment.yml
@@ -0,0 +1,34 @@
+name: l3d
+channels:
+  - pyg
+  - pytorch
+  - pytorch3d
+  - conda-forge
+  - fvcore
+  - iopath
+  - bottler
+  - defaults
+dependencies:
+#  - cudatoolkit=11.0
+#  - python=3.9
+  - pip
+  - pytorch
+  - pytorch3d
+  - torchvision
+  - fvcore
+  - iopath
+  - nvidiacub
+  - pip:
+    - hydra-core
+    - Pillow
+    - plotly
+    - requests
+    - imageio
+    - matplotlib
+    - numpy
+    - PyMCubes
+    - tqdm
+    - visdom
+    - nuscenes-devkit
+    - tensorboardX
+    - efficientnet_pytorch==0.7.0
diff --git a/main.py b/main.py
index d8c2a957..01671ff4 100644
--- a/main.py
+++ b/main.py
@@ -5,6 +5,7 @@
 """
 
 from fire import Fire
+from nuscenes import NuScenes
 
 import src
 
@@ -13,7 +14,7 @@
 Fire({
         'lidar_check': src.explore.lidar_check,
         'cumsum_check': src.explore.cumsum_check,
         'train': src.train.train,
         'eval_model_iou': src.explore.eval_model_iou,
         'viz_model_preds': src.explore.viz_model_preds,
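Since `compile_data` below now hands `dataroot` to the devkit unchanged (no more `os.path.join(dataroot, version)`), the `v1.0-*` table folders must sit directly under it. A minimal sanity check, assuming a hypothetical `/data/nuscenes` extraction directory:

```python
# Not part of the patch: verify the nuScenes layout that compile_data expects.
# The dataroot below is a placeholder; substitute your own extraction path.
from nuscenes.nuscenes import NuScenes

dataroot = '/data/nuscenes'  # expects v1.0-mini/, maps/, samples/, sweeps/ inside
nusc = NuScenes(version='v1.0-mini', dataroot=dataroot, verbose=False)
print(len(nusc.sample), 'samples found under', dataroot)
```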
diff --git a/src/data.py b/src/data.py
index da465257..66993b47 100644
--- a/src/data.py
+++ b/src/data.py
@@ -15,6 +15,9 @@
 from nuscenes.utils.data_classes import Box
 from glob import glob
 
+from pytorch3d.structures import Pointclouds
+from torch.utils.data import default_collate
+
 from .tools import get_lidar_data, img_transform, normalize_img, gen_dx_bx
 
 
@@ -69,7 +72,6 @@ def find_name(f):
             if rec['channel'] == 'LIDAR_TOP' or (rec['is_key_frame'] and rec['channel'] in self.data_aug_conf['cams']):
                 rec['filename'] = info[rec['filename']]
 
-
     def get_scenes(self):
        # filter by scene split
        split = {
@@ -219,22 +221,41 @@ def __getitem__(self, index):
         imgs, rots, trans, intrins, post_rots, post_trans = self.get_image_data(rec, cams)
         lidar_data = self.get_lidar_data(rec, nsweeps=3)
         binimg = self.get_binimg(rec)
-        
+
         return imgs, rots, trans, intrins, post_rots, post_trans, lidar_data, binimg
 
 
 class SegmentationData(NuscData):
     def __init__(self, *args, **kwargs):
         super(SegmentationData, self).__init__(*args, **kwargs)
-    
+        self.nsweeps = 1  # a single lidar sweep is used for depth supervision
+
     def __getitem__(self, index):
         rec = self.ixes[index]
 
         cams = self.choose_cams()
         imgs, rots, trans, intrins, post_rots, post_trans = self.get_image_data(rec, cams)
+
+        # get_lidar_data returns (3, N); transpose to (N, 3) so each row is a point
+        lidar_pc = self.get_lidar_data(rec, self.nsweeps)
+        lidar_pc = lidar_pc.permute(1, 0)
+
         binimg = self.get_binimg(rec)
-        
-        return imgs, rots, trans, intrins, post_rots, post_trans, binimg
+
+        return imgs, rots, trans, intrins, post_rots, post_trans, binimg, lidar_pc
+
+    def collate_fn(self, batch):
+        # Every field except the lidar cloud has a fixed shape, so default_collate
+        # can stack them; the clouds vary in length per sample and are batched as
+        # a pytorch3d Pointclouds object instead.
+        imgs = default_collate([x[0] for x in batch])
+        rots = default_collate([x[1] for x in batch])
+        trans = default_collate([x[2] for x in batch])
+        intrins = default_collate([x[3] for x in batch])
+        post_rots = default_collate([x[4] for x in batch])
+        post_trans = default_collate([x[5] for x in batch])
+        bin_img = default_collate([x[6] for x in batch])
+        lidar_pc = Pointclouds([x[7] for x in batch])
+
+        return imgs, rots, trans, intrins, post_rots, post_trans, bin_img, lidar_pc
 
 
 def worker_rnd_init(x):
@@ -243,8 +264,9 @@ def worker_rnd_init(x):
 
 def compile_data(version, dataroot, data_aug_conf, grid_conf, bsz,
                  nworkers, parser_name):
+
     nusc = NuScenes(version='v1.0-{}'.format(version),
-                    dataroot=os.path.join(dataroot, version),
+                    dataroot=dataroot,
                     verbose=False)
     parser = {
         'vizdata': VizData,
@@ -255,13 +277,29 @@ def compile_data(version, dataroot, data_aug_conf, grid_conf, bsz,
     valdata = parser(nusc, is_train=False, data_aug_conf=data_aug_conf,
                      grid_conf=grid_conf)
 
-    trainloader = torch.utils.data.DataLoader(traindata, batch_size=bsz,
-                                              shuffle=True,
-                                              num_workers=nworkers,
-                                              drop_last=True,
-                                              worker_init_fn=worker_rnd_init)
-    valloader = torch.utils.data.DataLoader(valdata, batch_size=bsz,
-                                            shuffle=False,
-                                            num_workers=nworkers)
+    # Only the segmentation parser needs the Pointclouds-aware collate_fn.
+    if parser_name == 'segmentationdata':
+        trainloader = torch.utils.data.DataLoader(traindata, batch_size=bsz,
+                                                  shuffle=True,
+                                                  num_workers=nworkers,
+                                                  drop_last=True,
+                                                  worker_init_fn=worker_rnd_init,
+                                                  collate_fn=traindata.collate_fn)
+        valloader = torch.utils.data.DataLoader(valdata, batch_size=bsz,
+                                                shuffle=False,
+                                                num_workers=nworkers,
+                                                collate_fn=valdata.collate_fn)
+    elif parser_name == 'vizdata':
+        trainloader = torch.utils.data.DataLoader(traindata, batch_size=bsz,
+                                                  shuffle=True,
+                                                  num_workers=nworkers,
+                                                  drop_last=True,
+                                                  worker_init_fn=worker_rnd_init)
+        valloader = torch.utils.data.DataLoader(valdata, batch_size=bsz,
+                                                shuffle=False,
+                                                num_workers=nworkers)
+    else:
+        raise ValueError(f'unknown parser_name: {parser_name}')
 
     return trainloader, valloader
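A quick standalone illustration of why `SegmentationData` gets a custom `collate_fn` above: lidar returns differ in point count from sample to sample, so `default_collate` cannot stack them, while `Pointclouds` batches them as a ragged structure (toy tensors, not nuScenes data):

```python
import torch
from pytorch3d.structures import Pointclouds

# Two samples with different numbers of lidar points.
clouds = [torch.rand(100, 3), torch.rand(250, 3)]
pc = Pointclouds(clouds)

print(pc.num_points_per_cloud())  # tensor([100, 250])
print(pc.points_padded().shape)   # torch.Size([2, 250, 3]), zero-padded
```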
diff --git a/src/models.py b/src/models.py
index 75f3dbd3..42316f11 100644
--- a/src/models.py
+++ b/src/models.py
@@ -3,8 +3,12 @@
 Licensed under the NVIDIA Source Code License. See LICENSE at https://github.com/nv-tlabs/lift-splat-shoot.
 Authors: Jonah Philion and Sanja Fidler
 """
+import numpy as np
+from typing import Optional, Any
 
 import torch
+import matplotlib
+from matplotlib import pyplot as plt
 from torch import nn
 from efficientnet_pytorch import EfficientNet
 from torchvision.models.resnet import resnet18
@@ -49,10 +53,14 @@ def get_depth_dist(self, x, eps=1e-20):
         return x.softmax(dim=1)
 
     def get_depth_feat(self, x):
-        x = self.get_eff_depth(x)
-        # Depth
-        x = self.depthnet(x)
+        # x: (B*N, 3, H, W) -> (B*N, feat_dim, H/downsample, W/downsample)
+        efficient_net_feats = self.get_eff_depth(x)
+
+        # The first D output channels are the depth distribution; the last C
+        # channels are the image context features.
+        x = self.depthnet(efficient_net_feats)
+
+        # Softmax over the D channels turns the depth logits into a categorical
+        # distribution over the depth bins (the alpha of the lift step).
         depth = self.get_depth_dist(x[:, :self.D])
         new_x = depth.unsqueeze(1) * x[:, self.D:(self.D + self.C)].unsqueeze(2)
@@ -84,7 +92,7 @@ def get_eff_depth(self, x):
     def forward(self, x):
         depth, x = self.get_depth_feat(x)
 
-        return x
+        return depth, x
 
 
 class BevEncode(nn.Module):
@@ -170,12 +178,13 @@ def get_geometry(self, rots, trans, intrins, post_rots, post_trans):
         """
         B, N, _ = trans.shape
 
-        # undo post-transformation
+        # 1. Undo the image-space augmentation (post_rots/post_trans) so the
+        #    frustum points live in the un-augmented image frame. N is the
+        #    number of cameras.
         # B x N x D x H x W x 3
         points = self.frustum - post_trans.view(B, N, 1, 1, 1, 3)
         points = torch.inverse(post_rots).view(B, N, 1, 1, 1, 3, 3).matmul(points.unsqueeze(-1))
 
-        # cam_to_ego
+        # 2. Transform from the camera frame to the ego frame: multiply (x, y)
+        #    by depth to undo the perspective divide before applying the
+        #    camera-to-ego transform.
         points = torch.cat((points[:, :, :, :, :, :2] * points[:, :, :, :, :, 2:3],
                             points[:, :, :, :, :, 2:3]
                             ), 5)
@@ -190,12 +199,12 @@ def get_cam_feats(self, x):
         """
         B, N, C, imH, imW = x.shape
 
-        x = x.view(B*N, C, imH, imW)
-        x = self.camencode(x)
+        x = x.view(B*N, C, imH, imW)  # fold cameras into the batch dimension
+        depth, x = self.camencode(x)  # depth: (B*N, D, h, w); x: (B*N, camC, D, h, w)
         x = x.view(B, N, self.camC, self.D, imH//self.downsample, imW//self.downsample)
         x = x.permute(0, 1, 3, 4, 5, 2)
 
-        return x
+        return depth, x
 
     def voxel_pooling(self, geom_feats, x):
         B, N, D, H, W, C = x.shape
@@ -213,8 +222,8 @@ def voxel_pooling(self, geom_feats, x):
 
         # filter out points that are outside box
         kept = (geom_feats[:, 0] >= 0) & (geom_feats[:, 0] < self.nx[0])\
-            & (geom_feats[:, 1] >= 0) & (geom_feats[:, 1] < self.nx[1])\
-            & (geom_feats[:, 2] >= 0) & (geom_feats[:, 2] < self.nx[2])
+               & (geom_feats[:, 1] >= 0) & (geom_feats[:, 1] < self.nx[1])\
+               & (geom_feats[:, 2] >= 0) & (geom_feats[:, 2] < self.nx[2])
         x = x[kept]
         geom_feats = geom_feats[kept]
@@ -241,18 +250,42 @@ def voxel_pooling(self, geom_feats, x):
 
         return final
 
-    def get_voxels(self, x, rots, trans, intrins, post_rots, post_trans):
+    def get_voxel_and_depth_dist(self, x, rots, trans, intrins, post_rots, post_trans, lidar_pc: Optional[Any] = None):
+        # Frustum points in the ego frame for a batch (B) of camera rigs (N).
         geom = self.get_geometry(rots, trans, intrins, post_rots, post_trans)
-        x = self.get_cam_feats(x)
 
-        x = self.voxel_pooling(geom, x)
+        B, N = geom.shape[0], geom.shape[1]
+        depth, x = self.get_cam_feats(x)
 
-        return x
+        # Split the fused (B*N) dimension back into batch and camera dimensions.
+        depth = depth.view(B, N, *depth.shape[1:])
 
-    def forward(self, x, rots, trans, intrins, post_rots, post_trans):
-        x = self.get_voxels(x, rots, trans, intrins, post_rots, post_trans)
+        # depth holds post-softmax probabilities over the D depth bins, so the
+        # expectation over the frustum geometry gives one predicted 3D point per
+        # pixel: (B, N, D, h, w, 1) * (B, N, D, h, w, 3), summed over D.
+        pred_pc = depth.unsqueeze(5) * geom
+        pred_pc = pred_pc.sum(dim=2)
+
+        x = self.voxel_pooling(geom, x)
+        return pred_pc, x
+
+    def forward(self, x, rots, trans, intrins, post_rots, post_trans, lidar_pc: Optional[Any] = None):
+        """
+        Perform the forward pass for the whole LSS pipeline.
+        :param x: the batch of camera images to encode, (B, N, 3, H, W)
+        :param rots: rotation part of the extrinsics of each camera
+        :param trans: translation part of the extrinsics of each camera
+        :param intrins: intrinsic matrix of each camera on the sensor rig
+        :param post_rots: rotation applied by the image augmentation
+        :param post_trans: translation applied by the image augmentation
+        :param lidar_pc: optional lidar point cloud to supervise depth with
+        :return: (pred_pc, x), the predicted point cloud and the BEV logits
+        """
+        pred_pc, x = self.get_voxel_and_depth_dist(x, rots, trans, intrins, post_rots, post_trans, lidar_pc)
         x = self.bevencode(x)
-        return x
+
+        # Flatten cameras and pixels into one (x, y, z) cloud per sample.
+        B = pred_pc.shape[0]
+        pred_pc = pred_pc.view(B, -1, 3)
+        return pred_pc, x
 
 
 def compile_model(grid_conf, data_aug_conf, outC):
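The heart of the models.py change is the expected-point computation: the softmaxed depth is treated as a categorical distribution over bins, and its expectation over the frustum geometry yields one 3D point per pixel. A self-contained sketch on toy shapes (B=1, N=2 cameras, D=3 bins, a 4x5 feature map; not the model's real dimensions):

```python
import torch

B, N, D, h, w = 1, 2, 3, 4, 5
depth = torch.rand(B, N, D, h, w).softmax(dim=2)  # probabilities over depth bins
geom = torch.rand(B, N, D, h, w, 3)               # candidate 3D point per bin

# Expectation over the D bins, exactly as in get_voxel_and_depth_dist.
pred_pc = (depth.unsqueeze(5) * geom).sum(dim=2)
print(pred_pc.shape)                 # torch.Size([1, 2, 4, 5, 3])
print(pred_pc.view(B, -1, 3).shape)  # torch.Size([1, 40, 3]), as in forward()
```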
diff --git a/src/tools.py b/src/tools.py
index b8870fe4..94ceabed 100644
--- a/src/tools.py
+++ b/src/tools.py
@@ -13,6 +13,7 @@
 from PIL import Image
 from functools import reduce
 import matplotlib as mpl
+
 mpl.use('Agg')
 import matplotlib.pyplot as plt
 from nuscenes.utils.data_classes import LidarPointCloud
@@ -249,8 +250,8 @@ def get_val_info(model, valloader, loss_fn, device, use_tqdm=False):
     loader = tqdm(valloader) if use_tqdm else valloader
     with torch.no_grad():
         for batch in loader:
-            allimgs, rots, trans, intrins, post_rots, post_trans, binimgs = batch
-            preds = model(allimgs.to(device), rots.to(device),
+            # The loader now also yields the lidar cloud, and the model returns
+            # (pred_pc, preds); only the segmentation logits are scored here.
+            allimgs, rots, trans, intrins, post_rots, post_trans, binimgs, _ = batch
+            _, preds = model(allimgs.to(device), rots.to(device),
                           trans.to(device), intrins.to(device), post_rots.to(device),
                           post_trans.to(device))
             binimgs = binimgs.to(device)
diff --git a/src/train.py b/src/train.py
index 4460c570..6328529e 100644
--- a/src/train.py
+++ b/src/train.py
@@ -3,9 +3,14 @@
 Licensed under the NVIDIA Source Code License. See LICENSE at https://github.com/nv-tlabs/lift-splat-shoot.
 Authors: Jonah Philion and Sanja Fidler
 """
+from typing import Optional, Any
+
 import torch
 from time import time
+
+from matplotlib import pyplot as plt
+from pytorch3d.loss import chamfer, chamfer_distance
+from pytorch3d.structures import Pointclouds
 from tensorboardX import SummaryWriter
 import numpy as np
 import os
@@ -15,32 +20,83 @@
 
 from .tools import SimpleLoss, get_batch_iou, get_val_info
 
 
+def visualize_gt_pred_pc(gt_pc, pred_pc, filepath: Optional[str] = None):
+    gt_pc_vis = gt_pc.view(-1, 3).detach().cpu().numpy()
+    pred_pc_vis = pred_pc.view(-1, 3).detach().cpu().numpy()
+
+    # gt_pc_vis and pred_pc_vis are (n_points, 3) arrays; plot x/y in the
+    # ground plane and color the points by their z value.
+    xs_gt, ys_gt = gt_pc_vis[:, 0], gt_pc_vis[:, 1]
+    xs_pred, ys_pred = pred_pc_vis[:, 0], pred_pc_vis[:, 1]
+
+    fig = plt.figure(figsize=(12, 7))
+    ax = fig.add_subplot(111)
+
+    # Predictions in red, ground truth in blue, each with its own color bar.
+    img_pred = ax.scatter(xs_pred, ys_pred, c=pred_pc_vis[:, 2], cmap='Reds')
+    img_gt = ax.scatter(xs_gt, ys_gt, c=gt_pc_vis[:, 2], cmap='Blues')
+
+    fig.colorbar(img_gt, ax=ax, shrink=0.5, aspect=5, label='Ground Truth')
+    fig.colorbar(img_pred, ax=ax, shrink=0.5, aspect=5, label='Prediction')
+
+    ax.set_xlabel('X')
+    ax.set_ylabel('Y')
+
+    fig.savefig(filepath)
+    plt.close(fig)  # the figure is re-created on every call during training
+
+
+def point_cloud_loss(gt_pc: Pointclouds, pred_pc: Pointclouds, mode: str = 'bidirectional'):
+    # Pointclouds objects already carry the per-cloud point counts, so
+    # chamfer_distance can handle the ragged batch without explicit lengths.
+    assert mode in ['bidirectional', 'gt_first', 'pred_first']
+    if mode == 'bidirectional':
+        chamdist, _ = chamfer_distance(gt_pc, pred_pc, single_directional=False)
+    elif mode == 'gt_first':
+        chamdist, _ = chamfer_distance(gt_pc, pred_pc, single_directional=True)
+    else:  # 'pred_first'
+        chamdist, _ = chamfer_distance(pred_pc, gt_pc, single_directional=True)
+
+    return chamdist
+
+
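For reference, `chamfer_distance` accepts `Pointclouds` batches directly and reads each cloud's length from the structure, which is why `point_cloud_loss` above needs no explicit lengths. A toy sketch, assuming a pytorch3d release recent enough to support `single_directional`:

```python
import torch
from pytorch3d.loss import chamfer_distance
from pytorch3d.structures import Pointclouds

gt = Pointclouds([torch.rand(300, 3), torch.rand(120, 3)])
pred = Pointclouds([torch.rand(50, 3), torch.rand(80, 3)])

loss_bi, _ = chamfer_distance(gt, pred)                           # both directions
loss_gt, _ = chamfer_distance(gt, pred, single_directional=True)  # gt -> pred only
print(loss_bi.item(), loss_gt.item())
```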
 def train(version,
-          dataroot='/data/nuscenes',
-          nepochs=10000,
-          gpuid=1,
-
-          H=900, W=1600,
-          resize_lim=(0.193, 0.225),
-          final_dim=(128, 352),
-          bot_pct_lim=(0.0, 0.22),
-          rot_lim=(-5.4, 5.4),
-          rand_flip=True,
-          ncams=5,
-          max_grad_norm=5.0,
-          pos_weight=2.13,
-          logdir='./runs',
-
-          xbound=[-50.0, 50.0, 0.5],
-          ybound=[-50.0, 50.0, 0.5],
-          zbound=[-10.0, 10.0, 20.0],
-          dbound=[4.0, 45.0, 1.0],
-
-          bsz=4,
-          nworkers=10,
-          lr=1e-3,
-          weight_decay=1e-7,
-          ):
+          dataroot='~/lss/data/',
+          nepochs=10000,
+          gpuid=1,
+          H=900, W=1600,
+          resize_lim=(0.193, 0.225),
+          final_dim=(128, 352),
+          bot_pct_lim=(0.0, 0.22),
+          rot_lim=(-5.4, 5.4),
+          rand_flip=True,
+          ncams=5,
+          max_grad_norm=5.0,
+          pos_weight=2.13,
+          logdir='./runs',
+          xbound=[-50.0, 50.0, 0.5],
+          ybound=[-50.0, 50.0, 0.5],
+          zbound=[-10.0, 10.0, 20.0],
+          dbound=[4.0, 45.0, 1.0],
+          bsz=4,
+          nworkers=10,
+          lr=1e-3,
+          weight_decay=1e-7,
+          pc_loss_weight=5e-2,
+          vis_dir='./visualize',
+          experiment_name='baseline'
+          ):
+    # The default dataroot contains '~', which NuScenes will not expand itself.
+    dataroot = os.path.expanduser(dataroot)
+
+    if not os.path.exists(vis_dir):
+        os.makedirs(vis_dir)
+
     grid_conf = {
         'xbound': xbound,
         'ybound': ybound,
@@ -48,16 +104,16 @@ def train(version,
         'dbound': dbound,
     }
     data_aug_conf = {
-                    'resize_lim': resize_lim,
-                    'final_dim': final_dim,
-                    'rot_lim': rot_lim,
-                    'H': H, 'W': W,
-                    'rand_flip': rand_flip,
-                    'bot_pct_lim': bot_pct_lim,
-                    'cams': ['CAM_FRONT_LEFT', 'CAM_FRONT', 'CAM_FRONT_RIGHT',
-                             'CAM_BACK_LEFT', 'CAM_BACK', 'CAM_BACK_RIGHT'],
-                    'Ncams': ncams,
-                }
+        'resize_lim': resize_lim,
+        'final_dim': final_dim,
+        'rot_lim': rot_lim,
+        'H': H, 'W': W,
+        'rand_flip': rand_flip,
+        'bot_pct_lim': bot_pct_lim,
+        'cams': ['CAM_FRONT_LEFT', 'CAM_FRONT', 'CAM_FRONT_RIGHT',
+                 'CAM_BACK_LEFT', 'CAM_BACK', 'CAM_BACK_RIGHT'],
+        'Ncams': ncams,
+    }
     trainloader, valloader = compile_data(version, dataroot, data_aug_conf=data_aug_conf,
                                           grid_conf=grid_conf, bsz=bsz, nworkers=nworkers,
                                           parser_name='segmentationdata')
@@ -71,34 +127,55 @@ def train(version,
 
     loss_fn = SimpleLoss(pos_weight).cuda(gpuid)
 
-    writer = SummaryWriter(logdir=logdir)
+    # One log and visualization sub-directory per experiment.
+    writer = SummaryWriter(logdir=f'{logdir}/{experiment_name}')
     val_step = 1000 if version == 'mini' else 10000
 
+    if not os.path.exists(f'{vis_dir}/{experiment_name}'):
+        os.makedirs(f'{vis_dir}/{experiment_name}')
+
     model.train()
     counter = 0
     for epoch in range(nepochs):
         np.random.seed()
-        for batchi, (imgs, rots, trans, intrins, post_rots, post_trans, binimgs) in enumerate(trainloader):
+        for batchi, (imgs, rots, trans, intrins, post_rots, post_trans, binimgs, lidar_pc) in enumerate(trainloader):
             t0 = time()
             opt.zero_grad()
-            preds = model(imgs.to(device),
-                          rots.to(device),
-                          trans.to(device),
-                          intrins.to(device),
-                          post_rots.to(device),
-                          post_trans.to(device),
-                          )
+            pred_pc, preds = model(
+                imgs.to(device),
+                rots.to(device),
+                trans.to(device),
+                intrins.to(device),
+                post_rots.to(device),
+                post_trans.to(device),
+            )
             binimgs = binimgs.to(device)
             loss = loss_fn(preds, binimgs)
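+            # `loss` above is the original LSS segmentation objective;
+            # everything from here to opt.step() adds the depth supervision.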
-            loss.backward()
+
+            # Chamfer loss between the ground-truth lidar cloud and the
+            # predicted (expected-depth) cloud.
+            lidar_pc = lidar_pc.to(device)
+            pred_pc = Pointclouds(pred_pc)
+            pc_loss = point_cloud_loss(gt_pc=lidar_pc, pred_pc=pred_pc, mode='bidirectional')
+
+            # Periodically visualize the GT and predicted clouds from a
+            # birds-eye view with different color maps.
+            if counter % 100 == 4:
+                lidar_pc_vis = lidar_pc.points_list()[0]
+                pred_pc_vis = pred_pc.points_list()[0]
+                visualize_gt_pred_pc(gt_pc=lidar_pc_vis, pred_pc=pred_pc_vis,
+                                     filepath=f'{vis_dir}/{experiment_name}/gt_pred_pc_{counter}.png')
+
+            total_loss = loss + pc_loss * pc_loss_weight
+            total_loss.backward()
             torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
             opt.step()
-            counter += 1
             t1 = time()
 
             if counter % 10 == 0:
-                print(counter, loss.item())
-                writer.add_scalar('train/loss', loss, counter)
+                print(f"Epoch: {epoch}, Iter: {batchi}, Total Loss: {total_loss.item()}, "
+                      f"Seg Loss: {loss.item()}, PC Loss: {pc_loss.item()}")
+                writer.add_scalar('train/total_loss', total_loss, counter)
+                writer.add_scalar('train/loss_seg', loss, counter)
+                writer.add_scalar('train/loss_pc', pc_loss, counter)
 
             if counter % 50 == 0:
                 _, _, iou = get_batch_iou(preds, binimgs)
@@ -118,3 +195,5 @@ def train(version,
             print('saving', mname)
             torch.save(model.state_dict(), mname)
             model.train()
+
+            counter += 1
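With these pieces in place, training can be launched through Fire (`python main.py train mini ...`) or by calling `train()` directly. A usage sketch with hypothetical settings; `depth_sup` is a placeholder experiment name:

```python
from src.train import train

train('mini',
      dataroot='~/lss/data/',       # expanded with expanduser inside train()
      gpuid=0,                      # pick a GPU that exists on your machine
      bsz=2,
      nworkers=4,
      pc_loss_weight=5e-2,          # weight of the chamfer term in total_loss
      experiment_name='depth_sup')  # names the runs/ and visualize/ subfolders
```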