Depth point cloud as predicted point cloud to be compared with GT point cloud (LIDAR) #53

Open
wants to merge 3 commits into base: master
16 changes: 15 additions & 1 deletion README.md
@@ -26,9 +26,23 @@ If you found this codebase useful in your research, please consider citing
### Preparation
Download nuscenes data from [https://www.nuscenes.org/](https://www.nuscenes.org/). Install dependencies.

Get Miniconda in your dev environment:

```
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
chmod +x Miniconda3-latest-Linux-x86_64.sh
bash Miniconda3-latest-Linux-x86_64.sh
```
```
pip install nuscenes-devkit tensorboardX efficientnet_pytorch==0.7.0
```

Install dependencies

```
conda env create -f environment.yml
conda activate lss
```

Ensure a consistent dataroot: the loader now passes `dataroot` to the nuScenes devkit unchanged (previously it appended the version subfolder), so point it at the directory that contains the `v1.0-*` metadata folders.

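Concretely, `compile_data` now hands `dataroot` to the nuScenes devkit unchanged, so the path must be the folder that directly contains `maps/`, `samples/`, `sweeps/`, and the `v1.0-*` metadata. A minimal sketch of what the loader does (the `/data/nuscenes` path is a placeholder):

```
from nuscenes import NuScenes

# dataroot is used as-is (no os.path.join(dataroot, version) anymore), so it
# must point at the directory holding the v1.0-mini / v1.0-trainval folders
nusc = NuScenes(version='v1.0-mini', dataroot='/data/nuscenes', verbose=False)
```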

### Pre-trained Model
Download a pre-trained BEV vehicle segmentation model from here: [https://drive.google.com/file/d/18fy-6beTFTZx5SrYLs9Xk7cY-fGSm7kw/view?usp=sharing](https://drive.google.com/file/d/18fy-6beTFTZx5SrYLs9Xk7cY-fGSm7kw/view?usp=sharing)
34 changes: 34 additions & 0 deletions environment.yml
@@ -0,0 +1,34 @@
name: l3d
channels:
  - pyg
  - pytorch
  - pytorch3d
  - conda-forge
  - fvcore
  - iopath
  - bottler
  - defaults
dependencies:
  # - cudatoolkit=11.0
  # - python=3.9
  - pip
  - pytorch
  - pytorch3d
  - torchvision
  - fvcore
  - iopath
  - nvidiacub
  - pip:
      - hydra-core
      - Pillow
      - plotly
      - requests
      - imageio
      - matplotlib
      - numpy
      - PyMCubes
      - tqdm
      - visdom
      - nuscenes-devkit
      - tensorboardX
      - efficientnet_pytorch==0.7.0
2 changes: 1 addition & 1 deletion main.py
@@ -5,6 +5,7 @@
"""

from fire import Fire
from nuscenes import NuScenes

import src

@@ -13,7 +14,6 @@
Fire({
'lidar_check': src.explore.lidar_check,
'cumsum_check': src.explore.cumsum_check,

'train': src.train.train,
'eval_model_iou': src.explore.eval_model_iou,
'viz_model_preds': src.explore.viz_model_preds,
66 changes: 52 additions & 14 deletions src/data.py
@@ -15,6 +15,9 @@
from nuscenes.utils.data_classes import Box
from glob import glob

from pytorch3d.structures import Pointclouds
from torch.utils.data import default_collate

from .tools import get_lidar_data, img_transform, normalize_img, gen_dx_bx


@@ -69,7 +72,6 @@ def find_name(f):
if rec['channel'] == 'LIDAR_TOP' or (rec['is_key_frame'] and rec['channel'] in self.data_aug_conf['cams']):
rec['filename'] = info[rec['filename']]


def get_scenes(self):
# filter by scene split
split = {
@@ -219,22 +221,41 @@ def __getitem__(self, index):
imgs, rots, trans, intrins, post_rots, post_trans = self.get_image_data(rec, cams)
lidar_data = self.get_lidar_data(rec, nsweeps=3)
binimg = self.get_binimg(rec)



return imgs, rots, trans, intrins, post_rots, post_trans, lidar_data, binimg


class SegmentationData(NuscData):
def __init__(self, *args, **kwargs):
super(SegmentationData, self).__init__(*args, **kwargs)

self.nsweeps = 1

def __getitem__(self, index):
rec = self.ixes[index]

cams = self.choose_cams()
imgs, rots, trans, intrins, post_rots, post_trans = self.get_image_data(rec, cams)
lidar_pc = self.get_lidar_data(rec, self.nsweeps)
lidar_pc = lidar_pc.permute(1, 0)

binimg = self.get_binimg(rec)

return imgs, rots, trans, intrins, post_rots, post_trans, binimg

return imgs, rots, trans, intrins, post_rots, post_trans, binimg, lidar_pc

def collate_fn(self, batch):

imgs = default_collate(list(map(lambda x: x[0], batch)))
rots = default_collate(list(map(lambda x: x[1], batch)))
trans = default_collate(list(map(lambda x: x[2], batch)))
intrins = default_collate(list(map(lambda x: x[3], batch)))
post_rots = default_collate(list(map(lambda x: x[4], batch)))
post_trans = default_collate(list(map(lambda x: x[5], batch)))
bin_img = default_collate(list(map(lambda x: x[6], batch)))
lidar_pc = Pointclouds(list(map(lambda x: x[7], batch)))

return imgs, rots, trans, intrins, post_rots, post_trans, bin_img, lidar_pc
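The custom `collate_fn` above is needed because `default_collate` can only stack equally-shaped tensors, while each sample's LIDAR sweep has a different number of points; `Pointclouds` keeps the batch ragged and pads only on demand. A small sketch of that behaviour (the point counts are made up):

```
import torch
from pytorch3d.structures import Pointclouds

# two clouds with different point counts cannot go through default_collate
clouds = Pointclouds([torch.rand(1000, 3), torch.rand(1500, 3)])

print(clouds.num_points_per_cloud())  # tensor([1000, 1500])
print(clouds.points_padded().shape)   # torch.Size([2, 1500, 3]), zero-padded
```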



def worker_rnd_init(x):
@@ -243,8 +264,9 @@

def compile_data(version, dataroot, data_aug_conf, grid_conf, bsz,
nworkers, parser_name):

nusc = NuScenes(version='v1.0-{}'.format(version),
dataroot=os.path.join(dataroot, version),
dataroot=dataroot,
verbose=False)
parser = {
'vizdata': VizData,
@@ -255,13 +277,29 @@ def compile_data(version, dataroot, data_aug_conf, grid_conf, bsz,
valdata = parser(nusc, is_train=False, data_aug_conf=data_aug_conf,
grid_conf=grid_conf)

trainloader = torch.utils.data.DataLoader(traindata, batch_size=bsz,
shuffle=True,
num_workers=nworkers,
drop_last=True,
worker_init_fn=worker_rnd_init)
valloader = torch.utils.data.DataLoader(valdata, batch_size=bsz,
shuffle=False,
num_workers=nworkers)
if parser_name == 'segmentationdata':
trainloader = torch.utils.data.DataLoader(traindata, batch_size=bsz,
shuffle=True,
num_workers=nworkers,
drop_last=True,
worker_init_fn=worker_rnd_init,
collate_fn=traindata.collate_fn)
valloader = torch.utils.data.DataLoader(valdata, batch_size=bsz,
shuffle=False,
num_workers=nworkers,
collate_fn=valdata.collate_fn)
elif parser_name == 'vizdata':
trainloader = torch.utils.data.DataLoader(traindata, batch_size=bsz,
shuffle=True,
num_workers=nworkers,
drop_last=True,
worker_init_fn=worker_rnd_init)
valloader = torch.utils.data.DataLoader(valdata, batch_size=bsz,
shuffle=False,
num_workers=nworkers)
else:
raise ValueError(parser_name)



return trainloader, valloader
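With `parser_name='segmentationdata'` each batch now carries the ground-truth LIDAR cloud as its last element. A usage sketch, assuming `data_aug_conf` and `grid_conf` are built as in the repo's training entry point and using a placeholder dataroot:

```
trainloader, valloader = compile_data('mini', '/data/nuscenes',
                                      data_aug_conf=data_aug_conf,
                                      grid_conf=grid_conf,
                                      bsz=4, nworkers=4,
                                      parser_name='segmentationdata')

# lidar_pc is a pytorch3d Pointclouds holding one ragged cloud per sample
imgs, rots, trans, intrins, post_rots, post_trans, binimgs, lidar_pc = next(iter(trainloader))
```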
69 changes: 51 additions & 18 deletions src/models.py
@@ -3,8 +3,12 @@
Licensed under the NVIDIA Source Code License. See LICENSE at https://github.com/nv-tlabs/lift-splat-shoot.
Authors: Jonah Philion and Sanja Fidler
"""
import numpy as np
from typing import Optional, Any

import torch
import matplotlib
from matplotlib import pyplot as plt
from torch import nn
from efficientnet_pytorch import EfficientNet
from torchvision.models.resnet import resnet18
@@ -49,10 +53,14 @@ def get_depth_dist(self, x, eps=1e-20):
return x.softmax(dim=1)

def get_depth_feat(self, x):
x = self.get_eff_depth(x)
# Depth
x = self.depthnet(x)

# x [B * N, 3, H, W] -> [B * N, feat_dim, H_Down, W_Down]
efficient_net_feats = self.get_eff_depth(x)

        # The first D channels correspond to the depth distribution; the last C are the context features.
x = self.depthnet(efficient_net_feats)

        # depth is the per-pixel probability distribution over the D depth bins (softmax over the bin dim)
depth = self.get_depth_dist(x[:, :self.D])
new_x = depth.unsqueeze(1) * x[:, self.D:(self.D + self.C)].unsqueeze(2)

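The multiplication on the last line is the "lift" step: an outer product between the per-pixel depth distribution (D bins) and the per-pixel context features (C channels), which places every feature vector at each candidate depth along its ray, weighted by its probability. A shape-only sketch with illustrative sizes:

```
import torch

BN, D, C, H, W = 2, 41, 64, 8, 22           # illustrative sizes only
depth = torch.rand(BN, D, H, W).softmax(1)  # per-pixel distribution over depth bins
feats = torch.rand(BN, C, H, W)             # per-pixel context features

# outer product: (BN, 1, D, H, W) * (BN, C, 1, H, W) -> (BN, C, D, H, W)
lifted = depth.unsqueeze(1) * feats.unsqueeze(2)
print(lifted.shape)  # torch.Size([2, 64, 41, 8, 22])
```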
@@ -84,7 +92,7 @@ def get_eff_depth(self, x):
def forward(self, x):
depth, x = self.get_depth_feat(x)

return x
return depth, x


class BevEncode(nn.Module):
@@ -170,12 +178,13 @@ def get_geometry(self, rots, trans, intrins, post_rots, post_trans):
"""
B, N, _ = trans.shape

# undo post-transformation
        # 1. Undo the image augmentation (post_rots / post_trans) so the frustum
        #    geometry is consistent across samples. N is the number of cameras.
# B x N x D x H x W x 3
points = self.frustum - post_trans.view(B, N, 1, 1, 1, 3)
points = torch.inverse(post_rots).view(B, N, 1, 1, 1, 3, 3).matmul(points.unsqueeze(-1))

# cam_to_ego
        # 2. Transform from the camera frame to the ego frame
points = torch.cat((points[:, :, :, :, :, :2] * points[:, :, :, :, :, 2:3],
points[:, :, :, :, :, 2:3]
), 5)
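Scaling the pixel coordinates by the depth turns each frustum entry (u, v, d) into homogeneous image coordinates (u·d, v·d, d); the inverse intrinsics, applied together with `rots`/`trans` in the elided lines below, then recover a metric point in the camera frame. A one-point sketch with made-up intrinsics:

```
import torch

u, v, d = 100.0, 60.0, 12.5                  # pixel coordinates and depth (made up)
K = torch.tensor([[1266.0,    0.0, 816.0],
                  [   0.0, 1266.0, 491.0],
                  [   0.0,    0.0,   1.0]])

# (u, v, d) -> (u*d, v*d, d), then K^-1 back-projects to camera coordinates
cam_point = torch.inverse(K) @ torch.tensor([u * d, v * d, d])
# cam_point == ((u - 816) * d / 1266, (v - 491) * d / 1266, d)
```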
@@ -190,12 +199,12 @@ def get_cam_feats(self, x):
"""
B, N, C, imH, imW = x.shape

x = x.view(B*N, C, imH, imW)
x = self.camencode(x)
x = x.view(B*N, C, imH, imW) # B * N, C, imH, imW
depth, x = self.camencode(x) # B * N, feat_dim, H, W
x = x.view(B, N, self.camC, self.D, imH//self.downsample, imW//self.downsample)
x = x.permute(0, 1, 3, 4, 5, 2)

return x
return depth, x

def voxel_pooling(self, geom_feats, x):
B, N, D, H, W, C = x.shape
@@ -213,8 +222,8 @@ def voxel_pooling(self, geom_feats, x):

# filter out points that are outside box
        kept = (geom_feats[:, 0] >= 0) & (geom_feats[:, 0] < self.nx[0])\
            & (geom_feats[:, 1] >= 0) & (geom_feats[:, 1] < self.nx[1])\
            & (geom_feats[:, 2] >= 0) & (geom_feats[:, 2] < self.nx[2])
x = x[kept]
geom_feats = geom_feats[kept]

@@ -241,18 +250,42 @@

return final

def get_voxels(self, x, rots, trans, intrins, post_rots, post_trans):
def get_voxel_and_depth_dist(self, x, rots, trans, intrins, post_rots, post_trans, lidar_pc: Optional[Any] = None):
# A batch (B) of frustums (N).
geom = self.get_geometry(rots, trans, intrins, post_rots, post_trans)
x = self.get_cam_feats(x)

x = self.voxel_pooling(geom, x)
B, N = geom.shape[0], geom.shape[1]
depth, x = self.get_cam_feats(x)

return x
        # Split the fused batch*camera dimension back into (B, N)
depth = depth.view(B, N, *depth.shape[1:])

def forward(self, x, rots, trans, intrins, post_rots, post_trans):
x = self.get_voxels(x, rots, trans, intrins, post_rots, post_trans)
        # depth holds the post-softmax depth distribution; its probability-weighted
        # sum over the depth bins gives the expected 3D point for every pixel
pred_pc = depth.unsqueeze(5) * geom
pred_pc = pred_pc.sum(dim=2)
x = self.voxel_pooling(geom, x)
return pred_pc, x

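The weighted sum over the depth dimension computes, for each pixel, the expected 3D location of its ray under the predicted depth distribution, i.e. one predicted point per camera pixel. A shape sketch with illustrative sizes:

```
import torch

B, N, D, H, W = 1, 6, 41, 8, 22               # illustrative sizes only
depth = torch.rand(B, N, D, H, W).softmax(2)  # per-pixel depth distribution
geom = torch.rand(B, N, D, H, W, 3)           # 3D location of every frustum cell

# E[point] = sum_d p(d) * point(d): one expected 3D point per pixel
pred_pc = (depth.unsqueeze(5) * geom).sum(dim=2)
print(pred_pc.shape)  # torch.Size([1, 6, 8, 22, 3])
```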
def forward(self, x, rots, trans, intrins, post_rots, post_trans, lidar_pc: Optional[Any] = None):
"""
Perform the forward pass for the whole LSS pipeline
:param x: <B, N, C, H, W> The images that we are trying to encode in the pipeline
:param rots: <B, N, 3, 3> The rotation matrices that represents the extrinsics for each cameras.
:param trans: <B, N, 3> The translation that represents the extrinsics for each cameras.
:param intrins: <B, N, 3, 3> The intrinsic matrices for each camera on the sensor rig.
:param post_rots: <B, N, 3, 3> Augmentation
:param post_trans: <B, N, 3> Augmentation
:param lidar_pc: <B, 3> The LIDAR point cloud that we are interested to learn depth from.
:return:
"""
# x is the batch of imgs
#
pred_pc, x = self.get_voxel_and_depth_dist(x, rots, trans, intrins, post_rots, post_trans, lidar_pc)
x = self.bevencode(x)
return x

B = pred_pc.shape[0]
pred_pc = pred_pc.view(B, -1, 3)
return pred_pc, x

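The diff does not yet show the loss that compares `pred_pc` against the ground-truth LIDAR cloud; a natural fit, and presumably why the dataloader returns a pytorch3d `Pointclouds`, is the chamfer distance. The following is an assumed sketch, not code from this PR:

```
from pytorch3d.loss import chamfer_distance

# pred_pc: <B, P, 3> from model.forward; lidar_pc: Pointclouds from the dataloader.
# chamfer_distance accepts a padded tensor plus per-cloud lengths for ragged batches.
loss, _ = chamfer_distance(pred_pc,
                           lidar_pc.points_padded(),
                           y_lengths=lidar_pc.num_points_per_cloud())
```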

def compile_model(grid_conf, data_aug_conf, outC):
5 changes: 3 additions & 2 deletions src/tools.py
@@ -13,6 +13,7 @@
from PIL import Image
from functools import reduce
import matplotlib as mpl

mpl.use('Agg')
import matplotlib.pyplot as plt
from nuscenes.utils.data_classes import LidarPointCloud
@@ -249,8 +250,8 @@ def get_val_info(model, valloader, loss_fn, device, use_tqdm=False):
loader = tqdm(valloader) if use_tqdm else valloader
with torch.no_grad():
for batch in loader:
allimgs, rots, trans, intrins, post_rots, post_trans, binimgs = batch
preds = model(allimgs.to(device), rots.to(device),
allimgs, rots, trans, intrins, post_rots, post_trans, binimgs, _ = batch
_, preds = model(allimgs.to(device), rots.to(device),
trans.to(device), intrins.to(device), post_rots.to(device),
post_trans.to(device))
binimgs = binimgs.to(device)