(1) add support for ScanNet. (2) Visualize depths during testing.
yashbhalgat committed Jan 2, 2023
1 parent 041e36f commit fed4664
Showing 8 changed files with 194 additions and 36 deletions.
3 changes: 3 additions & 0 deletions README.md
@@ -29,6 +29,9 @@ The code-base has additional support for:
* Total Variation Loss for smoother embeddings (use `--tv-loss-weight` to enable)
* Sparsity-inducing loss on the ray weights (use `--sparse-loss-weight` to enable)

## ScanNet dataset support
The repo now supports training a NeRF model on a scene from the ScanNet dataset. I personally found setting up the ScanNet dataset to be a bit tricky. Please find some instructions/notes in [ScanNet.md](ScanNet.md).


## TODO:
* Voxel pruning during training and/or inference
22 changes: 22 additions & 0 deletions ScanNet.md
@@ -0,0 +1,22 @@
# ScanNet Instructions

I personally found it a bit tricky to set up the ScanNet dataset the first time I tried it, so I am compiling some notes/instructions here in case someone finds them useful.

### 1. Dataset download

To download ScanNet data and its labels, follow the instructions [here](https://github.com/ScanNet/ScanNet). Basically, fill out the ScanNet Terms of Use agreement and email it to [scannet@googlegroups.com](mailto:scannet@googlegroups.com). You will receive a download link to the dataset. Download the dataset and unzip it.

### 2. Use [SensReader](https://github.com/ScanNet/ScanNet/tree/master/SensReader/python) to extract RGB-D and camera data
Use the `reader.py` script as follows for each scene you want to work with:
```
python reader.py --filename [.sens file to export data from] --output_path [output directory to export data to]
Options:
--export_depth_images: export all depth frames as 16-bit pngs (depth shift 1000)
--export_color_images: export all color frames as 8-bit rgb jpgs
--export_poses: export all camera poses (4x4 matrix, camera to world)
--export_intrinsics: export camera intrinsics (4x4 matrix)
```
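
If you want to extract several scenes in one go, here is a minimal sketch, assuming the standard ScanNet layout `scans/<sceneId>/<sceneId>.sens` (the paths below are placeholders to adapt):
```python
import os
import subprocess

scans_dir = "/path/to/ScanNet/scans"       # placeholder: where the .sens files live
out_root = "/path/to/ScanNet/extracted"    # placeholder: where to put the exports

for scene_id in sorted(os.listdir(scans_dir)):
    sens_file = os.path.join(scans_dir, scene_id, f"{scene_id}.sens")
    if not os.path.isfile(sens_file):
        continue
    # Export depth, color, poses, and intrinsics for this scene.
    subprocess.run([
        "python", "reader.py",
        "--filename", sens_file,
        "--output_path", os.path.join(out_root, scene_id),
        "--export_depth_images",
        "--export_color_images",
        "--export_poses",
        "--export_intrinsics",
    ], check=True)
```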

### 3. Use this [script](https://github.com/zju3dv/object_nerf/blob/main/data_preparation/scannet_sens_reader/convert_to_nerf_style_data.py) to convert the data to NeRF-style format
For instructions, see Step 1 [here](https://github.com/zju3dv/object_nerf/tree/main/data_preparation). Two caveats:
1. The generated `transforms_xxx.json` stores camera-to-world transformation matrices in the SLAM/OpenCV convention (xyz -> right, down, forward). You need to convert them to the NeRF convention (xyz -> right, up, back) in the dataloader before training (see the sketch below this list).
2. For an example of this conversion, see the one done in nice-slam [here](https://github.com/cvg/nice-slam/blob/7af15cc33729aa5a8ca052908d96f495e34ab34c/src/utils/datasets.py#L205).
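
For concreteness, here is a minimal sketch of that axis flip; it mirrors what the `load_scannet.py` loader added in this commit does (flipping the camera y and z axes of each pose):
```python
import numpy as np

def opencv_to_nerf(pose: np.ndarray) -> np.ndarray:
    """Convert a 4x4 camera-to-world pose from the OpenCV/SLAM convention
    (x right, y down, z forward) to the NeRF/OpenGL convention
    (x right, y up, z back) by flipping the camera y and z axes."""
    pose = pose.copy()
    pose[:3, 1] *= -1  # y: down -> up
    pose[:3, 2] *= -1  # z: forward -> back
    return pose
```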
14 changes: 14 additions & 0 deletions configs/scannet_scene0000.txt
@@ -0,0 +1,14 @@
expname = scannet_scene0000_00
basedir = ./logs
datadir = /work/yashsb/datasets/ScanNet/
dataset_type = scannet

no_batching = False

use_viewdirs = True
white_bkgd = False
lrate_decay = 500

N_samples = 64
N_importance = 128
N_rand = 1024
5 changes: 3 additions & 2 deletions hash_encoding.py
@@ -60,7 +60,7 @@ def forward(self, x):
x_embedded_all = []
for i in range(self.n_levels):
resolution = torch.floor(self.base_resolution * self.b**i)
voxel_min_vertex, voxel_max_vertex, hashed_voxel_indices = get_voxel_vertices(\
voxel_min_vertex, voxel_max_vertex, hashed_voxel_indices, keep_mask = get_voxel_vertices(\
x, self.bounding_box, \
resolution, self.log2_hashmap_size)

@@ -69,7 +69,8 @@ def forward(self, x):
x_embedded = self.trilinear_interp(x, voxel_min_vertex, voxel_max_vertex, voxel_embedds)
x_embedded_all.append(x_embedded)

return torch.cat(x_embedded_all, dim=-1)
keep_mask = keep_mask.sum(dim=-1)==keep_mask.shape[-1]
return torch.cat(x_embedded_all, dim=-1), keep_mask


class SHEncoder(nn.Module):
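A note on the new `keep_mask`: `get_voxel_vertices` now also returns a per-coordinate validity mask (presumably flagging which coordinates of each query point fall inside the hash grid's bounding box). The reduction above keeps only points whose coordinates are all valid, and `run_nerf.py` later sets the predicted sigma to zero for the rest. A minimal sketch of that reduction, with made-up values:
```python
import torch

# Hypothetical mask as returned by get_voxel_vertices: one boolean per
# coordinate, True when that coordinate lies inside the bounding box.
keep_mask = torch.tensor([
    [True, True, True],    # point fully inside the box
    [True, False, True],   # one coordinate outside -> point is rejected
])

# A point is kept only if *all* of its coordinates are valid.
point_valid = keep_mask.sum(dim=-1) == keep_mask.shape[-1]
print(point_valid)  # tensor([ True, False])
```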
4 changes: 1 addition & 3 deletions load_blender.py
@@ -88,6 +88,4 @@ def load_blender_data(basedir, half_res=False, testskip=1):

bounding_box = get_bbox3d_for_blenderobj(metas["train"], H, W, near=2.0, far=6.0)

return imgs, poses, render_poses, [H, W, focal], i_split, bounding_box


return imgs, poses, render_poses, [H, W, focal], i_split, bounding_box
107 changes: 107 additions & 0 deletions load_scannet.py
@@ -0,0 +1,107 @@
import os
import torch
import numpy as np
import imageio
import json
import torch.nn.functional as F
import cv2
import pyvista as pv

trans_t = lambda t : torch.Tensor([
[1,0,0,0],
[0,1,0,0],
[0,0,1,t],
[0,0,0,1]]).float()

rot_phi = lambda phi : torch.Tensor([
[1,0,0,0],
[0,np.cos(phi),-np.sin(phi),0],
[0,np.sin(phi), np.cos(phi),0],
[0,0,0,1]]).float()

rot_theta = lambda th : torch.Tensor([
[np.cos(th),0,-np.sin(th),0],
[0,1,0,0],
[np.sin(th),0, np.cos(th),0],
[0,0,0,1]]).float()


def pose_spherical(theta, phi, radius):
c2w = trans_t(radius)
c2w = rot_phi(phi/180.*np.pi) @ c2w
c2w = rot_theta(theta/180.*np.pi) @ c2w
c2w = torch.Tensor(np.array([[-1,0,0,0],[0,0,1,0],[0,1,0,0],[0,0,0,1]])) @ c2w
return c2w


def load_scannet_data(basedir, sceneID, half_res=False, trainskip=10, testskip=1):
'''
basedir is something like: "/work/yashsb/datasets/ScanNet/"
'''
scansdir = os.path.join(basedir, "scans")
basedir = os.path.join(basedir, "nerfstyle_"+sceneID)

splits = ['train', 'val', 'test']
metas = {}
for s in splits:
with open(os.path.join(basedir, 'transforms_{}.json'.format(s)), 'r') as fp:
metas[s] = json.load(fp)

all_imgs = []
all_poses = []
counts = [0]
for s in splits:
meta = metas[s]
imgs = []
poses = []
if s=='train':
skip = trainskip
else:
skip = testskip

for frame in meta['frames'][::skip]:
fname = os.path.join(basedir, frame['file_path'] + '.png')
imgs.append(imageio.imread(fname))
pose = np.array(frame['transform_matrix'])

### NEED to do this because ScanNet uses OpenCV convention
pose[:3, 1] *= -1
pose[:3, 2] *= -1

poses.append(pose)

imgs = (np.array(imgs) / 255.).astype(np.float32) # keep all 4 channels (RGBA)
poses = np.array(poses).astype(np.float32)
counts.append(counts[-1] + imgs.shape[0])
all_imgs.append(imgs)
all_poses.append(poses)

i_split = [np.arange(counts[i], counts[i+1]) for i in range(3)]

imgs = np.concatenate(all_imgs, 0)
poses = np.concatenate(all_poses, 0)

H, W = imgs[0].shape[:2]
camera_angle_x = float(meta['camera_angle_x'])
focal = .5 * W / np.tan(.5 * camera_angle_x)

render_poses = torch.stack([pose_spherical(angle, -30.0, 4.0) for angle in np.linspace(-180,180,40+1)[:-1]], 0)

if half_res:
H = H//2
W = W//2
focal = focal/2.

imgs_half_res = np.zeros((imgs.shape[0], H, W, 3))
for i, img in enumerate(imgs):
imgs_half_res[i] = cv2.resize(img, (W, H), interpolation=cv2.INTER_AREA)
imgs = imgs_half_res
# imgs = tf.image.resize_area(imgs, [400, 400]).numpy()

## getting an approximate bounding box for the scene
# load scene mesh
mesh = pv.read(os.path.join(scansdir, sceneID, f"{sceneID}_vh_clean.ply"))
# get the bounding box
bounding_box = torch.tensor(mesh.bounds[::2]) - 1, torch.tensor(mesh.bounds[1::2]) + 1

return imgs, poses, render_poses, [H, W, focal], i_split, bounding_box
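For reference, the focal length computed in `load_scannet_data` above (`focal = .5 * W / np.tan(.5 * camera_angle_x)`) is the standard pinhole relation between image width and horizontal field of view:

$$
f = \frac{W}{2\,\tan\!\left(\frac{\texttt{camera\_angle\_x}}{2}\right)}
$$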
63 changes: 42 additions & 21 deletions run_nerf.py
@@ -23,6 +23,7 @@
from load_llff import load_llff_data
from load_deepvoxels import load_dv_data
from load_blender import load_blender_data
from load_scannet import load_scannet_data
from load_LINEMOD import load_LINEMOD_data


@@ -45,7 +46,7 @@ def run_network(inputs, viewdirs, fn, embed_fn, embeddirs_fn, netchunk=1024*64):
"""Prepares inputs and applies network 'fn'.
"""
inputs_flat = torch.reshape(inputs, [-1, inputs.shape[-1]])
embedded = embed_fn(inputs_flat)
embedded, keep_mask = embed_fn(inputs_flat)

if viewdirs is not None:
input_dirs = viewdirs[:,None].expand(inputs.shape)
Expand All @@ -54,6 +55,7 @@ def run_network(inputs, viewdirs, fn, embed_fn, embeddirs_fn, netchunk=1024*64):
embedded = torch.cat([embedded, embedded_dirs], -1)

outputs_flat = batchify(fn, netchunk)(embedded)
outputs_flat[~keep_mask, -1] = 0 # set sigma to 0 for invalid points
outputs = torch.reshape(outputs_flat, list(inputs.shape[:-1]) + [outputs_flat.shape[-1]])
return outputs

@@ -135,7 +137,7 @@ def render(H, W, K, chunk=1024*32, rays=None, c2w=None, ndc=True,
k_sh = list(sh[:-1]) + list(all_ret[k].shape[1:])
all_ret[k] = torch.reshape(all_ret[k], k_sh)

k_extract = ['rgb_map', 'disp_map', 'acc_map']
k_extract = ['rgb_map', 'depth_map', 'acc_map']
ret_list = [all_ret[k] for k in k_extract]
ret_dict = {k : all_ret[k] for k in all_ret if k not in k_extract}
return ret_list + [ret_dict]
@@ -144,6 +146,7 @@ def render_path(render_poses, hwf, K, chunk, render_kwargs, gt_imgs=None, savedir=None, render_factor=0):
def render_path(render_poses, hwf, K, chunk, render_kwargs, gt_imgs=None, savedir=None, render_factor=0):

H, W, focal = hwf
near, far = render_kwargs['near'], render_kwargs['far']

if render_factor!=0:
# Render downsampled for speed
@@ -152,18 +155,20 @@ def render_path(render_poses, hwf, K, chunk, render_kwargs, gt_imgs=None, savedir=None, render_factor=0):
focal = focal/render_factor

rgbs = []
disps = []
depths = []
psnrs = []

t = time.time()
for i, c2w in enumerate(tqdm(render_poses)):
print(i, time.time() - t)
t = time.time()
rgb, disp, acc, _ = render(H, W, K, chunk=chunk, c2w=c2w[:3,:4], **render_kwargs)
rgb, depth, acc, _ = render(H, W, K, chunk=chunk, c2w=c2w[:3,:4], **render_kwargs)
rgbs.append(rgb.cpu().numpy())
disps.append(disp.cpu().numpy())
# normalize depth to [0,1]
depth = (depth - near) / (far - near)
depths.append(depth.cpu().numpy())
if i==0:
print(rgb.shape, disp.shape)
print(rgb.shape, depth.shape)

if gt_imgs is not None and render_factor==0:
try:
Expand All @@ -174,11 +179,21 @@ def render_path(render_poses, hwf, K, chunk, render_kwargs, gt_imgs=None, savedi
print(p)
psnrs.append(p)


if savedir is not None:
# save rgb and depth as a figure
fig = plt.figure(figsize=(25,15))
ax = fig.add_subplot(1, 2, 1)
rgb8 = to8b(rgbs[-1])
ax.imshow(rgb8)
ax.axis('off')
ax = fig.add_subplot(1, 2, 2)
ax.imshow(depths[-1], cmap='plasma', vmin=0, vmax=1)
ax.axis('off')
filename = os.path.join(savedir, '{:03d}.png'.format(i))
imageio.imwrite(filename, rgb8)
# save as png
plt.savefig(filename, bbox_inches='tight', pad_inches=0)
plt.close(fig)
# imageio.imwrite(filename, rgb8)


rgbs = np.stack(rgbs, 0)
@@ -224,9 +239,6 @@ def create_nerf(args):

model_fine = None

# if args.i_embed==1:
# args.N_importance = 0

if args.N_importance > 0:
if args.i_embed==1:
model_fine = NeRFSmall(num_layers=2,
@@ -248,9 +260,6 @@

# Create optimizer
if args.i_embed==1:
# sparse_opt = torch.optim.SparseAdam(embedding_params, lr=args.lrate, betas=(0.9, 0.99), eps=1e-15)
# dense_opt = torch.optim.Adam(grad_vars, lr=args.lrate, betas=(0.9, 0.99), weight_decay=1e-6)
# optimizer = MultiOptimizer(optimizers={"sparse_opt": sparse_opt, "dense_opt": dense_opt})
optimizer = RAdam([
{'params': grad_vars, 'weight_decay': 1e-6},
{'params': embedding_params, 'eps': 1e-15}
@@ -352,8 +361,8 @@ def raw2outputs(raw, z_vals, rays_d, raw_noise_std=0, white_bkgd=False, pytest=False):
weights = alpha * torch.cumprod(torch.cat([torch.ones((alpha.shape[0], 1)), 1.-alpha + 1e-10], -1), -1)[:, :-1]
rgb_map = torch.sum(weights[...,None] * rgb, -2) # [N_rays, 3]

depth_map = torch.sum(weights * z_vals, -1)
disp_map = 1./torch.max(1e-10 * torch.ones_like(depth_map), depth_map / torch.sum(weights, -1))
depth_map = torch.sum(weights * z_vals, -1) / torch.sum(weights, -1)
disp_map = 1./torch.max(1e-10 * torch.ones_like(depth_map), depth_map)
acc_map = torch.sum(weights, -1)

if white_bkgd:
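
The change to `depth_map` above normalizes the expected depth by the accumulated ray weights, so the depth is a true weighted average of the sample depths rather than being biased toward zero when the weights do not sum to one; disparity is then the reciprocal of that depth:

$$
\hat{d}(\mathbf{r}) = \frac{\sum_i w_i\, z_i}{\sum_i w_i},
\qquad
\widehat{\mathrm{disp}}(\mathbf{r}) = \frac{1}{\max\left(10^{-10},\ \hat{d}(\mathbf{r})\right)}
$$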
@@ -445,13 +454,12 @@ def render_rays(ray_batch,

pts = rays_o[...,None,:] + rays_d[...,None,:] * z_vals[...,:,None] # [N_rays, N_samples, 3]

# raw = run_network(pts)
raw = network_query_fn(pts, viewdirs, network_fn)
rgb_map, disp_map, acc_map, weights, depth_map, sparsity_loss = raw2outputs(raw, z_vals, rays_d, raw_noise_std, white_bkgd, pytest=pytest)

if N_importance > 0:

rgb_map_0, disp_map_0, acc_map_0, sparsity_loss_0 = rgb_map, disp_map, acc_map, sparsity_loss
rgb_map_0, depth_map_0, acc_map_0, sparsity_loss_0 = rgb_map, depth_map, acc_map, sparsity_loss

z_vals_mid = .5 * (z_vals[...,1:] + z_vals[...,:-1])
z_samples = sample_pdf(z_vals_mid, weights[...,1:-1], N_importance, det=(perturb==0.), pytest=pytest)
@@ -466,12 +474,12 @@

rgb_map, disp_map, acc_map, weights, depth_map, sparsity_loss = raw2outputs(raw, z_vals, rays_d, raw_noise_std, white_bkgd, pytest=pytest)

ret = {'rgb_map' : rgb_map, 'disp_map' : disp_map, 'acc_map' : acc_map, 'sparsity_loss': sparsity_loss}
ret = {'rgb_map' : rgb_map, 'depth_map' : depth_map, 'acc_map' : acc_map, 'sparsity_loss': sparsity_loss}
if retraw:
ret['raw'] = raw
if N_importance > 0:
ret['rgb0'] = rgb_map_0
ret['disp0'] = disp_map_0
ret['depth0'] = depth_map_0
ret['acc0'] = acc_map_0
ret['sparsity_loss0'] = sparsity_loss_0
ret['z_std'] = torch.std(z_samples, dim=-1, unbiased=False) # [N_rays]
@@ -571,6 +579,10 @@ def config_parser():
parser.add_argument("--half_res", action='store_true',
help='load blender synthetic data at 400x400 instead of 800x800')

## scannet flags
parser.add_argument("--scannet_sceneID", type=str, default='scene0000_00',
help='sceneID to load from scannet')

## llff flags
parser.add_argument("--factor", type=int, default=8,
help='downsample factor for LLFF images')
@@ -658,6 +670,15 @@ def train():
else:
images = images[...,:3]

elif args.dataset_type == 'scannet':
images, poses, render_poses, hwf, i_split, bounding_box = load_scannet_data(args.datadir, args.scannet_sceneID, args.half_res)
args.bounding_box = bounding_box
print('Loaded scannet', images.shape, render_poses.shape, hwf, args.datadir)
i_train, i_val, i_test = i_split

near = 0.1
far = 10.0

elif args.dataset_type == 'LINEMOD':
images, poses, render_poses, hwf, K, i_split, near, far = load_LINEMOD_data(args.datadir, args.half_res, args.testskip)
print(f'Loaded LINEMOD, images shape: {images.shape}, hwf: {hwf}, K: {K}')
@@ -854,7 +875,7 @@ def train():
target_s = target[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3)

##### Core optimization loop #####
rgb, disp, acc, extras = render(H, W, K, chunk=args.chunk, rays=batch_rays,
rgb, depth, acc, extras = render(H, W, K, chunk=args.chunk, rays=batch_rays,
verbose=i < 10, retraw=True,
**render_kwargs_train)

