diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..7419e4b --- /dev/null +++ b/.gitignore @@ -0,0 +1,15 @@ +__pycache__/ +build/ +*.egg-info/ +*.so +*.mp4 + +tmp* +trial*/ + +data +data_utils/face_tracking/3DMM/* +data_utils/face_parsing/79999_iter.pth + +pretrained +*.mp4 \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..f565fb7 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2022 hawkey + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/assets/main.png b/assets/main.png new file mode 100644 index 0000000..00a6375 Binary files /dev/null and b/assets/main.png differ diff --git a/data_utils/deepspeech_features/README.md b/data_utils/deepspeech_features/README.md new file mode 100644 index 0000000..c9f6c6b --- /dev/null +++ b/data_utils/deepspeech_features/README.md @@ -0,0 +1,20 @@ +# Routines for DeepSpeech features processing +Several routines for [DeepSpeech](https://github.com/mozilla/DeepSpeech) features processing, like speech features generation for [VOCA](https://github.com/TimoBolkart/voca) model. + +## Installation + +``` +pip3 install -r requirements.txt +``` + +## Usage + +Generate wav files: +``` +python3 extract_wav.py --in-video= +``` + +Generate files with DeepSpeech features: +``` +python3 extract_ds_features.py --input= +``` diff --git a/data_utils/deepspeech_features/deepspeech_features.py b/data_utils/deepspeech_features/deepspeech_features.py new file mode 100644 index 0000000..2efc586 --- /dev/null +++ b/data_utils/deepspeech_features/deepspeech_features.py @@ -0,0 +1,275 @@ +""" + DeepSpeech features processing routines. + NB: Based on VOCA code. See the corresponding license restrictions. +""" + +__all__ = ['conv_audios_to_deepspeech'] + +import numpy as np +import warnings +import resampy +from scipy.io import wavfile +from python_speech_features import mfcc +import tensorflow.compat.v1 as tf +tf.disable_v2_behavior() + +def conv_audios_to_deepspeech(audios, + out_files, + num_frames_info, + deepspeech_pb_path, + audio_window_size=1, + audio_window_stride=1): + """ + Convert list of audio files into files with DeepSpeech features. + + Parameters + ---------- + audios : list of str or list of None + Paths to input audio files. + out_files : list of str + Paths to output files with DeepSpeech features. + num_frames_info : list of int + List of numbers of frames. + deepspeech_pb_path : str + Path to DeepSpeech 0.1.0 frozen model. + audio_window_size : int, default 16 + Audio window size. + audio_window_stride : int, default 1 + Audio window stride. + """ + # deepspeech_pb_path="/disk4/keyu/DeepSpeech/deepspeech-0.9.2-models.pbmm" + graph, logits_ph, input_node_ph, input_lengths_ph = prepare_deepspeech_net( + deepspeech_pb_path) + + with tf.compat.v1.Session(graph=graph) as sess: + for audio_file_path, out_file_path, num_frames in zip(audios, out_files, num_frames_info): + print(audio_file_path) + print(out_file_path) + audio_sample_rate, audio = wavfile.read(audio_file_path) + if audio.ndim != 1: + warnings.warn( + "Audio has multiple channels, the first channel is used") + audio = audio[:, 0] + ds_features = pure_conv_audio_to_deepspeech( + audio=audio, + audio_sample_rate=audio_sample_rate, + audio_window_size=audio_window_size, + audio_window_stride=audio_window_stride, + num_frames=num_frames, + net_fn=lambda x: sess.run( + logits_ph, + feed_dict={ + input_node_ph: x[np.newaxis, ...], + input_lengths_ph: [x.shape[0]]})) + + net_output = ds_features.reshape(-1, 29) + win_size = 16 + zero_pad = np.zeros((int(win_size / 2), net_output.shape[1])) + net_output = np.concatenate( + (zero_pad, net_output, zero_pad), axis=0) + windows = [] + for window_index in range(0, net_output.shape[0] - win_size, 2): + windows.append( + net_output[window_index:window_index + win_size]) + print(np.array(windows).shape) + np.save(out_file_path, np.array(windows)) + + +def prepare_deepspeech_net(deepspeech_pb_path): + """ + Load and prepare DeepSpeech network. + + Parameters + ---------- + deepspeech_pb_path : str + Path to DeepSpeech 0.1.0 frozen model. + + Returns + ------- + graph : obj + ThensorFlow graph. + logits_ph : obj + ThensorFlow placeholder for `logits`. + input_node_ph : obj + ThensorFlow placeholder for `input_node`. + input_lengths_ph : obj + ThensorFlow placeholder for `input_lengths`. + """ + # Load graph and place_holders: + with tf.io.gfile.GFile(deepspeech_pb_path, "rb") as f: + graph_def = tf.compat.v1.GraphDef() + graph_def.ParseFromString(f.read()) + + graph = tf.compat.v1.get_default_graph() + tf.import_graph_def(graph_def, name="deepspeech") + logits_ph = graph.get_tensor_by_name("deepspeech/logits:0") + input_node_ph = graph.get_tensor_by_name("deepspeech/input_node:0") + input_lengths_ph = graph.get_tensor_by_name("deepspeech/input_lengths:0") + + return graph, logits_ph, input_node_ph, input_lengths_ph + + +def pure_conv_audio_to_deepspeech(audio, + audio_sample_rate, + audio_window_size, + audio_window_stride, + num_frames, + net_fn): + """ + Core routine for converting audion into DeepSpeech features. + + Parameters + ---------- + audio : np.array + Audio data. + audio_sample_rate : int + Audio sample rate. + audio_window_size : int + Audio window size. + audio_window_stride : int + Audio window stride. + num_frames : int or None + Numbers of frames. + net_fn : func + Function for DeepSpeech model call. + + Returns + ------- + np.array + DeepSpeech features. + """ + target_sample_rate = 16000 + if audio_sample_rate != target_sample_rate: + resampled_audio = resampy.resample( + x=audio.astype(np.float), + sr_orig=audio_sample_rate, + sr_new=target_sample_rate) + else: + resampled_audio = audio.astype(np.float) + input_vector = conv_audio_to_deepspeech_input_vector( + audio=resampled_audio.astype(np.int16), + sample_rate=target_sample_rate, + num_cepstrum=26, + num_context=9) + + network_output = net_fn(input_vector) + # print(network_output.shape) + + deepspeech_fps = 50 + video_fps = 50 # Change this option if video fps is different + audio_len_s = float(audio.shape[0]) / audio_sample_rate + if num_frames is None: + num_frames = int(round(audio_len_s * video_fps)) + else: + video_fps = num_frames / audio_len_s + network_output = interpolate_features( + features=network_output[:, 0], + input_rate=deepspeech_fps, + output_rate=video_fps, + output_len=num_frames) + + # Make windows: + zero_pad = np.zeros((int(audio_window_size / 2), network_output.shape[1])) + network_output = np.concatenate( + (zero_pad, network_output, zero_pad), axis=0) + windows = [] + for window_index in range(0, network_output.shape[0] - audio_window_size, audio_window_stride): + windows.append( + network_output[window_index:window_index + audio_window_size]) + + return np.array(windows) + + +def conv_audio_to_deepspeech_input_vector(audio, + sample_rate, + num_cepstrum, + num_context): + """ + Convert audio raw data into DeepSpeech input vector. + + Parameters + ---------- + audio : np.array + Audio data. + audio_sample_rate : int + Audio sample rate. + num_cepstrum : int + Number of cepstrum. + num_context : int + Number of context. + + Returns + ------- + np.array + DeepSpeech input vector. + """ + # Get mfcc coefficients: + features = mfcc( + signal=audio, + samplerate=sample_rate, + numcep=num_cepstrum) + + # We only keep every second feature (BiRNN stride = 2): + features = features[::2] + + # One stride per time step in the input: + num_strides = len(features) + + # Add empty initial and final contexts: + empty_context = np.zeros((num_context, num_cepstrum), dtype=features.dtype) + features = np.concatenate((empty_context, features, empty_context)) + + # Create a view into the array with overlapping strides of size + # numcontext (past) + 1 (present) + numcontext (future): + window_size = 2 * num_context + 1 + train_inputs = np.lib.stride_tricks.as_strided( + features, + shape=(num_strides, window_size, num_cepstrum), + strides=(features.strides[0], + features.strides[0], features.strides[1]), + writeable=False) + + # Flatten the second and third dimensions: + train_inputs = np.reshape(train_inputs, [num_strides, -1]) + + train_inputs = np.copy(train_inputs) + train_inputs = (train_inputs - np.mean(train_inputs)) / \ + np.std(train_inputs) + + return train_inputs + + +def interpolate_features(features, + input_rate, + output_rate, + output_len): + """ + Interpolate DeepSpeech features. + + Parameters + ---------- + features : np.array + DeepSpeech features. + input_rate : int + input rate (FPS). + output_rate : int + Output rate (FPS). + output_len : int + Output data length. + + Returns + ------- + np.array + Interpolated data. + """ + input_len = features.shape[0] + num_features = features.shape[1] + input_timestamps = np.arange(input_len) / float(input_rate) + output_timestamps = np.arange(output_len) / float(output_rate) + output_features = np.zeros((output_len, num_features)) + for feature_idx in range(num_features): + output_features[:, feature_idx] = np.interp( + x=output_timestamps, + xp=input_timestamps, + fp=features[:, feature_idx]) + return output_features diff --git a/data_utils/deepspeech_features/deepspeech_store.py b/data_utils/deepspeech_features/deepspeech_store.py new file mode 100644 index 0000000..4c2f603 --- /dev/null +++ b/data_utils/deepspeech_features/deepspeech_store.py @@ -0,0 +1,172 @@ +""" + Routines for loading DeepSpeech model. +""" + +__all__ = ['get_deepspeech_model_file'] + +import os +import zipfile +import logging +import hashlib + + +deepspeech_features_repo_url = 'https://github.com/osmr/deepspeech_features' + + +def get_deepspeech_model_file(local_model_store_dir_path=os.path.join("~", ".tensorflow", "models")): + """ + Return location for the pretrained on local file system. This function will download from online model zoo when + model cannot be found or has mismatch. The root directory will be created if it doesn't exist. + + Parameters + ---------- + local_model_store_dir_path : str, default $TENSORFLOW_HOME/models + Location for keeping the model parameters. + + Returns + ------- + file_path + Path to the requested pretrained model file. + """ + sha1_hash = "b90017e816572ddce84f5843f1fa21e6a377975e" + file_name = "deepspeech-0_1_0-b90017e8.pb" + local_model_store_dir_path = os.path.expanduser(local_model_store_dir_path) + file_path = os.path.join(local_model_store_dir_path, file_name) + if os.path.exists(file_path): + if _check_sha1(file_path, sha1_hash): + return file_path + else: + logging.warning("Mismatch in the content of model file detected. Downloading again.") + else: + logging.info("Model file not found. Downloading to {}.".format(file_path)) + + if not os.path.exists(local_model_store_dir_path): + os.makedirs(local_model_store_dir_path) + + zip_file_path = file_path + ".zip" + _download( + url="{repo_url}/releases/download/{repo_release_tag}/{file_name}.zip".format( + repo_url=deepspeech_features_repo_url, + repo_release_tag="v0.0.1", + file_name=file_name), + path=zip_file_path, + overwrite=True) + with zipfile.ZipFile(zip_file_path) as zf: + zf.extractall(local_model_store_dir_path) + os.remove(zip_file_path) + + if _check_sha1(file_path, sha1_hash): + return file_path + else: + raise ValueError("Downloaded file has different hash. Please try again.") + + +def _download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_ssl=True): + """ + Download an given URL + + Parameters + ---------- + url : str + URL to download + path : str, optional + Destination path to store downloaded file. By default stores to the + current directory with same name as in url. + overwrite : bool, optional + Whether to overwrite destination file if already exists. + sha1_hash : str, optional + Expected sha1 hash in hexadecimal digits. Will ignore existing file when hash is specified + but doesn't match. + retries : integer, default 5 + The number of times to attempt the download in case of failure or non 200 return codes + verify_ssl : bool, default True + Verify SSL certificates. + + Returns + ------- + str + The file path of the downloaded file. + """ + import warnings + try: + import requests + except ImportError: + class requests_failed_to_import(object): + pass + requests = requests_failed_to_import + + if path is None: + fname = url.split("/")[-1] + # Empty filenames are invalid + assert fname, "Can't construct file-name from this URL. Please set the `path` option manually." + else: + path = os.path.expanduser(path) + if os.path.isdir(path): + fname = os.path.join(path, url.split("/")[-1]) + else: + fname = path + assert retries >= 0, "Number of retries should be at least 0" + + if not verify_ssl: + warnings.warn( + "Unverified HTTPS request is being made (verify_ssl=False). " + "Adding certificate verification is strongly advised.") + + if overwrite or not os.path.exists(fname) or (sha1_hash and not _check_sha1(fname, sha1_hash)): + dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname))) + if not os.path.exists(dirname): + os.makedirs(dirname) + while retries + 1 > 0: + # Disable pyling too broad Exception + # pylint: disable=W0703 + try: + print("Downloading {} from {}...".format(fname, url)) + r = requests.get(url, stream=True, verify=verify_ssl) + if r.status_code != 200: + raise RuntimeError("Failed downloading url {}".format(url)) + with open(fname, "wb") as f: + for chunk in r.iter_content(chunk_size=1024): + if chunk: # filter out keep-alive new chunks + f.write(chunk) + if sha1_hash and not _check_sha1(fname, sha1_hash): + raise UserWarning("File {} is downloaded but the content hash does not match." + " The repo may be outdated or download may be incomplete. " + "If the `repo_url` is overridden, consider switching to " + "the default repo.".format(fname)) + break + except Exception as e: + retries -= 1 + if retries <= 0: + raise e + else: + print("download failed, retrying, {} attempt{} left" + .format(retries, "s" if retries > 1 else "")) + + return fname + + +def _check_sha1(filename, sha1_hash): + """ + Check whether the sha1 hash of the file content matches the expected hash. + + Parameters + ---------- + filename : str + Path to the file. + sha1_hash : str + Expected sha1 hash in hexadecimal digits. + + Returns + ------- + bool + Whether the file content matches the expected hash. + """ + sha1 = hashlib.sha1() + with open(filename, "rb") as f: + while True: + data = f.read(1048576) + if not data: + break + sha1.update(data) + + return sha1.hexdigest() == sha1_hash diff --git a/data_utils/deepspeech_features/extract_ds_features.py b/data_utils/deepspeech_features/extract_ds_features.py new file mode 100644 index 0000000..db525d1 --- /dev/null +++ b/data_utils/deepspeech_features/extract_ds_features.py @@ -0,0 +1,132 @@ +""" + Script for extracting DeepSpeech features from audio file. +""" + +import os +import argparse +import numpy as np +import pandas as pd +from deepspeech_store import get_deepspeech_model_file +from deepspeech_features import conv_audios_to_deepspeech + + +def parse_args(): + """ + Create python script parameters. + Returns + ------- + ArgumentParser + Resulted args. + """ + parser = argparse.ArgumentParser( + description="Extract DeepSpeech features from audio file", + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument( + "--input", + type=str, + required=True, + help="path to input audio file or directory") + parser.add_argument( + "--output", + type=str, + help="path to output file with DeepSpeech features") + parser.add_argument( + "--deepspeech", + type=str, + help="path to DeepSpeech 0.1.0 frozen model") + parser.add_argument( + "--metainfo", + type=str, + help="path to file with meta-information") + + args = parser.parse_args() + return args + + +def extract_features(in_audios, + out_files, + deepspeech_pb_path, + metainfo_file_path=None): + """ + Real extract audio from video file. + Parameters + ---------- + in_audios : list of str + Paths to input audio files. + out_files : list of str + Paths to output files with DeepSpeech features. + deepspeech_pb_path : str + Path to DeepSpeech 0.1.0 frozen model. + metainfo_file_path : str, default None + Path to file with meta-information. + """ + #deepspeech_pb_path="/disk4/keyu/DeepSpeech/deepspeech-0.9.2-models.pbmm" + if metainfo_file_path is None: + num_frames_info = [None] * len(in_audios) + else: + train_df = pd.read_csv( + metainfo_file_path, + sep="\t", + index_col=False, + dtype={"Id": np.int, "File": np.unicode, "Count": np.int}) + num_frames_info = train_df["Count"].values + assert (len(num_frames_info) == len(in_audios)) + + for i, in_audio in enumerate(in_audios): + if not out_files[i]: + file_stem, _ = os.path.splitext(in_audio) + out_files[i] = file_stem + ".npy" + #print(out_files[i]) + conv_audios_to_deepspeech( + audios=in_audios, + out_files=out_files, + num_frames_info=num_frames_info, + deepspeech_pb_path=deepspeech_pb_path) + + +def main(): + """ + Main body of script. + """ + args = parse_args() + in_audio = os.path.expanduser(args.input) + if not os.path.exists(in_audio): + raise Exception("Input file/directory doesn't exist: {}".format(in_audio)) + deepspeech_pb_path = args.deepspeech + #add + deepspeech_pb_path = True + args.deepspeech = '~/.tensorflow/models/deepspeech-0_1_0-b90017e8.pb' + #deepspeech_pb_path="/disk4/keyu/DeepSpeech/deepspeech-0.9.2-models.pbmm" + if deepspeech_pb_path is None: + deepspeech_pb_path = "" + if deepspeech_pb_path: + deepspeech_pb_path = os.path.expanduser(args.deepspeech) + if not os.path.exists(deepspeech_pb_path): + deepspeech_pb_path = get_deepspeech_model_file() + if os.path.isfile(in_audio): + extract_features( + in_audios=[in_audio], + out_files=[args.output], + deepspeech_pb_path=deepspeech_pb_path, + metainfo_file_path=args.metainfo) + else: + audio_file_paths = [] + for file_name in os.listdir(in_audio): + if not os.path.isfile(os.path.join(in_audio, file_name)): + continue + _, file_ext = os.path.splitext(file_name) + if file_ext.lower() == ".wav": + audio_file_path = os.path.join(in_audio, file_name) + audio_file_paths.append(audio_file_path) + audio_file_paths = sorted(audio_file_paths) + out_file_paths = [""] * len(audio_file_paths) + extract_features( + in_audios=audio_file_paths, + out_files=out_file_paths, + deepspeech_pb_path=deepspeech_pb_path, + metainfo_file_path=args.metainfo) + + +if __name__ == "__main__": + main() + diff --git a/data_utils/deepspeech_features/extract_wav.py b/data_utils/deepspeech_features/extract_wav.py new file mode 100644 index 0000000..5f39e8b --- /dev/null +++ b/data_utils/deepspeech_features/extract_wav.py @@ -0,0 +1,87 @@ +""" + Script for extracting audio (16-bit, mono, 22000 Hz) from video file. +""" + +import os +import argparse +import subprocess + + +def parse_args(): + """ + Create python script parameters. + + Returns + ------- + ArgumentParser + Resulted args. + """ + parser = argparse.ArgumentParser( + description="Extract audio from video file", + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument( + "--in-video", + type=str, + required=True, + help="path to input video file or directory") + parser.add_argument( + "--out-audio", + type=str, + help="path to output audio file") + + args = parser.parse_args() + return args + + +def extract_audio(in_video, + out_audio): + """ + Real extract audio from video file. + + Parameters + ---------- + in_video : str + Path to input video file. + out_audio : str + Path to output audio file. + """ + if not out_audio: + file_stem, _ = os.path.splitext(in_video) + out_audio = file_stem + ".wav" + # command1 = "ffmpeg -i {in_video} -vn -acodec copy {aac_audio}" + # command2 = "ffmpeg -i {aac_audio} -vn -acodec pcm_s16le -ac 1 -ar 22000 {out_audio}" + # command = "ffmpeg -i {in_video} -vn -acodec pcm_s16le -ac 1 -ar 22000 {out_audio}" + command = "ffmpeg -i {in_video} -vn -acodec pcm_s16le -ac 1 -ar 16000 {out_audio}" + subprocess.call([command.format(in_video=in_video, out_audio=out_audio)], shell=True) + + +def main(): + """ + Main body of script. + """ + args = parse_args() + in_video = os.path.expanduser(args.in_video) + if not os.path.exists(in_video): + raise Exception("Input file/directory doesn't exist: {}".format(in_video)) + if os.path.isfile(in_video): + extract_audio( + in_video=in_video, + out_audio=args.out_audio) + else: + video_file_paths = [] + for file_name in os.listdir(in_video): + if not os.path.isfile(os.path.join(in_video, file_name)): + continue + _, file_ext = os.path.splitext(file_name) + if file_ext.lower() in (".mp4", ".mkv", ".avi"): + video_file_path = os.path.join(in_video, file_name) + video_file_paths.append(video_file_path) + video_file_paths = sorted(video_file_paths) + for video_file_path in video_file_paths: + extract_audio( + in_video=video_file_path, + out_audio="") + + +if __name__ == "__main__": + main() diff --git a/data_utils/deepspeech_features/fea_win.py b/data_utils/deepspeech_features/fea_win.py new file mode 100644 index 0000000..4f9c666 --- /dev/null +++ b/data_utils/deepspeech_features/fea_win.py @@ -0,0 +1,11 @@ +import numpy as np + +net_output = np.load('french.ds.npy').reshape(-1, 29) +win_size = 16 +zero_pad = np.zeros((int(win_size / 2), net_output.shape[1])) +net_output = np.concatenate((zero_pad, net_output, zero_pad), axis=0) +windows = [] +for window_index in range(0, net_output.shape[0] - win_size, 2): + windows.append(net_output[window_index:window_index + win_size]) +print(np.array(windows).shape) +np.save('aud_french.npy', np.array(windows)) diff --git a/data_utils/face_parsing/logger.py b/data_utils/face_parsing/logger.py new file mode 100644 index 0000000..ad8452b --- /dev/null +++ b/data_utils/face_parsing/logger.py @@ -0,0 +1,23 @@ +#!/usr/bin/python +# -*- encoding: utf-8 -*- + + +import os.path as osp +import time +import sys +import logging + +import torch.distributed as dist + + +def setup_logger(logpth): + logfile = 'BiSeNet-{}.log'.format(time.strftime('%Y-%m-%d-%H-%M-%S')) + logfile = osp.join(logpth, logfile) + FORMAT = '%(levelname)s %(filename)s(%(lineno)d): %(message)s' + log_level = logging.INFO + if dist.is_initialized() and not dist.get_rank()==0: + log_level = logging.ERROR + logging.basicConfig(level=log_level, format=FORMAT, filename=logfile) + logging.root.addHandler(logging.StreamHandler()) + + diff --git a/data_utils/face_parsing/model.py b/data_utils/face_parsing/model.py new file mode 100644 index 0000000..43181f0 --- /dev/null +++ b/data_utils/face_parsing/model.py @@ -0,0 +1,285 @@ +#!/usr/bin/python +# -*- encoding: utf-8 -*- + + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision + +from resnet import Resnet18 +# from modules.bn import InPlaceABNSync as BatchNorm2d + + +class ConvBNReLU(nn.Module): + def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs): + super(ConvBNReLU, self).__init__() + self.conv = nn.Conv2d(in_chan, + out_chan, + kernel_size = ks, + stride = stride, + padding = padding, + bias = False) + self.bn = nn.BatchNorm2d(out_chan) + self.init_weight() + + def forward(self, x): + x = self.conv(x) + x = F.relu(self.bn(x)) + return x + + def init_weight(self): + for ly in self.children(): + if isinstance(ly, nn.Conv2d): + nn.init.kaiming_normal_(ly.weight, a=1) + if not ly.bias is None: nn.init.constant_(ly.bias, 0) + +class BiSeNetOutput(nn.Module): + def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs): + super(BiSeNetOutput, self).__init__() + self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1) + self.conv_out = nn.Conv2d(mid_chan, n_classes, kernel_size=1, bias=False) + self.init_weight() + + def forward(self, x): + x = self.conv(x) + x = self.conv_out(x) + return x + + def init_weight(self): + for ly in self.children(): + if isinstance(ly, nn.Conv2d): + nn.init.kaiming_normal_(ly.weight, a=1) + if not ly.bias is None: nn.init.constant_(ly.bias, 0) + + def get_params(self): + wd_params, nowd_params = [], [] + for name, module in self.named_modules(): + if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): + wd_params.append(module.weight) + if not module.bias is None: + nowd_params.append(module.bias) + elif isinstance(module, nn.BatchNorm2d): + nowd_params += list(module.parameters()) + return wd_params, nowd_params + + +class AttentionRefinementModule(nn.Module): + def __init__(self, in_chan, out_chan, *args, **kwargs): + super(AttentionRefinementModule, self).__init__() + self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1) + self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size= 1, bias=False) + self.bn_atten = nn.BatchNorm2d(out_chan) + self.sigmoid_atten = nn.Sigmoid() + self.init_weight() + + def forward(self, x): + feat = self.conv(x) + atten = F.avg_pool2d(feat, feat.size()[2:]) + atten = self.conv_atten(atten) + atten = self.bn_atten(atten) + atten = self.sigmoid_atten(atten) + out = torch.mul(feat, atten) + return out + + def init_weight(self): + for ly in self.children(): + if isinstance(ly, nn.Conv2d): + nn.init.kaiming_normal_(ly.weight, a=1) + if not ly.bias is None: nn.init.constant_(ly.bias, 0) + + +class ContextPath(nn.Module): + def __init__(self, *args, **kwargs): + super(ContextPath, self).__init__() + self.resnet = Resnet18() + self.arm16 = AttentionRefinementModule(256, 128) + self.arm32 = AttentionRefinementModule(512, 128) + self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1) + self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1) + self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0) + + self.init_weight() + + def forward(self, x): + H0, W0 = x.size()[2:] + feat8, feat16, feat32 = self.resnet(x) + H8, W8 = feat8.size()[2:] + H16, W16 = feat16.size()[2:] + H32, W32 = feat32.size()[2:] + + avg = F.avg_pool2d(feat32, feat32.size()[2:]) + avg = self.conv_avg(avg) + avg_up = F.interpolate(avg, (H32, W32), mode='nearest') + + feat32_arm = self.arm32(feat32) + feat32_sum = feat32_arm + avg_up + feat32_up = F.interpolate(feat32_sum, (H16, W16), mode='nearest') + feat32_up = self.conv_head32(feat32_up) + + feat16_arm = self.arm16(feat16) + feat16_sum = feat16_arm + feat32_up + feat16_up = F.interpolate(feat16_sum, (H8, W8), mode='nearest') + feat16_up = self.conv_head16(feat16_up) + + return feat8, feat16_up, feat32_up # x8, x8, x16 + + def init_weight(self): + for ly in self.children(): + if isinstance(ly, nn.Conv2d): + nn.init.kaiming_normal_(ly.weight, a=1) + if not ly.bias is None: nn.init.constant_(ly.bias, 0) + + def get_params(self): + wd_params, nowd_params = [], [] + for name, module in self.named_modules(): + if isinstance(module, (nn.Linear, nn.Conv2d)): + wd_params.append(module.weight) + if not module.bias is None: + nowd_params.append(module.bias) + elif isinstance(module, nn.BatchNorm2d): + nowd_params += list(module.parameters()) + return wd_params, nowd_params + + +### This is not used, since I replace this with the resnet feature with the same size +class SpatialPath(nn.Module): + def __init__(self, *args, **kwargs): + super(SpatialPath, self).__init__() + self.conv1 = ConvBNReLU(3, 64, ks=7, stride=2, padding=3) + self.conv2 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1) + self.conv3 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1) + self.conv_out = ConvBNReLU(64, 128, ks=1, stride=1, padding=0) + self.init_weight() + + def forward(self, x): + feat = self.conv1(x) + feat = self.conv2(feat) + feat = self.conv3(feat) + feat = self.conv_out(feat) + return feat + + def init_weight(self): + for ly in self.children(): + if isinstance(ly, nn.Conv2d): + nn.init.kaiming_normal_(ly.weight, a=1) + if not ly.bias is None: nn.init.constant_(ly.bias, 0) + + def get_params(self): + wd_params, nowd_params = [], [] + for name, module in self.named_modules(): + if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): + wd_params.append(module.weight) + if not module.bias is None: + nowd_params.append(module.bias) + elif isinstance(module, nn.BatchNorm2d): + nowd_params += list(module.parameters()) + return wd_params, nowd_params + + +class FeatureFusionModule(nn.Module): + def __init__(self, in_chan, out_chan, *args, **kwargs): + super(FeatureFusionModule, self).__init__() + self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0) + self.conv1 = nn.Conv2d(out_chan, + out_chan//4, + kernel_size = 1, + stride = 1, + padding = 0, + bias = False) + self.conv2 = nn.Conv2d(out_chan//4, + out_chan, + kernel_size = 1, + stride = 1, + padding = 0, + bias = False) + self.relu = nn.ReLU(inplace=True) + self.sigmoid = nn.Sigmoid() + self.init_weight() + + def forward(self, fsp, fcp): + fcat = torch.cat([fsp, fcp], dim=1) + feat = self.convblk(fcat) + atten = F.avg_pool2d(feat, feat.size()[2:]) + atten = self.conv1(atten) + atten = self.relu(atten) + atten = self.conv2(atten) + atten = self.sigmoid(atten) + feat_atten = torch.mul(feat, atten) + feat_out = feat_atten + feat + return feat_out + + def init_weight(self): + for ly in self.children(): + if isinstance(ly, nn.Conv2d): + nn.init.kaiming_normal_(ly.weight, a=1) + if not ly.bias is None: nn.init.constant_(ly.bias, 0) + + def get_params(self): + wd_params, nowd_params = [], [] + for name, module in self.named_modules(): + if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): + wd_params.append(module.weight) + if not module.bias is None: + nowd_params.append(module.bias) + elif isinstance(module, nn.BatchNorm2d): + nowd_params += list(module.parameters()) + return wd_params, nowd_params + + +class BiSeNet(nn.Module): + def __init__(self, n_classes, *args, **kwargs): + super(BiSeNet, self).__init__() + self.cp = ContextPath() + ## here self.sp is deleted + self.ffm = FeatureFusionModule(256, 256) + self.conv_out = BiSeNetOutput(256, 256, n_classes) + self.conv_out16 = BiSeNetOutput(128, 64, n_classes) + self.conv_out32 = BiSeNetOutput(128, 64, n_classes) + self.init_weight() + + def forward(self, x): + H, W = x.size()[2:] + feat_res8, feat_cp8, feat_cp16 = self.cp(x) # here return res3b1 feature + feat_sp = feat_res8 # use res3b1 feature to replace spatial path feature + feat_fuse = self.ffm(feat_sp, feat_cp8) + + feat_out = self.conv_out(feat_fuse) + feat_out16 = self.conv_out16(feat_cp8) + feat_out32 = self.conv_out32(feat_cp16) + + feat_out = F.interpolate(feat_out, (H, W), mode='bilinear', align_corners=True) + feat_out16 = F.interpolate(feat_out16, (H, W), mode='bilinear', align_corners=True) + feat_out32 = F.interpolate(feat_out32, (H, W), mode='bilinear', align_corners=True) + + # return feat_out, feat_out16, feat_out32 + return feat_out + + def init_weight(self): + for ly in self.children(): + if isinstance(ly, nn.Conv2d): + nn.init.kaiming_normal_(ly.weight, a=1) + if not ly.bias is None: nn.init.constant_(ly.bias, 0) + + def get_params(self): + wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], [] + for name, child in self.named_children(): + child_wd_params, child_nowd_params = child.get_params() + if isinstance(child, FeatureFusionModule) or isinstance(child, BiSeNetOutput): + lr_mul_wd_params += child_wd_params + lr_mul_nowd_params += child_nowd_params + else: + wd_params += child_wd_params + nowd_params += child_nowd_params + return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params + + +if __name__ == "__main__": + net = BiSeNet(19) + net.cuda() + net.eval() + in_ten = torch.randn(16, 3, 640, 480).cuda() + out, out16, out32 = net(in_ten) + print(out.shape) + + net.get_params() diff --git a/data_utils/face_parsing/resnet.py b/data_utils/face_parsing/resnet.py new file mode 100644 index 0000000..64969da --- /dev/null +++ b/data_utils/face_parsing/resnet.py @@ -0,0 +1,109 @@ +#!/usr/bin/python +# -*- encoding: utf-8 -*- + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.model_zoo as modelzoo + +# from modules.bn import InPlaceABNSync as BatchNorm2d + +resnet18_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth' + + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=False) + + +class BasicBlock(nn.Module): + def __init__(self, in_chan, out_chan, stride=1): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(in_chan, out_chan, stride) + self.bn1 = nn.BatchNorm2d(out_chan) + self.conv2 = conv3x3(out_chan, out_chan) + self.bn2 = nn.BatchNorm2d(out_chan) + self.relu = nn.ReLU(inplace=True) + self.downsample = None + if in_chan != out_chan or stride != 1: + self.downsample = nn.Sequential( + nn.Conv2d(in_chan, out_chan, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(out_chan), + ) + + def forward(self, x): + residual = self.conv1(x) + residual = F.relu(self.bn1(residual)) + residual = self.conv2(residual) + residual = self.bn2(residual) + + shortcut = x + if self.downsample is not None: + shortcut = self.downsample(x) + + out = shortcut + residual + out = self.relu(out) + return out + + +def create_layer_basic(in_chan, out_chan, bnum, stride=1): + layers = [BasicBlock(in_chan, out_chan, stride=stride)] + for i in range(bnum-1): + layers.append(BasicBlock(out_chan, out_chan, stride=1)) + return nn.Sequential(*layers) + + +class Resnet18(nn.Module): + def __init__(self): + super(Resnet18, self).__init__() + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, + bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1) + self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2) + self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2) + self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2) + self.init_weight() + + def forward(self, x): + x = self.conv1(x) + x = F.relu(self.bn1(x)) + x = self.maxpool(x) + + x = self.layer1(x) + feat8 = self.layer2(x) # 1/8 + feat16 = self.layer3(feat8) # 1/16 + feat32 = self.layer4(feat16) # 1/32 + return feat8, feat16, feat32 + + def init_weight(self): + state_dict = modelzoo.load_url(resnet18_url) + self_state_dict = self.state_dict() + for k, v in state_dict.items(): + if 'fc' in k: continue + self_state_dict.update({k: v}) + self.load_state_dict(self_state_dict) + + def get_params(self): + wd_params, nowd_params = [], [] + for name, module in self.named_modules(): + if isinstance(module, (nn.Linear, nn.Conv2d)): + wd_params.append(module.weight) + if not module.bias is None: + nowd_params.append(module.bias) + elif isinstance(module, nn.BatchNorm2d): + nowd_params += list(module.parameters()) + return wd_params, nowd_params + + +if __name__ == "__main__": + net = Resnet18() + x = torch.randn(16, 3, 224, 224) + out = net(x) + print(out[0].size()) + print(out[1].size()) + print(out[2].size()) + net.get_params() diff --git a/data_utils/face_parsing/test.py b/data_utils/face_parsing/test.py new file mode 100644 index 0000000..ede8481 --- /dev/null +++ b/data_utils/face_parsing/test.py @@ -0,0 +1,98 @@ +#!/usr/bin/python +# -*- encoding: utf-8 -*- +import numpy as np +from model import BiSeNet + +import torch + +import os +import os.path as osp + +from PIL import Image +import torchvision.transforms as transforms +import cv2 +from pathlib import Path +import configargparse +import tqdm + +# import ttach as tta + +def vis_parsing_maps(im, parsing_anno, stride, save_im=False, save_path='vis_results/parsing_map_on_im.jpg', + img_size=(512, 512)): + im = np.array(im) + vis_im = im.copy().astype(np.uint8) + vis_parsing_anno = parsing_anno.copy().astype(np.uint8) + vis_parsing_anno = cv2.resize( + vis_parsing_anno, None, fx=stride, fy=stride, interpolation=cv2.INTER_NEAREST) + vis_parsing_anno_color = np.zeros( + (vis_parsing_anno.shape[0], vis_parsing_anno.shape[1], 3)) + np.array([255, 255, 255]) # + 255 + + num_of_class = np.max(vis_parsing_anno) + # print(num_of_class) + for pi in range(1, 14): + index = np.where(vis_parsing_anno == pi) + vis_parsing_anno_color[index[0], index[1], :] = np.array([255, 0, 0]) + + for pi in range(14, 16): + index = np.where(vis_parsing_anno == pi) + vis_parsing_anno_color[index[0], index[1], :] = np.array([0, 255, 0]) + for pi in range(16, 17): + index = np.where(vis_parsing_anno == pi) + vis_parsing_anno_color[index[0], index[1], :] = np.array([0, 0, 255]) + for pi in range(17, num_of_class+1): + index = np.where(vis_parsing_anno == pi) + vis_parsing_anno_color[index[0], index[1], :] = np.array([255, 0, 0]) + + vis_parsing_anno_color = vis_parsing_anno_color.astype(np.uint8) + index = np.where(vis_parsing_anno == num_of_class-1) + vis_im = cv2.resize(vis_parsing_anno_color, img_size, + interpolation=cv2.INTER_NEAREST) + if save_im: + cv2.imwrite(save_path, vis_im) + + +def evaluate(respth='./res/test_res', dspth='./data', cp='model_final_diss.pth'): + + Path(respth).mkdir(parents=True, exist_ok=True) + + print(f'[INFO] loading model...') + n_classes = 19 + net = BiSeNet(n_classes=n_classes) + net.cuda() + net.load_state_dict(torch.load(cp)) + net.eval() + + to_tensor = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), + ]) + + image_paths = os.listdir(dspth) + + with torch.no_grad(): + for image_path in tqdm.tqdm(image_paths): + if image_path.endswith('.jpg') or image_path.endswith('.png'): + img = Image.open(osp.join(dspth, image_path)) + ori_size = img.size + image = img.resize((512, 512), Image.BILINEAR) + image = image.convert("RGB") + img = to_tensor(image) + + # test-time augmentation. + inputs = torch.unsqueeze(img, 0) # [1, 3, 512, 512] + outputs = net(inputs.cuda()) + parsing = outputs.mean(0).cpu().numpy().argmax(0) + + image_path = int(image_path[:-4]) + image_path = str(image_path) + '.png' + + vis_parsing_maps(image, parsing, stride=1, save_im=True, save_path=osp.join(respth, image_path), img_size=ori_size) + + +if __name__ == "__main__": + parser = configargparse.ArgumentParser() + parser.add_argument('--respath', type=str, default='./result/', help='result path for label') + parser.add_argument('--imgpath', type=str, default='./imgs/', help='path for input images') + parser.add_argument('--modelpath', type=str, default='data_utils/face_parsing/79999_iter.pth') + args = parser.parse_args() + evaluate(respth=args.respath, dspth=args.imgpath, cp=args.modelpath) diff --git a/data_utils/face_tracking/__init__.py b/data_utils/face_tracking/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/data_utils/face_tracking/convert_BFM.py b/data_utils/face_tracking/convert_BFM.py new file mode 100644 index 0000000..5c64af6 --- /dev/null +++ b/data_utils/face_tracking/convert_BFM.py @@ -0,0 +1,39 @@ +import numpy as np +from scipy.io import loadmat + +original_BFM = loadmat("3DMM/01_MorphableModel.mat") +sub_inds = np.load("3DMM/topology_info.npy", allow_pickle=True).item()["sub_inds"] + +shapePC = original_BFM["shapePC"] +shapeEV = original_BFM["shapeEV"] +shapeMU = original_BFM["shapeMU"] +texPC = original_BFM["texPC"] +texEV = original_BFM["texEV"] +texMU = original_BFM["texMU"] + +b_shape = shapePC.reshape(-1, 199).transpose(1, 0).reshape(199, -1, 3) +mu_shape = shapeMU.reshape(-1, 3) + +b_tex = texPC.reshape(-1, 199).transpose(1, 0).reshape(199, -1, 3) +mu_tex = texMU.reshape(-1, 3) + +b_shape = b_shape[:, sub_inds, :].reshape(199, -1) +mu_shape = mu_shape[sub_inds, :].reshape(-1) +b_tex = b_tex[:, sub_inds, :].reshape(199, -1) +mu_tex = mu_tex[sub_inds, :].reshape(-1) + +exp_info = np.load("3DMM/exp_info.npy", allow_pickle=True).item() +np.save( + "3DMM/3DMM_info.npy", + { + "mu_shape": mu_shape, + "b_shape": b_shape, + "sig_shape": shapeEV.reshape(-1), + "mu_exp": exp_info["mu_exp"], + "b_exp": exp_info["base_exp"], + "sig_exp": exp_info["sig_exp"], + "mu_tex": mu_tex, + "b_tex": b_tex, + "sig_tex": texEV.reshape(-1), + }, +) diff --git a/data_utils/face_tracking/data_loader.py b/data_utils/face_tracking/data_loader.py new file mode 100644 index 0000000..ba89904 --- /dev/null +++ b/data_utils/face_tracking/data_loader.py @@ -0,0 +1,16 @@ +import os +import torch +import numpy as np + + +def load_dir(path, start, end): + lmss = [] + imgs_paths = [] + for i in range(start, end): + if os.path.isfile(os.path.join(path, str(i) + ".lms")): + lms = np.loadtxt(os.path.join(path, str(i) + ".lms"), dtype=np.float32) + lmss.append(lms) + imgs_paths.append(os.path.join(path, str(i) + ".jpg")) + lmss = np.stack(lmss) + lmss = torch.as_tensor(lmss).cuda() + return lmss, imgs_paths diff --git a/data_utils/face_tracking/face_tracker.py b/data_utils/face_tracking/face_tracker.py new file mode 100644 index 0000000..438d112 --- /dev/null +++ b/data_utils/face_tracking/face_tracker.py @@ -0,0 +1,390 @@ +import os +import sys +import cv2 +import argparse +from pathlib import Path +import torch +import numpy as np +from data_loader import load_dir +from facemodel import Face_3DMM +from util import * +from render_3dmm import Render_3DMM + + +# torch.autograd.set_detect_anomaly(True) + +dir_path = os.path.dirname(os.path.realpath(__file__)) + + +def set_requires_grad(tensor_list): + for tensor in tensor_list: + tensor.requires_grad = True + + +parser = argparse.ArgumentParser() +parser.add_argument( + "--path", type=str, default="obama/ori_imgs", help="idname of target person" +) +parser.add_argument("--img_h", type=int, default=512, help="image height") +parser.add_argument("--img_w", type=int, default=512, help="image width") +parser.add_argument("--frame_num", type=int, default=11000, help="image number") +args = parser.parse_args() + +start_id = 0 +end_id = args.frame_num + +lms, img_paths = load_dir(args.path, start_id, end_id) +num_frames = lms.shape[0] +h, w = args.img_h, args.img_w +cxy = torch.tensor((w / 2.0, h / 2.0), dtype=torch.float).cuda() +id_dim, exp_dim, tex_dim, point_num = 100, 79, 100, 34650 +model_3dmm = Face_3DMM( + os.path.join(dir_path, "3DMM"), id_dim, exp_dim, tex_dim, point_num +) + +# only use one image per 40 to do fit the focal length +sel_ids = np.arange(0, num_frames, 40) +sel_num = sel_ids.shape[0] +arg_focal = 1600 +arg_landis = 1e5 + +print(f'[INFO] fitting focal length...') + +# fit the focal length +for focal in range(600, 1500, 100): + id_para = lms.new_zeros((1, id_dim), requires_grad=True) + exp_para = lms.new_zeros((sel_num, exp_dim), requires_grad=True) + euler_angle = lms.new_zeros((sel_num, 3), requires_grad=True) + trans = lms.new_zeros((sel_num, 3), requires_grad=True) + trans.data[:, 2] -= 7 + focal_length = lms.new_zeros(1, requires_grad=False) + focal_length.data += focal + set_requires_grad([id_para, exp_para, euler_angle, trans]) + + optimizer_idexp = torch.optim.Adam([id_para, exp_para], lr=0.1) + optimizer_frame = torch.optim.Adam([euler_angle, trans], lr=0.1) + + for iter in range(2000): + id_para_batch = id_para.expand(sel_num, -1) + geometry = model_3dmm.get_3dlandmarks( + id_para_batch, exp_para, euler_angle, trans, focal_length, cxy + ) + proj_geo = forward_transform(geometry, euler_angle, trans, focal_length, cxy) + loss_lan = cal_lan_loss(proj_geo[:, :, :2], lms[sel_ids].detach()) + loss = loss_lan + optimizer_frame.zero_grad() + loss.backward() + optimizer_frame.step() + # if iter % 100 == 0: + # print(focal, 'pose', iter, loss.item()) + + for iter in range(2500): + id_para_batch = id_para.expand(sel_num, -1) + geometry = model_3dmm.get_3dlandmarks( + id_para_batch, exp_para, euler_angle, trans, focal_length, cxy + ) + proj_geo = forward_transform(geometry, euler_angle, trans, focal_length, cxy) + loss_lan = cal_lan_loss(proj_geo[:, :, :2], lms[sel_ids].detach()) + loss_regid = torch.mean(id_para * id_para) + loss_regexp = torch.mean(exp_para * exp_para) + loss = loss_lan + loss_regid * 0.5 + loss_regexp * 0.4 + optimizer_idexp.zero_grad() + optimizer_frame.zero_grad() + loss.backward() + optimizer_idexp.step() + optimizer_frame.step() + # if iter % 100 == 0: + # print(focal, 'poseidexp', iter, loss_lan.item(), loss_regid.item(), loss_regexp.item()) + + if iter % 1500 == 0 and iter >= 1500: + for param_group in optimizer_idexp.param_groups: + param_group["lr"] *= 0.2 + for param_group in optimizer_frame.param_groups: + param_group["lr"] *= 0.2 + + print(focal, loss_lan.item(), torch.mean(trans[:, 2]).item()) + + if loss_lan.item() < arg_landis: + arg_landis = loss_lan.item() + arg_focal = focal + +print("[INFO] find best focal:", arg_focal) + +print(f'[INFO] coarse fitting...') + +# for all frames, do a coarse fitting ??? +id_para = lms.new_zeros((1, id_dim), requires_grad=True) +exp_para = lms.new_zeros((num_frames, exp_dim), requires_grad=True) +tex_para = lms.new_zeros( + (1, tex_dim), requires_grad=True +) # not optimized in this block ??? +euler_angle = lms.new_zeros((num_frames, 3), requires_grad=True) +trans = lms.new_zeros((num_frames, 3), requires_grad=True) +light_para = lms.new_zeros((num_frames, 27), requires_grad=True) +trans.data[:, 2] -= 7 # ??? +focal_length = lms.new_zeros(1, requires_grad=True) +focal_length.data += arg_focal + +set_requires_grad([id_para, exp_para, tex_para, euler_angle, trans, light_para]) + +optimizer_idexp = torch.optim.Adam([id_para, exp_para], lr=0.1) +optimizer_frame = torch.optim.Adam([euler_angle, trans], lr=1) + +for iter in range(1500): + id_para_batch = id_para.expand(num_frames, -1) + geometry = model_3dmm.get_3dlandmarks( + id_para_batch, exp_para, euler_angle, trans, focal_length, cxy + ) + proj_geo = forward_transform(geometry, euler_angle, trans, focal_length, cxy) + loss_lan = cal_lan_loss(proj_geo[:, :, :2], lms.detach()) + loss = loss_lan + optimizer_frame.zero_grad() + loss.backward() + optimizer_frame.step() + if iter == 1000: + for param_group in optimizer_frame.param_groups: + param_group["lr"] = 0.1 + # if iter % 100 == 0: + # print('pose', iter, loss.item()) + +for param_group in optimizer_frame.param_groups: + param_group["lr"] = 0.1 + +for iter in range(2000): + id_para_batch = id_para.expand(num_frames, -1) + geometry = model_3dmm.get_3dlandmarks( + id_para_batch, exp_para, euler_angle, trans, focal_length, cxy + ) + proj_geo = forward_transform(geometry, euler_angle, trans, focal_length, cxy) + loss_lan = cal_lan_loss(proj_geo[:, :, :2], lms.detach()) + loss_regid = torch.mean(id_para * id_para) + loss_regexp = torch.mean(exp_para * exp_para) + loss = loss_lan + loss_regid * 0.5 + loss_regexp * 0.4 + optimizer_idexp.zero_grad() + optimizer_frame.zero_grad() + loss.backward() + optimizer_idexp.step() + optimizer_frame.step() + # if iter % 100 == 0: + # print('poseidexp', iter, loss_lan.item(), loss_regid.item(), loss_regexp.item()) + if iter % 1000 == 0 and iter >= 1000: + for param_group in optimizer_idexp.param_groups: + param_group["lr"] *= 0.2 + for param_group in optimizer_frame.param_groups: + param_group["lr"] *= 0.2 + +print(loss_lan.item(), torch.mean(trans[:, 2]).item()) + +print(f'[INFO] fitting light...') + +batch_size = 32 + +device_default = torch.device("cuda:0") +device_render = torch.device("cuda:0") +renderer = Render_3DMM(arg_focal, h, w, batch_size, device_render) + +sel_ids = np.arange(0, num_frames, int(num_frames / batch_size))[:batch_size] +imgs = [] +for sel_id in sel_ids: + imgs.append(cv2.imread(img_paths[sel_id])[:, :, ::-1]) +imgs = np.stack(imgs) +sel_imgs = torch.as_tensor(imgs).cuda() +sel_lms = lms[sel_ids] +sel_light = light_para.new_zeros((batch_size, 27), requires_grad=True) +set_requires_grad([sel_light]) + +optimizer_tl = torch.optim.Adam([tex_para, sel_light], lr=0.1) +optimizer_id_frame = torch.optim.Adam([euler_angle, trans, exp_para, id_para], lr=0.01) + +for iter in range(71): + sel_exp_para, sel_euler, sel_trans = ( + exp_para[sel_ids], + euler_angle[sel_ids], + trans[sel_ids], + ) + sel_id_para = id_para.expand(batch_size, -1) + geometry = model_3dmm.get_3dlandmarks( + sel_id_para, sel_exp_para, sel_euler, sel_trans, focal_length, cxy + ) + proj_geo = forward_transform(geometry, sel_euler, sel_trans, focal_length, cxy) + + loss_lan = cal_lan_loss(proj_geo[:, :, :2], sel_lms.detach()) + loss_regid = torch.mean(id_para * id_para) + loss_regexp = torch.mean(sel_exp_para * sel_exp_para) + + sel_tex_para = tex_para.expand(batch_size, -1) + sel_texture = model_3dmm.forward_tex(sel_tex_para) + geometry = model_3dmm.forward_geo(sel_id_para, sel_exp_para) + rott_geo = forward_rott(geometry, sel_euler, sel_trans) + render_imgs = renderer( + rott_geo.to(device_render), + sel_texture.to(device_render), + sel_light.to(device_render), + ) + render_imgs = render_imgs.to(device_default) + + mask = (render_imgs[:, :, :, 3]).detach() > 0.0 + render_proj = sel_imgs.clone() + render_proj[mask] = render_imgs[mask][..., :3].byte() + loss_col = cal_col_loss(render_imgs[:, :, :, :3], sel_imgs.float(), mask) + + if iter > 50: + loss = loss_col + loss_lan * 0.05 + loss_regid * 1.0 + loss_regexp * 0.8 + else: + loss = loss_col + loss_lan * 3 + loss_regid * 2.0 + loss_regexp * 1.0 + + optimizer_tl.zero_grad() + optimizer_id_frame.zero_grad() + loss.backward() + + optimizer_tl.step() + optimizer_id_frame.step() + + if iter % 50 == 0 and iter > 0: + for param_group in optimizer_id_frame.param_groups: + param_group["lr"] *= 0.2 + for param_group in optimizer_tl.param_groups: + param_group["lr"] *= 0.2 + # print(iter, loss_col.item(), loss_lan.item(), loss_regid.item(), loss_regexp.item()) + + +light_mean = torch.mean(sel_light, 0).unsqueeze(0).repeat(num_frames, 1) +light_para.data = light_mean + +exp_para = exp_para.detach() +euler_angle = euler_angle.detach() +trans = trans.detach() +light_para = light_para.detach() + +print(f'[INFO] fine frame-wise fitting...') + +for i in range(int((num_frames - 1) / batch_size + 1)): + + if (i + 1) * batch_size > num_frames: + start_n = num_frames - batch_size + sel_ids = np.arange(num_frames - batch_size, num_frames) + else: + start_n = i * batch_size + sel_ids = np.arange(i * batch_size, i * batch_size + batch_size) + + imgs = [] + for sel_id in sel_ids: + imgs.append(cv2.imread(img_paths[sel_id])[:, :, ::-1]) + imgs = np.stack(imgs) + sel_imgs = torch.as_tensor(imgs).cuda() + sel_lms = lms[sel_ids] + + sel_exp_para = exp_para.new_zeros((batch_size, exp_dim), requires_grad=True) + sel_exp_para.data = exp_para[sel_ids].clone() + sel_euler = euler_angle.new_zeros((batch_size, 3), requires_grad=True) + sel_euler.data = euler_angle[sel_ids].clone() + sel_trans = trans.new_zeros((batch_size, 3), requires_grad=True) + sel_trans.data = trans[sel_ids].clone() + sel_light = light_para.new_zeros((batch_size, 27), requires_grad=True) + sel_light.data = light_para[sel_ids].clone() + + set_requires_grad([sel_exp_para, sel_euler, sel_trans, sel_light]) + + optimizer_cur_batch = torch.optim.Adam( + [sel_exp_para, sel_euler, sel_trans, sel_light], lr=0.005 + ) + + sel_id_para = id_para.expand(batch_size, -1).detach() + sel_tex_para = tex_para.expand(batch_size, -1).detach() + + pre_num = 5 + + if i > 0: + pre_ids = np.arange(start_n - pre_num, start_n) + + for iter in range(50): + + geometry = model_3dmm.get_3dlandmarks( + sel_id_para, sel_exp_para, sel_euler, sel_trans, focal_length, cxy + ) + proj_geo = forward_transform(geometry, sel_euler, sel_trans, focal_length, cxy) + loss_lan = cal_lan_loss(proj_geo[:, :, :2], sel_lms.detach()) + loss_regexp = torch.mean(sel_exp_para * sel_exp_para) + + sel_geometry = model_3dmm.forward_geo(sel_id_para, sel_exp_para) + sel_texture = model_3dmm.forward_tex(sel_tex_para) + geometry = model_3dmm.forward_geo(sel_id_para, sel_exp_para) + rott_geo = forward_rott(geometry, sel_euler, sel_trans) + render_imgs = renderer( + rott_geo.to(device_render), + sel_texture.to(device_render), + sel_light.to(device_render), + ) + render_imgs = render_imgs.to(device_default) + + mask = (render_imgs[:, :, :, 3]).detach() > 0.0 + + loss_col = cal_col_loss(render_imgs[:, :, :, :3], sel_imgs.float(), mask) + + if i > 0: + geometry_lap = model_3dmm.forward_geo_sub( + id_para.expand(batch_size + pre_num, -1).detach(), + torch.cat((exp_para[pre_ids].detach(), sel_exp_para)), + model_3dmm.rigid_ids, + ) + rott_geo_lap = forward_rott( + geometry_lap, + torch.cat((euler_angle[pre_ids].detach(), sel_euler)), + torch.cat((trans[pre_ids].detach(), sel_trans)), + ) + loss_lap = cal_lap_loss( + [rott_geo_lap.reshape(rott_geo_lap.shape[0], -1).permute(1, 0)], [1.0] + ) + else: + geometry_lap = model_3dmm.forward_geo_sub( + id_para.expand(batch_size, -1).detach(), + sel_exp_para, + model_3dmm.rigid_ids, + ) + rott_geo_lap = forward_rott(geometry_lap, sel_euler, sel_trans) + loss_lap = cal_lap_loss( + [rott_geo_lap.reshape(rott_geo_lap.shape[0], -1).permute(1, 0)], [1.0] + ) + + + if iter > 30: + loss = loss_col * 0.5 + loss_lan * 1.5 + loss_lap * 100000 + loss_regexp * 1.0 + else: + loss = loss_col * 0.5 + loss_lan * 8 + loss_lap * 100000 + loss_regexp * 1.0 + + optimizer_cur_batch.zero_grad() + loss.backward() + optimizer_cur_batch.step() + + # if iter % 10 == 0: + # print( + # i, + # iter, + # loss_col.item(), + # loss_lan.item(), + # loss_lap.item(), + # loss_regexp.item(), + # ) + + print(str(i) + " of " + str(int((num_frames - 1) / batch_size + 1)) + " done") + + render_proj = sel_imgs.clone() + render_proj[mask] = render_imgs[mask][..., :3].byte() + + exp_para[sel_ids] = sel_exp_para.clone() + euler_angle[sel_ids] = sel_euler.clone() + trans[sel_ids] = sel_trans.clone() + light_para[sel_ids] = sel_light.clone() + +torch.save( + { + "id": id_para.detach().cpu(), + "exp": exp_para.detach().cpu(), + "euler": euler_angle.detach().cpu(), + "trans": trans.detach().cpu(), + "focal": focal_length.detach().cpu(), + }, + os.path.join(os.path.dirname(args.path), "track_params.pt"), +) + +print("params saved") diff --git a/data_utils/face_tracking/facemodel.py b/data_utils/face_tracking/facemodel.py new file mode 100644 index 0000000..6d19c90 --- /dev/null +++ b/data_utils/face_tracking/facemodel.py @@ -0,0 +1,153 @@ +import torch +import torch.nn as nn +import numpy as np +import os +from util import * + + +class Face_3DMM(nn.Module): + def __init__(self, modelpath, id_dim, exp_dim, tex_dim, point_num): + super(Face_3DMM, self).__init__() + # id_dim = 100 + # exp_dim = 79 + # tex_dim = 100 + self.point_num = point_num + DMM_info = np.load( + os.path.join(modelpath, "3DMM_info.npy"), allow_pickle=True + ).item() + base_id = DMM_info["b_shape"][:id_dim, :] + mu_id = DMM_info["mu_shape"] + base_exp = DMM_info["b_exp"][:exp_dim, :] + mu_exp = DMM_info["mu_exp"] + mu = mu_id + mu_exp + mu = mu.reshape(-1, 3) + for i in range(3): + mu[:, i] -= np.mean(mu[:, i]) + mu = mu.reshape(-1) + self.base_id = torch.as_tensor(base_id).cuda() / 100000.0 + self.base_exp = torch.as_tensor(base_exp).cuda() / 100000.0 + self.mu = torch.as_tensor(mu).cuda() / 100000.0 + base_tex = DMM_info["b_tex"][:tex_dim, :] + mu_tex = DMM_info["mu_tex"] + self.base_tex = torch.as_tensor(base_tex).cuda() + self.mu_tex = torch.as_tensor(mu_tex).cuda() + sig_id = DMM_info["sig_shape"][:id_dim] + sig_tex = DMM_info["sig_tex"][:tex_dim] + sig_exp = DMM_info["sig_exp"][:exp_dim] + self.sig_id = torch.as_tensor(sig_id).cuda() + self.sig_tex = torch.as_tensor(sig_tex).cuda() + self.sig_exp = torch.as_tensor(sig_exp).cuda() + + keys_info = np.load( + os.path.join(modelpath, "keys_info.npy"), allow_pickle=True + ).item() + self.keyinds = torch.as_tensor(keys_info["keyinds"]).cuda() + self.left_contours = torch.as_tensor(keys_info["left_contour"]).cuda() + self.right_contours = torch.as_tensor(keys_info["right_contour"]).cuda() + self.rigid_ids = torch.as_tensor(keys_info["rigid_ids"]).cuda() + + def get_3dlandmarks(self, id_para, exp_para, euler_angle, trans, focal_length, cxy): + id_para = id_para * self.sig_id + exp_para = exp_para * self.sig_exp + batch_size = id_para.shape[0] + num_per_contour = self.left_contours.shape[1] + left_contours_flat = self.left_contours.reshape(-1) + right_contours_flat = self.right_contours.reshape(-1) + sel_index = torch.cat( + ( + 3 * left_contours_flat.unsqueeze(1), + 3 * left_contours_flat.unsqueeze(1) + 1, + 3 * left_contours_flat.unsqueeze(1) + 2, + ), + dim=1, + ).reshape(-1) + left_geometry = ( + torch.mm(id_para, self.base_id[:, sel_index]) + + torch.mm(exp_para, self.base_exp[:, sel_index]) + + self.mu[sel_index] + ) + left_geometry = left_geometry.view(batch_size, -1, 3) + proj_x = forward_transform( + left_geometry, euler_angle, trans, focal_length, cxy + )[:, :, 0] + proj_x = proj_x.reshape(batch_size, 8, num_per_contour) + arg_min = proj_x.argmin(dim=2) + left_geometry = left_geometry.view(batch_size * 8, num_per_contour, 3) + left_3dlands = left_geometry[ + torch.arange(batch_size * 8), arg_min.view(-1), : + ].view(batch_size, 8, 3) + + sel_index = torch.cat( + ( + 3 * right_contours_flat.unsqueeze(1), + 3 * right_contours_flat.unsqueeze(1) + 1, + 3 * right_contours_flat.unsqueeze(1) + 2, + ), + dim=1, + ).reshape(-1) + right_geometry = ( + torch.mm(id_para, self.base_id[:, sel_index]) + + torch.mm(exp_para, self.base_exp[:, sel_index]) + + self.mu[sel_index] + ) + right_geometry = right_geometry.view(batch_size, -1, 3) + proj_x = forward_transform( + right_geometry, euler_angle, trans, focal_length, cxy + )[:, :, 0] + proj_x = proj_x.reshape(batch_size, 8, num_per_contour) + arg_max = proj_x.argmax(dim=2) + right_geometry = right_geometry.view(batch_size * 8, num_per_contour, 3) + right_3dlands = right_geometry[ + torch.arange(batch_size * 8), arg_max.view(-1), : + ].view(batch_size, 8, 3) + + sel_index = torch.cat( + ( + 3 * self.keyinds.unsqueeze(1), + 3 * self.keyinds.unsqueeze(1) + 1, + 3 * self.keyinds.unsqueeze(1) + 2, + ), + dim=1, + ).reshape(-1) + geometry = ( + torch.mm(id_para, self.base_id[:, sel_index]) + + torch.mm(exp_para, self.base_exp[:, sel_index]) + + self.mu[sel_index] + ) + lands_3d = geometry.view(-1, self.keyinds.shape[0], 3) + lands_3d[:, :8, :] = left_3dlands + lands_3d[:, 9:17, :] = right_3dlands + return lands_3d + + def forward_geo_sub(self, id_para, exp_para, sub_index): + id_para = id_para * self.sig_id + exp_para = exp_para * self.sig_exp + sel_index = torch.cat( + ( + 3 * sub_index.unsqueeze(1), + 3 * sub_index.unsqueeze(1) + 1, + 3 * sub_index.unsqueeze(1) + 2, + ), + dim=1, + ).reshape(-1) + geometry = ( + torch.mm(id_para, self.base_id[:, sel_index]) + + torch.mm(exp_para, self.base_exp[:, sel_index]) + + self.mu[sel_index] + ) + return geometry.reshape(-1, sub_index.shape[0], 3) + + def forward_geo(self, id_para, exp_para): + id_para = id_para * self.sig_id + exp_para = exp_para * self.sig_exp + geometry = ( + torch.mm(id_para, self.base_id) + + torch.mm(exp_para, self.base_exp) + + self.mu + ) + return geometry.reshape(-1, self.point_num, 3) + + def forward_tex(self, tex_para): + tex_para = tex_para * self.sig_tex + texture = torch.mm(tex_para, self.base_tex) + self.mu_tex + return texture.reshape(-1, self.point_num, 3) diff --git a/data_utils/face_tracking/geo_transform.py b/data_utils/face_tracking/geo_transform.py new file mode 100644 index 0000000..c5f29b8 --- /dev/null +++ b/data_utils/face_tracking/geo_transform.py @@ -0,0 +1,69 @@ +"""This module contains functions for geometry transform and camera projection""" +import torch +import torch.nn as nn +import numpy as np + + +def euler2rot(euler_angle): + batch_size = euler_angle.shape[0] + theta = euler_angle[:, 0].reshape(-1, 1, 1) + phi = euler_angle[:, 1].reshape(-1, 1, 1) + psi = euler_angle[:, 2].reshape(-1, 1, 1) + one = torch.ones((batch_size, 1, 1), dtype=torch.float32, device=euler_angle.device) + zero = torch.zeros( + (batch_size, 1, 1), dtype=torch.float32, device=euler_angle.device + ) + rot_x = torch.cat( + ( + torch.cat((one, zero, zero), 1), + torch.cat((zero, theta.cos(), theta.sin()), 1), + torch.cat((zero, -theta.sin(), theta.cos()), 1), + ), + 2, + ) + rot_y = torch.cat( + ( + torch.cat((phi.cos(), zero, -phi.sin()), 1), + torch.cat((zero, one, zero), 1), + torch.cat((phi.sin(), zero, phi.cos()), 1), + ), + 2, + ) + rot_z = torch.cat( + ( + torch.cat((psi.cos(), -psi.sin(), zero), 1), + torch.cat((psi.sin(), psi.cos(), zero), 1), + torch.cat((zero, zero, one), 1), + ), + 2, + ) + return torch.bmm(rot_x, torch.bmm(rot_y, rot_z)) + + +def rot_trans_geo(geometry, rot, trans): + rott_geo = torch.bmm(rot, geometry.permute(0, 2, 1)) + trans.view(-1, 3, 1) + return rott_geo.permute(0, 2, 1) + + +def euler_trans_geo(geometry, euler, trans): + rot = euler2rot(euler) + return rot_trans_geo(geometry, rot, trans) + + +def proj_geo(rott_geo, camera_para): + fx = camera_para[:, 0] + fy = camera_para[:, 0] + cx = camera_para[:, 1] + cy = camera_para[:, 2] + + X = rott_geo[:, :, 0] + Y = rott_geo[:, :, 1] + Z = rott_geo[:, :, 2] + + fxX = fx[:, None] * X + fyY = fy[:, None] * Y + + proj_x = -fxX / Z + cx[:, None] + proj_y = fyY / Z + cy[:, None] + + return torch.cat((proj_x[:, :, None], proj_y[:, :, None], Z[:, :, None]), 2) diff --git a/data_utils/face_tracking/render_3dmm.py b/data_utils/face_tracking/render_3dmm.py new file mode 100644 index 0000000..9e8c1cc --- /dev/null +++ b/data_utils/face_tracking/render_3dmm.py @@ -0,0 +1,202 @@ +import torch +import torch.nn as nn +import numpy as np +import os +from pytorch3d.structures import Meshes +from pytorch3d.renderer import ( + look_at_view_transform, + PerspectiveCameras, + FoVPerspectiveCameras, + PointLights, + DirectionalLights, + Materials, + RasterizationSettings, + MeshRenderer, + MeshRasterizer, + SoftPhongShader, + TexturesUV, + TexturesVertex, + blending, +) + +from pytorch3d.ops import interpolate_face_attributes + +from pytorch3d.renderer.blending import ( + BlendParams, + hard_rgb_blend, + sigmoid_alpha_blend, + softmax_rgb_blend, +) + + +class SoftSimpleShader(nn.Module): + """ + Per pixel lighting - the lighting model is applied using the interpolated + coordinates and normals for each pixel. The blending function returns the + soft aggregated color using all the faces per pixel. + + To use the default values, simply initialize the shader with the desired + device e.g. + + """ + + def __init__( + self, device="cpu", cameras=None, lights=None, materials=None, blend_params=None + ): + super().__init__() + self.lights = lights if lights is not None else PointLights(device=device) + self.materials = ( + materials if materials is not None else Materials(device=device) + ) + self.cameras = cameras + self.blend_params = blend_params if blend_params is not None else BlendParams() + + def to(self, device): + # Manually move to device modules which are not subclasses of nn.Module + self.cameras = self.cameras.to(device) + self.materials = self.materials.to(device) + self.lights = self.lights.to(device) + return self + + def forward(self, fragments, meshes, **kwargs) -> torch.Tensor: + + texels = meshes.sample_textures(fragments) + blend_params = kwargs.get("blend_params", self.blend_params) + + cameras = kwargs.get("cameras", self.cameras) + if cameras is None: + msg = "Cameras must be specified either at initialization \ + or in the forward pass of SoftPhongShader" + raise ValueError(msg) + znear = kwargs.get("znear", getattr(cameras, "znear", 1.0)) + zfar = kwargs.get("zfar", getattr(cameras, "zfar", 100.0)) + images = softmax_rgb_blend( + texels, fragments, blend_params, znear=znear, zfar=zfar + ) + return images + + +class Render_3DMM(nn.Module): + def __init__( + self, + focal=1015, + img_h=500, + img_w=500, + batch_size=1, + device=torch.device("cuda:0"), + ): + super(Render_3DMM, self).__init__() + + self.focal = focal + self.img_h = img_h + self.img_w = img_w + self.device = device + self.renderer = self.get_render(batch_size) + + dir_path = os.path.dirname(os.path.realpath(__file__)) + topo_info = np.load( + os.path.join(dir_path, "3DMM", "topology_info.npy"), allow_pickle=True + ).item() + self.tris = torch.as_tensor(topo_info["tris"]).to(self.device) + self.vert_tris = torch.as_tensor(topo_info["vert_tris"]).to(self.device) + + def compute_normal(self, geometry): + vert_1 = torch.index_select(geometry, 1, self.tris[:, 0]) + vert_2 = torch.index_select(geometry, 1, self.tris[:, 1]) + vert_3 = torch.index_select(geometry, 1, self.tris[:, 2]) + nnorm = torch.cross(vert_2 - vert_1, vert_3 - vert_1, 2) + tri_normal = nn.functional.normalize(nnorm, dim=2) + v_norm = tri_normal[:, self.vert_tris, :].sum(2) + vert_normal = v_norm / v_norm.norm(dim=2).unsqueeze(2) + return vert_normal + + def get_render(self, batch_size=1): + half_s = self.img_w * 0.5 + R, T = look_at_view_transform(10, 0, 0) + R = R.repeat(batch_size, 1, 1) + T = torch.zeros((batch_size, 3), dtype=torch.float32).to(self.device) + + cameras = FoVPerspectiveCameras( + device=self.device, + R=R, + T=T, + znear=0.01, + zfar=20, + fov=2 * np.arctan(self.img_w // 2 / self.focal) * 180.0 / np.pi, + ) + lights = PointLights( + device=self.device, + location=[[0.0, 0.0, 1e5]], + ambient_color=[[1, 1, 1]], + specular_color=[[0.0, 0.0, 0.0]], + diffuse_color=[[0.0, 0.0, 0.0]], + ) + sigma = 1e-4 + raster_settings = RasterizationSettings( + image_size=(self.img_h, self.img_w), + blur_radius=np.log(1.0 / 1e-4 - 1.0) * sigma / 18.0, + faces_per_pixel=2, + perspective_correct=False, + ) + blend_params = blending.BlendParams(background_color=[0, 0, 0]) + renderer = MeshRenderer( + rasterizer=MeshRasterizer(raster_settings=raster_settings, cameras=cameras), + shader=SoftSimpleShader( + lights=lights, blend_params=blend_params, cameras=cameras + ), + ) + return renderer.to(self.device) + + @staticmethod + def Illumination_layer(face_texture, norm, gamma): + + n_b, num_vertex, _ = face_texture.size() + n_v_full = n_b * num_vertex + gamma = gamma.view(-1, 3, 9).clone() + gamma[:, :, 0] += 0.8 + + gamma = gamma.permute(0, 2, 1) + + a0 = np.pi + a1 = 2 * np.pi / np.sqrt(3.0) + a2 = 2 * np.pi / np.sqrt(8.0) + c0 = 1 / np.sqrt(4 * np.pi) + c1 = np.sqrt(3.0) / np.sqrt(4 * np.pi) + c2 = 3 * np.sqrt(5.0) / np.sqrt(12 * np.pi) + d0 = 0.5 / np.sqrt(3.0) + + Y0 = torch.ones(n_v_full).to(gamma.device).float() * a0 * c0 + norm = norm.view(-1, 3) + nx, ny, nz = norm[:, 0], norm[:, 1], norm[:, 2] + arrH = [] + + arrH.append(Y0) + arrH.append(-a1 * c1 * ny) + arrH.append(a1 * c1 * nz) + arrH.append(-a1 * c1 * nx) + arrH.append(a2 * c2 * nx * ny) + arrH.append(-a2 * c2 * ny * nz) + arrH.append(a2 * c2 * d0 * (3 * nz.pow(2) - 1)) + arrH.append(-a2 * c2 * nx * nz) + arrH.append(a2 * c2 * 0.5 * (nx.pow(2) - ny.pow(2))) + + H = torch.stack(arrH, 1) + Y = H.view(n_b, num_vertex, 9) + lighting = Y.bmm(gamma) + + face_color = face_texture * lighting + return face_color + + def forward(self, rott_geometry, texture, diffuse_sh): + face_normal = self.compute_normal(rott_geometry) + face_color = self.Illumination_layer(texture, face_normal, diffuse_sh) + face_color = TexturesVertex(face_color) + mesh = Meshes( + rott_geometry, + self.tris.float().repeat(rott_geometry.shape[0], 1, 1), + face_color, + ) + rendered_img = self.renderer(mesh) + rendered_img = torch.clamp(rendered_img, 0, 255) + + return rendered_img diff --git a/data_utils/face_tracking/render_land.py b/data_utils/face_tracking/render_land.py new file mode 100644 index 0000000..b4bd7fe --- /dev/null +++ b/data_utils/face_tracking/render_land.py @@ -0,0 +1,192 @@ +import torch +import torch.nn as nn +import render_util +import geo_transform +import numpy as np + + +def compute_tri_normal(geometry, tris): + geometry = geometry.permute(0, 2, 1) + tri_1 = tris[:, 0] + tri_2 = tris[:, 1] + tri_3 = tris[:, 2] + + vert_1 = torch.index_select(geometry, 2, tri_1) + vert_2 = torch.index_select(geometry, 2, tri_2) + vert_3 = torch.index_select(geometry, 2, tri_3) + + nnorm = torch.cross(vert_2 - vert_1, vert_3 - vert_1, 1) + normal = nn.functional.normalize(nnorm).permute(0, 2, 1) + return normal + + +class Compute_normal_base(torch.autograd.Function): + @staticmethod + def forward(ctx, normal): + (normal_b,) = render_util.normal_base_forward(normal) + ctx.save_for_backward(normal) + return normal_b + + @staticmethod + def backward(ctx, grad_normal_b): + (normal,) = ctx.saved_tensors + (grad_normal,) = render_util.normal_base_backward(grad_normal_b, normal) + return grad_normal + + +class Normal_Base(torch.nn.Module): + def __init__(self): + super(Normal_Base, self).__init__() + + def forward(self, normal): + return Compute_normal_base.apply(normal) + + +def preprocess_render(geometry, euler, trans, cam, tris, vert_tris, ori_img): + point_num = geometry.shape[1] + rott_geo = geo_transform.euler_trans_geo(geometry, euler, trans) + proj_geo = geo_transform.proj_geo(rott_geo, cam) + rot_tri_normal = compute_tri_normal(rott_geo, tris) + rot_vert_normal = torch.index_select(rot_tri_normal, 1, vert_tris) + is_visible = -torch.bmm( + rot_vert_normal.reshape(-1, 1, 3), + nn.functional.normalize(rott_geo.reshape(-1, 3, 1)), + ).reshape(-1, point_num) + is_visible[is_visible < 0.01] = -1 + pixel_valid = torch.zeros( + (ori_img.shape[0], ori_img.shape[1] * ori_img.shape[2]), + dtype=torch.float32, + device=ori_img.device, + ) + return rott_geo, proj_geo, rot_tri_normal, is_visible, pixel_valid + + +class Render_Face(torch.autograd.Function): + @staticmethod + def forward( + ctx, proj_geo, texture, nbl, ori_img, is_visible, tri_inds, pixel_valid + ): + batch_size, h, w, _ = ori_img.shape + ori_img = ori_img.view(batch_size, -1, 3) + ori_size = torch.cat( + ( + torch.ones((batch_size, 1), dtype=torch.int32, device=ori_img.device) + * h, + torch.ones((batch_size, 1), dtype=torch.int32, device=ori_img.device) + * w, + ), + dim=1, + ).view(-1) + tri_index, tri_coord, render, real = render_util.render_face_forward( + proj_geo, ori_img, ori_size, texture, nbl, is_visible, tri_inds, pixel_valid + ) + ctx.save_for_backward( + ori_img, ori_size, proj_geo, texture, nbl, tri_inds, tri_index, tri_coord + ) + return render, real + + @staticmethod + def backward(ctx, grad_render, grad_real): + ( + ori_img, + ori_size, + proj_geo, + texture, + nbl, + tri_inds, + tri_index, + tri_coord, + ) = ctx.saved_tensors + grad_proj_geo, grad_texture, grad_nbl = render_util.render_face_backward( + grad_render, + grad_real, + ori_img, + ori_size, + proj_geo, + texture, + nbl, + tri_inds, + tri_index, + tri_coord, + ) + return grad_proj_geo, grad_texture, grad_nbl, None, None, None, None + + +class Render_RGB(nn.Module): + def __init__(self): + super(Render_RGB, self).__init__() + + def forward( + self, proj_geo, texture, nbl, ori_img, is_visible, tri_inds, pixel_valid + ): + return Render_Face.apply( + proj_geo, texture, nbl, ori_img, is_visible, tri_inds, pixel_valid + ) + + +def cal_land(proj_geo, is_visible, lands_info, land_num): + (land_index,) = render_util.update_contour(lands_info, is_visible, land_num) + proj_land = torch.index_select(proj_geo.reshape(-1, 3), 0, land_index)[ + :, :2 + ].reshape(-1, land_num, 2) + return proj_land + + +class Render_Land(nn.Module): + def __init__(self): + super(Render_Land, self).__init__() + lands_info = np.loadtxt("../data/3DMM/lands_info.txt", dtype=np.int32) + self.lands_info = torch.as_tensor(lands_info).cuda() + tris = np.loadtxt("../data/3DMM/tris.txt", dtype=np.int64) + self.tris = torch.as_tensor(tris).cuda() - 1 + vert_tris = np.loadtxt("../data/3DMM/vert_tris.txt", dtype=np.int64) + self.vert_tris = torch.as_tensor(vert_tris).cuda() + self.normal_baser = Normal_Base().cuda() + self.renderer = Render_RGB().cuda() + + def render_mesh(self, geometry, euler, trans, cam, ori_img, light): + batch_size, h, w, _ = ori_img.shape + ori_img = ori_img.view(batch_size, -1, 3) + ori_size = torch.cat( + ( + torch.ones((batch_size, 1), dtype=torch.int32, device=ori_img.device) + * h, + torch.ones((batch_size, 1), dtype=torch.int32, device=ori_img.device) + * w, + ), + dim=1, + ).view(-1) + rott_geo, proj_geo, rot_tri_normal, _, _ = preprocess_render( + geometry, euler, trans, cam, self.tris, self.vert_tris, ori_img + ) + tri_nb = self.normal_baser(rot_tri_normal.contiguous()) + nbl = torch.bmm( + tri_nb, (light.reshape(-1, 9, 3))[:, :, 0].unsqueeze(-1).repeat(1, 1, 3) + ) + texture = torch.ones_like(geometry) * 200 + (render,) = render_util.render_mesh( + proj_geo, ori_img, ori_size, texture, nbl, self.tris + ) + return render.view(batch_size, h, w, 3).byte() + + def cal_loss_rgb(self, geometry, euler, trans, cam, ori_img, light, texture, lands): + rott_geo, proj_geo, rot_tri_normal, is_visible, pixel_valid = preprocess_render( + geometry, euler, trans, cam, self.tris, self.vert_tris, ori_img + ) + tri_nb = self.normal_baser(rot_tri_normal.contiguous()) + nbl = torch.bmm(tri_nb, light.reshape(-1, 9, 3)) + render, real = self.renderer( + proj_geo, texture, nbl, ori_img, is_visible, self.tris, pixel_valid + ) + proj_land = cal_land(proj_geo, is_visible, self.lands_info, lands.shape[1]) + col_minus = torch.norm((render - real).reshape(-1, 3), dim=1).reshape( + ori_img.shape[0], -1 + ) + col_dis = torch.mean(col_minus * pixel_valid) / ( + torch.mean(pixel_valid) + 0.00001 + ) + land_dists = torch.norm((proj_land - lands).reshape(-1, 2), dim=1).reshape( + ori_img.shape[0], -1 + ) + lan_dis = torch.mean(land_dists) + return col_dis, lan_dis diff --git a/data_utils/face_tracking/util.py b/data_utils/face_tracking/util.py new file mode 100644 index 0000000..cc0f3d8 --- /dev/null +++ b/data_utils/face_tracking/util.py @@ -0,0 +1,109 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def compute_tri_normal(geometry, tris): + tri_1 = tris[:, 0] + tri_2 = tris[:, 1] + tri_3 = tris[:, 2] + vert_1 = torch.index_select(geometry, 1, tri_1) + vert_2 = torch.index_select(geometry, 1, tri_2) + vert_3 = torch.index_select(geometry, 1, tri_3) + nnorm = torch.cross(vert_2 - vert_1, vert_3 - vert_1, 2) + normal = nn.functional.normalize(nnorm) + return normal + + +def euler2rot(euler_angle): + batch_size = euler_angle.shape[0] + theta = euler_angle[:, 0].reshape(-1, 1, 1) + phi = euler_angle[:, 1].reshape(-1, 1, 1) + psi = euler_angle[:, 2].reshape(-1, 1, 1) + one = torch.ones(batch_size, 1, 1).to(euler_angle.device) + zero = torch.zeros(batch_size, 1, 1).to(euler_angle.device) + rot_x = torch.cat( + ( + torch.cat((one, zero, zero), 1), + torch.cat((zero, theta.cos(), theta.sin()), 1), + torch.cat((zero, -theta.sin(), theta.cos()), 1), + ), + 2, + ) + rot_y = torch.cat( + ( + torch.cat((phi.cos(), zero, -phi.sin()), 1), + torch.cat((zero, one, zero), 1), + torch.cat((phi.sin(), zero, phi.cos()), 1), + ), + 2, + ) + rot_z = torch.cat( + ( + torch.cat((psi.cos(), -psi.sin(), zero), 1), + torch.cat((psi.sin(), psi.cos(), zero), 1), + torch.cat((zero, zero, one), 1), + ), + 2, + ) + return torch.bmm(rot_x, torch.bmm(rot_y, rot_z)) + + +def rot_trans_pts(geometry, rot, trans): + rott_geo = torch.bmm(rot, geometry.permute(0, 2, 1)) + trans[:, :, None] + return rott_geo.permute(0, 2, 1) + + +def cal_lap_loss(tensor_list, weight_list): + lap_kernel = ( + torch.Tensor((-0.5, 1.0, -0.5)) + .unsqueeze(0) + .unsqueeze(0) + .float() + .to(tensor_list[0].device) + ) + loss_lap = 0 + for i in range(len(tensor_list)): + in_tensor = tensor_list[i] + in_tensor = in_tensor.view(-1, 1, in_tensor.shape[-1]) + out_tensor = F.conv1d(in_tensor, lap_kernel) + loss_lap += torch.mean(out_tensor ** 2) * weight_list[i] + return loss_lap + + +def proj_pts(rott_geo, focal_length, cxy): + cx, cy = cxy[0], cxy[1] + X = rott_geo[:, :, 0] + Y = rott_geo[:, :, 1] + Z = rott_geo[:, :, 2] + fxX = focal_length * X + fyY = focal_length * Y + proj_x = -fxX / Z + cx + proj_y = fyY / Z + cy + return torch.cat((proj_x[:, :, None], proj_y[:, :, None], Z[:, :, None]), 2) + + +def forward_rott(geometry, euler_angle, trans): + rot = euler2rot(euler_angle) + rott_geo = rot_trans_pts(geometry, rot, trans) + return rott_geo + + +def forward_transform(geometry, euler_angle, trans, focal_length, cxy): + rot = euler2rot(euler_angle) + rott_geo = rot_trans_pts(geometry, rot, trans) + proj_geo = proj_pts(rott_geo, focal_length, cxy) + return proj_geo + + +def cal_lan_loss(proj_lan, gt_lan): + return torch.mean((proj_lan - gt_lan) ** 2) + + +def cal_col_loss(pred_img, gt_img, img_mask): + pred_img = pred_img.float() + # loss = torch.sqrt(torch.sum(torch.square(pred_img - gt_img), 3))*img_mask/255 + loss = (torch.sum(torch.square(pred_img - gt_img), 3)) * img_mask / 255 + loss = torch.sum(loss, dim=(1, 2)) / torch.sum(img_mask, dim=(1, 2)) + loss = torch.mean(loss) + return loss diff --git a/data_utils/process.py b/data_utils/process.py new file mode 100644 index 0000000..65aaa2d --- /dev/null +++ b/data_utils/process.py @@ -0,0 +1,444 @@ +import os +import glob +import tqdm +import json +import argparse +import cv2 +import numpy as np + +def extract_audio(path, out_path, sample_rate=16000): + + print(f'[INFO] ===== extract audio from {path} to {out_path} =====') + cmd = f'ffmpeg -i {path} -f wav -ar {sample_rate} {out_path}' + os.system(cmd) + print(f'[INFO] ===== extracted audio =====') + + +def extract_audio_features(path, mode='wav2vec'): + + print(f'[INFO] ===== extract audio labels for {path} =====') + if mode == 'wav2vec': + cmd = f'python nerf/asr.py --wav {path} --save_feats' + else: # deepspeech + cmd = f'python data_utils/deepspeech_features/extract_ds_features.py --input {path}' + os.system(cmd) + print(f'[INFO] ===== extracted audio labels =====') + + + +def extract_images(path, out_path, fps=25): + + print(f'[INFO] ===== extract images from {path} to {out_path} =====') + cmd = f'ffmpeg -i {path} -vf fps={fps} -qmin 1 -q:v 1 -start_number 0 {os.path.join(out_path, "%d.jpg")}' + os.system(cmd) + print(f'[INFO] ===== extracted images =====') + + +def extract_semantics(ori_imgs_dir, parsing_dir): + + print(f'[INFO] ===== extract semantics from {ori_imgs_dir} to {parsing_dir} =====') + cmd = f'python data_utils/face_parsing/test.py --respath={parsing_dir} --imgpath={ori_imgs_dir}' + os.system(cmd) + print(f'[INFO] ===== extracted semantics =====') + + +def extract_landmarks(ori_imgs_dir): + + print(f'[INFO] ===== extract face landmarks from {ori_imgs_dir} =====') + + import face_alignment + fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=False) + image_paths = glob.glob(os.path.join(ori_imgs_dir, '*.jpg')) + for image_path in tqdm.tqdm(image_paths): + input = cv2.imread(image_path, cv2.IMREAD_UNCHANGED) # [H, W, 3] + input = cv2.cvtColor(input, cv2.COLOR_BGR2RGB) + preds = fa.get_landmarks(input) + if len(preds) > 0: + lands = preds[0].reshape(-1, 2)[:,:2] + np.savetxt(image_path.replace('jpg', 'lms'), lands, '%f') + del fa + print(f'[INFO] ===== extracted face landmarks =====') + + +def extract_background(base_dir, ori_imgs_dir): + + print(f'[INFO] ===== extract background image from {ori_imgs_dir} =====') + + from sklearn.neighbors import NearestNeighbors + + image_paths = glob.glob(os.path.join(ori_imgs_dir, '*.jpg')) + # only use 1/20 image_paths + image_paths = image_paths[::20] + # read one image to get H/W + tmp_image = cv2.imread(image_paths[0], cv2.IMREAD_UNCHANGED) # [H, W, 3] + h, w = tmp_image.shape[:2] + + # nearest neighbors + all_xys = np.mgrid[0:h, 0:w].reshape(2, -1).transpose() + distss = [] + for image_path in tqdm.tqdm(image_paths): + parse_img = cv2.imread(image_path.replace('ori_imgs', 'parsing').replace('.jpg', '.png')) + bg = (parse_img[..., 0] == 255) & (parse_img[..., 1] == 255) & (parse_img[..., 2] == 255) + fg_xys = np.stack(np.nonzero(~bg)).transpose(1, 0) + nbrs = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(fg_xys) + dists, _ = nbrs.kneighbors(all_xys) + distss.append(dists) + + distss = np.stack(distss) + max_dist = np.max(distss, 0) + max_id = np.argmax(distss, 0) + + bc_pixs = max_dist > 5 + bc_pixs_id = np.nonzero(bc_pixs) + bc_ids = max_id[bc_pixs] + + imgs = [] + num_pixs = distss.shape[1] + for image_path in image_paths: + img = cv2.imread(image_path) + imgs.append(img) + imgs = np.stack(imgs).reshape(-1, num_pixs, 3) + + bc_img = np.zeros((h*w, 3), dtype=np.uint8) + bc_img[bc_pixs_id, :] = imgs[bc_ids, bc_pixs_id, :] + bc_img = bc_img.reshape(h, w, 3) + + max_dist = max_dist.reshape(h, w) + bc_pixs = max_dist > 5 + bg_xys = np.stack(np.nonzero(~bc_pixs)).transpose() + fg_xys = np.stack(np.nonzero(bc_pixs)).transpose() + nbrs = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(fg_xys) + distances, indices = nbrs.kneighbors(bg_xys) + bg_fg_xys = fg_xys[indices[:, 0]] + bc_img[bg_xys[:, 0], bg_xys[:, 1], :] = bc_img[bg_fg_xys[:, 0], bg_fg_xys[:, 1], :] + + cv2.imwrite(os.path.join(base_dir, 'bc.jpg'), bc_img) + + print(f'[INFO] ===== extracted background image =====') + + +def extract_torso_and_gt(base_dir, ori_imgs_dir): + + print(f'[INFO] ===== extract torso and gt images for {base_dir} =====') + + from scipy.ndimage import binary_erosion, binary_dilation + + # load bg + bg_image = cv2.imread(os.path.join(base_dir, 'bc.jpg'), cv2.IMREAD_UNCHANGED) + + image_paths = glob.glob(os.path.join(ori_imgs_dir, '*.jpg')) + + for image_path in tqdm.tqdm(image_paths): + # read ori image + ori_image = cv2.imread(image_path, cv2.IMREAD_UNCHANGED) # [H, W, 3] + + # read semantics + seg = cv2.imread(image_path.replace('ori_imgs', 'parsing').replace('.jpg', '.png')) + head_part = (seg[..., 0] == 255) & (seg[..., 1] == 0) & (seg[..., 2] == 0) + neck_part = (seg[..., 0] == 0) & (seg[..., 1] == 255) & (seg[..., 2] == 0) + torso_part = (seg[..., 0] == 0) & (seg[..., 1] == 0) & (seg[..., 2] == 255) + bg_part = (seg[..., 0] == 255) & (seg[..., 1] == 255) & (seg[..., 2] == 255) + + # get gt image + gt_image = ori_image.copy() + gt_image[bg_part] = bg_image[bg_part] + cv2.imwrite(image_path.replace('ori_imgs', 'gt_imgs'), gt_image) + + # get torso image + torso_image = gt_image.copy() # rgb + torso_image[head_part] = bg_image[head_part] + torso_alpha = 255 * np.ones((gt_image.shape[0], gt_image.shape[1], 1), dtype=np.uint8) # alpha + + # torso part "vertical" in-painting... + L = 8 + 1 + torso_coords = np.stack(np.nonzero(torso_part), axis=-1) # [M, 2] + # lexsort: sort 2D coords first by y then by x, + # ref: https://stackoverflow.com/questions/2706605/sorting-a-2d-numpy-array-by-multiple-axes + inds = np.lexsort((torso_coords[:, 0], torso_coords[:, 1])) + torso_coords = torso_coords[inds] + # choose the top pixel for each column + u, uid, ucnt = np.unique(torso_coords[:, 1], return_index=True, return_counts=True) + top_torso_coords = torso_coords[uid] # [m, 2] + # only keep top-is-head pixels + top_torso_coords_up = top_torso_coords.copy() - np.array([1, 0]) + mask = head_part[tuple(top_torso_coords_up.T)] + if mask.any(): + top_torso_coords = top_torso_coords[mask] + # get the color + top_torso_colors = gt_image[tuple(top_torso_coords.T)] # [m, 3] + # construct inpaint coords (vertically up, or minus in x) + inpaint_torso_coords = top_torso_coords[None].repeat(L, 0) # [L, m, 2] + inpaint_offsets = np.stack([-np.arange(L), np.zeros(L, dtype=np.int32)], axis=-1)[:, None] # [L, 1, 2] + inpaint_torso_coords += inpaint_offsets + inpaint_torso_coords = inpaint_torso_coords.reshape(-1, 2) # [Lm, 2] + inpaint_torso_colors = top_torso_colors[None].repeat(L, 0) # [L, m, 3] + darken_scaler = 0.98 ** np.arange(L).reshape(L, 1, 1) # [L, 1, 1] + inpaint_torso_colors = (inpaint_torso_colors * darken_scaler).reshape(-1, 3) # [Lm, 3] + # set color + torso_image[tuple(inpaint_torso_coords.T)] = inpaint_torso_colors + + inpaint_torso_mask = np.zeros_like(torso_image[..., 0]).astype(bool) + inpaint_torso_mask[tuple(inpaint_torso_coords.T)] = True + else: + inpaint_torso_mask = None + + + # neck part "vertical" in-painting... + push_down = 4 + L = 48 + push_down + 1 + + neck_part = binary_dilation(neck_part, structure=np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=bool), iterations=3) + + neck_coords = np.stack(np.nonzero(neck_part), axis=-1) # [M, 2] + # lexsort: sort 2D coords first by y then by x, + # ref: https://stackoverflow.com/questions/2706605/sorting-a-2d-numpy-array-by-multiple-axes + inds = np.lexsort((neck_coords[:, 0], neck_coords[:, 1])) + neck_coords = neck_coords[inds] + # choose the top pixel for each column + u, uid, ucnt = np.unique(neck_coords[:, 1], return_index=True, return_counts=True) + top_neck_coords = neck_coords[uid] # [m, 2] + # only keep top-is-head pixels + top_neck_coords_up = top_neck_coords.copy() - np.array([1, 0]) + mask = head_part[tuple(top_neck_coords_up.T)] + + top_neck_coords = top_neck_coords[mask] + # push these top down for 4 pixels to make the neck inpainting more natural... + offset_down = np.minimum(ucnt[mask] - 1, push_down) + top_neck_coords += np.stack([offset_down, np.zeros_like(offset_down)], axis=-1) + # get the color + top_neck_colors = gt_image[tuple(top_neck_coords.T)] # [m, 3] + # construct inpaint coords (vertically up, or minus in x) + inpaint_neck_coords = top_neck_coords[None].repeat(L, 0) # [L, m, 2] + inpaint_offsets = np.stack([-np.arange(L), np.zeros(L, dtype=np.int32)], axis=-1)[:, None] # [L, 1, 2] + inpaint_neck_coords += inpaint_offsets + inpaint_neck_coords = inpaint_neck_coords.reshape(-1, 2) # [Lm, 2] + inpaint_neck_colors = top_neck_colors[None].repeat(L, 0) # [L, m, 3] + darken_scaler = 0.98 ** np.arange(L).reshape(L, 1, 1) # [L, 1, 1] + inpaint_neck_colors = (inpaint_neck_colors * darken_scaler).reshape(-1, 3) # [Lm, 3] + # set color + torso_image[tuple(inpaint_neck_coords.T)] = inpaint_neck_colors + + # apply blurring to the inpaint area to avoid vertical-line artifects... + inpaint_mask = np.zeros_like(torso_image[..., 0]).astype(bool) + inpaint_mask[tuple(inpaint_neck_coords.T)] = True + + blur_img = torso_image.copy() + blur_img = cv2.GaussianBlur(blur_img, (5, 5), cv2.BORDER_DEFAULT) + + torso_image[inpaint_mask] = blur_img[inpaint_mask] + + # set mask + mask = (neck_part | torso_part | inpaint_mask) + if inpaint_torso_mask is not None: + mask = mask | inpaint_torso_mask + torso_image[~mask] = 0 + torso_alpha[~mask] = 0 + + cv2.imwrite(image_path.replace('ori_imgs', 'torso_imgs').replace('.jpg', '.png'), np.concatenate([torso_image, torso_alpha], axis=-1)) + + print(f'[INFO] ===== extracted torso and gt images =====') + + +def face_tracking(ori_imgs_dir): + + print(f'[INFO] ===== perform face tracking =====') + + image_paths = glob.glob(os.path.join(ori_imgs_dir, '*.jpg')) + + # read one image to get H/W + tmp_image = cv2.imread(image_paths[0], cv2.IMREAD_UNCHANGED) # [H, W, 3] + h, w = tmp_image.shape[:2] + + cmd = f'python data_utils/face_tracking/face_tracker.py --path={ori_imgs_dir} --img_h={h} --img_w={w} --frame_num={len(image_paths)}' + + os.system(cmd) + + print(f'[INFO] ===== finished face tracking =====') + + +def save_transforms(base_dir, ori_imgs_dir): + print(f'[INFO] ===== save transforms =====') + + import torch + + image_paths = glob.glob(os.path.join(ori_imgs_dir, '*.jpg')) + + # read one image to get H/W + tmp_image = cv2.imread(image_paths[0], cv2.IMREAD_UNCHANGED) # [H, W, 3] + h, w = tmp_image.shape[:2] + + params_dict = torch.load(os.path.join(base_dir, 'track_params.pt')) + focal_len = params_dict['focal'] + euler_angle = params_dict['euler'] + trans = params_dict['trans'] / 10.0 + valid_num = euler_angle.shape[0] + + def euler2rot(euler_angle): + batch_size = euler_angle.shape[0] + theta = euler_angle[:, 0].reshape(-1, 1, 1) + phi = euler_angle[:, 1].reshape(-1, 1, 1) + psi = euler_angle[:, 2].reshape(-1, 1, 1) + one = torch.ones((batch_size, 1, 1), dtype=torch.float32, device=euler_angle.device) + zero = torch.zeros((batch_size, 1, 1), dtype=torch.float32, device=euler_angle.device) + rot_x = torch.cat(( + torch.cat((one, zero, zero), 1), + torch.cat((zero, theta.cos(), theta.sin()), 1), + torch.cat((zero, -theta.sin(), theta.cos()), 1), + ), 2) + rot_y = torch.cat(( + torch.cat((phi.cos(), zero, -phi.sin()), 1), + torch.cat((zero, one, zero), 1), + torch.cat((phi.sin(), zero, phi.cos()), 1), + ), 2) + rot_z = torch.cat(( + torch.cat((psi.cos(), -psi.sin(), zero), 1), + torch.cat((psi.sin(), psi.cos(), zero), 1), + torch.cat((zero, zero, one), 1) + ), 2) + return torch.bmm(rot_x, torch.bmm(rot_y, rot_z)) + + + # train_val_split = int(valid_num*0.5) + # train_val_split = valid_num - 25 * 20 # take the last 20s as valid set. + train_val_split = int(valid_num * 10 / 11) + + train_ids = torch.arange(0, train_val_split) + val_ids = torch.arange(train_val_split, valid_num) + + rot = euler2rot(euler_angle) + rot_inv = rot.permute(0, 2, 1) + trans_inv = -torch.bmm(rot_inv, trans.unsqueeze(2)) + + pose = torch.eye(4, dtype=torch.float32) + save_ids = ['train', 'val'] + train_val_ids = [train_ids, val_ids] + mean_z = -float(torch.mean(trans[:, 2]).item()) + + for split in range(2): + transform_dict = dict() + transform_dict['focal_len'] = float(focal_len[0]) + transform_dict['cx'] = float(w/2.0) + transform_dict['cy'] = float(h/2.0) + transform_dict['frames'] = [] + ids = train_val_ids[split] + save_id = save_ids[split] + + for i in ids: + i = i.item() + frame_dict = dict() + frame_dict['img_id'] = i + frame_dict['aud_id'] = i + + pose[:3, :3] = rot_inv[i] + pose[:3, 3] = trans_inv[i, :, 0] + + frame_dict['transform_matrix'] = pose.numpy().tolist() + + transform_dict['frames'].append(frame_dict) + + with open(os.path.join(base_dir, 'transforms_' + save_id + '.json'), 'w') as fp: + json.dump(transform_dict, fp, indent=2, separators=(',', ': ')) + + print(f'[INFO] ===== finished saving transforms =====') + + + +def extract_torso_train(base_dir, ori_imgs_dir): + + print(f'[INFO] ===== extract training torso gt images for {base_dir} =====') + + # load bg + bg_image = cv2.imread(os.path.join(base_dir, 'bc.jpg'), cv2.IMREAD_UNCHANGED) + + image_paths = glob.glob(os.path.join(ori_imgs_dir, '*.jpg')) + + for image_path in tqdm.tqdm(image_paths): + # read ori image + ori_image = cv2.imread(image_path, cv2.IMREAD_UNCHANGED) # [H, W, 3] + + # read semantics + seg = cv2.imread(image_path.replace('ori_imgs', 'parsing').replace('.jpg', '.png')) + head_part = (seg[..., 0] == 255) & (seg[..., 1] == 0) & (seg[..., 2] == 0) + neck_part = (seg[..., 0] == 0) & (seg[..., 1] == 255) & (seg[..., 2] == 0) + torso_part = (seg[..., 0] == 0) & (seg[..., 1] == 0) & (seg[..., 2] == 255) + bg_part = (seg[..., 0] == 255) & (seg[..., 1] == 255) & (seg[..., 2] == 255) + + # get gt image + gt_image = ori_image.copy() + gt_image[bg_part] = bg_image[bg_part] + cv2.imwrite(image_path.replace('ori_imgs', 'gt_imgs'), gt_image) + + # get torso image + torso_image = gt_image.copy() # rgb + torso_image[head_part] = bg_image[head_part] + torso_alpha = 255 * np.ones((gt_image.shape[0], gt_image.shape[1], 1), dtype=np.uint8) # alpha + torso_alpha[head_part] = 0 + torso_alpha[bg_part] = 0 + cv2.imwrite(image_path.replace('ori_imgs', 'torso_imgs_train').replace('.jpg', '.png'), np.concatenate([torso_image, torso_alpha], axis=-1)) + + + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('path', type=str, help="path to video file") + parser.add_argument('--task', type=int, default=-1, help="-1 means all") + parser.add_argument('--asr', type=str, default='wav2vec', help="wav2vec or deepspeech") + + opt = parser.parse_args() + + base_dir = os.path.dirname(opt.path) + + wav_path = os.path.join(base_dir, 'aud.wav') + ori_imgs_dir = os.path.join(base_dir, 'ori_imgs') + parsing_dir = os.path.join(base_dir, 'parsing') + gt_imgs_dir = os.path.join(base_dir, 'gt_imgs') + torso_imgs_dir = os.path.join(base_dir, 'torso_imgs') + torso_imgs_train_dir = os.path.join(base_dir, 'torso_imgs_train') + + os.makedirs(ori_imgs_dir, exist_ok=True) + os.makedirs(parsing_dir, exist_ok=True) + os.makedirs(gt_imgs_dir, exist_ok=True) + os.makedirs(torso_imgs_dir, exist_ok=True) + os.makedirs(torso_imgs_train_dir, exist_ok=True) + + + # extract audio + if opt.task == -1 or opt.task == 1: + extract_audio(opt.path, wav_path) + + # extract audio features + if opt.task == -1 or opt.task == 2: + extract_audio_features(wav_path, mode=opt.asr) + + # extract images + if opt.task == -1 or opt.task == 3: + extract_images(opt.path, ori_imgs_dir) + + # face parsing + if opt.task == -1 or opt.task == 4: + extract_semantics(ori_imgs_dir, parsing_dir) + + # extract bg + if opt.task == -1 or opt.task == 5: + extract_background(base_dir, ori_imgs_dir) + + # extract torso images and gt_images + if opt.task == -1 or opt.task == 6: + extract_torso_and_gt(base_dir, ori_imgs_dir) + + # extract face landmarks + if opt.task == -1 or opt.task == 7: + extract_landmarks(ori_imgs_dir) + + # face tracking + if opt.task == -1 or opt.task == 8: + face_tracking(ori_imgs_dir) + + # save transforms.json + if opt.task == -1 or opt.task == 9: + save_transforms(base_dir, ori_imgs_dir) + + if opt.task == -1 or opt.task == 10: + extract_torso_train(base_dir, ori_imgs_dir) + diff --git a/encoding.py b/encoding.py new file mode 100644 index 0000000..c700b47 --- /dev/null +++ b/encoding.py @@ -0,0 +1,38 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def get_encoder(encoding, input_dim=3, + multires=6, + degree=4, + num_levels=16, level_dim=2, base_resolution=16, log2_hashmap_size=19, desired_resolution=2048, align_corners=False, + **kwargs): + + if encoding == 'None': + return lambda x, **kwargs: x, input_dim + + elif encoding == 'frequency': + from freqencoder import FreqEncoder + encoder = FreqEncoder(input_dim=input_dim, degree=multires) + + elif encoding == 'spherical_harmonics': + from shencoder import SHEncoder + encoder = SHEncoder(input_dim=input_dim, degree=degree) + + elif encoding == 'hashgrid': + from gridencoder import GridEncoder + encoder = GridEncoder(input_dim=input_dim, num_levels=num_levels, level_dim=level_dim, base_resolution=base_resolution, log2_hashmap_size=log2_hashmap_size, desired_resolution=desired_resolution, gridtype='hash', align_corners=align_corners) + + elif encoding == 'tiledgrid': + from gridencoder import GridEncoder + encoder = GridEncoder(input_dim=input_dim, num_levels=num_levels, level_dim=level_dim, base_resolution=base_resolution, log2_hashmap_size=log2_hashmap_size, desired_resolution=desired_resolution, gridtype='tiled', align_corners=align_corners) + + elif encoding == 'ash': + from ashencoder import AshEncoder + encoder = AshEncoder(input_dim=input_dim, output_dim=16, log2_hashmap_size=log2_hashmap_size, resolution=desired_resolution) + + else: + raise NotImplementedError('Unknown encoding mode, choose from [None, frequency, spherical_harmonics, hashgrid, tiledgrid]') + + return encoder, encoder.output_dim \ No newline at end of file diff --git a/freqencoder/__init__.py b/freqencoder/__init__.py new file mode 100644 index 0000000..69ec49c --- /dev/null +++ b/freqencoder/__init__.py @@ -0,0 +1 @@ +from .freq import FreqEncoder \ No newline at end of file diff --git a/freqencoder/backend.py b/freqencoder/backend.py new file mode 100644 index 0000000..a89e351 --- /dev/null +++ b/freqencoder/backend.py @@ -0,0 +1,41 @@ +import os +from torch.utils.cpp_extension import load + +_src_path = os.path.dirname(os.path.abspath(__file__)) + +nvcc_flags = [ + '-O3', '-std=c++14', + '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', + '-use_fast_math' +] + +if os.name == "posix": + c_flags = ['-O3', '-std=c++14'] +elif os.name == "nt": + c_flags = ['/O2', '/std:c++17'] + + # find cl.exe + def find_cl_path(): + import glob + for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: + paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) + if paths: + return paths[0] + + # If cl.exe is not on path, try to find it. + if os.system("where cl.exe >nul 2>nul") != 0: + cl_path = find_cl_path() + if cl_path is None: + raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") + os.environ["PATH"] += ";" + cl_path + +_backend = load(name='_freqencoder', + extra_cflags=c_flags, + extra_cuda_cflags=nvcc_flags, + sources=[os.path.join(_src_path, 'src', f) for f in [ + 'freqencoder.cu', + 'bindings.cpp', + ]], + ) + +__all__ = ['_backend'] \ No newline at end of file diff --git a/freqencoder/freq.py b/freqencoder/freq.py new file mode 100644 index 0000000..05179f1 --- /dev/null +++ b/freqencoder/freq.py @@ -0,0 +1,77 @@ +import numpy as np + +import torch +import torch.nn as nn +from torch.autograd import Function +from torch.autograd.function import once_differentiable +from torch.cuda.amp import custom_bwd, custom_fwd + +try: + import _freqencoder as _backend +except ImportError: + from .backend import _backend + + +class _freq_encoder(Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float32) # force float32 for better precision + def forward(ctx, inputs, degree, output_dim): + # inputs: [B, input_dim], float + # RETURN: [B, F], float + + if not inputs.is_cuda: inputs = inputs.cuda() + inputs = inputs.contiguous() + + B, input_dim = inputs.shape # batch size, coord dim + + outputs = torch.empty(B, output_dim, dtype=inputs.dtype, device=inputs.device) + + _backend.freq_encode_forward(inputs, B, input_dim, degree, output_dim, outputs) + + ctx.save_for_backward(inputs, outputs) + ctx.dims = [B, input_dim, degree, output_dim] + + return outputs + + @staticmethod + #@once_differentiable + @custom_bwd + def backward(ctx, grad): + # grad: [B, C * C] + + grad = grad.contiguous() + inputs, outputs = ctx.saved_tensors + B, input_dim, degree, output_dim = ctx.dims + + grad_inputs = torch.zeros_like(inputs) + _backend.freq_encode_backward(grad, outputs, B, input_dim, degree, output_dim, grad_inputs) + + return grad_inputs, None, None + + +freq_encode = _freq_encoder.apply + + +class FreqEncoder(nn.Module): + def __init__(self, input_dim=3, degree=4): + super().__init__() + + self.input_dim = input_dim + self.degree = degree + self.output_dim = input_dim + input_dim * 2 * degree + + def __repr__(self): + return f"FreqEncoder: input_dim={self.input_dim} degree={self.degree} output_dim={self.output_dim}" + + def forward(self, inputs, **kwargs): + # inputs: [..., input_dim] + # return: [..., ] + + prefix_shape = list(inputs.shape[:-1]) + inputs = inputs.reshape(-1, self.input_dim) + + outputs = freq_encode(inputs, self.degree, self.output_dim) + + outputs = outputs.reshape(prefix_shape + [self.output_dim]) + + return outputs \ No newline at end of file diff --git a/freqencoder/setup.py b/freqencoder/setup.py new file mode 100644 index 0000000..c9bb873 --- /dev/null +++ b/freqencoder/setup.py @@ -0,0 +1,51 @@ +import os +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +_src_path = os.path.dirname(os.path.abspath(__file__)) + +nvcc_flags = [ + '-O3', '-std=c++14', + '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', + '-use_fast_math' +] + +if os.name == "posix": + c_flags = ['-O3', '-std=c++14'] +elif os.name == "nt": + c_flags = ['/O2', '/std:c++17'] + + # find cl.exe + def find_cl_path(): + import glob + for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: + paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) + if paths: + return paths[0] + + # If cl.exe is not on path, try to find it. + if os.system("where cl.exe >nul 2>nul") != 0: + cl_path = find_cl_path() + if cl_path is None: + raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") + os.environ["PATH"] += ";" + cl_path + +setup( + name='freqencoder', # package name, import this to use python API + ext_modules=[ + CUDAExtension( + name='_freqencoder', # extension name, import this to use CUDA API + sources=[os.path.join(_src_path, 'src', f) for f in [ + 'freqencoder.cu', + 'bindings.cpp', + ]], + extra_compile_args={ + 'cxx': c_flags, + 'nvcc': nvcc_flags, + } + ), + ], + cmdclass={ + 'build_ext': BuildExtension, + } +) \ No newline at end of file diff --git a/freqencoder/src/bindings.cpp b/freqencoder/src/bindings.cpp new file mode 100644 index 0000000..dc48bd0 --- /dev/null +++ b/freqencoder/src/bindings.cpp @@ -0,0 +1,8 @@ +#include + +#include "freqencoder.h" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("freq_encode_forward", &freq_encode_forward, "freq encode forward (CUDA)"); + m.def("freq_encode_backward", &freq_encode_backward, "freq encode backward (CUDA)"); +} \ No newline at end of file diff --git a/freqencoder/src/freqencoder.cu b/freqencoder/src/freqencoder.cu new file mode 100644 index 0000000..e1e0e89 --- /dev/null +++ b/freqencoder/src/freqencoder.cu @@ -0,0 +1,129 @@ +#include + +#include +#include +#include + +#include +#include + +#include +#include + +#include + + +#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor") +#define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor") +#define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor") + +inline constexpr __device__ float PI() { return 3.141592653589793f; } + +template +__host__ __device__ T div_round_up(T val, T divisor) { + return (val + divisor - 1) / divisor; +} + +// inputs: [B, D] +// outputs: [B, C], C = D + D * deg * 2 +__global__ void kernel_freq( + const float * __restrict__ inputs, + uint32_t B, uint32_t D, uint32_t deg, uint32_t C, + float * outputs +) { + // parallel on per-element + const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x; + if (t >= B * C) return; + + // get index + const uint32_t b = t / C; + const uint32_t c = t - b * C; // t % C; + + // locate + inputs += b * D; + outputs += t; + + // write self + if (c < D) { + outputs[0] = inputs[c]; + // write freq + } else { + const uint32_t col = c / D - 1; + const uint32_t d = c % D; + const uint32_t freq = col / 2; + const float phase_shift = (col % 2) * (PI() / 2); + outputs[0] = __sinf(scalbnf(inputs[d], freq) + phase_shift); + } +} + +// grad: [B, C], C = D + D * deg * 2 +// outputs: [B, C] +// grad_inputs: [B, D] +__global__ void kernel_freq_backward( + const float * __restrict__ grad, + const float * __restrict__ outputs, + uint32_t B, uint32_t D, uint32_t deg, uint32_t C, + float * grad_inputs +) { + // parallel on per-element + const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x; + if (t >= B * D) return; + + const uint32_t b = t / D; + const uint32_t d = t - b * D; // t % D; + + // locate + grad += b * C; + outputs += b * C; + grad_inputs += t; + + // register + float result = grad[d]; + grad += D; + outputs += D; + + for (uint32_t f = 0; f < deg; f++) { + result += scalbnf(1.0f, f) * (grad[d] * outputs[D + d] - grad[D + d] * outputs[d]); + grad += 2 * D; + outputs += 2 * D; + } + + // write + grad_inputs[0] = result; +} + + +void freq_encode_forward(at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor outputs) { + CHECK_CUDA(inputs); + CHECK_CUDA(outputs); + + CHECK_CONTIGUOUS(inputs); + CHECK_CONTIGUOUS(outputs); + + CHECK_IS_FLOATING(inputs); + CHECK_IS_FLOATING(outputs); + + static constexpr uint32_t N_THREADS = 128; + + kernel_freq<<>>(inputs.data_ptr(), B, D, deg, C, outputs.data_ptr()); +} + + +void freq_encode_backward(at::Tensor grad, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor grad_inputs) { + CHECK_CUDA(grad); + CHECK_CUDA(outputs); + CHECK_CUDA(grad_inputs); + + CHECK_CONTIGUOUS(grad); + CHECK_CONTIGUOUS(outputs); + CHECK_CONTIGUOUS(grad_inputs); + + CHECK_IS_FLOATING(grad); + CHECK_IS_FLOATING(outputs); + CHECK_IS_FLOATING(grad_inputs); + + static constexpr uint32_t N_THREADS = 128; + + kernel_freq_backward<<>>(grad.data_ptr(), outputs.data_ptr(), B, D, deg, C, grad_inputs.data_ptr()); +} \ No newline at end of file diff --git a/freqencoder/src/freqencoder.h b/freqencoder/src/freqencoder.h new file mode 100644 index 0000000..cc420ee --- /dev/null +++ b/freqencoder/src/freqencoder.h @@ -0,0 +1,10 @@ +# pragma once + +#include +#include + +// _backend.freq_encode_forward(inputs, B, input_dim, degree, output_dim, outputs) +void freq_encode_forward(at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor outputs); + +// _backend.freq_encode_backward(grad, outputs, B, input_dim, degree, output_dim, grad_inputs) +void freq_encode_backward(at::Tensor grad, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor grad_inputs); \ No newline at end of file diff --git a/gridencoder/__init__.py b/gridencoder/__init__.py new file mode 100644 index 0000000..f1476ce --- /dev/null +++ b/gridencoder/__init__.py @@ -0,0 +1 @@ +from .grid import GridEncoder \ No newline at end of file diff --git a/gridencoder/backend.py b/gridencoder/backend.py new file mode 100644 index 0000000..64a39ff --- /dev/null +++ b/gridencoder/backend.py @@ -0,0 +1,40 @@ +import os +from torch.utils.cpp_extension import load + +_src_path = os.path.dirname(os.path.abspath(__file__)) + +nvcc_flags = [ + '-O3', '-std=c++14', + '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', +] + +if os.name == "posix": + c_flags = ['-O3', '-std=c++14', '-finput-charset=UTF-8'] +elif os.name == "nt": + c_flags = ['/O2', '/std:c++17', '/finput-charset=UTF-8'] + + # find cl.exe + def find_cl_path(): + import glob + for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: + paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) + if paths: + return paths[0] + + # If cl.exe is not on path, try to find it. + if os.system("where cl.exe >nul 2>nul") != 0: + cl_path = find_cl_path() + if cl_path is None: + raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") + os.environ["PATH"] += ";" + cl_path + +_backend = load(name='_grid_encoder', + extra_cflags=c_flags, + extra_cuda_cflags=nvcc_flags, + sources=[os.path.join(_src_path, 'src', f) for f in [ + 'gridencoder.cu', + 'bindings.cpp', + ]], + ) + +__all__ = ['_backend'] \ No newline at end of file diff --git a/gridencoder/grid.py b/gridencoder/grid.py new file mode 100644 index 0000000..8536992 --- /dev/null +++ b/gridencoder/grid.py @@ -0,0 +1,155 @@ +import numpy as np + +import torch +import torch.nn as nn +from torch.autograd import Function +from torch.autograd.function import once_differentiable +from torch.cuda.amp import custom_bwd, custom_fwd + +try: + import _gridencoder as _backend +except ImportError: + from .backend import _backend + +_gridtype_to_id = { + 'hash': 0, + 'tiled': 1, +} + +class _grid_encode(Function): + @staticmethod + @custom_fwd + def forward(ctx, inputs, embeddings, offsets, per_level_scale, base_resolution, calc_grad_inputs=False, gridtype=0, align_corners=False): + # inputs: [B, D], float in [0, 1] + # embeddings: [sO, C], float + # offsets: [L + 1], int + # RETURN: [B, F], float + + inputs = inputs.float().contiguous() + + B, D = inputs.shape # batch size, coord dim + L = offsets.shape[0] - 1 # level + C = embeddings.shape[1] # embedding dim for each level + S = np.log2(per_level_scale) # resolution multiplier at each level, apply log2 for later CUDA exp2f + H = base_resolution # base resolution + + # manually handle autocast (only use half precision embeddings, inputs must be float for enough precision) + # if C % 2 != 0, force float, since half for atomicAdd is very slow. + if torch.is_autocast_enabled() and C % 2 == 0: + embeddings = embeddings.to(torch.half) + + # L first, optimize cache for cuda kernel, but needs an extra permute later + outputs = torch.empty(L, B, C, device=inputs.device, dtype=embeddings.dtype) + + if calc_grad_inputs: + dy_dx = torch.empty(B, L * D * C, device=inputs.device, dtype=embeddings.dtype) + else: + dy_dx = None + + _backend.grid_encode_forward(inputs, embeddings, offsets, outputs, B, D, C, L, S, H, dy_dx, gridtype, align_corners) + + # permute back to [B, L * C] + outputs = outputs.permute(1, 0, 2).reshape(B, L * C) + + ctx.save_for_backward(inputs, embeddings, offsets, dy_dx) + ctx.dims = [B, D, C, L, S, H, gridtype] + ctx.align_corners = align_corners + + return outputs + + @staticmethod + #@once_differentiable + @custom_bwd + def backward(ctx, grad): + + inputs, embeddings, offsets, dy_dx = ctx.saved_tensors + B, D, C, L, S, H, gridtype = ctx.dims + align_corners = ctx.align_corners + + # grad: [B, L * C] --> [L, B, C] + grad = grad.view(B, L, C).permute(1, 0, 2).contiguous() + + grad_embeddings = torch.zeros_like(embeddings) + + if dy_dx is not None: + grad_inputs = torch.zeros_like(inputs, dtype=embeddings.dtype) + else: + grad_inputs = None + + _backend.grid_encode_backward(grad, inputs, embeddings, offsets, grad_embeddings, B, D, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners) + + if dy_dx is not None: + grad_inputs = grad_inputs.to(inputs.dtype) + + return grad_inputs, grad_embeddings, None, None, None, None, None, None + + + +grid_encode = _grid_encode.apply + + +class GridEncoder(nn.Module): + def __init__(self, input_dim=3, num_levels=16, level_dim=2, per_level_scale=2, base_resolution=16, log2_hashmap_size=19, desired_resolution=None, gridtype='hash', align_corners=False): + super().__init__() + + # the finest resolution desired at the last level, if provided, overridee per_level_scale + if desired_resolution is not None: + per_level_scale = np.exp2(np.log2(desired_resolution / base_resolution) / (num_levels - 1)) + + self.input_dim = input_dim # coord dims, 2 or 3 + self.num_levels = num_levels # num levels, each level multiply resolution by 2 + self.level_dim = level_dim # encode channels per level + self.per_level_scale = per_level_scale # multiply resolution by this scale at each level. + self.log2_hashmap_size = log2_hashmap_size + self.base_resolution = base_resolution + self.output_dim = num_levels * level_dim + self.gridtype = gridtype + self.gridtype_id = _gridtype_to_id[gridtype] # "tiled" or "hash" + self.align_corners = align_corners + + # allocate parameters + offsets = [] + offset = 0 + self.max_params = 2 ** log2_hashmap_size + for i in range(num_levels): + resolution = int(np.ceil(base_resolution * per_level_scale ** i)) + params_in_level = min(self.max_params, (resolution if align_corners else resolution + 1) ** input_dim) # limit max number + params_in_level = int(np.ceil(params_in_level / 8) * 8) # make divisible + offsets.append(offset) + offset += params_in_level + # print(resolution, params_in_level) + offsets.append(offset) + offsets = torch.from_numpy(np.array(offsets, dtype=np.int32)) + self.register_buffer('offsets', offsets) + + self.n_params = offsets[-1] * level_dim + + # parameters + self.embeddings = nn.Parameter(torch.empty(offset, level_dim)) + + self.reset_parameters() + + def reset_parameters(self): + std = 1e-4 + self.embeddings.data.uniform_(-std, std) + + def __repr__(self): + return f"GridEncoder: input_dim={self.input_dim} num_levels={self.num_levels} level_dim={self.level_dim} resolution={self.base_resolution} -> {int(round(self.base_resolution * self.per_level_scale ** (self.num_levels - 1)))} per_level_scale={self.per_level_scale:.4f} params={tuple(self.embeddings.shape)} gridtype={self.gridtype} align_corners={self.align_corners}" + + def forward(self, inputs, bound=1): + # inputs: [..., input_dim], normalized real world positions in [-bound, bound] + # return: [..., num_levels * level_dim] + + inputs = (inputs + bound) / (2 * bound) # map to [0, 1] + + #print('inputs', inputs.shape, inputs.dtype, inputs.min().item(), inputs.max().item()) + + prefix_shape = list(inputs.shape[:-1]) + inputs = inputs.view(-1, self.input_dim) + + outputs = grid_encode(inputs, self.embeddings, self.offsets, self.per_level_scale, self.base_resolution, inputs.requires_grad, self.gridtype_id, self.align_corners) + outputs = outputs.view(prefix_shape + [self.output_dim]) + + #print('outputs', outputs.shape, outputs.dtype, outputs.min().item(), outputs.max().item()) + + return outputs \ No newline at end of file diff --git a/gridencoder/setup.py b/gridencoder/setup.py new file mode 100644 index 0000000..bda10a1 --- /dev/null +++ b/gridencoder/setup.py @@ -0,0 +1,50 @@ +import os +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +_src_path = os.path.dirname(os.path.abspath(__file__)) + +nvcc_flags = [ + '-O3', '-std=c++14', + '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', +] + +if os.name == "posix": + c_flags = ['-O3', '-std=c++14'] +elif os.name == "nt": + c_flags = ['/O2', '/std:c++17'] + + # find cl.exe + def find_cl_path(): + import glob + for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: + paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) + if paths: + return paths[0] + + # If cl.exe is not on path, try to find it. + if os.system("where cl.exe >nul 2>nul") != 0: + cl_path = find_cl_path() + if cl_path is None: + raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") + os.environ["PATH"] += ";" + cl_path + +setup( + name='gridencoder', # package name, import this to use python API + ext_modules=[ + CUDAExtension( + name='_gridencoder', # extension name, import this to use CUDA API + sources=[os.path.join(_src_path, 'src', f) for f in [ + 'gridencoder.cu', + 'bindings.cpp', + ]], + extra_compile_args={ + 'cxx': c_flags, + 'nvcc': nvcc_flags, + } + ), + ], + cmdclass={ + 'build_ext': BuildExtension, + } +) \ No newline at end of file diff --git a/gridencoder/src/bindings.cpp b/gridencoder/src/bindings.cpp new file mode 100644 index 0000000..45f29b7 --- /dev/null +++ b/gridencoder/src/bindings.cpp @@ -0,0 +1,8 @@ +#include + +#include "gridencoder.h" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("grid_encode_forward", &grid_encode_forward, "grid_encode_forward (CUDA)"); + m.def("grid_encode_backward", &grid_encode_backward, "grid_encode_backward (CUDA)"); +} \ No newline at end of file diff --git a/gridencoder/src/gridencoder.cu b/gridencoder/src/gridencoder.cu new file mode 100644 index 0000000..34c1aba --- /dev/null +++ b/gridencoder/src/gridencoder.cu @@ -0,0 +1,479 @@ +#include +#include +#include + +#include +#include + +#include +#include + +#include +#include + + +#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor") +#define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor") +#define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor") + + +// just for compatability of half precision in AT_DISPATCH_FLOATING_TYPES_AND_HALF... +static inline __device__ at::Half atomicAdd(at::Half *address, at::Half val) { + // requires CUDA >= 10 and ARCH >= 70 + // this is very slow compared to float or __half2, and never used. + //return atomicAdd(reinterpret_cast<__half*>(address), val); +} + + +template +static inline __host__ __device__ T div_round_up(T val, T divisor) { + return (val + divisor - 1) / divisor; +} + + +template +__device__ uint32_t fast_hash(const uint32_t pos_grid[D]) { + static_assert(D <= 7, "fast_hash can only hash up to 7 dimensions."); + + // While 1 is technically not a good prime for hashing (or a prime at all), it helps memory coherence + // and is sufficient for our use case of obtaining a uniformly colliding index from high-dimensional + // coordinates. + constexpr uint32_t primes[7] = { 1, 2654435761, 805459861, 3674653429, 2097192037, 1434869437, 2165219737 }; + + uint32_t result = 0; + #pragma unroll + for (uint32_t i = 0; i < D; ++i) { + result ^= pos_grid[i] * primes[i]; + } + + return result; +} + + +template +__device__ uint32_t get_grid_index(const uint32_t gridtype, const bool align_corners, const uint32_t ch, const uint32_t hashmap_size, const uint32_t resolution, const uint32_t pos_grid[D]) { + uint32_t stride = 1; + uint32_t index = 0; + + #pragma unroll + for (uint32_t d = 0; d < D && stride <= hashmap_size; d++) { + index += pos_grid[d] * stride; + stride *= align_corners ? resolution: (resolution + 1); + } + + // NOTE: for NeRF, the hash is in fact not necessary. Check https://github.com/NVlabs/instant-ngp/issues/97. + // gridtype: 0 == hash, 1 == tiled + if (gridtype == 0 && stride > hashmap_size) { + index = fast_hash(pos_grid); + } + + return (index % hashmap_size) * C + ch; +} + + +template +__global__ void kernel_grid( + const float * __restrict__ inputs, + const scalar_t * __restrict__ grid, + const int * __restrict__ offsets, + scalar_t * __restrict__ outputs, + const uint32_t B, const uint32_t L, const float S, const uint32_t H, + scalar_t * __restrict__ dy_dx, + const uint32_t gridtype, + const bool align_corners +) { + const uint32_t b = blockIdx.x * blockDim.x + threadIdx.x; + + if (b >= B) return; + + const uint32_t level = blockIdx.y; + + // locate + grid += (uint32_t)offsets[level] * C; + inputs += b * D; + outputs += level * B * C + b * C; + + // check input range (should be in [0, 1]) + bool flag_oob = false; + #pragma unroll + for (uint32_t d = 0; d < D; d++) { + if (inputs[d] < 0 || inputs[d] > 1) { + flag_oob = true; + } + } + // if input out of bound, just set output to 0 + if (flag_oob) { + #pragma unroll + for (uint32_t ch = 0; ch < C; ch++) { + outputs[ch] = 0; + } + if (dy_dx) { + dy_dx += b * D * L * C + level * D * C; // B L D C + #pragma unroll + for (uint32_t d = 0; d < D; d++) { + #pragma unroll + for (uint32_t ch = 0; ch < C; ch++) { + dy_dx[d * C + ch] = 0; + } + } + } + return; + } + + const uint32_t hashmap_size = offsets[level + 1] - offsets[level]; + const float scale = exp2f(level * S) * H - 1.0f; + const uint32_t resolution = (uint32_t)ceil(scale) + 1; + + // calculate coordinate + float pos[D]; + uint32_t pos_grid[D]; + + #pragma unroll + for (uint32_t d = 0; d < D; d++) { + pos[d] = inputs[d] * scale + (align_corners ? 0.0f : 0.5f); + pos_grid[d] = floorf(pos[d]); + pos[d] -= (float)pos_grid[d]; + } + + //printf("[b=%d, l=%d] pos=(%f, %f)+(%d, %d)\n", b, level, pos[0], pos[1], pos_grid[0], pos_grid[1]); + + // interpolate + scalar_t results[C] = {0}; // temp results in register + + #pragma unroll + for (uint32_t idx = 0; idx < (1 << D); idx++) { + float w = 1; + uint32_t pos_grid_local[D]; + + #pragma unroll + for (uint32_t d = 0; d < D; d++) { + if ((idx & (1 << d)) == 0) { + w *= 1 - pos[d]; + pos_grid_local[d] = pos_grid[d]; + } else { + w *= pos[d]; + pos_grid_local[d] = pos_grid[d] + 1; + } + } + + uint32_t index = get_grid_index(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid_local); + + // writing to register (fast) + #pragma unroll + for (uint32_t ch = 0; ch < C; ch++) { + results[ch] += w * grid[index + ch]; + } + + //printf("[b=%d, l=%d] int %d, idx %d, w %f, val %f\n", b, level, idx, index, w, grid[index]); + } + + // writing to global memory (slow) + #pragma unroll + for (uint32_t ch = 0; ch < C; ch++) { + outputs[ch] = results[ch]; + } + + // prepare dy_dx + // differentiable (soft) indexing: https://discuss.pytorch.org/t/differentiable-indexing/17647/9 + if (dy_dx) { + + dy_dx += b * D * L * C + level * D * C; // B L D C + + #pragma unroll + for (uint32_t gd = 0; gd < D; gd++) { + + scalar_t results_grad[C] = {0}; + + #pragma unroll + for (uint32_t idx = 0; idx < (1 << (D - 1)); idx++) { + float w = scale; + uint32_t pos_grid_local[D]; + + #pragma unroll + for (uint32_t nd = 0; nd < D - 1; nd++) { + const uint32_t d = (nd >= gd) ? (nd + 1) : nd; + + if ((idx & (1 << nd)) == 0) { + w *= 1 - pos[d]; + pos_grid_local[d] = pos_grid[d]; + } else { + w *= pos[d]; + pos_grid_local[d] = pos_grid[d] + 1; + } + } + + pos_grid_local[gd] = pos_grid[gd]; + uint32_t index_left = get_grid_index(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid_local); + pos_grid_local[gd] = pos_grid[gd] + 1; + uint32_t index_right = get_grid_index(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid_local); + + #pragma unroll + for (uint32_t ch = 0; ch < C; ch++) { + results_grad[ch] += w * (grid[index_right + ch] - grid[index_left + ch]); + } + } + + #pragma unroll + for (uint32_t ch = 0; ch < C; ch++) { + dy_dx[gd * C + ch] = results_grad[ch]; + } + } + } +} + + +template +__global__ void kernel_grid_backward( + const scalar_t * __restrict__ grad, + const float * __restrict__ inputs, + const scalar_t * __restrict__ grid, + const int * __restrict__ offsets, + scalar_t * __restrict__ grad_grid, + const uint32_t B, const uint32_t L, const float S, const uint32_t H, + const uint32_t gridtype, + const bool align_corners +) { + const uint32_t b = (blockIdx.x * blockDim.x + threadIdx.x) * N_C / C; + if (b >= B) return; + + const uint32_t level = blockIdx.y; + const uint32_t ch = (blockIdx.x * blockDim.x + threadIdx.x) * N_C - b * C; + + // locate + grad_grid += offsets[level] * C; + inputs += b * D; + grad += level * B * C + b * C + ch; // L, B, C + + const uint32_t hashmap_size = offsets[level + 1] - offsets[level]; + const float scale = exp2f(level * S) * H - 1.0f; + const uint32_t resolution = (uint32_t)ceil(scale) + 1; + + // check input range (should be in [0, 1]) + #pragma unroll + for (uint32_t d = 0; d < D; d++) { + if (inputs[d] < 0 || inputs[d] > 1) { + return; // grad is init as 0, so we simply return. + } + } + + // calculate coordinate + float pos[D]; + uint32_t pos_grid[D]; + + #pragma unroll + for (uint32_t d = 0; d < D; d++) { + pos[d] = inputs[d] * scale + (align_corners ? 0.0f : 0.5f); + pos_grid[d] = floorf(pos[d]); + pos[d] -= (float)pos_grid[d]; + } + + scalar_t grad_cur[N_C] = {0}; // fetch to register + #pragma unroll + for (uint32_t c = 0; c < N_C; c++) { + grad_cur[c] = grad[c]; + } + + // interpolate + #pragma unroll + for (uint32_t idx = 0; idx < (1 << D); idx++) { + float w = 1; + uint32_t pos_grid_local[D]; + + #pragma unroll + for (uint32_t d = 0; d < D; d++) { + if ((idx & (1 << d)) == 0) { + w *= 1 - pos[d]; + pos_grid_local[d] = pos_grid[d]; + } else { + w *= pos[d]; + pos_grid_local[d] = pos_grid[d] + 1; + } + } + + uint32_t index = get_grid_index(gridtype, align_corners, ch, hashmap_size, resolution, pos_grid_local); + + // atomicAdd for __half is slow (especially for large values), so we use __half2 if N_C % 2 == 0 + // TODO: use float which is better than __half, if N_C % 2 != 0 + if (std::is_same::value && N_C % 2 == 0) { + #pragma unroll + for (uint32_t c = 0; c < N_C; c += 2) { + // process two __half at once (by interpreting as a __half2) + __half2 v = {(__half)(w * grad_cur[c]), (__half)(w * grad_cur[c + 1])}; + atomicAdd((__half2*)&grad_grid[index + c], v); + } + // float, or __half when N_C % 2 != 0 (which means C == 1) + } else { + #pragma unroll + for (uint32_t c = 0; c < N_C; c++) { + atomicAdd(&grad_grid[index + c], w * grad_cur[c]); + } + } + } +} + + +template +__global__ void kernel_input_backward( + const scalar_t * __restrict__ grad, + const scalar_t * __restrict__ dy_dx, + scalar_t * __restrict__ grad_inputs, + uint32_t B, uint32_t L +) { + const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x; + if (t >= B * D) return; + + const uint32_t b = t / D; + const uint32_t d = t - b * D; + + dy_dx += b * L * D * C; + + scalar_t result = 0; + + # pragma unroll + for (int l = 0; l < L; l++) { + # pragma unroll + for (int ch = 0; ch < C; ch++) { + result += grad[l * B * C + b * C + ch] * dy_dx[l * D * C + d * C + ch]; + } + } + + grad_inputs[t] = result; +} + + +template +void kernel_grid_wrapper(const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *outputs, const uint32_t B, const uint32_t C, const uint32_t L, const float S, const uint32_t H, scalar_t *dy_dx, const uint32_t gridtype, const bool align_corners) { + static constexpr uint32_t N_THREAD = 512; + const dim3 blocks_hashgrid = { div_round_up(B, N_THREAD), L, 1 }; + switch (C) { + case 1: kernel_grid<<>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners); break; + case 2: kernel_grid<<>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners); break; + case 4: kernel_grid<<>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners); break; + case 8: kernel_grid<<>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners); break; + default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."}; + } +} + +// inputs: [B, D], float, in [0, 1] +// embeddings: [sO, C], float +// offsets: [L + 1], uint32_t +// outputs: [L, B, C], float (L first, so only one level of hashmap needs to fit into cache at a time.) +// H: base resolution +// dy_dx: [B, L * D * C] +template +void grid_encode_forward_cuda(const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, scalar_t *dy_dx, const uint32_t gridtype, const bool align_corners) { + switch (D) { + case 1: kernel_grid_wrapper(inputs, embeddings, offsets, outputs, B, C, L, S, H, dy_dx, gridtype, align_corners); break; + case 2: kernel_grid_wrapper(inputs, embeddings, offsets, outputs, B, C, L, S, H, dy_dx, gridtype, align_corners); break; + case 3: kernel_grid_wrapper(inputs, embeddings, offsets, outputs, B, C, L, S, H, dy_dx, gridtype, align_corners); break; + case 4: kernel_grid_wrapper(inputs, embeddings, offsets, outputs, B, C, L, S, H, dy_dx, gridtype, align_corners); break; + case 5: kernel_grid_wrapper(inputs, embeddings, offsets, outputs, B, C, L, S, H, dy_dx, gridtype, align_corners); break; + default: throw std::runtime_error{"GridEncoding: D must be 1, 2, 3, 4, or 5"}; + } + +} + +template +void kernel_grid_backward_wrapper(const scalar_t *grad, const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *grad_embeddings, const uint32_t B, const uint32_t C, const uint32_t L, const float S, const uint32_t H, scalar_t *dy_dx, scalar_t *grad_inputs, const uint32_t gridtype, const bool align_corners) { + static constexpr uint32_t N_THREAD = 256; + const uint32_t N_C = std::min(2u, C); // n_features_per_thread + const dim3 blocks_hashgrid = { div_round_up(B * C / N_C, N_THREAD), L, 1 }; + switch (C) { + case 1: + kernel_grid_backward<<>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners); + if (dy_dx) kernel_input_backward<<>>(grad, dy_dx, grad_inputs, B, L); + break; + case 2: + kernel_grid_backward<<>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners); + if (dy_dx) kernel_input_backward<<>>(grad, dy_dx, grad_inputs, B, L); + break; + case 4: + kernel_grid_backward<<>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners); + if (dy_dx) kernel_input_backward<<>>(grad, dy_dx, grad_inputs, B, L); + break; + case 8: + kernel_grid_backward<<>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners); + if (dy_dx) kernel_input_backward<<>>(grad, dy_dx, grad_inputs, B, L); + break; + default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."}; + } +} + + +// grad: [L, B, C], float +// inputs: [B, D], float, in [0, 1] +// embeddings: [sO, C], float +// offsets: [L + 1], uint32_t +// grad_embeddings: [sO, C] +// H: base resolution +template +void grid_encode_backward_cuda(const scalar_t *grad, const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, scalar_t *dy_dx, scalar_t *grad_inputs, const uint32_t gridtype, const bool align_corners) { + switch (D) { + case 1: kernel_grid_backward_wrapper(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners); break; + case 2: kernel_grid_backward_wrapper(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners); break; + case 3: kernel_grid_backward_wrapper(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners); break; + case 4: kernel_grid_backward_wrapper(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners); break; + case 5: kernel_grid_backward_wrapper(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners); break; + default: throw std::runtime_error{"GridEncoding: D must be 1, 2, 3, 4, or 5"}; + } +} + + + +void grid_encode_forward(const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, at::optional dy_dx, const uint32_t gridtype, const bool align_corners) { + CHECK_CUDA(inputs); + CHECK_CUDA(embeddings); + CHECK_CUDA(offsets); + CHECK_CUDA(outputs); + // CHECK_CUDA(dy_dx); + + CHECK_CONTIGUOUS(inputs); + CHECK_CONTIGUOUS(embeddings); + CHECK_CONTIGUOUS(offsets); + CHECK_CONTIGUOUS(outputs); + // CHECK_CONTIGUOUS(dy_dx); + + CHECK_IS_FLOATING(inputs); + CHECK_IS_FLOATING(embeddings); + CHECK_IS_INT(offsets); + CHECK_IS_FLOATING(outputs); + // CHECK_IS_FLOATING(dy_dx); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + embeddings.scalar_type(), "grid_encode_forward", ([&] { + grid_encode_forward_cuda(inputs.data_ptr(), embeddings.data_ptr(), offsets.data_ptr(), outputs.data_ptr(), B, D, C, L, S, H, dy_dx.has_value() ? dy_dx.value().data_ptr() : nullptr, gridtype, align_corners); + })); +} + +void grid_encode_backward(const at::Tensor grad, const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const at::optional dy_dx, at::optional grad_inputs, const uint32_t gridtype, const bool align_corners) { + CHECK_CUDA(grad); + CHECK_CUDA(inputs); + CHECK_CUDA(embeddings); + CHECK_CUDA(offsets); + CHECK_CUDA(grad_embeddings); + // CHECK_CUDA(dy_dx); + // CHECK_CUDA(grad_inputs); + + CHECK_CONTIGUOUS(grad); + CHECK_CONTIGUOUS(inputs); + CHECK_CONTIGUOUS(embeddings); + CHECK_CONTIGUOUS(offsets); + CHECK_CONTIGUOUS(grad_embeddings); + // CHECK_CONTIGUOUS(dy_dx); + // CHECK_CONTIGUOUS(grad_inputs); + + CHECK_IS_FLOATING(grad); + CHECK_IS_FLOATING(inputs); + CHECK_IS_FLOATING(embeddings); + CHECK_IS_INT(offsets); + CHECK_IS_FLOATING(grad_embeddings); + // CHECK_IS_FLOATING(dy_dx); + // CHECK_IS_FLOATING(grad_inputs); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + grad.scalar_type(), "grid_encode_backward", ([&] { + grid_encode_backward_cuda(grad.data_ptr(), inputs.data_ptr(), embeddings.data_ptr(), offsets.data_ptr(), grad_embeddings.data_ptr(), B, D, C, L, S, H, dy_dx.has_value() ? dy_dx.value().data_ptr() : nullptr, grad_inputs.has_value() ? grad_inputs.value().data_ptr() : nullptr, gridtype, align_corners); + })); + +} diff --git a/gridencoder/src/gridencoder.h b/gridencoder/src/gridencoder.h new file mode 100644 index 0000000..89b6249 --- /dev/null +++ b/gridencoder/src/gridencoder.h @@ -0,0 +1,15 @@ +#ifndef _HASH_ENCODE_H +#define _HASH_ENCODE_H + +#include +#include + +// inputs: [B, D], float, in [0, 1] +// embeddings: [sO, C], float +// offsets: [L + 1], uint32_t +// outputs: [B, L * C], float +// H: base resolution +void grid_encode_forward(const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, at::optional dy_dx, const uint32_t gridtype, const bool align_corners); +void grid_encode_backward(const at::Tensor grad, const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const at::optional dy_dx, at::optional grad_inputs, const uint32_t gridtype, const bool align_corners); + +#endif \ No newline at end of file diff --git a/main.py b/main.py new file mode 100644 index 0000000..d6235ad --- /dev/null +++ b/main.py @@ -0,0 +1,260 @@ +import torch +import argparse + +from nerf_triplane.provider import NeRFDataset +from nerf_triplane.gui import NeRFGUI +from nerf_triplane.utils import * +from nerf_triplane.network import NeRFNetwork + +# torch.autograd.set_detect_anomaly(True) +# Close tf32 features. Fix low numerical accuracy on rtx30xx gpu. +try: + torch.backends.cuda.matmul.allow_tf32 = False + torch.backends.cudnn.allow_tf32 = False +except AttributeError as e: + print('Info. This pytorch version is not support with tf32.') + +if __name__ == '__main__': + + parser = argparse.ArgumentParser() + parser.add_argument('path', type=str) + parser.add_argument('-O', action='store_true', help="equals --fp16 --cuda_ray --exp_eye") + parser.add_argument('--test', action='store_true', help="test mode (load model and test dataset)") + parser.add_argument('--test_train', action='store_true', help="test mode (load model and train dataset)") + parser.add_argument('--data_range', type=int, nargs='*', default=[0, -1], help="data range to use") + parser.add_argument('--workspace', type=str, default='workspace') + parser.add_argument('--seed', type=int, default=0) + + ### training options + parser.add_argument('--iters', type=int, default=200000, help="training iters") + parser.add_argument('--lr', type=float, default=1e-2, help="initial learning rate") + parser.add_argument('--lr_net', type=float, default=1e-3, help="initial learning rate") + parser.add_argument('--ckpt', type=str, default='latest') + parser.add_argument('--num_rays', type=int, default=4096 * 16, help="num rays sampled per image for each training step") + parser.add_argument('--cuda_ray', action='store_true', help="use CUDA raymarching instead of pytorch") + parser.add_argument('--max_steps', type=int, default=16, help="max num steps sampled per ray (only valid when using --cuda_ray)") + parser.add_argument('--num_steps', type=int, default=16, help="num steps sampled per ray (only valid when NOT using --cuda_ray)") + parser.add_argument('--upsample_steps', type=int, default=0, help="num steps up-sampled per ray (only valid when NOT using --cuda_ray)") + parser.add_argument('--update_extra_interval', type=int, default=16, help="iter interval to update extra status (only valid when using --cuda_ray)") + parser.add_argument('--max_ray_batch', type=int, default=4096, help="batch size of rays at inference to avoid OOM (only valid when NOT using --cuda_ray)") + + ### loss set + parser.add_argument('--warmup_step', type=int, default=10000, help="warm up steps") + parser.add_argument('--amb_aud_loss', type=int, default=1, help="use ambient aud loss") + parser.add_argument('--amb_eye_loss', type=int, default=1, help="use ambient eye loss") + parser.add_argument('--unc_loss', type=int, default=1, help="use uncertainty loss") + parser.add_argument('--lambda_amb', type=float, default=1e-4, help="lambda for ambient loss") + + ### network backbone options + parser.add_argument('--fp16', action='store_true', help="use amp mixed precision training") + + parser.add_argument('--bg_img', type=str, default='', help="background image") + parser.add_argument('--fbg', action='store_true', help="frame-wise bg") + parser.add_argument('--exp_eye', action='store_true', help="explicitly control the eyes") + parser.add_argument('--fix_eye', type=float, default=-1, help="fixed eye area, negative to disable, set to 0-0.3 for a reasonable eye") + parser.add_argument('--smooth_eye', action='store_true', help="smooth the eye area sequence") + + parser.add_argument('--torso_shrink', type=float, default=0.8, help="shrink bg coords to allow more flexibility in deform") + + ### dataset options + parser.add_argument('--color_space', type=str, default='srgb', help="Color space, supports (linear, srgb)") + parser.add_argument('--preload', type=int, default=0, help="0 means load data from disk on-the-fly, 1 means preload to CPU, 2 means GPU.") + # (the default value is for the fox dataset) + parser.add_argument('--bound', type=float, default=1, help="assume the scene is bounded in box[-bound, bound]^3, if > 1, will invoke adaptive ray marching.") + parser.add_argument('--scale', type=float, default=4, help="scale camera location into box[-bound, bound]^3") + parser.add_argument('--offset', type=float, nargs='*', default=[0, 0, 0], help="offset of camera location") + parser.add_argument('--dt_gamma', type=float, default=1/256, help="dt_gamma (>=0) for adaptive ray marching. set to 0 to disable, >0 to accelerate rendering (but usually with worse quality)") + parser.add_argument('--min_near', type=float, default=0.05, help="minimum near distance for camera") + parser.add_argument('--density_thresh', type=float, default=10, help="threshold for density grid to be occupied (sigma)") + parser.add_argument('--density_thresh_torso', type=float, default=0.01, help="threshold for density grid to be occupied (alpha)") + parser.add_argument('--patch_size', type=int, default=1, help="[experimental] render patches in training, so as to apply LPIPS loss. 1 means disabled, use [64, 32, 16] to enable") + + parser.add_argument('--init_lips', action='store_true', help="init lips region") + parser.add_argument('--finetune_lips', action='store_true', help="use LPIPS and landmarks to fine tune lips region") + parser.add_argument('--smooth_lips', action='store_true', help="smooth the enc_a in a exponential decay way...") + + parser.add_argument('--torso', action='store_true', help="fix head and train torso") + parser.add_argument('--head_ckpt', type=str, default='', help="head model") + + ### GUI options + parser.add_argument('--gui', action='store_true', help="start a GUI") + parser.add_argument('--W', type=int, default=450, help="GUI width") + parser.add_argument('--H', type=int, default=450, help="GUI height") + parser.add_argument('--radius', type=float, default=3.35, help="default GUI camera radius from center") + parser.add_argument('--fovy', type=float, default=21.24, help="default GUI camera fovy") + parser.add_argument('--max_spp', type=int, default=1, help="GUI rendering max sample per pixel") + + ### else + parser.add_argument('--att', type=int, default=2, help="audio attention mode (0 = turn off, 1 = left-direction, 2 = bi-direction)") + parser.add_argument('--aud', type=str, default='', help="audio source (empty will load the default, else should be a path to a npy file)") + parser.add_argument('--emb', action='store_true', help="use audio class + embedding instead of logits") + + parser.add_argument('--ind_dim', type=int, default=4, help="individual code dim, 0 to turn off") + parser.add_argument('--ind_num', type=int, default=10000, help="number of individual codes, should be larger than training dataset size") + + parser.add_argument('--ind_dim_torso', type=int, default=8, help="individual code dim, 0 to turn off") + + parser.add_argument('--amb_dim', type=int, default=2, help="ambient dimension") + parser.add_argument('--part', action='store_true', help="use partial training data (1/10)") + parser.add_argument('--part2', action='store_true', help="use partial training data (first 15s)") + + parser.add_argument('--train_camera', action='store_true', help="optimize camera pose") + parser.add_argument('--smooth_path', action='store_true', help="brute-force smooth camera pose trajectory with a window size") + parser.add_argument('--smooth_path_window', type=int, default=7, help="smoothing window size") + + # asr + parser.add_argument('--asr', action='store_true', help="load asr for real-time app") + parser.add_argument('--asr_wav', type=str, default='', help="load the wav and use as input") + parser.add_argument('--asr_play', action='store_true', help="play out the audio") + + parser.add_argument('--asr_model', type=str, default='deepspeech') + # parser.add_argument('--asr_model', type=str, default='cpierse/wav2vec2-large-xlsr-53-esperanto') + # parser.add_argument('--asr_model', type=str, default='facebook/wav2vec2-large-960h-lv60-self') + + parser.add_argument('--asr_save_feats', action='store_true') + # audio FPS + parser.add_argument('--fps', type=int, default=50) + # sliding window left-middle-right length (unit: 20ms) + parser.add_argument('-l', type=int, default=10) + parser.add_argument('-m', type=int, default=50) + parser.add_argument('-r', type=int, default=10) + + opt = parser.parse_args() + + if opt.O: + opt.fp16 = True + opt.exp_eye = True + + if opt.test and False: + opt.smooth_path = True + opt.smooth_eye = True + opt.smooth_lips = True + + opt.cuda_ray = True + # assert opt.cuda_ray, "Only support CUDA ray mode." + + if opt.patch_size > 1: + # assert opt.patch_size > 16, "patch_size should > 16 to run LPIPS loss." + assert opt.num_rays % (opt.patch_size ** 2) == 0, "patch_size ** 2 should be dividable by num_rays." + + # if opt.finetune_lips: + # # do not update density grid in finetune stage + # opt.update_extra_interval = 1e9 + + print(opt) + + seed_everything(opt.seed) + + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + model = NeRFNetwork(opt) + + # manually load state dict for head + if opt.torso and opt.head_ckpt != '': + + model_dict = torch.load(opt.head_ckpt, map_location='cpu')['model'] + + missing_keys, unexpected_keys = model.load_state_dict(model_dict, strict=False) + + if len(missing_keys) > 0: + print(f"[WARN] missing keys: {missing_keys}") + if len(unexpected_keys) > 0: + print(f"[WARN] unexpected keys: {unexpected_keys}") + + # freeze these keys + for k, v in model.named_parameters(): + if k in model_dict: + # print(f'[INFO] freeze {k}, {v.shape}') + v.requires_grad = False + + + # print(model) + + criterion = torch.nn.MSELoss(reduction='none') + + if opt.test: + + if opt.gui: + metrics = [] # use no metric in GUI for faster initialization... + else: + # metrics = [PSNRMeter(), LPIPSMeter(device=device)] + metrics = [PSNRMeter(), LPIPSMeter(device=device), LMDMeter(backend='fan')] + + trainer = Trainer('ngp', opt, model, device=device, workspace=opt.workspace, criterion=criterion, fp16=opt.fp16, metrics=metrics, use_checkpoint=opt.ckpt) + + if opt.test_train: + test_set = NeRFDataset(opt, device=device, type='train') + # a manual fix to test on the training dataset + test_set.training = False + test_set.num_rays = -1 + test_loader = test_set.dataloader() + else: + test_loader = NeRFDataset(opt, device=device, type='test').dataloader() + + + # temp fix: for update_extra_states + model.aud_features = test_loader._data.auds + model.eye_areas = test_loader._data.eye_area + + if opt.gui: + # we still need test_loader to provide audio features for testing. + with NeRFGUI(opt, trainer, test_loader) as gui: + gui.render() + + else: + ### test and save video (fast) + trainer.test(test_loader) + + ### evaluate metrics (slow) + if test_loader.has_gt: + trainer.evaluate(test_loader) + + + + else: + + optimizer = lambda model: torch.optim.AdamW(model.get_params(opt.lr, opt.lr_net), betas=(0, 0.99), eps=1e-8) + + train_loader = NeRFDataset(opt, device=device, type='train').dataloader() + + assert len(train_loader) < opt.ind_num, f"[ERROR] dataset too many frames: {len(train_loader)}, please increase --ind_num to this number!" + + # temp fix: for update_extra_states + model.aud_features = train_loader._data.auds + model.eye_area = train_loader._data.eye_area + model.poses = train_loader._data.poses + + # decay to 0.1 * init_lr at last iter step + if opt.finetune_lips: + scheduler = lambda optimizer: optim.lr_scheduler.LambdaLR(optimizer, lambda iter: 0.05 ** (iter / opt.iters)) + else: + scheduler = lambda optimizer: optim.lr_scheduler.LambdaLR(optimizer, lambda iter: 0.5 ** (iter / opt.iters)) + + metrics = [PSNRMeter(), LPIPSMeter(device=device)] + + eval_interval = max(1, int(5000 / len(train_loader))) + trainer = Trainer('ngp', opt, model, device=device, workspace=opt.workspace, optimizer=optimizer, criterion=criterion, ema_decay=0.95, fp16=opt.fp16, lr_scheduler=scheduler, scheduler_update_every_step=True, metrics=metrics, use_checkpoint=opt.ckpt, eval_interval=eval_interval) + with open(os.path.join(opt.workspace, 'opt.txt'), 'a') as f: + f.write(str(opt)) + if opt.gui: + with NeRFGUI(opt, trainer, train_loader) as gui: + gui.render() + + else: + valid_loader = NeRFDataset(opt, device=device, type='val', downscale=1).dataloader() + + max_epochs = np.ceil(opt.iters / len(train_loader)).astype(np.int32) + print(f'[INFO] max_epoch = {max_epochs}') + trainer.train(train_loader, valid_loader, max_epochs) + + # free some mem + del train_loader, valid_loader + torch.cuda.empty_cache() + + # also test + test_loader = NeRFDataset(opt, device=device, type='test').dataloader() + + if test_loader.has_gt: + trainer.evaluate(test_loader) # blender has gt, so evaluate it. + + trainer.test(test_loader) \ No newline at end of file diff --git a/nerf_triplane/asr.py b/nerf_triplane/asr.py new file mode 100644 index 0000000..dc8db9c --- /dev/null +++ b/nerf_triplane/asr.py @@ -0,0 +1,419 @@ +import time +import numpy as np +import torch +import torch.nn.functional as F +from transformers import AutoModelForCTC, AutoProcessor + +import pyaudio +import soundfile as sf +import resampy + +from queue import Queue +from threading import Thread, Event + + +def _read_frame(stream, exit_event, queue, chunk): + + while True: + if exit_event.is_set(): + print(f'[INFO] read frame thread ends') + break + frame = stream.read(chunk, exception_on_overflow=False) + frame = np.frombuffer(frame, dtype=np.int16).astype(np.float32) / 32767 # [chunk] + queue.put(frame) + +def _play_frame(stream, exit_event, queue, chunk): + + while True: + if exit_event.is_set(): + print(f'[INFO] play frame thread ends') + break + frame = queue.get() + frame = (frame * 32767).astype(np.int16).tobytes() + stream.write(frame, chunk) + +class ASR: + def __init__(self, opt): + + self.opt = opt + + self.play = opt.asr_play + + self.device = 'cuda' if torch.cuda.is_available() else 'cpu' + self.fps = opt.fps # 20 ms per frame + self.sample_rate = 16000 + self.chunk = self.sample_rate // self.fps # 320 samples per chunk (20ms * 16000 / 1000) + self.mode = 'live' if opt.asr_wav == '' else 'file' + + if 'esperanto' in self.opt.asr_model: + self.audio_dim = 44 + elif 'deepspeech' in self.opt.asr_model: + self.audio_dim = 29 + else: + self.audio_dim = 32 + + # prepare context cache + # each segment is (stride_left + ctx + stride_right) * 20ms, latency should be (ctx + stride_right) * 20ms + self.context_size = opt.m + self.stride_left_size = opt.l + self.stride_right_size = opt.r + self.text = '[START]\n' + self.terminated = False + self.frames = [] + + # pad left frames + if self.stride_left_size > 0: + self.frames.extend([np.zeros(self.chunk, dtype=np.float32)] * self.stride_left_size) + + + self.exit_event = Event() + self.audio_instance = pyaudio.PyAudio() + + # create input stream + if self.mode == 'file': + self.file_stream = self.create_file_stream() + else: + # start a background process to read frames + self.input_stream = self.audio_instance.open(format=pyaudio.paInt16, channels=1, rate=self.sample_rate, input=True, output=False, frames_per_buffer=self.chunk) + self.queue = Queue() + self.process_read_frame = Thread(target=_read_frame, args=(self.input_stream, self.exit_event, self.queue, self.chunk)) + + # play out the audio too...? + if self.play: + self.output_stream = self.audio_instance.open(format=pyaudio.paInt16, channels=1, rate=self.sample_rate, input=False, output=True, frames_per_buffer=self.chunk) + self.output_queue = Queue() + self.process_play_frame = Thread(target=_play_frame, args=(self.output_stream, self.exit_event, self.output_queue, self.chunk)) + + # current location of audio + self.idx = 0 + + # create wav2vec model + print(f'[INFO] loading ASR model {self.opt.asr_model}...') + self.processor = AutoProcessor.from_pretrained(opt.asr_model) + self.model = AutoModelForCTC.from_pretrained(opt.asr_model).to(self.device) + + # prepare to save logits + if self.opt.asr_save_feats: + self.all_feats = [] + + # the extracted features + # use a loop queue to efficiently record endless features: [f--t---][-------][-------] + self.feat_buffer_size = 4 + self.feat_buffer_idx = 0 + self.feat_queue = torch.zeros(self.feat_buffer_size * self.context_size, self.audio_dim, dtype=torch.float32, device=self.device) + + # TODO: hard coded 16 and 8 window size... + self.front = self.feat_buffer_size * self.context_size - 8 # fake padding + self.tail = 8 + # attention window... + self.att_feats = [torch.zeros(self.audio_dim, 16, dtype=torch.float32, device=self.device)] * 4 # 4 zero padding... + + # warm up steps needed: mid + right + window_size + attention_size + self.warm_up_steps = self.context_size + self.stride_right_size + 8 + 2 * 3 + + self.listening = False + self.playing = False + + def listen(self): + # start + if self.mode == 'live' and not self.listening: + print(f'[INFO] starting read frame thread...') + self.process_read_frame.start() + self.listening = True + + if self.play and not self.playing: + print(f'[INFO] starting play frame thread...') + self.process_play_frame.start() + self.playing = True + + def stop(self): + + self.exit_event.set() + + if self.play: + self.output_stream.stop_stream() + self.output_stream.close() + if self.playing: + self.process_play_frame.join() + self.playing = False + + if self.mode == 'live': + self.input_stream.stop_stream() + self.input_stream.close() + if self.listening: + self.process_read_frame.join() + self.listening = False + + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + + self.stop() + + if self.mode == 'live': + # live mode: also print the result text. + self.text += '\n[END]' + print(self.text) + + def get_next_feat(self): + # return a [1/8, 16] window, for the next input to nerf side. + + while len(self.att_feats) < 8: + # [------f+++t-----] + if self.front < self.tail: + feat = self.feat_queue[self.front:self.tail] + # [++t-----------f+] + else: + feat = torch.cat([self.feat_queue[self.front:], self.feat_queue[:self.tail]], dim=0) + + self.front = (self.front + 2) % self.feat_queue.shape[0] + self.tail = (self.tail + 2) % self.feat_queue.shape[0] + + # print(self.front, self.tail, feat.shape) + + self.att_feats.append(feat.permute(1, 0)) + + att_feat = torch.stack(self.att_feats, dim=0) # [8, 44, 16] + + # discard old + self.att_feats = self.att_feats[1:] + + return att_feat + + def run_step(self): + + if self.terminated: + return + + # get a frame of audio + frame = self.get_audio_frame() + + # the last frame + if frame is None: + # terminate, but always run the network for the left frames + self.terminated = True + else: + self.frames.append(frame) + # put to output + if self.play: + self.output_queue.put(frame) + # context not enough, do not run network. + if len(self.frames) < self.stride_left_size + self.context_size + self.stride_right_size: + return + + inputs = np.concatenate(self.frames) # [N * chunk] + + # discard the old part to save memory + if not self.terminated: + self.frames = self.frames[-(self.stride_left_size + self.stride_right_size):] + + logits, labels, text = self.frame_to_text(inputs) + feats = logits # better lips-sync than labels + + # save feats + if self.opt.asr_save_feats: + self.all_feats.append(feats) + + # record the feats efficiently.. (no concat, constant memory) + start = self.feat_buffer_idx * self.context_size + end = start + feats.shape[0] + self.feat_queue[start:end] = feats + self.feat_buffer_idx = (self.feat_buffer_idx + 1) % self.feat_buffer_size + + # very naive, just concat the text output. + if text != '': + self.text = self.text + ' ' + text + + # will only run once at ternimation + if self.terminated: + self.text += '\n[END]' + print(self.text) + if self.opt.asr_save_feats: + print(f'[INFO] save all feats for training purpose... ') + feats = torch.cat(self.all_feats, dim=0) # [N, C] + # print('[INFO] before unfold', feats.shape) + window_size = 16 + padding = window_size // 2 + feats = feats.view(-1, self.audio_dim).permute(1, 0).contiguous() # [C, M] + feats = feats.view(1, self.audio_dim, -1, 1) # [1, C, M, 1] + unfold_feats = F.unfold(feats, kernel_size=(window_size, 1), padding=(padding, 0), stride=(2, 1)) # [1, C * window_size, M / 2 + 1] + unfold_feats = unfold_feats.view(self.audio_dim, window_size, -1).permute(2, 1, 0).contiguous() # [C, window_size, M / 2 + 1] --> [M / 2 + 1, window_size, C] + # print('[INFO] after unfold', unfold_feats.shape) + # save to a npy file + if 'esperanto' in self.opt.asr_model: + output_path = self.opt.asr_wav.replace('.wav', '_eo.npy') + else: + output_path = self.opt.asr_wav.replace('.wav', '.npy') + np.save(output_path, unfold_feats.cpu().numpy()) + print(f"[INFO] saved logits to {output_path}") + + def create_file_stream(self): + + stream, sample_rate = sf.read(self.opt.asr_wav) # [T*sample_rate,] float64 + stream = stream.astype(np.float32) + + if stream.ndim > 1: + print(f'[WARN] audio has {stream.shape[1]} channels, only use the first.') + stream = stream[:, 0] + + if sample_rate != self.sample_rate: + print(f'[WARN] audio sample rate is {sample_rate}, resampling into {self.sample_rate}.') + stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self.sample_rate) + + print(f'[INFO] loaded audio stream {self.opt.asr_wav}: {stream.shape}') + + return stream + + + def create_pyaudio_stream(self): + + import pyaudio + + print(f'[INFO] creating live audio stream ...') + + audio = pyaudio.PyAudio() + + # get devices + info = audio.get_host_api_info_by_index(0) + n_devices = info.get('deviceCount') + + for i in range(0, n_devices): + if (audio.get_device_info_by_host_api_device_index(0, i).get('maxInputChannels')) > 0: + name = audio.get_device_info_by_host_api_device_index(0, i).get('name') + print(f'[INFO] choose audio device {name}, id {i}') + break + + # get stream + stream = audio.open(input_device_index=i, + format=pyaudio.paInt16, + channels=1, + rate=self.sample_rate, + input=True, + frames_per_buffer=self.chunk) + + return audio, stream + + + def get_audio_frame(self): + + if self.mode == 'file': + + if self.idx < self.file_stream.shape[0]: + frame = self.file_stream[self.idx: self.idx + self.chunk] + self.idx = self.idx + self.chunk + return frame + else: + return None + + else: + + frame = self.queue.get() + # print(f'[INFO] get frame {frame.shape}') + + self.idx = self.idx + self.chunk + + return frame + + + def frame_to_text(self, frame): + # frame: [N * 320], N = (context_size + 2 * stride_size) + + inputs = self.processor(frame, sampling_rate=self.sample_rate, return_tensors="pt", padding=True) + + with torch.no_grad(): + result = self.model(inputs.input_values.to(self.device)) + logits = result.logits # [1, N - 1, 32] + + # cut off stride + left = max(0, self.stride_left_size) + right = min(logits.shape[1], logits.shape[1] - self.stride_right_size + 1) # +1 to make sure output is the same length as input. + + # do not cut right if terminated. + if self.terminated: + right = logits.shape[1] + + logits = logits[:, left:right] + + # print(frame.shape, inputs.input_values.shape, logits.shape) + + predicted_ids = torch.argmax(logits, dim=-1) + transcription = self.processor.batch_decode(predicted_ids)[0].lower() + + + # for esperanto + # labels = np.array(['ŭ', '»', 'c', 'ĵ', 'ñ', '”', '„', '“', 'ǔ', 'o', 'ĝ', 'm', 'k', 'd', 'a', 'ŝ', 'z', 'i', '«', '—', '‘', 'ĥ', 'f', 'y', 'h', 'j', '|', 'r', 'u', 'ĉ', 's', '–', 'fi', 'l', 'p', '’', 'g', 'v', 't', 'b', 'n', 'e', '[UNK]', '[PAD]']) + + # labels = np.array([' ', ' ', ' ', '-', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z']) + # print(''.join(labels[predicted_ids[0].detach().cpu().long().numpy()])) + # print(predicted_ids[0]) + # print(transcription) + + return logits[0], predicted_ids[0], transcription # [N,] + + + def run(self): + + self.listen() + + while not self.terminated: + self.run_step() + + def clear_queue(self): + # clear the queue, to reduce potential latency... + print(f'[INFO] clear queue') + if self.mode == 'live': + self.queue.queue.clear() + if self.play: + self.output_queue.queue.clear() + + def warm_up(self): + + self.listen() + + print(f'[INFO] warm up ASR live model, expected latency = {self.warm_up_steps / self.fps:.6f}s') + t = time.time() + for _ in range(self.warm_up_steps): + self.run_step() + if torch.cuda.is_available(): + torch.cuda.synchronize() + t = time.time() - t + print(f'[INFO] warm-up done, actual latency = {t:.6f}s') + + self.clear_queue() + + + + +if __name__ == '__main__': + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument('--wav', type=str, default='') + parser.add_argument('--play', action='store_true', help="play out the audio") + + parser.add_argument('--model', type=str, default='cpierse/wav2vec2-large-xlsr-53-esperanto') + # parser.add_argument('--model', type=str, default='facebook/wav2vec2-large-960h-lv60-self') + + parser.add_argument('--save_feats', action='store_true') + # audio FPS + parser.add_argument('--fps', type=int, default=50) + # sliding window left-middle-right length. + parser.add_argument('-l', type=int, default=10) + parser.add_argument('-m', type=int, default=50) + parser.add_argument('-r', type=int, default=10) + + opt = parser.parse_args() + + # fix + opt.asr_wav = opt.wav + opt.asr_play = opt.play + opt.asr_model = opt.model + opt.asr_save_feats = opt.save_feats + + if 'deepspeech' in opt.asr_model: + raise ValueError("DeepSpeech features should not use this code to extract...") + + with ASR(opt) as asr: + asr.run() \ No newline at end of file diff --git a/nerf_triplane/gui.py b/nerf_triplane/gui.py new file mode 100644 index 0000000..7a6798d --- /dev/null +++ b/nerf_triplane/gui.py @@ -0,0 +1,565 @@ +import math +import torch +import numpy as np +import dearpygui.dearpygui as dpg +from scipy.spatial.transform import Rotation as R + +from .utils import * + +from .asr import ASR + + +class OrbitCamera: + def __init__(self, W, H, r=2, fovy=60): + self.W = W + self.H = H + self.radius = r # camera distance from center + self.fovy = fovy # in degree + self.center = np.array([0, 0, 0], dtype=np.float32) # look at this point + self.rot = R.from_matrix([[0, -1, 0], [0, 0, -1], [1, 0, 0]]) # init camera matrix: [[1, 0, 0], [0, -1, 0], [0, 0, 1]] (to suit ngp convention) + self.up = np.array([1, 0, 0], dtype=np.float32) # need to be normalized! + + # pose + @property + def pose(self): + # first move camera to radius + res = np.eye(4, dtype=np.float32) + res[2, 3] -= self.radius + # rotate + rot = np.eye(4, dtype=np.float32) + rot[:3, :3] = self.rot.as_matrix() + res = rot @ res + # translate + res[:3, 3] -= self.center + return res + + def update_pose(self, pose): + # pose: [4, 4] numpy array + # assert self.center is 0 + self.radius = np.linalg.norm(pose[:3, 3]) + T = np.eye(4) + T[2, 3] = -self.radius + rot = pose @ np.linalg.inv(T) + self.rot = R.from_matrix(rot[:3, :3]) + + def update_intrinsics(self, intrinsics): + fl_x, fl_y, cx, cy = intrinsics + self.W = int(cx * 2) + self.H = int(cy * 2) + self.fovy = np.rad2deg(2 * np.arctan2(self.H, 2 * fl_y)) + + # intrinsics + @property + def intrinsics(self): + focal = self.H / (2 * np.tan(np.deg2rad(self.fovy) / 2)) + return np.array([focal, focal, self.W // 2, self.H // 2]) + + def orbit(self, dx, dy): + # rotate along camera up/side axis! + side = self.rot.as_matrix()[:3, 0] # why this is side --> ? # already normalized. + rotvec_x = self.up * np.radians(-0.01 * dx) + rotvec_y = side * np.radians(-0.01 * dy) + self.rot = R.from_rotvec(rotvec_x) * R.from_rotvec(rotvec_y) * self.rot + + def scale(self, delta): + self.radius *= 1.1 ** (-delta) + + def pan(self, dx, dy, dz=0): + # pan in camera coordinate system (careful on the sensitivity!) + self.center += 0.0001 * self.rot.as_matrix()[:3, :3] @ np.array([dx, dy, dz]) + + +class NeRFGUI: + def __init__(self, opt, trainer, data_loader, debug=True): + self.opt = opt # shared with the trainer's opt to support in-place modification of rendering parameters. + self.W = opt.W + self.H = opt.H + self.cam = OrbitCamera(opt.W, opt.H, r=opt.radius, fovy=opt.fovy) + self.debug = debug + self.training = False + self.step = 0 # training step + + self.trainer = trainer + self.data_loader = data_loader + + # override with dataloader's intrinsics + self.W = data_loader._data.W + self.H = data_loader._data.H + self.cam.update_intrinsics(data_loader._data.intrinsics) + + # use dataloader's pose + pose_init = data_loader._data.poses[0] + self.cam.update_pose(pose_init.detach().cpu().numpy()) + + # use dataloader's bg + bg_img = data_loader._data.bg_img #.view(1, -1, 3) + if self.H != bg_img.shape[0] or self.W != bg_img.shape[1]: + bg_img = F.interpolate(bg_img.permute(2, 0, 1).unsqueeze(0).contiguous(), (self.H, self.W), mode='bilinear').squeeze(0).permute(1, 2, 0).contiguous() + self.bg_color = bg_img.view(1, -1, 3) + + # audio features (from dataloader, only used in non-playing mode) + self.audio_features = data_loader._data.auds # [N, 29, 16] + self.audio_idx = 0 + + # control eye + self.eye_area = None if not self.opt.exp_eye else data_loader._data.eye_area.mean().item() + + # playing seq from dataloader, or pause. + self.playing = False + self.loader = iter(data_loader) + + self.render_buffer = np.zeros((self.W, self.H, 3), dtype=np.float32) + self.need_update = True # camera moved, should reset accumulation + self.spp = 1 # sample per pixel + self.mode = 'image' # choose from ['image', 'depth'] + + self.dynamic_resolution = False # assert False! + self.downscale = 1 + self.train_steps = 16 + + self.ind_index = 0 + self.ind_num = trainer.model.individual_codes.shape[0] + + # build asr + if self.opt.asr: + self.asr = ASR(opt) + + dpg.create_context() + self.register_dpg() + self.test_step() + + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + if self.opt.asr: + self.asr.stop() + dpg.destroy_context() + + def train_step(self): + + starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True) + starter.record() + + outputs = self.trainer.train_gui(self.data_loader, step=self.train_steps) + + ender.record() + torch.cuda.synchronize() + t = starter.elapsed_time(ender) + + self.step += self.train_steps + self.need_update = True + + dpg.set_value("_log_train_time", f'{t:.4f}ms ({int(1000/t)} FPS)') + dpg.set_value("_log_train_log", f'step = {self.step: 5d} (+{self.train_steps: 2d}), loss = {outputs["loss"]:.4f}, lr = {outputs["lr"]:.5f}') + + # dynamic train steps + # max allowed train time per-frame is 500 ms + full_t = t / self.train_steps * 16 + train_steps = min(16, max(4, int(16 * 500 / full_t))) + if train_steps > self.train_steps * 1.2 or train_steps < self.train_steps * 0.8: + self.train_steps = train_steps + + def prepare_buffer(self, outputs): + if self.mode == 'image': + return outputs['image'] + else: + return np.expand_dims(outputs['depth'], -1).repeat(3, -1) + + def test_step(self): + + if self.need_update or self.spp < self.opt.max_spp: + + starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True) + starter.record() + + if self.playing: + try: + data = next(self.loader) + except StopIteration: + self.loader = iter(self.data_loader) + data = next(self.loader) + + if self.opt.asr: + # use the live audio stream + data['auds'] = self.asr.get_next_feat() + + outputs = self.trainer.test_gui_with_data(data, self.W, self.H) + + # sync local camera pose + self.cam.update_pose(data['poses_matrix'][0].detach().cpu().numpy()) + + else: + if self.audio_features is not None: + auds = get_audio_features(self.audio_features, self.opt.att, self.audio_idx) + else: + auds = None + outputs = self.trainer.test_gui(self.cam.pose, self.cam.intrinsics, self.W, self.H, auds, self.eye_area, self.ind_index, self.bg_color, self.spp, self.downscale) + + ender.record() + torch.cuda.synchronize() + t = starter.elapsed_time(ender) + + # update dynamic resolution + if self.dynamic_resolution: + # max allowed infer time per-frame is 200 ms + full_t = t / (self.downscale ** 2) + downscale = min(1, max(1/4, math.sqrt(200 / full_t))) + if downscale > self.downscale * 1.2 or downscale < self.downscale * 0.8: + self.downscale = downscale + + if self.need_update: + self.render_buffer = self.prepare_buffer(outputs) + self.spp = 1 + self.need_update = False + else: + self.render_buffer = (self.render_buffer * self.spp + self.prepare_buffer(outputs)) / (self.spp + 1) + self.spp += 1 + + if self.playing: + self.need_update = True + + dpg.set_value("_log_infer_time", f'{t:.4f}ms ({int(1000/t)} FPS)') + dpg.set_value("_log_resolution", f'{int(self.downscale * self.W)}x{int(self.downscale * self.H)}') + dpg.set_value("_log_spp", self.spp) + dpg.set_value("_texture", self.render_buffer) + + + def register_dpg(self): + + ### register texture + + with dpg.texture_registry(show=False): + dpg.add_raw_texture(self.W, self.H, self.render_buffer, format=dpg.mvFormat_Float_rgb, tag="_texture") + + ### register window + + # the rendered image, as the primary window + with dpg.window(tag="_primary_window", width=self.W, height=self.H): + + # add the texture + dpg.add_image("_texture") + + # dpg.set_primary_window("_primary_window", True) + + dpg.show_tool(dpg.mvTool_Metrics) + + # control window + with dpg.window(label="Control", tag="_control_window", width=400, height=300): + + # button theme + with dpg.theme() as theme_button: + with dpg.theme_component(dpg.mvButton): + dpg.add_theme_color(dpg.mvThemeCol_Button, (23, 3, 18)) + dpg.add_theme_color(dpg.mvThemeCol_ButtonHovered, (51, 3, 47)) + dpg.add_theme_color(dpg.mvThemeCol_ButtonActive, (83, 18, 83)) + dpg.add_theme_style(dpg.mvStyleVar_FrameRounding, 5) + dpg.add_theme_style(dpg.mvStyleVar_FramePadding, 3, 3) + + # time + if not self.opt.test: + with dpg.group(horizontal=True): + dpg.add_text("Train time: ") + dpg.add_text("no data", tag="_log_train_time") + + with dpg.group(horizontal=True): + dpg.add_text("Infer time: ") + dpg.add_text("no data", tag="_log_infer_time") + + with dpg.group(horizontal=True): + dpg.add_text("SPP: ") + dpg.add_text("1", tag="_log_spp") + + # train button + if not self.opt.test: + with dpg.collapsing_header(label="Train", default_open=True): + + # train / stop + with dpg.group(horizontal=True): + dpg.add_text("Train: ") + + def callback_train(sender, app_data): + if self.training: + self.training = False + dpg.configure_item("_button_train", label="start") + else: + self.training = True + dpg.configure_item("_button_train", label="stop") + + dpg.add_button(label="start", tag="_button_train", callback=callback_train) + dpg.bind_item_theme("_button_train", theme_button) + + def callback_reset(sender, app_data): + @torch.no_grad() + def weight_reset(m: nn.Module): + reset_parameters = getattr(m, "reset_parameters", None) + if callable(reset_parameters): + m.reset_parameters() + self.trainer.model.apply(fn=weight_reset) + self.trainer.model.reset_extra_state() # for cuda_ray density_grid and step_counter + self.need_update = True + + dpg.add_button(label="reset", tag="_button_reset", callback=callback_reset) + dpg.bind_item_theme("_button_reset", theme_button) + + # save ckpt + with dpg.group(horizontal=True): + dpg.add_text("Checkpoint: ") + + def callback_save(sender, app_data): + self.trainer.save_checkpoint(full=True, best=False) + dpg.set_value("_log_ckpt", "saved " + os.path.basename(self.trainer.stats["checkpoints"][-1])) + self.trainer.epoch += 1 # use epoch to indicate different calls. + + dpg.add_button(label="save", tag="_button_save", callback=callback_save) + dpg.bind_item_theme("_button_save", theme_button) + + dpg.add_text("", tag="_log_ckpt") + + # save mesh + with dpg.group(horizontal=True): + dpg.add_text("Marching Cubes: ") + + def callback_mesh(sender, app_data): + self.trainer.save_mesh(resolution=256, threshold=10) + dpg.set_value("_log_mesh", "saved " + f'{self.trainer.name}_{self.trainer.epoch}.ply') + self.trainer.epoch += 1 # use epoch to indicate different calls. + + dpg.add_button(label="mesh", tag="_button_mesh", callback=callback_mesh) + dpg.bind_item_theme("_button_mesh", theme_button) + + dpg.add_text("", tag="_log_mesh") + + with dpg.group(horizontal=True): + dpg.add_text("", tag="_log_train_log") + + + # rendering options + with dpg.collapsing_header(label="Options", default_open=True): + + # playing + with dpg.group(horizontal=True): + dpg.add_text("Play: ") + + def callback_play(sender, app_data): + + if self.playing: + self.playing = False + dpg.configure_item("_button_play", label="start") + else: + self.playing = True + dpg.configure_item("_button_play", label="stop") + if self.opt.asr: + self.asr.warm_up() + self.need_update = True + + dpg.add_button(label="start", tag="_button_play", callback=callback_play) + dpg.bind_item_theme("_button_play", theme_button) + + # set asr + if self.opt.asr: + + # clear queue button + def callback_clear_queue(sender, app_data): + + self.asr.clear_queue() + self.need_update = True + + dpg.add_button(label="clear", tag="_button_clear_queue", callback=callback_clear_queue) + dpg.bind_item_theme("_button_clear_queue", theme_button) + + # dynamic rendering resolution + with dpg.group(horizontal=True): + + def callback_set_dynamic_resolution(sender, app_data): + if self.dynamic_resolution: + self.dynamic_resolution = False + self.downscale = 1 + else: + self.dynamic_resolution = True + self.need_update = True + + # Disable dynamic resolution for face. + # dpg.add_checkbox(label="dynamic resolution", default_value=self.dynamic_resolution, callback=callback_set_dynamic_resolution) + dpg.add_text(f"{self.W}x{self.H}", tag="_log_resolution") + + # mode combo + def callback_change_mode(sender, app_data): + self.mode = app_data + self.need_update = True + + dpg.add_combo(('image', 'depth'), label='mode', default_value=self.mode, callback=callback_change_mode) + + + # bg_color picker + def callback_change_bg(sender, app_data): + self.bg_color = torch.tensor(app_data[:3], dtype=torch.float32) # only need RGB in [0, 1] + self.need_update = True + + dpg.add_color_edit((255, 255, 255), label="Background Color", width=200, tag="_color_editor", no_alpha=True, callback=callback_change_bg) + + # audio index slider + if not self.opt.asr: + def callback_set_audio_index(sender, app_data): + self.audio_idx = app_data + self.need_update = True + + dpg.add_slider_int(label="Audio", min_value=0, max_value=self.audio_features.shape[0] - 1, format="%d", default_value=self.audio_idx, callback=callback_set_audio_index) + + # ind code index slider + if self.opt.ind_dim > 0: + def callback_set_individual_code(sender, app_data): + self.ind_index = app_data + self.need_update = True + + dpg.add_slider_int(label="Individual", min_value=0, max_value=self.ind_num - 1, format="%d", default_value=self.ind_index, callback=callback_set_individual_code) + + # eye area slider + if self.opt.exp_eye: + def callback_set_eye(sender, app_data): + self.eye_area = app_data + self.need_update = True + + dpg.add_slider_float(label="eye area", min_value=0, max_value=0.5, format="%.2f percent", default_value=self.eye_area, callback=callback_set_eye) + + # fov slider + def callback_set_fovy(sender, app_data): + self.cam.fovy = app_data + self.need_update = True + + dpg.add_slider_int(label="FoV (vertical)", min_value=1, max_value=120, format="%d deg", default_value=self.cam.fovy, callback=callback_set_fovy) + + # dt_gamma slider + def callback_set_dt_gamma(sender, app_data): + self.opt.dt_gamma = app_data + self.need_update = True + + dpg.add_slider_float(label="dt_gamma", min_value=0, max_value=0.1, format="%.5f", default_value=self.opt.dt_gamma, callback=callback_set_dt_gamma) + + # max_steps slider + def callback_set_max_steps(sender, app_data): + self.opt.max_steps = app_data + self.need_update = True + + dpg.add_slider_int(label="max steps", min_value=1, max_value=1024, format="%d", default_value=self.opt.max_steps, callback=callback_set_max_steps) + + # aabb slider + def callback_set_aabb(sender, app_data, user_data): + # user_data is the dimension for aabb (xmin, ymin, zmin, xmax, ymax, zmax) + self.trainer.model.aabb_infer[user_data] = app_data + + # also change train aabb ? [better not...] + #self.trainer.model.aabb_train[user_data] = app_data + + self.need_update = True + + dpg.add_separator() + dpg.add_text("Axis-aligned bounding box:") + + with dpg.group(horizontal=True): + dpg.add_slider_float(label="x", width=150, min_value=-self.opt.bound, max_value=0, format="%.2f", default_value=-self.opt.bound, callback=callback_set_aabb, user_data=0) + dpg.add_slider_float(label="", width=150, min_value=0, max_value=self.opt.bound, format="%.2f", default_value=self.opt.bound, callback=callback_set_aabb, user_data=3) + + with dpg.group(horizontal=True): + dpg.add_slider_float(label="y", width=150, min_value=-self.opt.bound, max_value=0, format="%.2f", default_value=-self.opt.bound, callback=callback_set_aabb, user_data=1) + dpg.add_slider_float(label="", width=150, min_value=0, max_value=self.opt.bound, format="%.2f", default_value=self.opt.bound, callback=callback_set_aabb, user_data=4) + + with dpg.group(horizontal=True): + dpg.add_slider_float(label="z", width=150, min_value=-self.opt.bound, max_value=0, format="%.2f", default_value=-self.opt.bound, callback=callback_set_aabb, user_data=2) + dpg.add_slider_float(label="", width=150, min_value=0, max_value=self.opt.bound, format="%.2f", default_value=self.opt.bound, callback=callback_set_aabb, user_data=5) + + + # debug info + if self.debug: + with dpg.collapsing_header(label="Debug"): + # pose + dpg.add_separator() + dpg.add_text("Camera Pose:") + dpg.add_text(str(self.cam.pose), tag="_log_pose") + + + ### register camera handler + + def callback_camera_drag_rotate(sender, app_data): + + if not dpg.is_item_focused("_primary_window"): + return + + dx = app_data[1] + dy = app_data[2] + + self.cam.orbit(dx, dy) + self.need_update = True + + if self.debug: + dpg.set_value("_log_pose", str(self.cam.pose)) + + + def callback_camera_wheel_scale(sender, app_data): + + if not dpg.is_item_focused("_primary_window"): + return + + delta = app_data + + self.cam.scale(delta) + self.need_update = True + + if self.debug: + dpg.set_value("_log_pose", str(self.cam.pose)) + + + def callback_camera_drag_pan(sender, app_data): + + if not dpg.is_item_focused("_primary_window"): + return + + dx = app_data[1] + dy = app_data[2] + + self.cam.pan(dx, dy) + self.need_update = True + + if self.debug: + dpg.set_value("_log_pose", str(self.cam.pose)) + + + with dpg.handler_registry(): + dpg.add_mouse_drag_handler(button=dpg.mvMouseButton_Left, callback=callback_camera_drag_rotate) + dpg.add_mouse_wheel_handler(callback=callback_camera_wheel_scale) + dpg.add_mouse_drag_handler(button=dpg.mvMouseButton_Middle, callback=callback_camera_drag_pan) + + + dpg.create_viewport(title='RAD-NeRF', width=1080, height=720, resizable=True) + + ### global theme + with dpg.theme() as theme_no_padding: + with dpg.theme_component(dpg.mvAll): + # set all padding to 0 to avoid scroll bar + dpg.add_theme_style(dpg.mvStyleVar_WindowPadding, 0, 0, category=dpg.mvThemeCat_Core) + dpg.add_theme_style(dpg.mvStyleVar_FramePadding, 0, 0, category=dpg.mvThemeCat_Core) + dpg.add_theme_style(dpg.mvStyleVar_CellPadding, 0, 0, category=dpg.mvThemeCat_Core) + + dpg.bind_item_theme("_primary_window", theme_no_padding) + + dpg.setup_dearpygui() + + #dpg.show_metrics() + + dpg.show_viewport() + + + def render(self): + + while dpg.is_dearpygui_running(): + # update texture every frame + if self.training: + self.train_step() + # audio stream thread... + if self.opt.asr and self.playing: + # run 2 ASR steps (audio is at 50FPS, video is at 25FPS) + for _ in range(2): + self.asr.run_step() + self.test_step() + dpg.render_dearpygui_frame() \ No newline at end of file diff --git a/nerf_triplane/network.py b/nerf_triplane/network.py new file mode 100644 index 0000000..fc41359 --- /dev/null +++ b/nerf_triplane/network.py @@ -0,0 +1,352 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from encoding import get_encoder +from .renderer import NeRFRenderer + +# Audio feature extractor +class AudioAttNet(nn.Module): + def __init__(self, dim_aud=64, seq_len=8): + super(AudioAttNet, self).__init__() + self.seq_len = seq_len + self.dim_aud = dim_aud + self.attentionConvNet = nn.Sequential( # b x subspace_dim x seq_len + nn.Conv1d(self.dim_aud, 16, kernel_size=3, stride=1, padding=1, bias=True), + nn.LeakyReLU(0.02, True), + nn.Conv1d(16, 8, kernel_size=3, stride=1, padding=1, bias=True), + nn.LeakyReLU(0.02, True), + nn.Conv1d(8, 4, kernel_size=3, stride=1, padding=1, bias=True), + nn.LeakyReLU(0.02, True), + nn.Conv1d(4, 2, kernel_size=3, stride=1, padding=1, bias=True), + nn.LeakyReLU(0.02, True), + nn.Conv1d(2, 1, kernel_size=3, stride=1, padding=1, bias=True), + nn.LeakyReLU(0.02, True) + ) + self.attentionNet = nn.Sequential( + nn.Linear(in_features=self.seq_len, out_features=self.seq_len, bias=True), + nn.Softmax(dim=1) + ) + + def forward(self, x): + # x: [1, seq_len, dim_aud] + y = x.permute(0, 2, 1) # [1, dim_aud, seq_len] + y = self.attentionConvNet(y) + y = self.attentionNet(y.view(1, self.seq_len)).view(1, self.seq_len, 1) + return torch.sum(y * x, dim=1) # [1, dim_aud] + + +# Audio feature extractor +class AudioNet(nn.Module): + def __init__(self, dim_in=29, dim_aud=64, win_size=16): + super(AudioNet, self).__init__() + self.win_size = win_size + self.dim_aud = dim_aud + self.encoder_conv = nn.Sequential( # n x 29 x 16 + nn.Conv1d(dim_in, 32, kernel_size=3, stride=2, padding=1, bias=True), # n x 32 x 8 + nn.LeakyReLU(0.02, True), + nn.Conv1d(32, 32, kernel_size=3, stride=2, padding=1, bias=True), # n x 32 x 4 + nn.LeakyReLU(0.02, True), + nn.Conv1d(32, 64, kernel_size=3, stride=2, padding=1, bias=True), # n x 64 x 2 + nn.LeakyReLU(0.02, True), + nn.Conv1d(64, 64, kernel_size=3, stride=2, padding=1, bias=True), # n x 64 x 1 + nn.LeakyReLU(0.02, True), + ) + self.encoder_fc1 = nn.Sequential( + nn.Linear(64, 64), + nn.LeakyReLU(0.02, True), + nn.Linear(64, dim_aud), + ) + + def forward(self, x): + half_w = int(self.win_size/2) + x = x[:, :, 8-half_w:8+half_w] + x = self.encoder_conv(x).squeeze(-1) + x = self.encoder_fc1(x) + return x + + +class MLP(nn.Module): + def __init__(self, dim_in, dim_out, dim_hidden, num_layers): + super().__init__() + self.dim_in = dim_in + self.dim_out = dim_out + self.dim_hidden = dim_hidden + self.num_layers = num_layers + + net = [] + for l in range(num_layers): + net.append(nn.Linear(self.dim_in if l == 0 else self.dim_hidden, self.dim_out if l == num_layers - 1 else self.dim_hidden, bias=False)) + + self.net = nn.ModuleList(net) + + def forward(self, x): + for l in range(self.num_layers): + x = self.net[l](x) + if l != self.num_layers - 1: + x = F.relu(x, inplace=True) + # x = F.dropout(x, p=0.1, training=self.training) + + return x + + +class NeRFNetwork(NeRFRenderer): + def __init__(self, + opt, + # torso net (hard coded for now) + ): + super().__init__(opt) + + # audio embedding + self.emb = self.opt.emb + + if 'esperanto' in self.opt.asr_model: + self.audio_in_dim = 44 + elif 'deepspeech' in self.opt.asr_model: + self.audio_in_dim = 29 + else: + self.audio_in_dim = 32 + + if self.emb: + self.embedding = nn.Embedding(self.audio_in_dim, self.audio_in_dim) + + # audio network + audio_dim = 32 + self.audio_dim = audio_dim + self.audio_net = AudioNet(self.audio_in_dim, self.audio_dim) + + self.att = self.opt.att + if self.att > 0: + self.audio_att_net = AudioAttNet(self.audio_dim) + + # DYNAMIC PART + self.num_levels = 12 + self.level_dim = 1 + self.encoder_xy, self.in_dim_xy = get_encoder('hashgrid', input_dim=2, num_levels=self.num_levels, level_dim=self.level_dim, base_resolution=64, log2_hashmap_size=14, desired_resolution=512 * self.bound) + self.encoder_yz, self.in_dim_yz = get_encoder('hashgrid', input_dim=2, num_levels=self.num_levels, level_dim=self.level_dim, base_resolution=64, log2_hashmap_size=14, desired_resolution=512 * self.bound) + self.encoder_xz, self.in_dim_xz = get_encoder('hashgrid', input_dim=2, num_levels=self.num_levels, level_dim=self.level_dim, base_resolution=64, log2_hashmap_size=14, desired_resolution=512 * self.bound) + + self.in_dim = self.in_dim_xy + self.in_dim_yz + self.in_dim_xz + + ## sigma network + self.num_layers = 3 + self.hidden_dim = 64 + self.geo_feat_dim = 64 + self.eye_att_net = MLP(self.in_dim, 1, 16, 2) + self.eye_dim = 1 if self.exp_eye else 0 + self.sigma_net = MLP(self.in_dim + self.audio_dim + self.eye_dim, 1 + self.geo_feat_dim, self.hidden_dim, self.num_layers) + ## color network + self.num_layers_color = 2 + self.hidden_dim_color = 64 + self.encoder_dir, self.in_dim_dir = get_encoder('spherical_harmonics') + self.color_net = MLP(self.in_dim_dir + self.geo_feat_dim + self.individual_dim, 3, self.hidden_dim_color, self.num_layers_color) + + self.unc_net = MLP(self.in_dim, 1, 32, 2) + + self.aud_ch_att_net = MLP(self.in_dim, self.audio_dim, 64, 2) + + self.testing = False + + if self.torso: + # torso deform network + self.register_parameter('anchor_points', + nn.Parameter(torch.tensor([[0.01, 0.01, 0.1, 1], [-0.1, -0.1, 0.1, 1], [0.1, -0.1, 0.1, 1]]))) + self.torso_deform_encoder, self.torso_deform_in_dim = get_encoder('frequency', input_dim=2, multires=8) + # self.torso_deform_encoder, self.torso_deform_in_dim = get_encoder('tiledgrid', input_dim=2, num_levels=16, level_dim=1, base_resolution=16, log2_hashmap_size=16, desired_resolution=512) + self.anchor_encoder, self.anchor_in_dim = get_encoder('frequency', input_dim=6, multires=3) + self.torso_deform_net = MLP(self.torso_deform_in_dim + self.anchor_in_dim + self.individual_dim_torso, 2, 32, 3) + + # torso color network + self.torso_encoder, self.torso_in_dim = get_encoder('tiledgrid', input_dim=2, num_levels=16, level_dim=2, base_resolution=16, log2_hashmap_size=16, desired_resolution=2048) + self.torso_net = MLP(self.torso_in_dim + self.torso_deform_in_dim + self.anchor_in_dim + self.individual_dim_torso, 4, 32, 3) + + + def forward_torso(self, x, poses, c=None): + # x: [N, 2] in [-1, 1] + # head poses: [1, 4, 4] + # c: [1, ind_dim], individual code + + # test: shrink x + x = x * self.opt.torso_shrink + + # deformation-based + wrapped_anchor = self.anchor_points[None, ...] @ poses.permute(0, 2, 1).inverse() + wrapped_anchor = (wrapped_anchor[:, :, :2] / wrapped_anchor[:, :, 3, None] / wrapped_anchor[:, :, 2, None]).view(1, -1) + # print(wrapped_anchor) + # enc_pose = self.pose_encoder(poses) + enc_anchor = self.anchor_encoder(wrapped_anchor) + enc_x = self.torso_deform_encoder(x) + + if c is not None: + h = torch.cat([enc_x, enc_anchor.repeat(x.shape[0], 1), c.repeat(x.shape[0], 1)], dim=-1) + else: + h = torch.cat([enc_x, enc_anchor.repeat(x.shape[0], 1)], dim=-1) + + dx = self.torso_deform_net(h) + + x = (x + dx).clamp(-1, 1) + + x = self.torso_encoder(x, bound=1) + + # h = torch.cat([x, h, enc_a.repeat(x.shape[0], 1)], dim=-1) + h = torch.cat([x, h], dim=-1) + + h = self.torso_net(h) + + alpha = torch.sigmoid(h[..., :1])*(1 + 2*0.001) - 0.001 + color = torch.sigmoid(h[..., 1:])*(1 + 2*0.001) - 0.001 + + return alpha, color, dx + + + @staticmethod + @torch.jit.script + def split_xyz(x): + xy, yz, xz = x[:, :-1], x[:, 1:], torch.cat([x[:,:1], x[:,-1:]], dim=-1) + return xy, yz, xz + + + def encode_x(self, xyz, bound): + # x: [N, 3], in [-bound, bound] + N, M = xyz.shape + xy, yz, xz = self.split_xyz(xyz) + feat_xy = self.encoder_xy(xy, bound=bound) + feat_yz = self.encoder_yz(yz, bound=bound) + feat_xz = self.encoder_xz(xz, bound=bound) + + return torch.cat([feat_xy, feat_yz, feat_xz], dim=-1) + + + def encode_audio(self, a): + # a: [1, 29, 16] or [8, 29, 16], audio features from deepspeech + # if emb, a should be: [1, 16] or [8, 16] + + # fix audio traininig + if a is None: return None + + if self.emb: + a = self.embedding(a).transpose(-1, -2).contiguous() # [1/8, 29, 16] + + enc_a = self.audio_net(a) # [1/8, 64] + + if self.att > 0: + enc_a = self.audio_att_net(enc_a.unsqueeze(0)) # [1, 64] + + return enc_a + + + def predict_uncertainty(self, unc_inp): + if self.testing or not self.opt.unc_loss: + unc = torch.zeros_like(unc_inp) + else: + unc = self.unc_net(unc_inp.detach()) + + return unc + + + def forward(self, x, d, enc_a, c, e=None): + # x: [N, 3], in [-bound, bound] + # d: [N, 3], nomalized in [-1, 1] + # enc_a: [1, aud_dim] + # c: [1, ind_dim], individual code + # e: [1, 1], eye feature + enc_x = self.encode_x(x, bound=self.bound) + + sigma_result = self.density(x, enc_a, e, enc_x) + sigma = sigma_result['sigma'] + geo_feat = sigma_result['geo_feat'] + aud_ch_att = sigma_result['ambient_aud'] + eye_att = sigma_result['ambient_eye'] + + # color + enc_d = self.encoder_dir(d) + + if c is not None: + h = torch.cat([enc_d, geo_feat, c.repeat(x.shape[0], 1)], dim=-1) + else: + h = torch.cat([enc_d, geo_feat], dim=-1) + + h_color = self.color_net(h) + color = torch.sigmoid(h_color)*(1 + 2*0.001) - 0.001 + + uncertainty = self.predict_uncertainty(enc_x) + uncertainty = torch.log(1 + torch.exp(uncertainty)) + + return sigma, color, aud_ch_att, eye_att, uncertainty[..., None] + + + def density(self, x, enc_a, e=None, enc_x=None): + # x: [N, 3], in [-bound, bound] + if enc_x is None: + enc_x = self.encode_x(x, bound=self.bound) + + enc_a = enc_a.repeat(enc_x.shape[0], 1) + aud_ch_att = self.aud_ch_att_net(enc_x) + enc_w = enc_a * aud_ch_att + + if e is not None: + # e = self.encoder_eye(e) + eye_att = torch.sigmoid(self.eye_att_net(enc_x)) + e = e * eye_att + # e = e.repeat(enc_x.shape[0], 1) + h = torch.cat([enc_x, enc_w, e], dim=-1) + else: + h = torch.cat([enc_x, enc_w], dim=-1) + + h = self.sigma_net(h) + + sigma = torch.exp(h[..., 0]) + geo_feat = h[..., 1:] + + return { + 'sigma': sigma, + 'geo_feat': geo_feat, + 'ambient_aud' : aud_ch_att.norm(dim=-1, keepdim=True), + 'ambient_eye' : eye_att, + } + + + # optimizer utils + def get_params(self, lr, lr_net, wd=0): + + # ONLY train torso + if self.torso: + params = [ + {'params': self.torso_encoder.parameters(), 'lr': lr}, + {'params': self.torso_deform_encoder.parameters(), 'lr': lr, 'weight_decay': wd}, + {'params': self.torso_net.parameters(), 'lr': lr_net, 'weight_decay': wd}, + {'params': self.torso_deform_net.parameters(), 'lr': lr_net, 'weight_decay': wd}, + {'params': self.anchor_points, 'lr': lr_net, 'weight_decay': wd} + ] + + if self.individual_dim_torso > 0: + params.append({'params': self.individual_codes_torso, 'lr': lr_net, 'weight_decay': wd}) + + return params + + params = [ + {'params': self.audio_net.parameters(), 'lr': lr_net, 'weight_decay': wd}, + + {'params': self.encoder_xy.parameters(), 'lr': lr}, + {'params': self.encoder_yz.parameters(), 'lr': lr}, + {'params': self.encoder_xz.parameters(), 'lr': lr}, + # {'params': self.encoder_xyz.parameters(), 'lr': lr}, + + {'params': self.sigma_net.parameters(), 'lr': lr_net, 'weight_decay': wd}, + {'params': self.color_net.parameters(), 'lr': lr_net, 'weight_decay': wd}, + ] + if self.att > 0: + params.append({'params': self.audio_att_net.parameters(), 'lr': lr_net * 5, 'weight_decay': 0.0001}) + if self.emb: + params.append({'params': self.embedding.parameters(), 'lr': lr}) + if self.individual_dim > 0: + params.append({'params': self.individual_codes, 'lr': lr_net, 'weight_decay': wd}) + if self.train_camera: + params.append({'params': self.camera_dT, 'lr': 1e-5, 'weight_decay': 0}) + params.append({'params': self.camera_dR, 'lr': 1e-5, 'weight_decay': 0}) + + params.append({'params': self.aud_ch_att_net.parameters(), 'lr': lr_net, 'weight_decay': wd}) + params.append({'params': self.unc_net.parameters(), 'lr': lr_net, 'weight_decay': wd}) + params.append({'params': self.eye_att_net.parameters(), 'lr': lr_net, 'weight_decay': wd}) + + return params \ No newline at end of file diff --git a/nerf_triplane/provider.py b/nerf_triplane/provider.py new file mode 100644 index 0000000..f61d92c --- /dev/null +++ b/nerf_triplane/provider.py @@ -0,0 +1,764 @@ +import os +import cv2 +import glob +import json +import tqdm +import numpy as np +from scipy.spatial.transform import Slerp, Rotation +import matplotlib.pyplot as plt + +import trimesh + +import torch +import torch.nn.functional as F +from torch.utils.data import DataLoader + +from .utils import get_audio_features, get_rays, get_bg_coords, convert_poses + +# ref: https://github.com/NVlabs/instant-ngp/blob/b76004c8cf478880227401ae763be4c02f80b62f/include/neural-graphics-primitives/nerf_loader.h#L50 +def nerf_matrix_to_ngp(pose, scale=0.33, offset=[0, 0, 0]): + new_pose = np.array([ + [pose[1, 0], -pose[1, 1], -pose[1, 2], pose[1, 3] * scale + offset[0]], + [pose[2, 0], -pose[2, 1], -pose[2, 2], pose[2, 3] * scale + offset[1]], + [pose[0, 0], -pose[0, 1], -pose[0, 2], pose[0, 3] * scale + offset[2]], + [0, 0, 0, 1], + ], dtype=np.float32) + return new_pose + + +def smooth_camera_path(poses, kernel_size=5): + # smooth the camera trajectory... + # poses: [N, 4, 4], numpy array + + N = poses.shape[0] + K = kernel_size // 2 + + trans = poses[:, :3, 3].copy() # [N, 3] + rots = poses[:, :3, :3].copy() # [N, 3, 3] + + for i in range(N): + start = max(0, i - K) + end = min(N, i + K + 1) + poses[i, :3, 3] = trans[start:end].mean(0) + poses[i, :3, :3] = Rotation.from_matrix(rots[start:end]).mean().as_matrix() + + return poses + +def polygon_area(x, y): + x_ = x - x.mean() + y_ = y - y.mean() + correction = x_[-1] * y_[0] - y_[-1]* x_[0] + main_area = np.dot(x_[:-1], y_[1:]) - np.dot(y_[:-1], x_[1:]) + return 0.5 * np.abs(main_area + correction) + + +def visualize_poses(poses, size=0.1): + # poses: [B, 4, 4] + + print(f'[INFO] visualize poses: {poses.shape}') + + axes = trimesh.creation.axis(axis_length=4) + box = trimesh.primitives.Box(extents=(2, 2, 2)).as_outline() + box.colors = np.array([[128, 128, 128]] * len(box.entities)) + objects = [axes, box] + + for pose in poses: + # a camera is visualized with 8 line segments. + pos = pose[:3, 3] + a = pos + size * pose[:3, 0] + size * pose[:3, 1] + size * pose[:3, 2] + b = pos - size * pose[:3, 0] + size * pose[:3, 1] + size * pose[:3, 2] + c = pos - size * pose[:3, 0] - size * pose[:3, 1] + size * pose[:3, 2] + d = pos + size * pose[:3, 0] - size * pose[:3, 1] + size * pose[:3, 2] + + dir = (a + b + c + d) / 4 - pos + dir = dir / (np.linalg.norm(dir) + 1e-8) + o = pos + dir * 3 + + segs = np.array([[pos, a], [pos, b], [pos, c], [pos, d], [a, b], [b, c], [c, d], [d, a], [pos, o]]) + segs = trimesh.load_path(segs) + objects.append(segs) + + trimesh.Scene(objects).show() + + +class NeRFDataset_Test: + def __init__(self, opt, device, downscale=1): + super().__init__() + + self.opt = opt + self.device = device + self.downscale = downscale + self.scale = opt.scale # camera radius scale to make sure camera are inside the bounding box. + self.offset = opt.offset # camera offset + self.bound = opt.bound # bounding box half length, also used as the radius to random sample poses. + self.fp16 = opt.fp16 + + self.start_index = opt.data_range[0] + self.end_index = opt.data_range[1] + + self.training = False + self.num_rays = -1 + + # load nerf-compatible format data. + + with open(opt.pose, 'r') as f: + transform = json.load(f) + + # load image size + self.H = int(transform['cy']) * 2 // downscale + self.W = int(transform['cx']) * 2 // downscale + + # read images + frames = transform["frames"] + + # use a slice of the dataset + if self.end_index == -1: # abuse... + self.end_index = len(frames) + + frames = frames[self.start_index:self.end_index] + + print(f'[INFO] load {len(frames)} frames.') + + # only load pre-calculated aud features when not live-streaming + if not self.opt.asr: + + aud_features = np.load(self.opt.aud) + + aud_features = torch.from_numpy(aud_features) + + # support both [N, 16] labels and [N, 16, K] logits + if len(aud_features.shape) == 3: + aud_features = aud_features.float().permute(0, 2, 1) # [N, 16, 29] --> [N, 29, 16] + + if self.opt.emb: + print(f'[INFO] argmax to aud features {aud_features.shape} for --emb mode') + aud_features = aud_features.argmax(1) # [N, 16] + + else: + assert self.opt.emb, "aud only provide labels, must use --emb" + aud_features = aud_features.long() + + print(f'[INFO] load {self.opt.aud} aud_features: {aud_features.shape}') + + self.poses = [] + self.auds = [] + self.eye_area = [] + + for f in tqdm.tqdm(frames, desc=f'Loading data'): + + pose = np.array(f['transform_matrix'], dtype=np.float32) # [4, 4] + pose = nerf_matrix_to_ngp(pose, scale=self.scale, offset=self.offset) + self.poses.append(pose) + + # find the corresponding audio to the image frame + if not self.opt.asr and self.opt.aud == '': + aud = aud_features[min(f['aud_id'], aud_features.shape[0] - 1)] # careful for the last frame... + self.auds.append(aud) + + if self.opt.exp_eye: + + if 'eye_ratio' in f: + area = f['eye_ratio'] + else: + area = 0.25 # default value for opened eye + + self.eye_area.append(area) + + # load pre-extracted background image (should be the same size as training image...) + + if self.opt.bg_img == 'white': # special + bg_img = np.ones((self.H, self.W, 3), dtype=np.float32) + elif self.opt.bg_img == 'black': # special + bg_img = np.zeros((self.H, self.W, 3), dtype=np.float32) + else: # load from file + bg_img = cv2.imread(self.opt.bg_img, cv2.IMREAD_UNCHANGED) # [H, W, 3] + if bg_img.shape[0] != self.H or bg_img.shape[1] != self.W: + bg_img = cv2.resize(bg_img, (self.W, self.H), interpolation=cv2.INTER_AREA) + bg_img = cv2.cvtColor(bg_img, cv2.COLOR_BGR2RGB) + bg_img = bg_img.astype(np.float32) / 255 # [H, W, 3/4] + + self.bg_img = bg_img + + self.poses = np.stack(self.poses, axis=0) + + # smooth camera path... + if self.opt.smooth_path: + self.poses = smooth_camera_path(self.poses, self.opt.smooth_path_window) + + self.poses = torch.from_numpy(self.poses) # [N, 4, 4] + + if self.opt.asr: + # live streaming, no pre-calculated auds + self.auds = None + else: + # auds corresponding to images + if self.opt.aud == '': + self.auds = torch.stack(self.auds, dim=0) # [N, 32, 16] + # auds is novel, may have a different length with images + else: + self.auds = aud_features + + self.bg_img = torch.from_numpy(self.bg_img) + + if self.opt.exp_eye: + self.eye_area = np.array(self.eye_area, dtype=np.float32) # [N] + print(f'[INFO] eye_area: {self.eye_area.min()} - {self.eye_area.max()}') + + if self.opt.smooth_eye: + + # naive 5 window average + ori_eye = self.eye_area.copy() + for i in range(ori_eye.shape[0]): + start = max(0, i - 1) + end = min(ori_eye.shape[0], i + 2) + self.eye_area[i] = ori_eye[start:end].mean() + + self.eye_area = torch.from_numpy(self.eye_area).view(-1, 1) # [N, 1] + + # always preload + self.poses = self.poses.to(self.device) + + if self.auds is not None: + self.auds = self.auds.to(self.device) + + self.bg_img = self.bg_img.to(torch.half).to(self.device) + + if self.opt.exp_eye: + self.eye_area = self.eye_area.to(self.device) + + # load intrinsics + + fl_x = fl_y = transform['focal_len'] + + cx = (transform['cx'] / downscale) + cy = (transform['cy'] / downscale) + + self.intrinsics = np.array([fl_x, fl_y, cx, cy]) + + # directly build the coordinate meshgrid in [-1, 1]^2 + self.bg_coords = get_bg_coords(self.H, self.W, self.device) # [1, H*W, 2] in [-1, 1] + + def mirror_index(self, index): + size = self.poses.shape[0] + turn = index // size + res = index % size + if turn % 2 == 0: + return res + else: + return size - res - 1 + + def collate(self, index): + + B = len(index) # a list of length 1 + # assert B == 1 + + results = {} + + # audio use the original index + if self.auds is not None: + auds = get_audio_features(self.auds, self.opt.att, index[0]).to(self.device) + results['auds'] = auds + + # head pose and bg image may mirror (replay --> <-- --> <--). + index[0] = self.mirror_index(index[0]) + + poses = self.poses[index].to(self.device) # [B, 4, 4] + + rays = get_rays(poses, self.intrinsics, self.H, self.W, self.num_rays, self.opt.patch_size) + + results['index'] = index # for ind. code + results['H'] = self.H + results['W'] = self.W + results['rays_o'] = rays['rays_o'] + results['rays_d'] = rays['rays_d'] + + if self.opt.exp_eye: + results['eye'] = self.eye_area[index].to(self.device) # [1] + else: + results['eye'] = None + + bg_img = self.bg_img.view(1, -1, 3).repeat(B, 1, 1).to(self.device) + + results['bg_color'] = bg_img + + bg_coords = self.bg_coords # [1, N, 2] + results['bg_coords'] = bg_coords + + # results['poses'] = convert_poses(poses) # [B, 6] + # results['poses_matrix'] = poses # [B, 4, 4] + results['poses'] = poses # [B, 4, 4] + + return results + + def dataloader(self): + + + # test with novel auds, then use its length + if self.auds is not None: + size = self.auds.shape[0] + # live stream test, use 2 * len(poses), so it naturally mirrors. + else: + size = 2 * self.poses.shape[0] + + loader = DataLoader(list(range(size)), batch_size=1, collate_fn=self.collate, shuffle=False, num_workers=0) + loader._data = self # an ugly fix... we need poses in trainer. + + # do evaluate if has gt images and use self-driven setting + loader.has_gt = False + + return loader + + +class NeRFDataset: + def __init__(self, opt, device, type='train', downscale=1): + super().__init__() + + self.opt = opt + self.device = device + self.type = type # train, val, test + self.downscale = downscale + self.root_path = opt.path + self.preload = opt.preload # 0 = disk, 1 = cpu, 2 = gpu + self.scale = opt.scale # camera radius scale to make sure camera are inside the bounding box. + self.offset = opt.offset # camera offset + self.bound = opt.bound # bounding box half length, also used as the radius to random sample poses. + self.fp16 = opt.fp16 + + self.start_index = opt.data_range[0] + self.end_index = opt.data_range[1] + + self.training = self.type in ['train', 'all', 'trainval'] + self.num_rays = self.opt.num_rays if self.training else -1 + + # load nerf-compatible format data. + + # load all splits (train/valid/test) + if type == 'all': + transform_paths = glob.glob(os.path.join(self.root_path, '*.json')) + transform = None + for transform_path in transform_paths: + with open(transform_path, 'r') as f: + tmp_transform = json.load(f) + if transform is None: + transform = tmp_transform + else: + transform['frames'].extend(tmp_transform['frames']) + # load train and val split + elif type == 'trainval': + with open(os.path.join(self.root_path, f'transforms_train.json'), 'r') as f: + transform = json.load(f) + with open(os.path.join(self.root_path, f'transforms_val.json'), 'r') as f: + transform_val = json.load(f) + transform['frames'].extend(transform_val['frames']) + # only load one specified split + else: + # no test, use val as test + _split = 'val' if type == 'test' else type + with open(os.path.join(self.root_path, f'transforms_{_split}.json'), 'r') as f: + transform = json.load(f) + + # load image size + if 'h' in transform and 'w' in transform: + self.H = int(transform['h']) // downscale + self.W = int(transform['w']) // downscale + else: + self.H = int(transform['cy']) * 2 // downscale + self.W = int(transform['cx']) * 2 // downscale + + # read images + frames = transform["frames"] + + # use a slice of the dataset + if self.end_index == -1: # abuse... + self.end_index = len(frames) + + frames = frames[self.start_index:self.end_index] + + # use a subset of dataset. + if type == 'train': + if self.opt.part: + frames = frames[::10] # 1/10 frames + elif self.opt.part2: + frames = frames[:375] # first 15s + elif type == 'val': + frames = frames[:100] # first 100 frames for val + + print(f'[INFO] load {len(frames)} {type} frames.') + + # only load pre-calculated aud features when not live-streaming + if not self.opt.asr: + + # empty means the default self-driven extracted features. + if self.opt.aud == '': + if 'esperanto' in self.opt.asr_model: + aud_features = np.load(os.path.join(self.root_path, 'aud_eo.npy')) + elif 'deepspeech' in self.opt.asr_model: + aud_features = np.load(os.path.join(self.root_path, 'aud_ds.npy')) + else: + aud_features = np.load(os.path.join(self.root_path, 'aud.npy')) + # cross-driven extracted features. + else: + aud_features = np.load(self.opt.aud) + + aud_features = torch.from_numpy(aud_features) + + # support both [N, 16] labels and [N, 16, K] logits + if len(aud_features.shape) == 3: + aud_features = aud_features.float().permute(0, 2, 1) # [N, 16, 29] --> [N, 29, 16] + + if self.opt.emb: + print(f'[INFO] argmax to aud features {aud_features.shape} for --emb mode') + aud_features = aud_features.argmax(1) # [N, 16] + + else: + assert self.opt.emb, "aud only provide labels, must use --emb" + aud_features = aud_features.long() + + print(f'[INFO] load {self.opt.aud} aud_features: {aud_features.shape}') + + # load action units + import pandas as pd + au_blink_info=pd.read_csv(os.path.join(self.root_path, 'au.csv')) + au_blink = au_blink_info[' AU45_r'].values + + self.torso_img = [] + self.images = [] + + self.poses = [] + self.exps = [] + + self.auds = [] + self.face_rect = [] + self.lhalf_rect = [] + self.lips_rect = [] + self.eye_area = [] + self.eye_rect = [] + + for f in tqdm.tqdm(frames, desc=f'Loading {type} data'): + + f_path = os.path.join(self.root_path, 'gt_imgs', str(f['img_id']) + '.jpg') + + if not os.path.exists(f_path): + print('[WARN]', f_path, 'NOT FOUND!') + continue + + pose = np.array(f['transform_matrix'], dtype=np.float32) # [4, 4] + pose = nerf_matrix_to_ngp(pose, scale=self.scale, offset=self.offset) + self.poses.append(pose) + + if self.preload > 0: + image = cv2.imread(f_path, cv2.IMREAD_UNCHANGED) # [H, W, 3] o [H, W, 4] + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + image = image.astype(np.float32) / 255 # [H, W, 3/4] + + self.images.append(image) + else: + self.images.append(f_path) + + # load frame-wise bg + + torso_img_path = os.path.join(self.root_path, 'torso_imgs', str(f['img_id']) + '.png') + + if self.preload > 0: + torso_img = cv2.imread(torso_img_path, cv2.IMREAD_UNCHANGED) # [H, W, 4] + torso_img = cv2.cvtColor(torso_img, cv2.COLOR_BGRA2RGBA) + torso_img = torso_img.astype(np.float32) / 255 # [H, W, 3/4] + + self.torso_img.append(torso_img) + else: + self.torso_img.append(torso_img_path) + + # find the corresponding audio to the image frame + if not self.opt.asr and self.opt.aud == '': + aud = aud_features[min(f['aud_id'], aud_features.shape[0] - 1)] # careful for the last frame... + self.auds.append(aud) + + # load lms and extract face + lms = np.loadtxt(os.path.join(self.root_path, 'ori_imgs', str(f['img_id']) + '.lms')) # [68, 2] + + lh_xmin, lh_xmax = int(lms[31:36, 1].min()), int(lms[:, 1].max()) # actually lower half area + xmin, xmax = int(lms[:, 1].min()), int(lms[:, 1].max()) + ymin, ymax = int(lms[:, 0].min()), int(lms[:, 0].max()) + self.face_rect.append([xmin, xmax, ymin, ymax]) + self.lhalf_rect.append([lh_xmin, lh_xmax, ymin, ymax]) + + if self.opt.exp_eye: + # eyes_left = slice(36, 42) + # eyes_right = slice(42, 48) + + # area_left = polygon_area(lms[eyes_left, 0], lms[eyes_left, 1]) + # area_right = polygon_area(lms[eyes_right, 0], lms[eyes_right, 1]) + + # # area percentage of two eyes of the whole image... + # area = (area_left + area_right) / (self.H * self.W) * 100 + + # action units blink AU45 + area = au_blink[f['img_id']] + area = np.clip(area, 0, 2) / 2 + # area = area + np.random.rand() / 10 + self.eye_area.append(area) + + xmin, xmax = int(lms[36:48, 1].min()), int(lms[36:48, 1].max()) + ymin, ymax = int(lms[36:48, 0].min()), int(lms[36:48, 0].max()) + self.eye_rect.append([xmin, xmax, ymin, ymax]) + + if self.opt.finetune_lips: + lips = slice(48, 60) + xmin, xmax = int(lms[lips, 1].min()), int(lms[lips, 1].max()) + ymin, ymax = int(lms[lips, 0].min()), int(lms[lips, 0].max()) + + # padding to H == W + cx = (xmin + xmax) // 2 + cy = (ymin + ymax) // 2 + + l = max(xmax - xmin, ymax - ymin) // 2 + xmin = max(0, cx - l) + xmax = min(self.H, cx + l) + ymin = max(0, cy - l) + ymax = min(self.W, cy + l) + + self.lips_rect.append([xmin, xmax, ymin, ymax]) + + # load pre-extracted background image (should be the same size as training image...) + + if self.opt.bg_img == 'white': # special + bg_img = np.ones((self.H, self.W, 3), dtype=np.float32) + elif self.opt.bg_img == 'black': # special + bg_img = np.zeros((self.H, self.W, 3), dtype=np.float32) + else: # load from file + # default bg + if self.opt.bg_img == '': + self.opt.bg_img = os.path.join(self.root_path, 'bc.jpg') + bg_img = cv2.imread(self.opt.bg_img, cv2.IMREAD_UNCHANGED) # [H, W, 3] + if bg_img.shape[0] != self.H or bg_img.shape[1] != self.W: + bg_img = cv2.resize(bg_img, (self.W, self.H), interpolation=cv2.INTER_AREA) + bg_img = cv2.cvtColor(bg_img, cv2.COLOR_BGR2RGB) + bg_img = bg_img.astype(np.float32) / 255 # [H, W, 3/4] + + self.bg_img = bg_img + + self.poses = np.stack(self.poses, axis=0) + + # smooth camera path... + if self.opt.smooth_path: + self.poses = smooth_camera_path(self.poses, self.opt.smooth_path_window) + + self.poses = torch.from_numpy(self.poses) # [N, 4, 4] + + if self.preload > 0: + self.images = torch.from_numpy(np.stack(self.images, axis=0)) # [N, H, W, C] + self.torso_img = torch.from_numpy(np.stack(self.torso_img, axis=0)) # [N, H, W, C] + else: + self.images = np.array(self.images) + self.torso_img = np.array(self.torso_img) + + if self.opt.asr: + # live streaming, no pre-calculated auds + self.auds = None + else: + # auds corresponding to images + if self.opt.aud == '': + self.auds = torch.stack(self.auds, dim=0) # [N, 32, 16] + # auds is novel, may have a different length with images + else: + self.auds = aud_features + + self.bg_img = torch.from_numpy(self.bg_img) + + if self.opt.exp_eye: + self.eye_area = np.array(self.eye_area, dtype=np.float32) # [N] + print(f'[INFO] eye_area: {self.eye_area.min()} - {self.eye_area.max()}') + + if self.opt.smooth_eye: + + # naive 5 window average + ori_eye = self.eye_area.copy() + for i in range(ori_eye.shape[0]): + start = max(0, i - 1) + end = min(ori_eye.shape[0], i + 2) + self.eye_area[i] = ori_eye[start:end].mean() + + self.eye_area = torch.from_numpy(self.eye_area).view(-1, 1) # [N, 1] + + + # calculate mean radius of all camera poses + self.radius = self.poses[:, :3, 3].norm(dim=-1).mean(0).item() + #print(f'[INFO] dataset camera poses: radius = {self.radius:.4f}, bound = {self.bound}') + + + # [debug] uncomment to view all training poses. + # visualize_poses(self.poses.numpy()) + + # [debug] uncomment to view examples of randomly generated poses. + # visualize_poses(rand_poses(100, self.device, radius=self.radius).cpu().numpy()) + + if self.preload > 1: + self.poses = self.poses.to(self.device) + + if self.auds is not None: + self.auds = self.auds.to(self.device) + + self.bg_img = self.bg_img.to(torch.half).to(self.device) + + self.torso_img = self.torso_img.to(torch.half).to(self.device) + self.images = self.images.to(torch.half).to(self.device) + + if self.opt.exp_eye: + self.eye_area = self.eye_area.to(self.device) + + # load intrinsics + if 'focal_len' in transform: + fl_x = fl_y = transform['focal_len'] + elif 'fl_x' in transform or 'fl_y' in transform: + fl_x = (transform['fl_x'] if 'fl_x' in transform else transform['fl_y']) / downscale + fl_y = (transform['fl_y'] if 'fl_y' in transform else transform['fl_x']) / downscale + elif 'camera_angle_x' in transform or 'camera_angle_y' in transform: + # blender, assert in radians. already downscaled since we use H/W + fl_x = self.W / (2 * np.tan(transform['camera_angle_x'] / 2)) if 'camera_angle_x' in transform else None + fl_y = self.H / (2 * np.tan(transform['camera_angle_y'] / 2)) if 'camera_angle_y' in transform else None + if fl_x is None: fl_x = fl_y + if fl_y is None: fl_y = fl_x + else: + raise RuntimeError('Failed to load focal length, please check the transforms.json!') + + cx = (transform['cx'] / downscale) if 'cx' in transform else (self.W / 2) + cy = (transform['cy'] / downscale) if 'cy' in transform else (self.H / 2) + + self.intrinsics = np.array([fl_x, fl_y, cx, cy]) + + # directly build the coordinate meshgrid in [-1, 1]^2 + self.bg_coords = get_bg_coords(self.H, self.W, self.device) # [1, H*W, 2] in [-1, 1] + + + def mirror_index(self, index): + size = self.poses.shape[0] + turn = index // size + res = index % size + if turn % 2 == 0: + return res + else: + return size - res - 1 + + + def collate(self, index): + + B = len(index) # a list of length 1 + # assert B == 1 + + results = {} + + # audio use the original index + if self.auds is not None: + auds = get_audio_features(self.auds, self.opt.att, index[0]).to(self.device) + results['auds'] = auds + + # head pose and bg image may mirror (replay --> <-- --> <--). + index[0] = self.mirror_index(index[0]) + + poses = self.poses[index].to(self.device) # [B, 4, 4] + + if self.training and self.opt.finetune_lips: + rect = self.lips_rect[index[0]] + results['rect'] = rect + rays = get_rays(poses, self.intrinsics, self.H, self.W, -1, rect=rect) + else: + rays = get_rays(poses, self.intrinsics, self.H, self.W, self.num_rays, self.opt.patch_size) + + results['index'] = index # for ind. code + results['H'] = self.H + results['W'] = self.W + results['rays_o'] = rays['rays_o'] + results['rays_d'] = rays['rays_d'] + + # get a mask for rays inside rect_face + if self.training: + xmin, xmax, ymin, ymax = self.face_rect[index[0]] + face_mask = (rays['j'] >= xmin) & (rays['j'] < xmax) & (rays['i'] >= ymin) & (rays['i'] < ymax) # [B, N] + results['face_mask'] = face_mask + + xmin, xmax, ymin, ymax = self.lhalf_rect[index[0]] + lhalf_mask = (rays['j'] >= xmin) & (rays['j'] < xmax) & (rays['i'] >= ymin) & (rays['i'] < ymax) # [B, N] + results['lhalf_mask'] = lhalf_mask + + if self.opt.exp_eye: + results['eye'] = self.eye_area[index].to(self.device) # [1] + if self.training: + results['eye'] += (np.random.rand()-0.5) / 10 + xmin, xmax, ymin, ymax = self.eye_rect[index[0]] + eye_mask = (rays['j'] >= xmin) & (rays['j'] < xmax) & (rays['i'] >= ymin) & (rays['i'] < ymax) # [B, N] + results['eye_mask'] = eye_mask + + else: + results['eye'] = None + + # load bg + bg_torso_img = self.torso_img[index] + if self.preload == 0: # on the fly loading + bg_torso_img = cv2.imread(bg_torso_img[0], cv2.IMREAD_UNCHANGED) # [H, W, 4] + bg_torso_img = cv2.cvtColor(bg_torso_img, cv2.COLOR_BGRA2RGBA) + bg_torso_img = bg_torso_img.astype(np.float32) / 255 # [H, W, 3/4] + bg_torso_img = torch.from_numpy(bg_torso_img).unsqueeze(0) + bg_torso_img = bg_torso_img[..., :3] * bg_torso_img[..., 3:] + self.bg_img * (1 - bg_torso_img[..., 3:]) + bg_torso_img = bg_torso_img.view(B, -1, 3).to(self.device) + + if not self.opt.torso: + bg_img = bg_torso_img + else: + bg_img = self.bg_img.view(1, -1, 3).repeat(B, 1, 1).to(self.device) + + if self.training: + bg_img = torch.gather(bg_img, 1, torch.stack(3 * [rays['inds']], -1)) # [B, N, 3] + + results['bg_color'] = bg_img + + if self.opt.torso and self.training: + bg_torso_img = torch.gather(bg_torso_img, 1, torch.stack(3 * [rays['inds']], -1)) # [B, N, 3] + results['bg_torso_color'] = bg_torso_img + + images = self.images[index] # [B, H, W, 3/4] + if self.preload == 0: + images = cv2.imread(images[0], cv2.IMREAD_UNCHANGED) # [H, W, 3] + images = cv2.cvtColor(images, cv2.COLOR_BGR2RGB) + images = images.astype(np.float32) / 255 # [H, W, 3] + images = torch.from_numpy(images).unsqueeze(0) + images = images.to(self.device) + + if self.training: + C = images.shape[-1] + images = torch.gather(images.view(B, -1, C), 1, torch.stack(C * [rays['inds']], -1)) # [B, N, 3/4] + + results['images'] = images + + if self.training: + bg_coords = torch.gather(self.bg_coords, 1, torch.stack(2 * [rays['inds']], -1)) # [1, N, 2] + else: + bg_coords = self.bg_coords # [1, N, 2] + + results['bg_coords'] = bg_coords + + # results['poses'] = convert_poses(poses) # [B, 6] + # results['poses_matrix'] = poses # [B, 4, 4] + results['poses'] = poses # [B, 4, 4] + + return results + + def dataloader(self): + + if self.training: + # training len(poses) == len(auds) + size = self.poses.shape[0] + else: + # test with novel auds, then use its length + if self.auds is not None: + size = self.auds.shape[0] + # live stream test, use 2 * len(poses), so it naturally mirrors. + else: + size = 2 * self.poses.shape[0] + + loader = DataLoader(list(range(size)), batch_size=1, collate_fn=self.collate, shuffle=self.training, num_workers=0) + loader._data = self # an ugly fix... we need poses in trainer. + + # do evaluate if has gt images and use self-driven setting + loader.has_gt = (self.opt.aud == '') + + return loader \ No newline at end of file diff --git a/nerf_triplane/renderer.py b/nerf_triplane/renderer.py new file mode 100644 index 0000000..3dfe8f2 --- /dev/null +++ b/nerf_triplane/renderer.py @@ -0,0 +1,700 @@ +import math +import trimesh +import numpy as np +import random + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import raymarching +from .utils import custom_meshgrid, get_audio_features, euler_angles_to_matrix, convert_poses + +def sample_pdf(bins, weights, n_samples, det=False): + # This implementation is from NeRF + # bins: [B, T], old_z_vals + # weights: [B, T - 1], bin weights. + # return: [B, n_samples], new_z_vals + + # Get pdf + weights = weights + 1e-5 # prevent nans + pdf = weights / torch.sum(weights, -1, keepdim=True) + cdf = torch.cumsum(pdf, -1) + cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1) + # Take uniform samples + if det: + u = torch.linspace(0. + 0.5 / n_samples, 1. - 0.5 / n_samples, steps=n_samples).to(weights.device) + u = u.expand(list(cdf.shape[:-1]) + [n_samples]) + else: + u = torch.rand(list(cdf.shape[:-1]) + [n_samples]).to(weights.device) + + # Invert CDF + u = u.contiguous() + inds = torch.searchsorted(cdf, u, right=True) + below = torch.max(torch.zeros_like(inds - 1), inds - 1) + above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds) + inds_g = torch.stack([below, above], -1) # (B, n_samples, 2) + + matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]] + cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g) + bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g) + + denom = (cdf_g[..., 1] - cdf_g[..., 0]) + denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom) + t = (u - cdf_g[..., 0]) / denom + samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0]) + + return samples + + +def plot_pointcloud(pc, color=None): + # pc: [N, 3] + # color: [N, 3/4] + print('[visualize points]', pc.shape, pc.dtype, pc.min(0), pc.max(0)) + pc = trimesh.PointCloud(pc, color) + # axis + axes = trimesh.creation.axis(axis_length=4) + # sphere + sphere = trimesh.creation.icosphere(radius=1) + trimesh.Scene([pc, axes, sphere]).show() + + +class NeRFRenderer(nn.Module): + def __init__(self, opt): + + super().__init__() + + self.opt = opt + self.bound = opt.bound + self.cascade = 1 + math.ceil(math.log2(opt.bound)) + self.grid_size = 128 + self.density_scale = 1 + + self.min_near = opt.min_near + self.density_thresh = opt.density_thresh + self.density_thresh_torso = opt.density_thresh_torso + + self.exp_eye = opt.exp_eye + self.test_train = opt.test_train + self.smooth_lips = opt.smooth_lips + + self.torso = opt.torso + self.cuda_ray = opt.cuda_ray + + # prepare aabb with a 6D tensor (xmin, ymin, zmin, xmax, ymax, zmax) + # NOTE: aabb (can be rectangular) is only used to generate points, we still rely on bound (always cubic) to calculate density grid and hashing. + aabb_train = torch.FloatTensor([-opt.bound, -opt.bound/2, -opt.bound, opt.bound, opt.bound/2, opt.bound]) + aabb_infer = aabb_train.clone() + self.register_buffer('aabb_train', aabb_train) + self.register_buffer('aabb_infer', aabb_infer) + + # individual codes + self.individual_num = opt.ind_num + + self.individual_dim = opt.ind_dim + if self.individual_dim > 0: + self.individual_codes = nn.Parameter(torch.randn(self.individual_num, self.individual_dim) * 0.1) + + if self.torso: + self.individual_dim_torso = opt.ind_dim_torso + if self.individual_dim_torso > 0: + self.individual_codes_torso = nn.Parameter(torch.randn(self.individual_num, self.individual_dim_torso) * 0.1) + + # optimize camera pose + self.train_camera = self.opt.train_camera + if self.train_camera: + self.camera_dR = nn.Parameter(torch.zeros(self.individual_num, 3)) # euler angle + self.camera_dT = nn.Parameter(torch.zeros(self.individual_num, 3)) # xyz offset + + # extra state for cuda raymarching + + # 3D head density grid + density_grid = torch.zeros([self.cascade, self.grid_size ** 3]) # [CAS, H * H * H] + density_bitfield = torch.zeros(self.cascade * self.grid_size ** 3 // 8, dtype=torch.uint8) # [CAS * H * H * H // 8] + self.register_buffer('density_grid', density_grid) + self.register_buffer('density_bitfield', density_bitfield) + self.mean_density = 0 + self.iter_density = 0 + + # 2D torso density grid + if self.torso: + density_grid_torso = torch.zeros([self.grid_size ** 2]) # [H * H] + self.register_buffer('density_grid_torso', density_grid_torso) + self.mean_density_torso = 0 + + # step counter + step_counter = torch.zeros(16, 2, dtype=torch.int32) # 16 is hardcoded for averaging... + self.register_buffer('step_counter', step_counter) + self.mean_count = 0 + self.local_step = 0 + + # decay for enc_a + if self.smooth_lips: + self.enc_a = None + + def forward(self, x, d): + raise NotImplementedError() + + # separated density and color query (can accelerate non-cuda-ray mode.) + def density(self, x): + raise NotImplementedError() + + def color(self, x, d, mask=None, **kwargs): + raise NotImplementedError() + + def reset_extra_state(self): + if not self.cuda_ray: + return + # density grid + self.density_grid.zero_() + self.mean_density = 0 + self.iter_density = 0 + # step counter + self.step_counter.zero_() + self.mean_count = 0 + self.local_step = 0 + + + def run_cuda(self, rays_o, rays_d, auds, bg_coords, poses, eye=None, index=0, dt_gamma=0, bg_color=None, perturb=False, force_all_rays=False, max_steps=1024, T_thresh=1e-4, **kwargs): + # rays_o, rays_d: [B, N, 3], assumes B == 1 + # auds: [B, 16] + # index: [B] + # return: image: [B, N, 3], depth: [B, N] + + prefix = rays_o.shape[:-1] + rays_o = rays_o.contiguous().view(-1, 3) + rays_d = rays_d.contiguous().view(-1, 3) + bg_coords = bg_coords.contiguous().view(-1, 2) + + # only add camera offset at training! + if self.train_camera and (self.training or self.test_train): + dT = self.camera_dT[index] # [1, 3] + dR = euler_angles_to_matrix(self.camera_dR[index] / 180 * np.pi + 1e-8).squeeze(0) # [1, 3] --> [3, 3] + + rays_o = rays_o + dT + rays_d = rays_d @ dR + + N = rays_o.shape[0] # N = B * N, in fact + device = rays_o.device + + results = {} + + # pre-calculate near far + nears, fars = raymarching.near_far_from_aabb(rays_o, rays_d, self.aabb_train if self.training else self.aabb_infer, self.min_near) + nears = nears.detach() + fars = fars.detach() + + # encode audio + enc_a = self.encode_audio(auds) # [1, 64] + + if enc_a is not None and self.smooth_lips: + if self.enc_a is not None: + _lambda = 0.35 + enc_a = _lambda * self.enc_a + (1 - _lambda) * enc_a + self.enc_a = enc_a + + + if self.individual_dim > 0: + if self.training: + ind_code = self.individual_codes[index] + # use a fixed ind code for the unknown test data. + else: + ind_code = self.individual_codes[0] + else: + ind_code = None + + if self.training: + # setup counter + counter = self.step_counter[self.local_step % 16] + counter.zero_() # set to 0 + self.local_step += 1 + + xyzs, dirs, deltas, rays = raymarching.march_rays_train(rays_o, rays_d, self.bound, self.density_bitfield, self.cascade, self.grid_size, nears, fars, counter, self.mean_count, perturb, 128, force_all_rays, dt_gamma, max_steps) + sigmas, rgbs, amb_aud, amb_eye, uncertainty = self(xyzs, dirs, enc_a, ind_code, eye) + sigmas = self.density_scale * sigmas + + #print(f'valid RGB query ratio: {mask.sum().item() / mask.shape[0]} (total = {mask.sum().item()})') + + # weights_sum, ambient_sum, uncertainty_sum, depth, image = raymarching.composite_rays_train_uncertainty(sigmas, rgbs, ambient.abs().sum(-1), uncertainty, deltas, rays) + weights_sum, amb_aud_sum, amb_eye_sum, uncertainty_sum, depth, image = raymarching.composite_rays_train_triplane(sigmas, rgbs, amb_aud.abs().sum(-1), amb_eye.abs().sum(-1), uncertainty, deltas, rays) + + # for training only + results['weights_sum'] = weights_sum + results['ambient_aud'] = amb_aud_sum + results['ambient_eye'] = amb_eye_sum + results['uncertainty'] = uncertainty_sum + + results['rays'] = xyzs, dirs, enc_a, ind_code, eye + + else: + + dtype = torch.float32 + + weights_sum = torch.zeros(N, dtype=dtype, device=device) + depth = torch.zeros(N, dtype=dtype, device=device) + image = torch.zeros(N, 3, dtype=dtype, device=device) + amb_aud_sum = torch.zeros(N, dtype=dtype, device=device) + amb_eye_sum = torch.zeros(N, dtype=dtype, device=device) + uncertainty_sum = torch.zeros(N, dtype=dtype, device=device) + + n_alive = N + rays_alive = torch.arange(n_alive, dtype=torch.int32, device=device) # [N] + rays_t = nears.clone() # [N] + + step = 0 + + while step < max_steps: + + # count alive rays + n_alive = rays_alive.shape[0] + + # exit loop + if n_alive <= 0: + break + + # decide compact_steps + n_step = max(min(N // n_alive, 8), 1) + + xyzs, dirs, deltas = raymarching.march_rays(n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, self.bound, self.density_bitfield, self.cascade, self.grid_size, nears, fars, 128, perturb if step == 0 else False, dt_gamma, max_steps) + + sigmas, rgbs, ambients_aud, ambients_eye, uncertainties = self(xyzs, dirs, enc_a, ind_code, eye) + sigmas = self.density_scale * sigmas + + # raymarching.composite_rays_uncertainty(n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, deltas, ambients, uncertainties, weights_sum, depth, image, ambient_sum, uncertainty_sum, T_thresh) + raymarching.composite_rays_triplane(n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, deltas, ambients_aud, ambients_eye, uncertainties, weights_sum, depth, image, amb_aud_sum, amb_eye_sum, uncertainty_sum, T_thresh) + + rays_alive = rays_alive[rays_alive >= 0] + + # print(f'step = {step}, n_step = {n_step}, n_alive = {n_alive}, xyzs: {xyzs.shape}') + + step += n_step + + torso_results = self.run_torso(rays_o, bg_coords, poses, index, bg_color) + bg_color = torso_results['bg_color'] + + image = image + (1 - weights_sum).unsqueeze(-1) * bg_color + image = image.view(*prefix, 3) + image = image.clamp(0, 1) + + depth = torch.clamp(depth - nears, min=0) / (fars - nears) + depth = depth.view(*prefix) + + amb_aud_sum = amb_aud_sum.view(*prefix) + amb_eye_sum = amb_eye_sum.view(*prefix) + + results['depth'] = depth + results['image'] = image # head_image if train, else com_image + results['ambient_aud'] = amb_aud_sum + results['ambient_eye'] = amb_eye_sum + results['uncertainty'] = uncertainty_sum + + return results + + + def run_torso(self, rays_o, bg_coords, poses, index=0, bg_color=None, **kwargs): + # rays_o, rays_d: [B, N, 3], assumes B == 1 + # auds: [B, 16] + # index: [B] + # return: image: [B, N, 3], depth: [B, N] + + rays_o = rays_o.contiguous().view(-1, 3) + bg_coords = bg_coords.contiguous().view(-1, 2) + + N = rays_o.shape[0] # N = B * N, in fact + device = rays_o.device + + results = {} + + # background + if bg_color is None: + bg_color = 1 + + # first mix torso with background + if self.torso: + # torso ind code + if self.individual_dim_torso > 0: + if self.training: + ind_code_torso = self.individual_codes_torso[index] + # use a fixed ind code for the unknown test data. + else: + ind_code_torso = self.individual_codes_torso[0] + else: + ind_code_torso = None + + # 2D density grid for acceleration... + density_thresh_torso = min(self.density_thresh_torso, self.mean_density_torso) + occupancy = F.grid_sample(self.density_grid_torso.view(1, 1, self.grid_size, self.grid_size), bg_coords.view(1, -1, 1, 2), align_corners=True).view(-1) + mask = occupancy > density_thresh_torso + + # masked query of torso + torso_alpha = torch.zeros([N, 1], device=device) + torso_color = torch.zeros([N, 3], device=device) + + if mask.any(): + torso_alpha_mask, torso_color_mask, deform = self.forward_torso(bg_coords[mask], poses, ind_code_torso) + + torso_alpha[mask] = torso_alpha_mask.float() + torso_color[mask] = torso_color_mask.float() + + results['deform'] = deform + + # first mix torso with background + + bg_color = torso_color * torso_alpha + bg_color * (1 - torso_alpha) + + results['torso_alpha'] = torso_alpha + results['torso_color'] = bg_color + + # print(torso_alpha.shape, torso_alpha.max().item(), torso_alpha.min().item()) + + results['bg_color'] = bg_color + + return results + + + @torch.no_grad() + def mark_untrained_grid(self, poses, intrinsic, S=64): + # poses: [B, 4, 4] + # intrinsic: [3, 3] + + if not self.cuda_ray: + return + + if isinstance(poses, np.ndarray): + poses = torch.from_numpy(poses) + + B = poses.shape[0] + + fx, fy, cx, cy = intrinsic + + X = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S) + Y = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S) + Z = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S) + + count = torch.zeros_like(self.density_grid) + poses = poses.to(count.device) + + # 5-level loop, forgive me... + + for xs in X: + for ys in Y: + for zs in Z: + + # construct points + xx, yy, zz = custom_meshgrid(xs, ys, zs) + coords = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1) # [N, 3], in [0, 128) + indices = raymarching.morton3D(coords).long() # [N] + world_xyzs = (2 * coords.float() / (self.grid_size - 1) - 1).unsqueeze(0) # [1, N, 3] in [-1, 1] + + # cascading + for cas in range(self.cascade): + bound = min(2 ** cas, self.bound) + half_grid_size = bound / self.grid_size + # scale to current cascade's resolution + cas_world_xyzs = world_xyzs * (bound - half_grid_size) + + # split batch to avoid OOM + head = 0 + while head < B: + tail = min(head + S, B) + + # world2cam transform (poses is c2w, so we need to transpose it. Another transpose is needed for batched matmul, so the final form is without transpose.) + cam_xyzs = cas_world_xyzs - poses[head:tail, :3, 3].unsqueeze(1) + cam_xyzs = cam_xyzs @ poses[head:tail, :3, :3] # [S, N, 3] + + # query if point is covered by any camera + mask_z = cam_xyzs[:, :, 2] > 0 # [S, N] + mask_x = torch.abs(cam_xyzs[:, :, 0]) < cx / fx * cam_xyzs[:, :, 2] + half_grid_size * 2 + mask_y = torch.abs(cam_xyzs[:, :, 1]) < cy / fy * cam_xyzs[:, :, 2] + half_grid_size * 2 + mask = (mask_z & mask_x & mask_y).sum(0).reshape(-1) # [N] + + # update count + count[cas, indices] += mask + head += S + + # mark untrained grid as -1 + self.density_grid[count == 0] = -1 + + #print(f'[mark untrained grid] {(count == 0).sum()} from {resolution ** 3 * self.cascade}') + + @torch.no_grad() + def update_extra_state(self, decay=0.95, S=128): + # call before each epoch to update extra states. + + if not self.cuda_ray: + return + + # use random auds (different expressions should have similar density grid...) + rand_idx = random.randint(0, self.aud_features.shape[0] - 1) + auds = get_audio_features(self.aud_features, self.att, rand_idx).to(self.density_bitfield.device) + + # encode audio + enc_a = self.encode_audio(auds) + + ### update density grid + if not self.torso: # forbid updating head if is training torso... + + tmp_grid = torch.zeros_like(self.density_grid) + + # use a random eye area based on training dataset's statistics... + if self.exp_eye: + eye = self.eye_area[[rand_idx]].to(self.density_bitfield.device) # [1, 1] + else: + eye = None + + # full update + X = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S) + Y = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S) + Z = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S) + + for xs in X: + for ys in Y: + for zs in Z: + + # construct points + xx, yy, zz = custom_meshgrid(xs, ys, zs) + coords = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1) # [N, 3], in [0, 128) + indices = raymarching.morton3D(coords).long() # [N] + xyzs = 2 * coords.float() / (self.grid_size - 1) - 1 # [N, 3] in [-1, 1] + + # cascading + for cas in range(self.cascade): + bound = min(2 ** cas, self.bound) + half_grid_size = bound / self.grid_size + # scale to current cascade's resolution + cas_xyzs = xyzs * (bound - half_grid_size) + # add noise in [-hgs, hgs] + cas_xyzs += (torch.rand_like(cas_xyzs) * 2 - 1) * half_grid_size + # query density + sigmas = self.density(cas_xyzs, enc_a, eye)['sigma'].reshape(-1).detach().to(tmp_grid.dtype) + sigmas *= self.density_scale + # assign + tmp_grid[cas, indices] = sigmas + + # dilate the density_grid (less aggressive culling) + tmp_grid = raymarching.morton3D_dilation(tmp_grid) + + # ema update + valid_mask = (self.density_grid >= 0) & (tmp_grid >= 0) + self.density_grid[valid_mask] = torch.maximum(self.density_grid[valid_mask] * decay, tmp_grid[valid_mask]) + self.mean_density = torch.mean(self.density_grid.clamp(min=0)).item() # -1 non-training regions are viewed as 0 density. + self.iter_density += 1 + + # convert to bitfield + density_thresh = min(self.mean_density, self.density_thresh) + self.density_bitfield = raymarching.packbits(self.density_grid, density_thresh, self.density_bitfield) + + ### update torso density grid + if self.torso: + tmp_grid_torso = torch.zeros_like(self.density_grid_torso) + + # random pose, random ind_code + rand_idx = random.randint(0, self.poses.shape[0] - 1) + # pose = convert_poses(self.poses[[rand_idx]]).to(self.density_bitfield.device) + pose = self.poses[[rand_idx]].to(self.density_bitfield.device) + + if self.opt.ind_dim_torso > 0: + ind_code = self.individual_codes_torso[[rand_idx]] + else: + ind_code = None + + X = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S) + Y = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S) + + half_grid_size = 1 / self.grid_size + + for xs in X: + for ys in Y: + xx, yy = custom_meshgrid(xs, ys) + coords = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1)], dim=-1) # [N, 2], in [0, 128) + indices = (coords[:, 1] * self.grid_size + coords[:, 0]).long() # NOTE: xy transposed! + xys = 2 * coords.float() / (self.grid_size - 1) - 1 # [N, 2] in [-1, 1] + xys = xys * (1 - half_grid_size) + # add noise in [-hgs, hgs] + xys += (torch.rand_like(xys) * 2 - 1) * half_grid_size + # query density + alphas, _, _ = self.forward_torso(xys, pose, ind_code) # [N, 1] + + # assign + tmp_grid_torso[indices] = alphas.squeeze(1).float() + + # dilate + tmp_grid_torso = tmp_grid_torso.view(1, 1, self.grid_size, self.grid_size) + # tmp_grid_torso = F.max_pool2d(tmp_grid_torso, kernel_size=3, stride=1, padding=1) + tmp_grid_torso = F.max_pool2d(tmp_grid_torso, kernel_size=5, stride=1, padding=2) + tmp_grid_torso = tmp_grid_torso.view(-1) + + self.density_grid_torso = torch.maximum(self.density_grid_torso * decay, tmp_grid_torso) + self.mean_density_torso = torch.mean(self.density_grid_torso).item() + + # density_thresh_torso = min(self.density_thresh_torso, self.mean_density_torso) + # print(f'[density grid torso] min={self.density_grid_torso.min().item():.4f}, max={self.density_grid_torso.max().item():.4f}, mean={self.mean_density_torso:.4f}, occ_rate={(self.density_grid_torso > density_thresh_torso).sum() / (128**2):.3f}') + + ### update step counter + total_step = min(16, self.local_step) + if total_step > 0: + self.mean_count = int(self.step_counter[:total_step, 0].sum().item() / total_step) + self.local_step = 0 + + #print(f'[density grid] min={self.density_grid.min().item():.4f}, max={self.density_grid.max().item():.4f}, mean={self.mean_density:.4f}, occ_rate={(self.density_grid > 0.01).sum() / (128**3 * self.cascade):.3f} | [step counter] mean={self.mean_count}') + + + @torch.no_grad() + def get_audio_grid(self, S=128): + # call before each epoch to update extra states. + + if not self.cuda_ray: + return + + # use random auds (different expressions should have similar density grid...) + rand_idx = random.randint(0, self.aud_features.shape[0] - 1) + auds = get_audio_features(self.aud_features, self.att, rand_idx).to(self.density_bitfield.device) + + # encode audio + enc_a = self.encode_audio(auds) + tmp_grid = torch.zeros_like(self.density_grid) + + # use a random eye area based on training dataset's statistics... + if self.exp_eye: + eye = self.eye_area[[rand_idx]].to(self.density_bitfield.device) # [1, 1] + else: + eye = None + + # full update + X = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S) + Y = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S) + Z = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S) + + for xs in X: + for ys in Y: + for zs in Z: + + # construct points + xx, yy, zz = custom_meshgrid(xs, ys, zs) + coords = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1) # [N, 3], in [0, 128) + indices = raymarching.morton3D(coords).long() # [N] + xyzs = 2 * coords.float() / (self.grid_size - 1) - 1 # [N, 3] in [-1, 1] + + # cascading + for cas in range(self.cascade): + bound = min(2 ** cas, self.bound) + half_grid_size = bound / self.grid_size + # scale to current cascade's resolution + cas_xyzs = xyzs * (bound - half_grid_size) + # add noise in [-hgs, hgs] + cas_xyzs += (torch.rand_like(cas_xyzs) * 2 - 1) * half_grid_size + # query density + aud_norms = self.density(cas_xyzs.to(tmp_grid.dtype), enc_a, eye)['ambient_aud'].reshape(-1).detach().to(tmp_grid.dtype) + # assign + tmp_grid[cas, indices] = aud_norms + + # dilate the density_grid (less aggressive culling) + tmp_grid = raymarching.morton3D_dilation(tmp_grid) + return tmp_grid + # # ema update + # valid_mask = (self.density_grid >= 0) & (tmp_grid >= 0) + # self.density_grid[valid_mask] = torch.maximum(self.density_grid[valid_mask] * decay, tmp_grid[valid_mask]) + + + @torch.no_grad() + def get_eye_grid(self, S=128): + # call before each epoch to update extra states. + + if not self.cuda_ray: + return + + # use random auds (different expressions should have similar density grid...) + rand_idx = random.randint(0, self.aud_features.shape[0] - 1) + auds = get_audio_features(self.aud_features, self.att, rand_idx).to(self.density_bitfield.device) + + # encode audio + enc_a = self.encode_audio(auds) + tmp_grid = torch.zeros_like(self.density_grid) + + # use a random eye area based on training dataset's statistics... + if self.exp_eye: + eye = self.eye_area[[rand_idx]].to(self.density_bitfield.device) # [1, 1] + else: + eye = None + + # full update + X = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S) + Y = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S) + Z = torch.arange(self.grid_size, dtype=torch.int32, device=self.density_bitfield.device).split(S) + + for xs in X: + for ys in Y: + for zs in Z: + + # construct points + xx, yy, zz = custom_meshgrid(xs, ys, zs) + coords = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1) # [N, 3], in [0, 128) + indices = raymarching.morton3D(coords).long() # [N] + xyzs = 2 * coords.float() / (self.grid_size - 1) - 1 # [N, 3] in [-1, 1] + + # cascading + for cas in range(self.cascade): + bound = min(2 ** cas, self.bound) + half_grid_size = bound / self.grid_size + # scale to current cascade's resolution + cas_xyzs = xyzs * (bound - half_grid_size) + # add noise in [-hgs, hgs] + cas_xyzs += (torch.rand_like(cas_xyzs) * 2 - 1) * half_grid_size + # query density + eye_norms = self.density(cas_xyzs.to(tmp_grid.dtype), enc_a, eye)['ambient_eye'].reshape(-1).detach().to(tmp_grid.dtype) + # assign + tmp_grid[cas, indices] = eye_norms + + # dilate the density_grid (less aggressive culling) + tmp_grid = raymarching.morton3D_dilation(tmp_grid) + return tmp_grid + # # ema update + # valid_mask = (self.density_grid >= 0) & (tmp_grid >= 0) + # self.density_grid[valid_mask] = torch.maximum(self.density_grid[valid_mask] * decay, tmp_grid[valid_mask]) + + + + def render(self, rays_o, rays_d, auds, bg_coords, poses, staged=False, max_ray_batch=4096, **kwargs): + # rays_o, rays_d: [B, N, 3], assumes B == 1 + # auds: [B, 29, 16] + # eye: [B, 1] + # bg_coords: [1, N, 2] + # return: pred_rgb: [B, N, 3] + + _run = self.run_cuda + + B, N = rays_o.shape[:2] + device = rays_o.device + + # never stage when cuda_ray + if staged and not self.cuda_ray: + # not used + raise NotImplementedError + + else: + results = _run(rays_o, rays_d, auds, bg_coords, poses, **kwargs) + + return results + + + def render_torso(self, rays_o, rays_d, auds, bg_coords, poses, staged=False, max_ray_batch=4096, **kwargs): + # rays_o, rays_d: [B, N, 3], assumes B == 1 + # auds: [B, 29, 16] + # eye: [B, 1] + # bg_coords: [1, N, 2] + # return: pred_rgb: [B, N, 3] + + _run = self.run_torso + + B, N = rays_o.shape[:2] + device = rays_o.device + + # never stage when cuda_ray + if staged and not self.cuda_ray: + # not used + raise NotImplementedError + + else: + results = _run(rays_o, bg_coords, poses, **kwargs) + + return results \ No newline at end of file diff --git a/nerf_triplane/utils.py b/nerf_triplane/utils.py new file mode 100644 index 0000000..d9304d3 --- /dev/null +++ b/nerf_triplane/utils.py @@ -0,0 +1,1514 @@ +import os +import glob +import tqdm +import math +import random +import warnings +import tensorboardX + +import numpy as np +import pandas as pd + +import time +from datetime import datetime + +import cv2 +import matplotlib.pyplot as plt + +import torch +import torch.nn as nn +import torch.optim as optim +import torch.nn.functional as F +import torch.distributed as dist +from torch.utils.data import Dataset, DataLoader + +import trimesh +import mcubes +from rich.console import Console +from torch_ema import ExponentialMovingAverage + +from packaging import version as pver +import imageio +import lpips + +def custom_meshgrid(*args): + # ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid + if pver.parse(torch.__version__) < pver.parse('1.10'): + return torch.meshgrid(*args) + else: + return torch.meshgrid(*args, indexing='ij') + + +def get_audio_features(features, att_mode, index): + if att_mode == 0: + return features[[index]] + elif att_mode == 1: + left = index - 8 + pad_left = 0 + if left < 0: + pad_left = -left + left = 0 + auds = features[left:index] + if pad_left > 0: + # pad may be longer than auds, so do not use zeros_like + auds = torch.cat([torch.zeros(pad_left, *auds.shape[1:], device=auds.device, dtype=auds.dtype), auds], dim=0) + return auds + elif att_mode == 2: + left = index - 4 + right = index + 4 + pad_left = 0 + pad_right = 0 + if left < 0: + pad_left = -left + left = 0 + if right > features.shape[0]: + pad_right = right - features.shape[0] + right = features.shape[0] + auds = features[left:right] + if pad_left > 0: + auds = torch.cat([torch.zeros_like(auds[:pad_left]), auds], dim=0) + if pad_right > 0: + auds = torch.cat([auds, torch.zeros_like(auds[:pad_right])], dim=0) # [8, 16] + return auds + else: + raise NotImplementedError(f'wrong att_mode: {att_mode}') + + +@torch.jit.script +def linear_to_srgb(x): + return torch.where(x < 0.0031308, 12.92 * x, 1.055 * x ** 0.41666 - 0.055) + + +@torch.jit.script +def srgb_to_linear(x): + return torch.where(x < 0.04045, x / 12.92, ((x + 0.055) / 1.055) ** 2.4) + +# copied from pytorch3d +def _angle_from_tan( + axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool +) -> torch.Tensor: + """ + Extract the first or third Euler angle from the two members of + the matrix which are positive constant times its sine and cosine. + + Args: + axis: Axis label "X" or "Y or "Z" for the angle we are finding. + other_axis: Axis label "X" or "Y or "Z" for the middle axis in the + convention. + data: Rotation matrices as tensor of shape (..., 3, 3). + horizontal: Whether we are looking for the angle for the third axis, + which means the relevant entries are in the same row of the + rotation matrix. If not, they are in the same column. + tait_bryan: Whether the first and third axes in the convention differ. + + Returns: + Euler Angles in radians for each matrix in data as a tensor + of shape (...). + """ + + i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis] + if horizontal: + i2, i1 = i1, i2 + even = (axis + other_axis) in ["XY", "YZ", "ZX"] + if horizontal == even: + return torch.atan2(data[..., i1], data[..., i2]) + if tait_bryan: + return torch.atan2(-data[..., i2], data[..., i1]) + return torch.atan2(data[..., i2], -data[..., i1]) + + +def _index_from_letter(letter: str) -> int: + if letter == "X": + return 0 + if letter == "Y": + return 1 + if letter == "Z": + return 2 + raise ValueError("letter must be either X, Y or Z.") + + +def matrix_to_euler_angles(matrix: torch.Tensor, convention: str = 'XYZ') -> torch.Tensor: + """ + Convert rotations given as rotation matrices to Euler angles in radians. + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + convention: Convention string of three uppercase letters. + + Returns: + Euler angles in radians as tensor of shape (..., 3). + """ + # if len(convention) != 3: + # raise ValueError("Convention must have 3 letters.") + # if convention[1] in (convention[0], convention[2]): + # raise ValueError(f"Invalid convention {convention}.") + # for letter in convention: + # if letter not in ("X", "Y", "Z"): + # raise ValueError(f"Invalid letter {letter} in convention string.") + # if matrix.size(-1) != 3 or matrix.size(-2) != 3: + # raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.") + i0 = _index_from_letter(convention[0]) + i2 = _index_from_letter(convention[2]) + tait_bryan = i0 != i2 + if tait_bryan: + central_angle = torch.asin( + matrix[..., i0, i2] * (-1.0 if i0 - i2 in [-1, 2] else 1.0) + ) + else: + central_angle = torch.acos(matrix[..., i0, i0]) + + o = ( + _angle_from_tan( + convention[0], convention[1], matrix[..., i2], False, tait_bryan + ), + central_angle, + _angle_from_tan( + convention[2], convention[1], matrix[..., i0, :], True, tait_bryan + ), + ) + return torch.stack(o, -1) + +@torch.cuda.amp.autocast(enabled=False) +def _axis_angle_rotation(axis: str, angle: torch.Tensor) -> torch.Tensor: + """ + Return the rotation matrices for one of the rotations about an axis + of which Euler angles describe, for each value of the angle given. + Args: + axis: Axis label "X" or "Y or "Z". + angle: any shape tensor of Euler angles in radians + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + + cos = torch.cos(angle) + sin = torch.sin(angle) + one = torch.ones_like(angle) + zero = torch.zeros_like(angle) + + if axis == "X": + R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos) + elif axis == "Y": + R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos) + elif axis == "Z": + R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one) + else: + raise ValueError("letter must be either X, Y or Z.") + + return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3)) + +@torch.cuda.amp.autocast(enabled=False) +def euler_angles_to_matrix(euler_angles: torch.Tensor, convention: str='XYZ') -> torch.Tensor: + """ + Convert rotations given as Euler angles in radians to rotation matrices. + Args: + euler_angles: Euler angles in radians as tensor of shape (..., 3). + convention: Convention string of three uppercase letters from + {"X", "Y", and "Z"}. + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + + # print(euler_angles, euler_angles.dtype) + + if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3: + raise ValueError("Invalid input euler angles.") + if len(convention) != 3: + raise ValueError("Convention must have 3 letters.") + if convention[1] in (convention[0], convention[2]): + raise ValueError(f"Invalid convention {convention}.") + for letter in convention: + if letter not in ("X", "Y", "Z"): + raise ValueError(f"Invalid letter {letter} in convention string.") + matrices = [ + _axis_angle_rotation(c, e) + for c, e in zip(convention, torch.unbind(euler_angles, -1)) + ] + + return torch.matmul(torch.matmul(matrices[0], matrices[1]), matrices[2]) + + +@torch.cuda.amp.autocast(enabled=False) +def convert_poses(poses): + # poses: [B, 4, 4] + # return [B, 3], 4 rot, 3 trans + out = torch.empty(poses.shape[0], 6, dtype=torch.float32, device=poses.device) + out[:, :3] = matrix_to_euler_angles(poses[:, :3, :3]) + out[:, 3:] = poses[:, :3, 3] + return out + +@torch.cuda.amp.autocast(enabled=False) +def get_bg_coords(H, W, device): + X = torch.arange(H, device=device) / (H - 1) * 2 - 1 # in [-1, 1] + Y = torch.arange(W, device=device) / (W - 1) * 2 - 1 # in [-1, 1] + xs, ys = custom_meshgrid(X, Y) + bg_coords = torch.cat([xs.reshape(-1, 1), ys.reshape(-1, 1)], dim=-1).unsqueeze(0) # [1, H*W, 2], in [-1, 1] + return bg_coords + + +@torch.cuda.amp.autocast(enabled=False) +def get_rays(poses, intrinsics, H, W, N=-1, patch_size=1, rect=None): + ''' get rays + Args: + poses: [B, 4, 4], cam2world + intrinsics: [4] + H, W, N: int + Returns: + rays_o, rays_d: [B, N, 3] + inds: [B, N] + ''' + + device = poses.device + B = poses.shape[0] + fx, fy, cx, cy = intrinsics + + if rect is not None: + xmin, xmax, ymin, ymax = rect + N = (xmax - xmin) * (ymax - ymin) + + i, j = custom_meshgrid(torch.linspace(0, W-1, W, device=device), torch.linspace(0, H-1, H, device=device)) # float + i = i.t().reshape([1, H*W]).expand([B, H*W]) + 0.5 + j = j.t().reshape([1, H*W]).expand([B, H*W]) + 0.5 + + results = {} + + if N > 0: + N = min(N, H*W) + + if patch_size > 1: + + # random sample left-top cores. + # NOTE: this impl will lead to less sampling on the image corner pixels... but I don't have other ideas. + num_patch = N // (patch_size ** 2) + inds_x = torch.randint(0, H - patch_size, size=[num_patch], device=device) + inds_y = torch.randint(0, W - patch_size, size=[num_patch], device=device) + inds = torch.stack([inds_x, inds_y], dim=-1) # [np, 2] + + # create meshgrid for each patch + pi, pj = custom_meshgrid(torch.arange(patch_size, device=device), torch.arange(patch_size, device=device)) + offsets = torch.stack([pi.reshape(-1), pj.reshape(-1)], dim=-1) # [p^2, 2] + + inds = inds.unsqueeze(1) + offsets.unsqueeze(0) # [np, p^2, 2] + inds = inds.view(-1, 2) # [N, 2] + inds = inds[:, 0] * W + inds[:, 1] # [N], flatten + + inds = inds.expand([B, N]) + + # only get rays in the specified rect + elif rect is not None: + # assert B == 1 + mask = torch.zeros(H, W, dtype=torch.bool, device=device) + xmin, xmax, ymin, ymax = rect + mask[xmin:xmax, ymin:ymax] = 1 + inds = torch.where(mask.view(-1))[0] # [nzn] + inds = inds.unsqueeze(0) # [1, N] + + else: + inds = torch.randint(0, H*W, size=[N], device=device) # may duplicate + inds = inds.expand([B, N]) + + i = torch.gather(i, -1, inds) + j = torch.gather(j, -1, inds) + + + else: + inds = torch.arange(H*W, device=device).expand([B, H*W]) + + results['i'] = i + results['j'] = j + results['inds'] = inds + + zs = torch.ones_like(i) + xs = (i - cx) / fx * zs + ys = (j - cy) / fy * zs + directions = torch.stack((xs, ys, zs), dim=-1) + directions = directions / torch.norm(directions, dim=-1, keepdim=True) + + rays_d = directions @ poses[:, :3, :3].transpose(-1, -2) # (B, N, 3) + + rays_o = poses[..., :3, 3] # [B, 3] + rays_o = rays_o[..., None, :].expand_as(rays_d) # [B, N, 3] + + results['rays_o'] = rays_o + results['rays_d'] = rays_d + + return results + + +def seed_everything(seed): + random.seed(seed) + os.environ['PYTHONHASHSEED'] = str(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + #torch.backends.cudnn.deterministic = True + #torch.backends.cudnn.benchmark = True + + +def torch_vis_2d(x, renormalize=False): + # x: [3, H, W] or [1, H, W] or [H, W] + import matplotlib.pyplot as plt + import numpy as np + import torch + + if isinstance(x, torch.Tensor): + if len(x.shape) == 3: + x = x.permute(1,2,0).squeeze() + x = x.detach().cpu().numpy() + + print(f'[torch_vis_2d] {x.shape}, {x.dtype}, {x.min()} ~ {x.max()}') + + x = x.astype(np.float32) + + # renormalize + if renormalize: + x = (x - x.min(axis=0, keepdims=True)) / (x.max(axis=0, keepdims=True) - x.min(axis=0, keepdims=True) + 1e-8) + + plt.imshow(x) + plt.show() + + +def extract_fields(bound_min, bound_max, resolution, query_func, S=128): + + X = torch.linspace(bound_min[0], bound_max[0], resolution).split(S) + Y = torch.linspace(bound_min[1], bound_max[1], resolution).split(S) + Z = torch.linspace(bound_min[2], bound_max[2], resolution).split(S) + + u = np.zeros([resolution, resolution, resolution], dtype=np.float32) + with torch.no_grad(): + for xi, xs in enumerate(X): + for yi, ys in enumerate(Y): + for zi, zs in enumerate(Z): + xx, yy, zz = custom_meshgrid(xs, ys, zs) + pts = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1) # [S, 3] + val = query_func(pts).reshape(len(xs), len(ys), len(zs)).detach().cpu().numpy() # [S, 1] --> [x, y, z] + u[xi * S: xi * S + len(xs), yi * S: yi * S + len(ys), zi * S: zi * S + len(zs)] = val + return u + + +def extract_geometry(bound_min, bound_max, resolution, threshold, query_func): + #print('threshold: {}'.format(threshold)) + u = extract_fields(bound_min, bound_max, resolution, query_func) + + #print(u.shape, u.max(), u.min(), np.percentile(u, 50)) + + vertices, triangles = mcubes.marching_cubes(u, threshold) + + b_max_np = bound_max.detach().cpu().numpy() + b_min_np = bound_min.detach().cpu().numpy() + + vertices = vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :] + b_min_np[None, :] + return vertices, triangles + + +class PSNRMeter: + def __init__(self): + self.V = 0 + self.N = 0 + + def clear(self): + self.V = 0 + self.N = 0 + + def prepare_inputs(self, *inputs): + outputs = [] + for i, inp in enumerate(inputs): + if torch.is_tensor(inp): + inp = inp.detach().cpu().numpy() + outputs.append(inp) + + return outputs + + def update(self, preds, truths): + preds, truths = self.prepare_inputs(preds, truths) # [B, N, 3] or [B, H, W, 3], range in [0, 1] + + # simplified since max_pixel_value is 1 here. + psnr = -10 * np.log10(np.mean((preds - truths) ** 2)) + + self.V += psnr + self.N += 1 + + def measure(self): + return self.V / self.N + + def write(self, writer, global_step, prefix=""): + writer.add_scalar(os.path.join(prefix, "PSNR"), self.measure(), global_step) + + def report(self): + return f'PSNR = {self.measure():.6f}' + +class LPIPSMeter: + def __init__(self, net='alex', device=None): + self.V = 0 + self.N = 0 + self.net = net + + self.device = device if device is not None else torch.device('cuda' if torch.cuda.is_available() else 'cpu') + self.fn = lpips.LPIPS(net=net).eval().to(self.device) + + def clear(self): + self.V = 0 + self.N = 0 + + def prepare_inputs(self, *inputs): + outputs = [] + for i, inp in enumerate(inputs): + inp = inp.permute(0, 3, 1, 2).contiguous() # [B, 3, H, W] + inp = inp.to(self.device) + outputs.append(inp) + return outputs + + def update(self, preds, truths): + preds, truths = self.prepare_inputs(preds, truths) # [B, H, W, 3] --> [B, 3, H, W], range in [0, 1] + v = self.fn(truths, preds, normalize=True).item() # normalize=True: [0, 1] to [-1, 1] + self.V += v + self.N += 1 + + def measure(self): + return self.V / self.N + + def write(self, writer, global_step, prefix=""): + writer.add_scalar(os.path.join(prefix, f"LPIPS ({self.net})"), self.measure(), global_step) + + def report(self): + return f'LPIPS ({self.net}) = {self.measure():.6f}' + + +class LMDMeter: + def __init__(self, backend='dlib', region='mouth'): + self.backend = backend + self.region = region # mouth or face + + if self.backend == 'dlib': + import dlib + + # load checkpoint manually + self.predictor_path = './shape_predictor_68_face_landmarks.dat' + if not os.path.exists(self.predictor_path): + raise FileNotFoundError('Please download dlib checkpoint from http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2') + + self.detector = dlib.get_frontal_face_detector() + self.predictor = dlib.shape_predictor(self.predictor_path) + + else: + + import face_alignment + + self.predictor = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=False) + + self.V = 0 + self.N = 0 + + def get_landmarks(self, img): + + if self.backend == 'dlib': + dets = self.detector(img, 1) + for det in dets: + shape = self.predictor(img, det) + # ref: https://github.com/PyImageSearch/imutils/blob/c12f15391fcc945d0d644b85194b8c044a392e0a/imutils/face_utils/helpers.py + lms = np.zeros((68, 2), dtype=np.int32) + for i in range(0, 68): + lms[i, 0] = shape.part(i).x + lms[i, 1] = shape.part(i).y + break + + else: + lms = self.predictor.get_landmarks(img)[-1] + + # self.vis_landmarks(img, lms) + lms = lms.astype(np.float32) + + return lms + + def vis_landmarks(self, img, lms): + plt.imshow(img) + plt.plot(lms[48:68, 0], lms[48:68, 1], marker='o', markersize=1, linestyle='-', lw=2) + plt.show() + + def clear(self): + self.V = 0 + self.N = 0 + + def prepare_inputs(self, *inputs): + outputs = [] + for i, inp in enumerate(inputs): + inp = inp.detach().cpu().numpy() + inp = (inp * 255).astype(np.uint8) + outputs.append(inp) + return outputs + + def update(self, preds, truths): + # assert B == 1 + preds, truths = self.prepare_inputs(preds[0], truths[0]) # [H, W, 3] numpy array + + # get lms + lms_pred = self.get_landmarks(preds) + lms_truth = self.get_landmarks(truths) + + if self.region == 'mouth': + lms_pred = lms_pred[48:68] + lms_truth = lms_truth[48:68] + + # avarage + lms_pred = lms_pred - lms_pred.mean(0) + lms_truth = lms_truth - lms_truth.mean(0) + + # distance + dist = np.sqrt(((lms_pred - lms_truth) ** 2).sum(1)).mean(0) + + self.V += dist + self.N += 1 + + def measure(self): + return self.V / self.N + + def write(self, writer, global_step, prefix=""): + writer.add_scalar(os.path.join(prefix, f"LMD ({self.backend})"), self.measure(), global_step) + + def report(self): + return f'LMD ({self.backend}) = {self.measure():.6f}' + + +class Trainer(object): + def __init__(self, + name, # name of this experiment + opt, # extra conf + model, # network + criterion=None, # loss function, if None, assume inline implementation in train_step + optimizer=None, # optimizer + ema_decay=None, # if use EMA, set the decay + ema_update_interval=1000, # update ema per $ training steps. + lr_scheduler=None, # scheduler + metrics=[], # metrics for evaluation, if None, use val_loss to measure performance, else use the first metric. + local_rank=0, # which GPU am I + world_size=1, # total num of GPUs + device=None, # device to use, usually setting to None is OK. (auto choose device) + mute=False, # whether to mute all print + fp16=False, # amp optimize level + eval_interval=1, # eval once every $ epoch + max_keep_ckpt=2, # max num of saved ckpts in disk + workspace='workspace', # workspace to save logs & ckpts + best_mode='min', # the smaller/larger result, the better + use_loss_as_metric=True, # use loss as the first metric + report_metric_at_train=False, # also report metrics at training + use_checkpoint="latest", # which ckpt to use at init time + use_tensorboardX=True, # whether to use tensorboard for logging + scheduler_update_every_step=False, # whether to call scheduler.step() after every train step + ): + + self.name = name + self.opt = opt + self.mute = mute + self.metrics = metrics + self.local_rank = local_rank + self.world_size = world_size + self.workspace = workspace + self.ema_decay = ema_decay + self.ema_update_interval = ema_update_interval + self.fp16 = fp16 + self.best_mode = best_mode + self.use_loss_as_metric = use_loss_as_metric + self.report_metric_at_train = report_metric_at_train + self.max_keep_ckpt = max_keep_ckpt + self.eval_interval = eval_interval + self.use_checkpoint = use_checkpoint + self.use_tensorboardX = use_tensorboardX + self.flip_finetune_lips = self.opt.finetune_lips + self.flip_init_lips = self.opt.init_lips + self.time_stamp = time.strftime("%Y-%m-%d_%H-%M-%S") + self.scheduler_update_every_step = scheduler_update_every_step + self.device = device if device is not None else torch.device(f'cuda:{local_rank}' if torch.cuda.is_available() else 'cpu') + self.console = Console() + + model.to(self.device) + if self.world_size > 1: + model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank]) + self.model = model + + if isinstance(criterion, nn.Module): + criterion.to(self.device) + self.criterion = criterion + + if optimizer is None: + self.optimizer = optim.Adam(self.model.parameters(), lr=0.001, weight_decay=5e-4) # naive adam + else: + self.optimizer = optimizer(self.model) + + if lr_scheduler is None: + self.lr_scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lambda epoch: 1) # fake scheduler + else: + self.lr_scheduler = lr_scheduler(self.optimizer) + + if ema_decay is not None: + self.ema = ExponentialMovingAverage(self.model.parameters(), decay=ema_decay) + else: + self.ema = None + + self.scaler = torch.cuda.amp.GradScaler(enabled=self.fp16) + + # optionally use LPIPS loss for patch-based training + if self.opt.patch_size > 1 or self.opt.finetune_lips or True: + import lpips + # self.criterion_lpips_vgg = lpips.LPIPS(net='vgg').to(self.device) + self.criterion_lpips_alex = lpips.LPIPS(net='alex').to(self.device) + + # variable init + self.epoch = 0 + self.global_step = 0 + self.local_step = 0 + self.stats = { + "loss": [], + "valid_loss": [], + "results": [], # metrics[0], or valid_loss + "checkpoints": [], # record path of saved ckpt, to automatically remove old ckpt + "best_result": None, + } + + # auto fix + if len(metrics) == 0 or self.use_loss_as_metric: + self.best_mode = 'min' + + # workspace prepare + self.log_ptr = None + if self.workspace is not None: + os.makedirs(self.workspace, exist_ok=True) + self.log_path = os.path.join(workspace, f"log_{self.name}.txt") + self.log_ptr = open(self.log_path, "a+") + + self.ckpt_path = os.path.join(self.workspace, 'checkpoints') + self.best_path = f"{self.ckpt_path}/{self.name}.pth" + os.makedirs(self.ckpt_path, exist_ok=True) + + self.log(f'[INFO] Trainer: {self.name} | {self.time_stamp} | {self.device} | {"fp16" if self.fp16 else "fp32"} | {self.workspace}') + self.log(f'[INFO] #parameters: {sum([p.numel() for p in model.parameters() if p.requires_grad])}') + + if self.workspace is not None: + if self.use_checkpoint == "scratch": + self.log("[INFO] Training from scratch ...") + elif self.use_checkpoint == "latest": + self.log("[INFO] Loading latest checkpoint ...") + self.load_checkpoint() + elif self.use_checkpoint == "latest_model": + self.log("[INFO] Loading latest checkpoint (model only)...") + self.load_checkpoint(model_only=True) + elif self.use_checkpoint == "best": + if os.path.exists(self.best_path): + self.log("[INFO] Loading best checkpoint ...") + self.load_checkpoint(self.best_path) + else: + self.log(f"[INFO] {self.best_path} not found, loading latest ...") + self.load_checkpoint() + else: # path to ckpt + self.log(f"[INFO] Loading {self.use_checkpoint} ...") + self.load_checkpoint(self.use_checkpoint) + + def __del__(self): + if self.log_ptr: + self.log_ptr.close() + + + def log(self, *args, **kwargs): + if self.local_rank == 0: + if not self.mute: + #print(*args) + self.console.print(*args, **kwargs) + if self.log_ptr: + print(*args, file=self.log_ptr) + self.log_ptr.flush() # write immediately to file + + ### ------------------------------ + + def train_step(self, data): + + rays_o = data['rays_o'] # [B, N, 3] + rays_d = data['rays_d'] # [B, N, 3] + bg_coords = data['bg_coords'] # [1, N, 2] + poses = data['poses'] # [B, 6] + face_mask = data['face_mask'] # [B, N] + eye_mask = data['eye_mask'] # [B, N] + lhalf_mask = data['lhalf_mask'] + eye = data['eye'] # [B, 1] + auds = data['auds'] # [B, 29, 16] + index = data['index'] # [B] + + if not self.opt.torso: + rgb = data['images'] # [B, N, 3] + else: + rgb = data['bg_torso_color'] + + B, N, C = rgb.shape + + if self.opt.color_space == 'linear': + rgb[..., :3] = srgb_to_linear(rgb[..., :3]) + + bg_color = data['bg_color'] + + if not self.opt.torso: + outputs = self.model.render(rays_o, rays_d, auds, bg_coords, poses, eye=eye, index=index, staged=False, bg_color=bg_color, perturb=True, force_all_rays=False if (self.opt.patch_size <= 1 and not self.opt.train_camera) else True, **vars(self.opt)) + else: + outputs = self.model.render_torso(rays_o, rays_d, auds, bg_coords, poses, eye=eye, index=index, staged=False, bg_color=bg_color, perturb=True, force_all_rays=False if (self.opt.patch_size <= 1 and not self.opt.train_camera) else True, **vars(self.opt)) + + if not self.opt.torso: + pred_rgb = outputs['image'] + else: + pred_rgb = outputs['torso_color'] + + + # loss factor + step_factor = min(self.global_step / self.opt.iters, 1.0) + + # MSE loss + loss = self.criterion(pred_rgb, rgb).mean(-1) # [B, N, 3] --> [B, N] + + if self.opt.torso: + loss = loss.mean() + loss += ((1 - self.model.anchor_points[:, 3])**2).mean() + return pred_rgb, rgb, loss + + # camera optim regularization + # if self.opt.train_camera: + # cam_reg = self.model.camera_dR[index].abs().mean() + self.model.camera_dT[index].abs().mean() + # loss = loss + 1e-2 * cam_reg + + if self.opt.unc_loss and not self.flip_finetune_lips: + alpha = 0.2 + uncertainty = outputs['uncertainty'] # [N], abs sum + beta = uncertainty + 1 + + unc_weight = F.softmax(uncertainty, dim=-1) * N + # print(unc_weight.shape, unc_weight.max(), unc_weight.min()) + loss *= alpha + (1-alpha)*((1 - step_factor) + step_factor * unc_weight.detach()).clamp(0, 10) + # loss *= unc_weight.detach() + + beta = uncertainty + 1 + norm_rgb = torch.norm((pred_rgb - rgb), dim=-1).detach() + loss_u = norm_rgb / (2*beta**2) + (torch.log(beta)**2) / 2 + loss_u *= face_mask.view(-1) + loss += step_factor * loss_u + + loss_static_uncertainty = (uncertainty * (~face_mask.view(-1))) + loss += 1e-3 * step_factor * loss_static_uncertainty + + # patch-based rendering + if self.opt.patch_size > 1 and not self.opt.finetune_lips: + rgb = rgb.view(-1, self.opt.patch_size, self.opt.patch_size, 3).permute(0, 3, 1, 2).contiguous() + pred_rgb = pred_rgb.view(-1, self.opt.patch_size, self.opt.patch_size, 3).permute(0, 3, 1, 2).contiguous() + + # torch_vis_2d(rgb[0]) + # torch_vis_2d(pred_rgb[0]) + + # LPIPS loss ? + loss_lpips = self.criterion_lpips_alex(pred_rgb, rgb) + loss = loss + 0.1 * loss_lpips + + # lips finetune + if self.opt.finetune_lips: + xmin, xmax, ymin, ymax = data['rect'] + rgb = rgb.view(-1, xmax - xmin, ymax - ymin, 3).permute(0, 3, 1, 2).contiguous() + pred_rgb = pred_rgb.view(-1, xmax - xmin, ymax - ymin, 3).permute(0, 3, 1, 2).contiguous() + + # torch_vis_2d(rgb[0]) + # torch_vis_2d(pred_rgb[0]) + + # LPIPS loss + loss = loss + 0.01 * self.criterion_lpips_alex(pred_rgb, rgb) + + # flip every step... if finetune lips + if self.flip_finetune_lips: + self.opt.finetune_lips = not self.opt.finetune_lips + + loss = loss.mean() + + # weights_sum loss + # entropy to encourage weights_sum to be 0 or 1. + if self.opt.torso: + alphas = outputs['torso_alpha'].clamp(1e-5, 1 - 1e-5) + # alphas = alphas ** 2 # skewed entropy, favors 0 over 1 + loss_ws = - alphas * torch.log2(alphas) - (1 - alphas) * torch.log2(1 - alphas) + loss = loss + 1e-4 * loss_ws.mean() + + else: + alphas = outputs['weights_sum'].clamp(1e-5, 1 - 1e-5) + loss_ws = - alphas * torch.log2(alphas) - (1 - alphas) * torch.log2(1 - alphas) + loss = loss + 1e-4 * loss_ws.mean() + + # aud att loss (regions out of face should be static) + if self.opt.amb_aud_loss and not self.opt.torso: + ambient_aud = outputs['ambient_aud'] + loss_amb_aud = (ambient_aud * (~face_mask.view(-1))).mean() + # gradually increase it + lambda_amb = step_factor * self.opt.lambda_amb + loss += lambda_amb * loss_amb_aud + + # eye att loss + if self.opt.amb_eye_loss and not self.opt.torso: + ambient_eye = outputs['ambient_eye'] / self.opt.max_steps + + loss_cross = ((ambient_eye * ambient_aud.detach())*face_mask.view(-1)).mean() + loss += lambda_amb * loss_cross + + # regularize + if self.global_step % 16 == 0 and not self.flip_finetune_lips: + xyzs, dirs, enc_a, ind_code, eye = outputs['rays'] + xyz_delta = (torch.rand(size=xyzs.shape, dtype=xyzs.dtype, device=xyzs.device) * 2 - 1) * 1e-3 + with torch.no_grad(): + sigmas_raw, rgbs_raw, ambient_aud_raw, ambient_eye_raw, unc_raw = self.model(xyzs, dirs, enc_a.detach(), ind_code.detach(), eye) + sigmas_reg, rgbs_reg, ambient_aud_reg, ambient_eye_reg, unc_reg = self.model(xyzs+xyz_delta, dirs, enc_a.detach(), ind_code.detach(), eye) + + lambda_reg = step_factor * 1e-5 + reg_loss = 0 + if self.opt.unc_loss: + reg_loss += self.criterion(unc_raw, unc_reg).mean() + if self.opt.amb_aud_loss: + reg_loss += self.criterion(ambient_aud_raw, ambient_aud_reg).mean() + if self.opt.amb_eye_loss: + reg_loss += self.criterion(ambient_eye_raw, ambient_eye_reg).mean() + + loss += reg_loss * lambda_reg + + return pred_rgb, rgb, loss + + + def eval_step(self, data): + + rays_o = data['rays_o'] # [B, N, 3] + rays_d = data['rays_d'] # [B, N, 3] + bg_coords = data['bg_coords'] # [1, N, 2] + poses = data['poses'] # [B, 7] + + images = data['images'] # [B, H, W, 3/4] + auds = data['auds'] + index = data['index'] # [B] + eye = data['eye'] # [B, 1] + + B, H, W, C = images.shape + + if self.opt.color_space == 'linear': + images[..., :3] = srgb_to_linear(images[..., :3]) + + # eval with fixed background color + # bg_color = 1 + bg_color = data['bg_color'] + + outputs = self.model.render(rays_o, rays_d, auds, bg_coords, poses, eye=eye, index=index, staged=True, bg_color=bg_color, perturb=False, **vars(self.opt)) + + pred_rgb = outputs['image'].reshape(B, H, W, 3) + pred_depth = outputs['depth'].reshape(B, H, W) + pred_ambient_aud = outputs['ambient_aud'].reshape(B, H, W) + pred_ambient_eye = outputs['ambient_eye'].reshape(B, H, W) + pred_uncertainty = outputs['uncertainty'].reshape(B, H, W) + + loss_raw = self.criterion(pred_rgb, images) + loss = loss_raw.mean() + + return pred_rgb, pred_depth, pred_ambient_aud, pred_ambient_eye, pred_uncertainty, images, loss, loss_raw + + # moved out bg_color and perturb for more flexible control... + def test_step(self, data, bg_color=None, perturb=False): + + rays_o = data['rays_o'] # [B, N, 3] + rays_d = data['rays_d'] # [B, N, 3] + bg_coords = data['bg_coords'] # [1, N, 2] + poses = data['poses'] # [B, 7] + + auds = data['auds'] # [B, 29, 16] + index = data['index'] + H, W = data['H'], data['W'] + + # allow using a fixed eye area (avoid eye blink) at test + if self.opt.exp_eye and self.opt.fix_eye >= 0: + eye = torch.FloatTensor([self.opt.fix_eye]).view(1, 1).to(self.device) + else: + eye = data['eye'] # [B, 1] + + if bg_color is not None: + bg_color = bg_color.to(self.device) + else: + bg_color = data['bg_color'] + + self.model.testing = True + outputs = self.model.render(rays_o, rays_d, auds, bg_coords, poses, eye=eye, index=index, staged=True, bg_color=bg_color, perturb=perturb, **vars(self.opt)) + self.model.testing = False + + pred_rgb = outputs['image'].reshape(-1, H, W, 3) + pred_depth = outputs['depth'].reshape(-1, H, W) + + return pred_rgb, pred_depth + + + def save_mesh(self, save_path=None, resolution=256, threshold=10): + + if save_path is None: + save_path = os.path.join(self.workspace, 'meshes', f'{self.name}_{self.epoch}.ply') + + self.log(f"==> Saving mesh to {save_path}") + + os.makedirs(os.path.dirname(save_path), exist_ok=True) + + def query_func(pts): + with torch.no_grad(): + with torch.cuda.amp.autocast(enabled=self.fp16): + sigma = self.model.density(pts.to(self.device))['sigma'] + return sigma + + vertices, triangles = extract_geometry(self.model.aabb_infer[:3], self.model.aabb_infer[3:], resolution=resolution, threshold=threshold, query_func=query_func) + + mesh = trimesh.Trimesh(vertices, triangles, process=False) # important, process=True leads to seg fault... + mesh.export(save_path) + + self.log(f"==> Finished saving mesh.") + + ### ------------------------------ + + def train(self, train_loader, valid_loader, max_epochs): + if self.use_tensorboardX and self.local_rank == 0: + self.writer = tensorboardX.SummaryWriter(os.path.join(self.workspace, "run", self.name)) + + # mark untrained region (i.e., not covered by any camera from the training dataset) + if self.model.cuda_ray: + self.model.mark_untrained_grid(train_loader._data.poses, train_loader._data.intrinsics) + + for epoch in range(self.epoch + 1, max_epochs + 1): + self.epoch = epoch + + self.train_one_epoch(train_loader) + + if self.workspace is not None and self.local_rank == 0: + self.save_checkpoint(full=True, best=False) + + if self.epoch % self.eval_interval == 0: + self.evaluate_one_epoch(valid_loader) + self.save_checkpoint(full=False, best=True) + + if self.use_tensorboardX and self.local_rank == 0: + self.writer.close() + + def evaluate(self, loader, name=None): + self.use_tensorboardX, use_tensorboardX = False, self.use_tensorboardX + self.evaluate_one_epoch(loader, name) + self.use_tensorboardX = use_tensorboardX + + def test(self, loader, save_path=None, name=None, write_image=False): + + if save_path is None: + save_path = os.path.join(self.workspace, 'results') + + if name is None: + name = f'{self.name}_ep{self.epoch:04d}' + + os.makedirs(save_path, exist_ok=True) + + self.log(f"==> Start Test, save results to {save_path}") + + pbar = tqdm.tqdm(total=len(loader) * loader.batch_size, bar_format='{percentage:3.0f}% {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]') + self.model.eval() + + all_preds = [] + + with torch.no_grad(): + + for i, data in enumerate(loader): + + with torch.cuda.amp.autocast(enabled=self.fp16): + preds, preds_depth = self.test_step(data) + + path = os.path.join(save_path, f'{name}_{i:04d}_rgb.png') + path_depth = os.path.join(save_path, f'{name}_{i:04d}_depth.png') + + #self.log(f"[INFO] saving test image to {path}") + + if self.opt.color_space == 'linear': + preds = linear_to_srgb(preds) + + pred = preds[0].detach().cpu().numpy() + pred = (pred * 255).astype(np.uint8) + + pred_depth = preds_depth[0].detach().cpu().numpy() + pred_depth = (pred_depth * 255).astype(np.uint8) + + if write_image: + imageio.imwrite(path, pred) + imageio.imwrite(path_depth, pred_depth) + + all_preds.append(pred) + + pbar.update(loader.batch_size) + + # write video + all_preds = np.stack(all_preds, axis=0) + imageio.mimwrite(os.path.join(save_path, f'{name}.mp4'), all_preds, fps=25, quality=8, macro_block_size=1) + + self.log(f"==> Finished Test.") + + # [GUI] just train for 16 steps, without any other overhead that may slow down rendering. + def train_gui(self, train_loader, step=16): + + self.model.train() + + total_loss = torch.tensor([0], dtype=torch.float32, device=self.device) + + loader = iter(train_loader) + + # mark untrained grid + if self.global_step == 0: + self.model.mark_untrained_grid(train_loader._data.poses, train_loader._data.intrinsics) + + for _ in range(step): + + # mimic an infinite loop dataloader (in case the total dataset is smaller than step) + try: + data = next(loader) + except StopIteration: + loader = iter(train_loader) + data = next(loader) + + # update grid every 16 steps + if self.model.cuda_ray and self.global_step % self.opt.update_extra_interval == 0: + with torch.cuda.amp.autocast(enabled=self.fp16): + self.model.update_extra_state() + + self.global_step += 1 + + self.optimizer.zero_grad() + + with torch.cuda.amp.autocast(enabled=self.fp16): + preds, truths, loss = self.train_step(data) + + self.scaler.scale(loss).backward() + self.scaler.step(self.optimizer) + self.scaler.update() + + if self.scheduler_update_every_step: + self.lr_scheduler.step() + + total_loss += loss.detach() + + if self.ema is not None and self.global_step % self.ema_update_interval == 0: + self.ema.update() + + average_loss = total_loss.item() / step + + if not self.scheduler_update_every_step: + if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): + self.lr_scheduler.step(average_loss) + else: + self.lr_scheduler.step() + + outputs = { + 'loss': average_loss, + 'lr': self.optimizer.param_groups[0]['lr'], + } + + return outputs + + # [GUI] test on a single image + def test_gui(self, pose, intrinsics, W, H, auds, eye=None, index=0, bg_color=None, spp=1, downscale=1): + + # render resolution (may need downscale to for better frame rate) + rH = int(H * downscale) + rW = int(W * downscale) + intrinsics = intrinsics * downscale + + if auds is not None: + auds = auds.to(self.device) + + pose = torch.from_numpy(pose).unsqueeze(0).to(self.device) + rays = get_rays(pose, intrinsics, rH, rW, -1) + + bg_coords = get_bg_coords(rH, rW, self.device) + + if eye is not None: + eye = torch.FloatTensor([eye]).view(1, 1).to(self.device) + + data = { + 'rays_o': rays['rays_o'], + 'rays_d': rays['rays_d'], + 'H': rH, + 'W': rW, + 'auds': auds, + 'index': [index], # support choosing index for individual codes + 'eye': eye, + 'poses': pose, + 'bg_coords': bg_coords, + } + + self.model.eval() + + if self.ema is not None: + self.ema.store() + self.ema.copy_to() + + with torch.no_grad(): + with torch.cuda.amp.autocast(enabled=self.fp16): + # here spp is used as perturb random seed! + # face: do not perturb for the first spp, else lead to scatters. + preds, preds_depth = self.test_step(data, bg_color=bg_color, perturb=False if spp == 1 else spp) + + if self.ema is not None: + self.ema.restore() + + # interpolation to the original resolution + if downscale != 1: + # TODO: have to permute twice with torch... + preds = F.interpolate(preds.permute(0, 3, 1, 2), size=(H, W), mode='bilinear').permute(0, 2, 3, 1).contiguous() + preds_depth = F.interpolate(preds_depth.unsqueeze(1), size=(H, W), mode='nearest').squeeze(1) + + if self.opt.color_space == 'linear': + preds = linear_to_srgb(preds) + + pred = preds[0].detach().cpu().numpy() + pred_depth = preds_depth[0].detach().cpu().numpy() + + outputs = { + 'image': pred, + 'depth': pred_depth, + } + + return outputs + + # [GUI] test with provided data + def test_gui_with_data(self, data, W, H): + + self.model.eval() + + if self.ema is not None: + self.ema.store() + self.ema.copy_to() + + with torch.no_grad(): + with torch.cuda.amp.autocast(enabled=self.fp16): + # here spp is used as perturb random seed! + # face: do not perturb for the first spp, else lead to scatters. + preds, preds_depth = self.test_step(data, perturb=False) + + if self.ema is not None: + self.ema.restore() + + if self.opt.color_space == 'linear': + preds = linear_to_srgb(preds) + + # the H/W in data may be differnt to GUI, so we still need to resize... + preds = F.interpolate(preds.permute(0, 3, 1, 2), size=(H, W), mode='bilinear').permute(0, 2, 3, 1).contiguous() + preds_depth = F.interpolate(preds_depth.unsqueeze(1), size=(H, W), mode='nearest').squeeze(1) + + pred = preds[0].detach().cpu().numpy() + pred_depth = preds_depth[0].detach().cpu().numpy() + + outputs = { + 'image': pred, + 'depth': pred_depth, + } + + return outputs + + def train_one_epoch(self, loader): + self.log(f"==> Start Training Epoch {self.epoch}, lr={self.optimizer.param_groups[0]['lr']:.6f} ...") + + total_loss = 0 + if self.local_rank == 0 and self.report_metric_at_train: + for metric in self.metrics: + metric.clear() + + self.model.train() + + # distributedSampler: must call set_epoch() to shuffle indices across multiple epochs + # ref: https://pytorch.org/docs/stable/data.html + if self.world_size > 1: + loader.sampler.set_epoch(self.epoch) + + if self.local_rank == 0: + pbar = tqdm.tqdm(total=len(loader) * loader.batch_size, mininterval=1, bar_format='{desc}: {percentage:3.0f}% {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]') + + self.local_step = 0 + + for data in loader: + # update grid every 16 steps + if self.model.cuda_ray and self.global_step % self.opt.update_extra_interval == 0: + with torch.cuda.amp.autocast(enabled=self.fp16): + self.model.update_extra_state() + + self.local_step += 1 + self.global_step += 1 + + self.optimizer.zero_grad() + + with torch.cuda.amp.autocast(enabled=self.fp16): + preds, truths, loss = self.train_step(data) + + self.scaler.scale(loss).backward() + self.scaler.step(self.optimizer) + self.scaler.update() + + if self.scheduler_update_every_step: + self.lr_scheduler.step() + + loss_val = loss.item() + total_loss += loss_val + + if self.ema is not None and self.global_step % self.ema_update_interval == 0: + self.ema.update() + + if self.local_rank == 0: + if self.report_metric_at_train: + for metric in self.metrics: + metric.update(preds, truths) + + if self.use_tensorboardX: + self.writer.add_scalar("train/loss", loss_val, self.global_step) + self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]['lr'], self.global_step) + + if self.scheduler_update_every_step: + pbar.set_description(f"loss={loss_val:.4f} ({total_loss/self.local_step:.4f}), lr={self.optimizer.param_groups[0]['lr']:.6f}") + else: + pbar.set_description(f"loss={loss_val:.4f} ({total_loss/self.local_step:.4f})") + pbar.update(loader.batch_size) + + average_loss = total_loss / self.local_step + self.stats["loss"].append(average_loss) + + if self.local_rank == 0: + pbar.close() + if self.report_metric_at_train: + for metric in self.metrics: + self.log(metric.report(), style="red") + if self.use_tensorboardX: + metric.write(self.writer, self.epoch, prefix="train") + metric.clear() + + if not self.scheduler_update_every_step: + if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): + self.lr_scheduler.step(average_loss) + else: + self.lr_scheduler.step() + + self.log(f"==> Finished Epoch {self.epoch}.") + + + def evaluate_one_epoch(self, loader, name=None): + self.log(f"++> Evaluate at epoch {self.epoch} ...") + + if name is None: + name = f'{self.name}_ep{self.epoch:04d}' + + total_loss = 0 + if self.local_rank == 0: + for metric in self.metrics: + metric.clear() + + self.model.eval() + + if self.ema is not None: + self.ema.store() + self.ema.copy_to() + + if self.local_rank == 0: + pbar = tqdm.tqdm(total=len(loader) * loader.batch_size, bar_format='{desc}: {percentage:3.0f}% {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]') + + with torch.no_grad(): + self.local_step = 0 + + for data in loader: + self.local_step += 1 + + with torch.cuda.amp.autocast(enabled=self.fp16): + preds, preds_depth, pred_ambient_aud, pred_ambient_eye, pred_uncertainty, truths, loss, loss_raw = self.eval_step(data) + + loss_val = loss.item() + total_loss += loss_val + + # only rank = 0 will perform evaluation. + if self.local_rank == 0: + + for metric in self.metrics: + metric.update(preds, truths) + + # save image + save_path = os.path.join(self.workspace, 'validation', f'{name}_{self.local_step:04d}_rgb.png') + save_path_depth = os.path.join(self.workspace, 'validation', f'{name}_{self.local_step:04d}_depth.png') + # save_path_error = os.path.join(self.workspace, 'validation', f'{name}_{self.local_step:04d}_errormap.png') + save_path_ambient_aud = os.path.join(self.workspace, 'validation', f'{name}_{self.local_step:04d}_aud.png') + save_path_ambient_eye = os.path.join(self.workspace, 'validation', f'{name}_{self.local_step:04d}_eye.png') + save_path_uncertainty = os.path.join(self.workspace, 'validation', f'{name}_{self.local_step:04d}_uncertainty.png') + #save_path_gt = os.path.join(self.workspace, 'validation', f'{name}_{self.local_step:04d}_gt.png') + + #self.log(f"==> Saving validation image to {save_path}") + os.makedirs(os.path.dirname(save_path), exist_ok=True) + + if self.opt.color_space == 'linear': + preds = linear_to_srgb(preds) + + pred = preds[0].detach().cpu().numpy() + pred_depth = preds_depth[0].detach().cpu().numpy() + # loss_raw = loss_raw[0].mean(-1).detach().cpu().numpy() + # loss_raw = (loss_raw - np.min(loss_raw)) / (np.max(loss_raw) - np.min(loss_raw)) + pred_ambient_aud = pred_ambient_aud[0].detach().cpu().numpy() + pred_ambient_aud /= np.max(pred_ambient_aud) + pred_ambient_eye = pred_ambient_eye[0].detach().cpu().numpy() + pred_ambient_eye /= np.max(pred_ambient_eye) + # pred_ambient = pred_ambient / 16 + # print(pred_ambient.shape) + pred_uncertainty = pred_uncertainty[0].detach().cpu().numpy() + # pred_uncertainty = (pred_uncertainty - np.min(pred_uncertainty)) / (np.max(pred_uncertainty) - np.min(pred_uncertainty)) + pred_uncertainty /= np.max(pred_uncertainty) + + cv2.imwrite(save_path, cv2.cvtColor((pred * 255).astype(np.uint8), cv2.COLOR_RGB2BGR)) + + if not self.opt.torso: + cv2.imwrite(save_path_depth, (pred_depth * 255).astype(np.uint8)) + # cv2.imwrite(save_path_error, (loss_raw * 255).astype(np.uint8)) + cv2.imwrite(save_path_ambient_aud, (pred_ambient_aud * 255).astype(np.uint8)) + cv2.imwrite(save_path_ambient_eye, (pred_ambient_eye * 255).astype(np.uint8)) + cv2.imwrite(save_path_uncertainty, (pred_uncertainty * 255).astype(np.uint8)) + #cv2.imwrite(save_path_gt, cv2.cvtColor((linear_to_srgb(truths[0].detach().cpu().numpy()) * 255).astype(np.uint8), cv2.COLOR_RGB2BGR)) + + pbar.set_description(f"loss={loss_val:.4f} ({total_loss/self.local_step:.4f})") + pbar.update(loader.batch_size) + + + average_loss = total_loss / self.local_step + self.stats["valid_loss"].append(average_loss) + + if self.local_rank == 0: + pbar.close() + if not self.use_loss_as_metric and len(self.metrics) > 0: + result = self.metrics[0].measure() + self.stats["results"].append(result if self.best_mode == 'min' else - result) # if max mode, use -result + else: + self.stats["results"].append(average_loss) # if no metric, choose best by min loss + + for metric in self.metrics: + self.log(metric.report(), style="blue") + if self.use_tensorboardX: + metric.write(self.writer, self.epoch, prefix="evaluate") + metric.clear() + + if self.ema is not None: + self.ema.restore() + + self.log(f"++> Evaluate epoch {self.epoch} Finished.") + + def save_checkpoint(self, name=None, full=False, best=False, remove_old=True): + + if name is None: + name = f'{self.name}_ep{self.epoch:04d}' + + state = { + 'epoch': self.epoch, + 'global_step': self.global_step, + 'stats': self.stats, + } + + + state['mean_count'] = self.model.mean_count + state['mean_density'] = self.model.mean_density + state['mean_density_torso'] = self.model.mean_density_torso + + if full: + state['optimizer'] = self.optimizer.state_dict() + state['lr_scheduler'] = self.lr_scheduler.state_dict() + state['scaler'] = self.scaler.state_dict() + if self.ema is not None: + state['ema'] = self.ema.state_dict() + + if not best: + + state['model'] = self.model.state_dict() + + file_path = f"{self.ckpt_path}/{name}.pth" + + if remove_old: + self.stats["checkpoints"].append(file_path) + + if len(self.stats["checkpoints"]) > self.max_keep_ckpt: + old_ckpt = self.stats["checkpoints"].pop(0) + if os.path.exists(old_ckpt): + os.remove(old_ckpt) + + torch.save(state, file_path) + + else: + if len(self.stats["results"]) > 0: + # always save new as best... (since metric cannot really reflect performance...) + if True: + + # save ema results + if self.ema is not None: + self.ema.store() + self.ema.copy_to() + + state['model'] = self.model.state_dict() + + # we don't consider continued training from the best ckpt, so we discard the unneeded density_grid to save some storage (especially important for dnerf) + if 'density_grid' in state['model']: + del state['model']['density_grid'] + + if self.ema is not None: + self.ema.restore() + + torch.save(state, self.best_path) + else: + self.log(f"[WARN] no evaluated results found, skip saving best checkpoint.") + + def load_checkpoint(self, checkpoint=None, model_only=False): + if checkpoint is None: + checkpoint_list = sorted(glob.glob(f'{self.ckpt_path}/{self.name}_ep*.pth')) + if checkpoint_list: + checkpoint = checkpoint_list[-1] + self.log(f"[INFO] Latest checkpoint is {checkpoint}") + else: + self.log("[WARN] No checkpoint found, model randomly initialized.") + return + + checkpoint_dict = torch.load(checkpoint, map_location=self.device) + + if 'model' not in checkpoint_dict: + self.model.load_state_dict(checkpoint_dict) + self.log("[INFO] loaded bare model.") + return + + missing_keys, unexpected_keys = self.model.load_state_dict(checkpoint_dict['model'], strict=False) + self.log("[INFO] loaded model.") + if len(missing_keys) > 0: + self.log(f"[WARN] missing keys: {missing_keys}") + if len(unexpected_keys) > 0: + self.log(f"[WARN] unexpected keys: {unexpected_keys}") + + if self.ema is not None and 'ema' in checkpoint_dict: + self.ema.load_state_dict(checkpoint_dict['ema']) + + + if 'mean_count' in checkpoint_dict: + self.model.mean_count = checkpoint_dict['mean_count'] + if 'mean_density' in checkpoint_dict: + self.model.mean_density = checkpoint_dict['mean_density'] + if 'mean_density_torso' in checkpoint_dict: + self.model.mean_density_torso = checkpoint_dict['mean_density_torso'] + + if model_only: + return + + self.stats = checkpoint_dict['stats'] + self.epoch = checkpoint_dict['epoch'] + self.global_step = checkpoint_dict['global_step'] + self.log(f"[INFO] load at epoch {self.epoch}, global step {self.global_step}") + + if self.optimizer and 'optimizer' in checkpoint_dict: + try: + self.optimizer.load_state_dict(checkpoint_dict['optimizer']) + self.log("[INFO] loaded optimizer.") + except: + self.log("[WARN] Failed to load optimizer.") + + if self.lr_scheduler and 'lr_scheduler' in checkpoint_dict: + try: + self.lr_scheduler.load_state_dict(checkpoint_dict['lr_scheduler']) + self.log("[INFO] loaded scheduler.") + except: + self.log("[WARN] Failed to load scheduler.") + + if self.scaler and 'scaler' in checkpoint_dict: + try: + self.scaler.load_state_dict(checkpoint_dict['scaler']) + self.log("[INFO] loaded scaler.") + except: + self.log("[WARN] Failed to load scaler.") \ No newline at end of file diff --git a/raymarching/__init__.py b/raymarching/__init__.py new file mode 100644 index 0000000..26d3cc6 --- /dev/null +++ b/raymarching/__init__.py @@ -0,0 +1 @@ +from .raymarching import * \ No newline at end of file diff --git a/raymarching/backend.py b/raymarching/backend.py new file mode 100644 index 0000000..2d41d14 --- /dev/null +++ b/raymarching/backend.py @@ -0,0 +1,40 @@ +import os +from torch.utils.cpp_extension import load + +_src_path = os.path.dirname(os.path.abspath(__file__)) + +nvcc_flags = [ + '-O3', '-std=c++14', + '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', +] + +if os.name == "posix": + c_flags = ['-O3', '-std=c++14'] +elif os.name == "nt": + c_flags = ['/O2', '/std:c++17'] + + # find cl.exe + def find_cl_path(): + import glob + for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: + paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) + if paths: + return paths[0] + + # If cl.exe is not on path, try to find it. + if os.system("where cl.exe >nul 2>nul") != 0: + cl_path = find_cl_path() + if cl_path is None: + raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") + os.environ["PATH"] += ";" + cl_path + +_backend = load(name='_raymarching_face', + extra_cflags=c_flags, + extra_cuda_cflags=nvcc_flags, + sources=[os.path.join(_src_path, 'src', f) for f in [ + 'raymarching.cu', + 'bindings.cpp', + ]], + ) + +__all__ = ['_backend'] \ No newline at end of file diff --git a/raymarching/raymarching.py b/raymarching/raymarching.py new file mode 100644 index 0000000..05fb1e6 --- /dev/null +++ b/raymarching/raymarching.py @@ -0,0 +1,671 @@ +import numpy as np +import time + +import torch +import torch.nn as nn +from torch.autograd import Function +from torch.cuda.amp import custom_bwd, custom_fwd + +try: + import _raymarching_face as _backend +except ImportError: + from .backend import _backend + +# ---------------------------------------- +# utils +# ---------------------------------------- + +class _near_far_from_aabb(Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward(ctx, rays_o, rays_d, aabb, min_near=0.2): + ''' near_far_from_aabb, CUDA implementation + Calculate rays' intersection time (near and far) with aabb + Args: + rays_o: float, [N, 3] + rays_d: float, [N, 3] + aabb: float, [6], (xmin, ymin, zmin, xmax, ymax, zmax) + min_near: float, scalar + Returns: + nears: float, [N] + fars: float, [N] + ''' + if not rays_o.is_cuda: rays_o = rays_o.cuda() + if not rays_d.is_cuda: rays_d = rays_d.cuda() + + rays_o = rays_o.contiguous().view(-1, 3) + rays_d = rays_d.contiguous().view(-1, 3) + + N = rays_o.shape[0] # num rays + + nears = torch.empty(N, dtype=rays_o.dtype, device=rays_o.device) + fars = torch.empty(N, dtype=rays_o.dtype, device=rays_o.device) + + _backend.near_far_from_aabb(rays_o, rays_d, aabb, N, min_near, nears, fars) + + return nears, fars + +near_far_from_aabb = _near_far_from_aabb.apply + + +class _sph_from_ray(Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward(ctx, rays_o, rays_d, radius): + ''' sph_from_ray, CUDA implementation + get spherical coordinate on the background sphere from rays. + Assume rays_o are inside the Sphere(radius). + Args: + rays_o: [N, 3] + rays_d: [N, 3] + radius: scalar, float + Return: + coords: [N, 2], in [-1, 1], theta and phi on a sphere. (further-surface) + ''' + if not rays_o.is_cuda: rays_o = rays_o.cuda() + if not rays_d.is_cuda: rays_d = rays_d.cuda() + + rays_o = rays_o.contiguous().view(-1, 3) + rays_d = rays_d.contiguous().view(-1, 3) + + N = rays_o.shape[0] # num rays + + coords = torch.empty(N, 2, dtype=rays_o.dtype, device=rays_o.device) + + _backend.sph_from_ray(rays_o, rays_d, radius, N, coords) + + return coords + +sph_from_ray = _sph_from_ray.apply + + +class _morton3D(Function): + @staticmethod + def forward(ctx, coords): + ''' morton3D, CUDA implementation + Args: + coords: [N, 3], int32, in [0, 128) (for some reason there is no uint32 tensor in torch...) + TODO: check if the coord range is valid! (current 128 is safe) + Returns: + indices: [N], int32, in [0, 128^3) + + ''' + if not coords.is_cuda: coords = coords.cuda() + + N = coords.shape[0] + + indices = torch.empty(N, dtype=torch.int32, device=coords.device) + + _backend.morton3D(coords.int(), N, indices) + + return indices + +morton3D = _morton3D.apply + +class _morton3D_invert(Function): + @staticmethod + def forward(ctx, indices): + ''' morton3D_invert, CUDA implementation + Args: + indices: [N], int32, in [0, 128^3) + Returns: + coords: [N, 3], int32, in [0, 128) + + ''' + if not indices.is_cuda: indices = indices.cuda() + + N = indices.shape[0] + + coords = torch.empty(N, 3, dtype=torch.int32, device=indices.device) + + _backend.morton3D_invert(indices.int(), N, coords) + + return coords + +morton3D_invert = _morton3D_invert.apply + + +class _packbits(Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward(ctx, grid, thresh, bitfield=None): + ''' packbits, CUDA implementation + Pack up the density grid into a bit field to accelerate ray marching. + Args: + grid: float, [C, H * H * H], assume H % 2 == 0 + thresh: float, threshold + Returns: + bitfield: uint8, [C, H * H * H / 8] + ''' + if not grid.is_cuda: grid = grid.cuda() + grid = grid.contiguous() + + C = grid.shape[0] + H3 = grid.shape[1] + N = C * H3 // 8 + + if bitfield is None: + bitfield = torch.empty(N, dtype=torch.uint8, device=grid.device) + + _backend.packbits(grid, N, thresh, bitfield) + + return bitfield + +packbits = _packbits.apply + + +class _morton3D_dilation(Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward(ctx, grid): + ''' max pooling with morton coord, CUDA implementation + or maybe call it dilation... we don't support adjust kernel size. + Args: + grid: float, [C, H * H * H], assume H % 2 == 0 + Returns: + grid_dilate: float, [C, H * H * H], assume H % 2 == 0bitfield: uint8, [C, H * H * H / 8] + ''' + if not grid.is_cuda: grid = grid.cuda() + grid = grid.contiguous() + + C = grid.shape[0] + H3 = grid.shape[1] + H = int(np.cbrt(H3)) + grid_dilation = torch.empty_like(grid) + + _backend.morton3D_dilation(grid, C, H, grid_dilation) + + return grid_dilation + +morton3D_dilation = _morton3D_dilation.apply + +# ---------------------------------------- +# train functions +# ---------------------------------------- + +class _march_rays_train(Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward(ctx, rays_o, rays_d, bound, density_bitfield, C, H, nears, fars, step_counter=None, mean_count=-1, perturb=False, align=-1, force_all_rays=False, dt_gamma=0, max_steps=1024): + ''' march rays to generate points (forward only) + Args: + rays_o/d: float, [N, 3] + bound: float, scalar + density_bitfield: uint8: [CHHH // 8] + C: int + H: int + nears/fars: float, [N] + step_counter: int32, (2), used to count the actual number of generated points. + mean_count: int32, estimated mean steps to accelerate training. (but will randomly drop rays if the actual point count exceeded this threshold.) + perturb: bool + align: int, pad output so its size is dividable by align, set to -1 to disable. + force_all_rays: bool, ignore step_counter and mean_count, always calculate all rays. Useful if rendering the whole image, instead of some rays. + dt_gamma: float, called cone_angle in instant-ngp, exponentially accelerate ray marching if > 0. (very significant effect, but generally lead to worse performance) + max_steps: int, max number of sampled points along each ray, also affect min_stepsize. + Returns: + xyzs: float, [M, 3], all generated points' coords. (all rays concated, need to use `rays` to extract points belonging to each ray) + dirs: float, [M, 3], all generated points' view dirs. + deltas: float, [M, 2], first is delta_t, second is rays_t + rays: int32, [N, 3], all rays' (index, point_offset, point_count), e.g., xyzs[rays[i, 1]:rays[i, 1] + rays[i, 2]] --> points belonging to rays[i, 0] + ''' + + if not rays_o.is_cuda: rays_o = rays_o.cuda() + if not rays_d.is_cuda: rays_d = rays_d.cuda() + if not density_bitfield.is_cuda: density_bitfield = density_bitfield.cuda() + + rays_o = rays_o.contiguous().view(-1, 3) + rays_d = rays_d.contiguous().view(-1, 3) + density_bitfield = density_bitfield.contiguous() + + N = rays_o.shape[0] # num rays + M = N * max_steps # init max points number in total + + # running average based on previous epoch (mimic `measured_batch_size_before_compaction` in instant-ngp) + # It estimate the max points number to enable faster training, but will lead to random ignored rays if underestimated. + if not force_all_rays and mean_count > 0: + if align > 0: + mean_count += align - mean_count % align + M = mean_count + + xyzs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device) + dirs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device) + deltas = torch.zeros(M, 2, dtype=rays_o.dtype, device=rays_o.device) + rays = torch.empty(N, 3, dtype=torch.int32, device=rays_o.device) # id, offset, num_steps + + if step_counter is None: + step_counter = torch.zeros(2, dtype=torch.int32, device=rays_o.device) # point counter, ray counter + + if perturb: + noises = torch.rand(N, dtype=rays_o.dtype, device=rays_o.device) + else: + noises = torch.zeros(N, dtype=rays_o.dtype, device=rays_o.device) + + _backend.march_rays_train(rays_o, rays_d, density_bitfield, bound, dt_gamma, max_steps, N, C, H, M, nears, fars, xyzs, dirs, deltas, rays, step_counter, noises) # m is the actually used points number + + #print(step_counter, M) + + # only used at the first (few) epochs. + if force_all_rays or mean_count <= 0: + m = step_counter[0].item() # D2H copy + if align > 0: + m += align - m % align + xyzs = xyzs[:m] + dirs = dirs[:m] + deltas = deltas[:m] + + torch.cuda.empty_cache() + + ctx.save_for_backward(rays, deltas) + + return xyzs, dirs, deltas, rays + + # to support optimizing camera poses. + @staticmethod + @custom_bwd + def backward(ctx, grad_xyzs, grad_dirs, grad_deltas, grad_rays): + # grad_xyzs/dirs: [M, 3] + + rays, deltas = ctx.saved_tensors + + N = rays.shape[0] + M = grad_xyzs.shape[0] + + grad_rays_o = torch.zeros(N, 3, device=rays.device) + grad_rays_d = torch.zeros(N, 3, device=rays.device) + + _backend.march_rays_train_backward(grad_xyzs, grad_dirs, rays, deltas, N, M, grad_rays_o, grad_rays_d) + + return grad_rays_o, grad_rays_d, None, None, None, None, None, None, None, None, None, None, None, None, None + +march_rays_train = _march_rays_train.apply + + +class _composite_rays_train(Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward(ctx, sigmas, rgbs, ambient, deltas, rays, T_thresh=1e-4): + ''' composite rays' rgbs, according to the ray marching formula. + Args: + rgbs: float, [M, 3] + sigmas: float, [M,] + ambient: float, [M,] (after summing up the last dimension) + deltas: float, [M, 2] + rays: int32, [N, 3] + Returns: + weights_sum: float, [N,], the alpha channel + depth: float, [N, ], the Depth + image: float, [N, 3], the RGB channel (after multiplying alpha!) + ''' + + sigmas = sigmas.contiguous() + rgbs = rgbs.contiguous() + ambient = ambient.contiguous() + + M = sigmas.shape[0] + N = rays.shape[0] + + weights_sum = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device) + ambient_sum = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device) + depth = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device) + image = torch.empty(N, 3, dtype=sigmas.dtype, device=sigmas.device) + + _backend.composite_rays_train_forward(sigmas, rgbs, ambient, deltas, rays, M, N, T_thresh, weights_sum, ambient_sum, depth, image) + + ctx.save_for_backward(sigmas, rgbs, ambient, deltas, rays, weights_sum, ambient_sum, depth, image) + ctx.dims = [M, N, T_thresh] + + return weights_sum, ambient_sum, depth, image + + @staticmethod + @custom_bwd + def backward(ctx, grad_weights_sum, grad_ambient_sum, grad_depth, grad_image): + + # NOTE: grad_depth is not used now! It won't be propagated to sigmas. + + grad_weights_sum = grad_weights_sum.contiguous() + grad_ambient_sum = grad_ambient_sum.contiguous() + grad_image = grad_image.contiguous() + + sigmas, rgbs, ambient, deltas, rays, weights_sum, ambient_sum, depth, image = ctx.saved_tensors + M, N, T_thresh = ctx.dims + + grad_sigmas = torch.zeros_like(sigmas) + grad_rgbs = torch.zeros_like(rgbs) + grad_ambient = torch.zeros_like(ambient) + + _backend.composite_rays_train_backward(grad_weights_sum, grad_ambient_sum, grad_image, sigmas, rgbs, ambient, deltas, rays, weights_sum, ambient_sum, image, M, N, T_thresh, grad_sigmas, grad_rgbs, grad_ambient) + + return grad_sigmas, grad_rgbs, grad_ambient, None, None, None + + +composite_rays_train = _composite_rays_train.apply + +# ---------------------------------------- +# infer functions +# ---------------------------------------- + +class _march_rays(Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward(ctx, n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, bound, density_bitfield, C, H, near, far, align=-1, perturb=False, dt_gamma=0, max_steps=1024): + ''' march rays to generate points (forward only, for inference) + Args: + n_alive: int, number of alive rays + n_step: int, how many steps we march + rays_alive: int, [N], the alive rays' IDs in N (N >= n_alive, but we only use first n_alive) + rays_t: float, [N], the alive rays' time, we only use the first n_alive. + rays_o/d: float, [N, 3] + bound: float, scalar + density_bitfield: uint8: [CHHH // 8] + C: int + H: int + nears/fars: float, [N] + align: int, pad output so its size is dividable by align, set to -1 to disable. + perturb: bool/int, int > 0 is used as the random seed. + dt_gamma: float, called cone_angle in instant-ngp, exponentially accelerate ray marching if > 0. (very significant effect, but generally lead to worse performance) + max_steps: int, max number of sampled points along each ray, also affect min_stepsize. + Returns: + xyzs: float, [n_alive * n_step, 3], all generated points' coords + dirs: float, [n_alive * n_step, 3], all generated points' view dirs. + deltas: float, [n_alive * n_step, 2], all generated points' deltas (here we record two deltas, the first is for RGB, the second for depth). + ''' + + if not rays_o.is_cuda: rays_o = rays_o.cuda() + if not rays_d.is_cuda: rays_d = rays_d.cuda() + + rays_o = rays_o.contiguous().view(-1, 3) + rays_d = rays_d.contiguous().view(-1, 3) + + M = n_alive * n_step + + if align > 0: + M += align - (M % align) + + xyzs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device) + dirs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device) + deltas = torch.zeros(M, 2, dtype=rays_o.dtype, device=rays_o.device) # 2 vals, one for rgb, one for depth + + if perturb: + # torch.manual_seed(perturb) # test_gui uses spp index as seed + noises = torch.rand(n_alive, dtype=rays_o.dtype, device=rays_o.device) + else: + noises = torch.zeros(n_alive, dtype=rays_o.dtype, device=rays_o.device) + + _backend.march_rays(n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, bound, dt_gamma, max_steps, C, H, density_bitfield, near, far, xyzs, dirs, deltas, noises) + + return xyzs, dirs, deltas + +march_rays = _march_rays.apply + + +class _composite_rays(Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float32) # need to cast sigmas & rgbs to float + def forward(ctx, n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, deltas, weights_sum, depth, image, T_thresh=1e-2): + ''' composite rays' rgbs, according to the ray marching formula. (for inference) + Args: + n_alive: int, number of alive rays + n_step: int, how many steps we march + rays_alive: int, [n_alive], the alive rays' IDs in N (N >= n_alive) + rays_t: float, [N], the alive rays' time + sigmas: float, [n_alive * n_step,] + rgbs: float, [n_alive * n_step, 3] + deltas: float, [n_alive * n_step, 2], all generated points' deltas (here we record two deltas, the first is for RGB, the second for depth). + In-place Outputs: + weights_sum: float, [N,], the alpha channel + depth: float, [N,], the depth value + image: float, [N, 3], the RGB channel (after multiplying alpha!) + ''' + _backend.composite_rays(n_alive, n_step, T_thresh, rays_alive, rays_t, sigmas, rgbs, deltas, weights_sum, depth, image) + return tuple() + + +composite_rays = _composite_rays.apply + + +class _composite_rays_ambient(Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float32) # need to cast sigmas & rgbs to float + def forward(ctx, n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, deltas, ambients, weights_sum, depth, image, ambient_sum, T_thresh=1e-2): + _backend.composite_rays_ambient(n_alive, n_step, T_thresh, rays_alive, rays_t, sigmas, rgbs, deltas, ambients, weights_sum, depth, image, ambient_sum) + return tuple() + + +composite_rays_ambient = _composite_rays_ambient.apply + + + + + +# custom + +class _composite_rays_train_sigma(Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward(ctx, sigmas, rgbs, ambient, deltas, rays, T_thresh=1e-4): + ''' composite rays' rgbs, according to the ray marching formula. + Args: + rgbs: float, [M, 3] + sigmas: float, [M,] + ambient: float, [M,] (after summing up the last dimension) + deltas: float, [M, 2] + rays: int32, [N, 3] + Returns: + weights_sum: float, [N,], the alpha channel + depth: float, [N, ], the Depth + image: float, [N, 3], the RGB channel (after multiplying alpha!) + ''' + + sigmas = sigmas.contiguous() + rgbs = rgbs.contiguous() + ambient = ambient.contiguous() + + M = sigmas.shape[0] + N = rays.shape[0] + + weights_sum = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device) + ambient_sum = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device) + depth = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device) + image = torch.empty(N, 3, dtype=sigmas.dtype, device=sigmas.device) + + _backend.composite_rays_train_sigma_forward(sigmas, rgbs, ambient, deltas, rays, M, N, T_thresh, weights_sum, ambient_sum, depth, image) + + ctx.save_for_backward(sigmas, rgbs, ambient, deltas, rays, weights_sum, ambient_sum, depth, image) + ctx.dims = [M, N, T_thresh] + + return weights_sum, ambient_sum, depth, image + + @staticmethod + @custom_bwd + def backward(ctx, grad_weights_sum, grad_ambient_sum, grad_depth, grad_image): + + # NOTE: grad_depth is not used now! It won't be propagated to sigmas. + + grad_weights_sum = grad_weights_sum.contiguous() + grad_ambient_sum = grad_ambient_sum.contiguous() + grad_image = grad_image.contiguous() + + sigmas, rgbs, ambient, deltas, rays, weights_sum, ambient_sum, depth, image = ctx.saved_tensors + M, N, T_thresh = ctx.dims + + grad_sigmas = torch.zeros_like(sigmas) + grad_rgbs = torch.zeros_like(rgbs) + grad_ambient = torch.zeros_like(ambient) + + _backend.composite_rays_train_sigma_backward(grad_weights_sum, grad_ambient_sum, grad_image, sigmas, rgbs, ambient, deltas, rays, weights_sum, ambient_sum, image, M, N, T_thresh, grad_sigmas, grad_rgbs, grad_ambient) + + return grad_sigmas, grad_rgbs, grad_ambient, None, None, None + + +composite_rays_train_sigma = _composite_rays_train_sigma.apply + + +class _composite_rays_ambient_sigma(Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float32) # need to cast sigmas & rgbs to float + def forward(ctx, n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, deltas, ambients, weights_sum, depth, image, ambient_sum, T_thresh=1e-2): + _backend.composite_rays_ambient_sigma(n_alive, n_step, T_thresh, rays_alive, rays_t, sigmas, rgbs, deltas, ambients, weights_sum, depth, image, ambient_sum) + return tuple() + + +composite_rays_ambient_sigma = _composite_rays_ambient_sigma.apply + + + +# uncertainty +class _composite_rays_train_uncertainty(Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward(ctx, sigmas, rgbs, ambient, uncertainty, deltas, rays, T_thresh=1e-4): + ''' composite rays' rgbs, according to the ray marching formula. + Args: + rgbs: float, [M, 3] + sigmas: float, [M,] + ambient: float, [M,] (after summing up the last dimension) + deltas: float, [M, 2] + rays: int32, [N, 3] + Returns: + weights_sum: float, [N,], the alpha channel + depth: float, [N, ], the Depth + image: float, [N, 3], the RGB channel (after multiplying alpha!) + ''' + + sigmas = sigmas.contiguous() + rgbs = rgbs.contiguous() + ambient = ambient.contiguous() + uncertainty = uncertainty.contiguous() + + M = sigmas.shape[0] + N = rays.shape[0] + + weights_sum = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device) + ambient_sum = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device) + uncertainty_sum = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device) + depth = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device) + image = torch.empty(N, 3, dtype=sigmas.dtype, device=sigmas.device) + + _backend.composite_rays_train_uncertainty_forward(sigmas, rgbs, ambient, uncertainty, deltas, rays, M, N, T_thresh, weights_sum, ambient_sum, uncertainty_sum, depth, image) + + ctx.save_for_backward(sigmas, rgbs, ambient, uncertainty, deltas, rays, weights_sum, ambient_sum, uncertainty_sum, depth, image) + ctx.dims = [M, N, T_thresh] + + return weights_sum, ambient_sum, uncertainty_sum, depth, image + + @staticmethod + @custom_bwd + def backward(ctx, grad_weights_sum, grad_ambient_sum, grad_uncertainty_sum, grad_depth, grad_image): + + # NOTE: grad_depth is not used now! It won't be propagated to sigmas. + + grad_weights_sum = grad_weights_sum.contiguous() + grad_ambient_sum = grad_ambient_sum.contiguous() + grad_uncertainty_sum = grad_uncertainty_sum.contiguous() + grad_image = grad_image.contiguous() + + sigmas, rgbs, ambient, uncertainty, deltas, rays, weights_sum, ambient_sum, uncertainty_sum, depth, image = ctx.saved_tensors + M, N, T_thresh = ctx.dims + + grad_sigmas = torch.zeros_like(sigmas) + grad_rgbs = torch.zeros_like(rgbs) + grad_ambient = torch.zeros_like(ambient) + grad_uncertainty = torch.zeros_like(uncertainty) + + _backend.composite_rays_train_uncertainty_backward(grad_weights_sum, grad_ambient_sum, grad_uncertainty_sum, grad_image, sigmas, rgbs, ambient, uncertainty, deltas, rays, weights_sum, ambient_sum, uncertainty_sum, image, M, N, T_thresh, grad_sigmas, grad_rgbs, grad_ambient, grad_uncertainty) + + return grad_sigmas, grad_rgbs, grad_ambient, grad_uncertainty, None, None, None + + +composite_rays_train_uncertainty = _composite_rays_train_uncertainty.apply + + +class _composite_rays_uncertainty(Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float32) # need to cast sigmas & rgbs to float + def forward(ctx, n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, deltas, ambients, uncertainties, weights_sum, depth, image, ambient_sum, uncertainty_sum, T_thresh=1e-2): + _backend.composite_rays_uncertainty(n_alive, n_step, T_thresh, rays_alive, rays_t, sigmas, rgbs, deltas, ambients, uncertainties, weights_sum, depth, image, ambient_sum, uncertainty_sum) + return tuple() + + +composite_rays_uncertainty = _composite_rays_uncertainty.apply + + + +# triplane(eye) +class _composite_rays_train_triplane(Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward(ctx, sigmas, rgbs, amb_aud, amb_eye, uncertainty, deltas, rays, T_thresh=1e-4): + ''' composite rays' rgbs, according to the ray marching formula. + Args: + rgbs: float, [M, 3] + sigmas: float, [M,] + ambient: float, [M,] (after summing up the last dimension) + deltas: float, [M, 2] + rays: int32, [N, 3] + Returns: + weights_sum: float, [N,], the alpha channel + depth: float, [N, ], the Depth + image: float, [N, 3], the RGB channel (after multiplying alpha!) + ''' + + sigmas = sigmas.contiguous() + rgbs = rgbs.contiguous() + amb_aud = amb_aud.contiguous() + amb_eye = amb_eye.contiguous() + uncertainty = uncertainty.contiguous() + + M = sigmas.shape[0] + N = rays.shape[0] + + weights_sum = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device) + amb_aud_sum = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device) + amb_eye_sum = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device) + uncertainty_sum = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device) + depth = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device) + image = torch.empty(N, 3, dtype=sigmas.dtype, device=sigmas.device) + + _backend.composite_rays_train_triplane_forward(sigmas, rgbs, amb_aud, amb_eye, uncertainty, deltas, rays, M, N, T_thresh, weights_sum, amb_aud_sum, amb_eye_sum, uncertainty_sum, depth, image) + + ctx.save_for_backward(sigmas, rgbs, amb_aud, amb_eye, uncertainty, deltas, rays, weights_sum, amb_aud_sum, amb_eye_sum, uncertainty_sum, depth, image) + ctx.dims = [M, N, T_thresh] + + return weights_sum, amb_aud_sum, amb_eye_sum, uncertainty_sum, depth, image + + @staticmethod + @custom_bwd + def backward(ctx, grad_weights_sum, grad_amb_aud_sum, grad_amb_eye_sum, grad_uncertainty_sum, grad_depth, grad_image): + + # NOTE: grad_depth is not used now! It won't be propagated to sigmas. + + grad_weights_sum = grad_weights_sum.contiguous() + grad_amb_aud_sum = grad_amb_aud_sum.contiguous() + grad_amb_eye_sum = grad_amb_eye_sum.contiguous() + grad_uncertainty_sum = grad_uncertainty_sum.contiguous() + grad_image = grad_image.contiguous() + + sigmas, rgbs, amb_aud, amb_eye, uncertainty, deltas, rays, weights_sum, amb_aud_sum, amb_eye_sum, uncertainty_sum, depth, image = ctx.saved_tensors + M, N, T_thresh = ctx.dims + + grad_sigmas = torch.zeros_like(sigmas) + grad_rgbs = torch.zeros_like(rgbs) + grad_amb_aud = torch.zeros_like(amb_aud) + grad_amb_eye = torch.zeros_like(amb_eye) + grad_uncertainty = torch.zeros_like(uncertainty) + + _backend.composite_rays_train_triplane_backward(grad_weights_sum, grad_amb_aud_sum, grad_amb_eye_sum, grad_uncertainty_sum, grad_image, sigmas, rgbs, amb_aud, amb_eye, uncertainty, deltas, rays, weights_sum, amb_aud_sum, amb_eye_sum, uncertainty_sum, image, M, N, T_thresh, grad_sigmas, grad_rgbs, grad_amb_aud, grad_amb_eye, grad_uncertainty) + + return grad_sigmas, grad_rgbs, grad_amb_aud, grad_amb_eye, grad_uncertainty, None, None, None + + +composite_rays_train_triplane = _composite_rays_train_triplane.apply + + +class _composite_rays_triplane(Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float32) # need to cast sigmas & rgbs to float + def forward(ctx, n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, deltas, ambs_aud, ambs_eye, uncertainties, weights_sum, depth, image, amb_aud_sum, amb_eye_sum, uncertainty_sum, T_thresh=1e-2): + _backend.composite_rays_triplane(n_alive, n_step, T_thresh, rays_alive, rays_t, sigmas, rgbs, deltas, ambs_aud, ambs_eye, uncertainties, weights_sum, depth, image, amb_aud_sum, amb_eye_sum, uncertainty_sum) + return tuple() + + +composite_rays_triplane = _composite_rays_triplane.apply \ No newline at end of file diff --git a/raymarching/setup.py b/raymarching/setup.py new file mode 100644 index 0000000..c2fbd1b --- /dev/null +++ b/raymarching/setup.py @@ -0,0 +1,63 @@ +import os +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +_src_path = os.path.dirname(os.path.abspath(__file__)) + +nvcc_flags = [ + '-O3', '-std=c++14', + # '-lineinfo', # to debug illegal memory access + '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', +] + +if os.name == "posix": + c_flags = ['-O3', '-std=c++14'] +elif os.name == "nt": + c_flags = ['/O2', '/std:c++17'] + + # find cl.exe + def find_cl_path(): + import glob + for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: + paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) + if paths: + return paths[0] + + # If cl.exe is not on path, try to find it. + if os.system("where cl.exe >nul 2>nul") != 0: + cl_path = find_cl_path() + if cl_path is None: + raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") + os.environ["PATH"] += ";" + cl_path + +''' +Usage: + +python setup.py build_ext --inplace # build extensions locally, do not install (only can be used from the parent directory) + +python setup.py install # build extensions and install (copy) to PATH. +pip install . # ditto but better (e.g., dependency & metadata handling) + +python setup.py develop # build extensions and install (symbolic) to PATH. +pip install -e . # ditto but better (e.g., dependency & metadata handling) + +''' +setup( + name='raymarching_face', # package name, import this to use python API + ext_modules=[ + CUDAExtension( + name='_raymarching_face', # extension name, import this to use CUDA API + sources=[os.path.join(_src_path, 'src', f) for f in [ + 'raymarching.cu', + 'bindings.cpp', + ]], + extra_compile_args={ + 'cxx': c_flags, + 'nvcc': nvcc_flags, + } + ), + ], + cmdclass={ + 'build_ext': BuildExtension, + } +) \ No newline at end of file diff --git a/raymarching/src/bindings.cpp b/raymarching/src/bindings.cpp new file mode 100644 index 0000000..f8298bf --- /dev/null +++ b/raymarching/src/bindings.cpp @@ -0,0 +1,39 @@ +#include + +#include "raymarching.h" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + // utils + m.def("packbits", &packbits, "packbits (CUDA)"); + m.def("near_far_from_aabb", &near_far_from_aabb, "near_far_from_aabb (CUDA)"); + m.def("sph_from_ray", &sph_from_ray, "sph_from_ray (CUDA)"); + m.def("morton3D", &morton3D, "morton3D (CUDA)"); + m.def("morton3D_invert", &morton3D_invert, "morton3D_invert (CUDA)"); + m.def("morton3D_dilation", &morton3D_dilation, "morton3D_dilation (CUDA)"); + // train + m.def("march_rays_train", &march_rays_train, "march_rays_train (CUDA)"); + m.def("march_rays_train_backward", &march_rays_train_backward, "march_rays_train_backward (CUDA)"); + m.def("composite_rays_train_forward", &composite_rays_train_forward, "composite_rays_train_forward (CUDA)"); + m.def("composite_rays_train_backward", &composite_rays_train_backward, "composite_rays_train_backward (CUDA)"); + // infer + m.def("march_rays", &march_rays, "march rays (CUDA)"); + m.def("composite_rays", &composite_rays, "composite rays (CUDA)"); + m.def("composite_rays_ambient", &composite_rays_ambient, "composite rays with ambient (CUDA)"); + + // train + m.def("composite_rays_train_sigma_forward", &composite_rays_train_sigma_forward, "composite_rays_train_forward (CUDA)"); + m.def("composite_rays_train_sigma_backward", &composite_rays_train_sigma_backward, "composite_rays_train_backward (CUDA)"); + // infer + m.def("composite_rays_ambient_sigma", &composite_rays_ambient_sigma, "composite rays with ambient (CUDA)"); + + // uncertainty train + m.def("composite_rays_train_uncertainty_forward", &composite_rays_train_uncertainty_forward, "composite_rays_train_forward (CUDA)"); + m.def("composite_rays_train_uncertainty_backward", &composite_rays_train_uncertainty_backward, "composite_rays_train_backward (CUDA)"); + m.def("composite_rays_uncertainty", &composite_rays_uncertainty, "composite rays with ambient (CUDA)"); + + // triplane + m.def("composite_rays_train_triplane_forward", &composite_rays_train_triplane_forward, "composite_rays_train_forward (CUDA)"); + m.def("composite_rays_train_triplane_backward", &composite_rays_train_triplane_backward, "composite_rays_train_backward (CUDA)"); + m.def("composite_rays_triplane", &composite_rays_triplane, "composite rays with ambient (CUDA)"); + +} \ No newline at end of file diff --git a/raymarching/src/raymarching.cu b/raymarching/src/raymarching.cu new file mode 100644 index 0000000..d7788b7 --- /dev/null +++ b/raymarching/src/raymarching.cu @@ -0,0 +1,2258 @@ +#include +#include +#include + +#include +#include + +#include +#include +#include +#include + +#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor") +#define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor") +#define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor") + + +inline constexpr __device__ float SQRT3() { return 1.7320508075688772f; } +inline constexpr __device__ float RSQRT3() { return 0.5773502691896258f; } +inline constexpr __device__ float PI() { return 3.141592653589793f; } +inline constexpr __device__ float RPI() { return 0.3183098861837907f; } + + +template +inline __host__ __device__ T div_round_up(T val, T divisor) { + return (val + divisor - 1) / divisor; +} + +inline __host__ __device__ float signf(const float x) { + return copysignf(1.0, x); +} + +inline __host__ __device__ float clamp(const float x, const float min, const float max) { + return fminf(max, fmaxf(min, x)); +} + +inline __host__ __device__ void swapf(float& a, float& b) { + float c = a; a = b; b = c; +} + +inline __device__ int mip_from_pos(const float x, const float y, const float z, const float max_cascade) { + const float mx = fmaxf(fabsf(x), fmaxf(fabs(y), fabs(z))); + int exponent; + frexpf(mx, &exponent); // [0, 0.5) --> -1, [0.5, 1) --> 0, [1, 2) --> 1, [2, 4) --> 2, ... + return fminf(max_cascade - 1, fmaxf(0, exponent)); +} + +inline __device__ int mip_from_dt(const float dt, const float H, const float max_cascade) { + const float mx = dt * H * 0.5; + int exponent; + frexpf(mx, &exponent); + return fminf(max_cascade - 1, fmaxf(0, exponent)); +} + +inline __host__ __device__ uint32_t __expand_bits(uint32_t v) +{ + v = (v * 0x00010001u) & 0xFF0000FFu; + v = (v * 0x00000101u) & 0x0F00F00Fu; + v = (v * 0x00000011u) & 0xC30C30C3u; + v = (v * 0x00000005u) & 0x49249249u; + return v; +} + +inline __host__ __device__ uint32_t __morton3D(uint32_t x, uint32_t y, uint32_t z) +{ + uint32_t xx = __expand_bits(x); + uint32_t yy = __expand_bits(y); + uint32_t zz = __expand_bits(z); + return xx | (yy << 1) | (zz << 2); +} + +inline __host__ __device__ uint32_t __morton3D_invert(uint32_t x) +{ + x = x & 0x49249249; + x = (x | (x >> 2)) & 0xc30c30c3; + x = (x | (x >> 4)) & 0x0f00f00f; + x = (x | (x >> 8)) & 0xff0000ff; + x = (x | (x >> 16)) & 0x0000ffff; + return x; +} + + +//////////////////////////////////////////////////// +///////////// utils ///////////// +//////////////////////////////////////////////////// + +// rays_o/d: [N, 3] +// nears/fars: [N] +// scalar_t should always be float in use. +template +__global__ void kernel_near_far_from_aabb( + const scalar_t * __restrict__ rays_o, + const scalar_t * __restrict__ rays_d, + const scalar_t * __restrict__ aabb, + const uint32_t N, + const float min_near, + scalar_t * nears, scalar_t * fars +) { + // parallel per ray + const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; + if (n >= N) return; + + // locate + rays_o += n * 3; + rays_d += n * 3; + + const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2]; + const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2]; + const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz; + + // get near far (assume cube scene) + float near = (aabb[0] - ox) * rdx; + float far = (aabb[3] - ox) * rdx; + if (near > far) swapf(near, far); + + float near_y = (aabb[1] - oy) * rdy; + float far_y = (aabb[4] - oy) * rdy; + if (near_y > far_y) swapf(near_y, far_y); + + if (near > far_y || near_y > far) { + nears[n] = fars[n] = std::numeric_limits::max(); + return; + } + + if (near_y > near) near = near_y; + if (far_y < far) far = far_y; + + float near_z = (aabb[2] - oz) * rdz; + float far_z = (aabb[5] - oz) * rdz; + if (near_z > far_z) swapf(near_z, far_z); + + if (near > far_z || near_z > far) { + nears[n] = fars[n] = std::numeric_limits::max(); + return; + } + + if (near_z > near) near = near_z; + if (far_z < far) far = far_z; + + if (near < min_near) near = min_near; + + nears[n] = near; + fars[n] = far; +} + + +void near_far_from_aabb(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor aabb, const uint32_t N, const float min_near, at::Tensor nears, at::Tensor fars) { + + static constexpr uint32_t N_THREAD = 128; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + rays_o.scalar_type(), "near_far_from_aabb", ([&] { + kernel_near_far_from_aabb<<>>(rays_o.data_ptr(), rays_d.data_ptr(), aabb.data_ptr(), N, min_near, nears.data_ptr(), fars.data_ptr()); + })); +} + + +// rays_o/d: [N, 3] +// radius: float +// coords: [N, 2] +template +__global__ void kernel_sph_from_ray( + const scalar_t * __restrict__ rays_o, + const scalar_t * __restrict__ rays_d, + const float radius, + const uint32_t N, + scalar_t * coords +) { + // parallel per ray + const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; + if (n >= N) return; + + // locate + rays_o += n * 3; + rays_d += n * 3; + coords += n * 2; + + const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2]; + const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2]; + const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz; + + // solve t from || o + td || = radius + const float A = dx * dx + dy * dy + dz * dz; + const float B = ox * dx + oy * dy + oz * dz; // in fact B / 2 + const float C = ox * ox + oy * oy + oz * oz - radius * radius; + + const float t = (- B + sqrtf(B * B - A * C)) / A; // always use the larger solution (positive) + + // solve theta, phi (assume y is the up axis) + const float x = ox + t * dx, y = oy + t * dy, z = oz + t * dz; + const float theta = atan2(sqrtf(x * x + z * z), y); // [0, PI) + const float phi = atan2(z, x); // [-PI, PI) + + // normalize to [-1, 1] + coords[0] = 2 * theta * RPI() - 1; + coords[1] = phi * RPI(); +} + + +void sph_from_ray(const at::Tensor rays_o, const at::Tensor rays_d, const float radius, const uint32_t N, at::Tensor coords) { + + static constexpr uint32_t N_THREAD = 128; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + rays_o.scalar_type(), "sph_from_ray", ([&] { + kernel_sph_from_ray<<>>(rays_o.data_ptr(), rays_d.data_ptr(), radius, N, coords.data_ptr()); + })); +} + + +// coords: int32, [N, 3] +// indices: int32, [N] +__global__ void kernel_morton3D( + const int * __restrict__ coords, + const uint32_t N, + int * indices +) { + // parallel + const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; + if (n >= N) return; + + // locate + coords += n * 3; + indices[n] = __morton3D(coords[0], coords[1], coords[2]); +} + + +void morton3D(const at::Tensor coords, const uint32_t N, at::Tensor indices) { + static constexpr uint32_t N_THREAD = 128; + kernel_morton3D<<>>(coords.data_ptr(), N, indices.data_ptr()); +} + + +// indices: int32, [N] +// coords: int32, [N, 3] +__global__ void kernel_morton3D_invert( + const int * __restrict__ indices, + const uint32_t N, + int * coords +) { + // parallel + const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; + if (n >= N) return; + + // locate + coords += n * 3; + + const int ind = indices[n]; + + coords[0] = __morton3D_invert(ind >> 0); + coords[1] = __morton3D_invert(ind >> 1); + coords[2] = __morton3D_invert(ind >> 2); +} + + +void morton3D_invert(const at::Tensor indices, const uint32_t N, at::Tensor coords) { + static constexpr uint32_t N_THREAD = 128; + kernel_morton3D_invert<<>>(indices.data_ptr(), N, coords.data_ptr()); +} + + +// grid: float, [C, H, H, H] +// N: int, C * H * H * H / 8 +// density_thresh: float +// bitfield: uint8, [N] +template +__global__ void kernel_packbits( + const scalar_t * __restrict__ grid, + const uint32_t N, + const float density_thresh, + uint8_t * bitfield +) { + // parallel per byte + const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; + if (n >= N) return; + + // locate + grid += n * 8; + + uint8_t bits = 0; + + #pragma unroll + for (uint8_t i = 0; i < 8; i++) { + bits |= (grid[i] > density_thresh) ? ((uint8_t)1 << i) : 0; + } + + bitfield[n] = bits; +} + + +void packbits(const at::Tensor grid, const uint32_t N, const float density_thresh, at::Tensor bitfield) { + + static constexpr uint32_t N_THREAD = 128; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + grid.scalar_type(), "packbits", ([&] { + kernel_packbits<<>>(grid.data_ptr(), N, density_thresh, bitfield.data_ptr()); + })); +} + + +// grid: float, [C, H, H, H] +__global__ void kernel_morton3D_dilation( + const float * __restrict__ grid, + const uint32_t C, + const uint32_t H, + float * __restrict__ grid_dilation +) { + // parallel per byte + const uint32_t H3 = H * H * H; + const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; + if (n >= C * H3) return; + + // locate + const uint32_t c = n / H3; + const uint32_t ind = n - c * H3; + + const uint32_t x = __morton3D_invert(ind >> 0); + const uint32_t y = __morton3D_invert(ind >> 1); + const uint32_t z = __morton3D_invert(ind >> 2); + + // manual max pool + float res = grid[n]; + + if (x + 1 < H) res = fmaxf(res, grid[c * H3 + __morton3D(x + 1, y, z)]); + if (x > 0) res = fmaxf(res, grid[c * H3 + __morton3D(x - 1, y, z)]); + if (y + 1 < H) res = fmaxf(res, grid[c * H3 + __morton3D(x, y + 1, z)]); + if (y > 0) res = fmaxf(res, grid[c * H3 + __morton3D(x, y - 1, z)]); + if (z + 1 < H) res = fmaxf(res, grid[c * H3 + __morton3D(x, y, z + 1)]); + if (z > 0) res = fmaxf(res, grid[c * H3 + __morton3D(x, y, z - 1)]); + + // write + grid_dilation[n] = res; +} + +void morton3D_dilation(const at::Tensor grid, const uint32_t C, const uint32_t H, at::Tensor grid_dilation) { + static constexpr uint32_t N_THREAD = 128; + + kernel_morton3D_dilation<<>>(grid.data_ptr(), C, H, grid_dilation.data_ptr()); +} + +//////////////////////////////////////////////////// +///////////// training ///////////// +//////////////////////////////////////////////////// + +// rays_o/d: [N, 3] +// grid: [CHHH / 8] +// xyzs, dirs, deltas: [M, 3], [M, 3], [M, 2] +// dirs: [M, 3] +// rays: [N, 3], idx, offset, num_steps +template +__global__ void kernel_march_rays_train( + const scalar_t * __restrict__ rays_o, + const scalar_t * __restrict__ rays_d, + const uint8_t * __restrict__ grid, + const float bound, + const float dt_gamma, const uint32_t max_steps, + const uint32_t N, const uint32_t C, const uint32_t H, const uint32_t M, + const scalar_t* __restrict__ nears, + const scalar_t* __restrict__ fars, + scalar_t * xyzs, scalar_t * dirs, scalar_t * deltas, + int * rays, + int * counter, + const scalar_t* __restrict__ noises +) { + // parallel per ray + const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; + if (n >= N) return; + + // locate + rays_o += n * 3; + rays_d += n * 3; + + // ray marching + const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2]; + const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2]; + const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz; + const float rH = 1 / (float)H; + const float H3 = H * H * H; + + const float near = nears[n]; + const float far = fars[n]; + const float noise = noises[n]; + + const float dt_max = 2 * SQRT3() * (1 << (C - 1)) / H; + const float dt_min = fminf(dt_max, 2 * SQRT3() / max_steps); + + float t0 = near; + + // perturb + t0 += clamp(t0 * dt_gamma, dt_min, dt_max) * noise; + + // first pass: estimation of num_steps + float t = t0; + uint32_t num_steps = 0; + + //if (t < far) printf("valid ray %d t=%f near=%f far=%f \n", n, t, near, far); + + while (t < far && num_steps < max_steps) { + // current point + const float x = clamp(ox + t * dx, -bound, bound); + const float y = clamp(oy + t * dy, -bound, bound); + const float z = clamp(oz + t * dz, -bound, bound); + + const float dt = clamp(t * dt_gamma, dt_min, dt_max); + + // get mip level + const int level = max(mip_from_pos(x, y, z, C), mip_from_dt(dt, H, C)); // range in [0, C - 1] + + const float mip_bound = fminf(scalbnf(1.0f, level), bound); + const float mip_rbound = 1 / mip_bound; + + // convert to nearest grid position + const int nx = clamp(0.5 * (x * mip_rbound + 1) * H, 0.0f, (float)(H - 1)); + const int ny = clamp(0.5 * (y * mip_rbound + 1) * H, 0.0f, (float)(H - 1)); + const int nz = clamp(0.5 * (z * mip_rbound + 1) * H, 0.0f, (float)(H - 1)); + + const uint32_t index = level * H3 + __morton3D(nx, ny, nz); + const bool occ = grid[index / 8] & (1 << (index % 8)); + + // if occpuied, advance a small step, and write to output + //if (n == 0) printf("t=%f density=%f vs thresh=%f step=%d\n", t, density, density_thresh, num_steps); + + if (occ) { + num_steps++; + t += dt; + // else, skip a large step (basically skip a voxel grid) + } else { + // calc distance to next voxel + const float tx = (((nx + 0.5f + 0.5f * signf(dx)) * rH * 2 - 1) * mip_bound - x) * rdx; + const float ty = (((ny + 0.5f + 0.5f * signf(dy)) * rH * 2 - 1) * mip_bound - y) * rdy; + const float tz = (((nz + 0.5f + 0.5f * signf(dz)) * rH * 2 - 1) * mip_bound - z) * rdz; + + const float tt = t + fmaxf(0.0f, fminf(tx, fminf(ty, tz))); + // step until next voxel + do { + t += clamp(t * dt_gamma, dt_min, dt_max); + } while (t < tt); + } + } + + //printf("[n=%d] num_steps=%d, near=%f, far=%f, dt=%f, max_steps=%f\n", n, num_steps, near, far, dt_min, (far - near) / dt_min); + + // second pass: really locate and write points & dirs + uint32_t point_index = atomicAdd(counter, num_steps); + uint32_t ray_index = atomicAdd(counter + 1, 1); + + //printf("[n=%d] num_steps=%d, point_index=%d, ray_index=%d\n", n, num_steps, point_index, ray_index); + + // write rays + rays[ray_index * 3] = n; + rays[ray_index * 3 + 1] = point_index; + rays[ray_index * 3 + 2] = num_steps; + + if (num_steps == 0) return; + if (point_index + num_steps > M) return; + + xyzs += point_index * 3; + dirs += point_index * 3; + deltas += point_index * 2; + + t = t0; + uint32_t step = 0; + + while (t < far && step < num_steps) { + // current point + const float x = clamp(ox + t * dx, -bound, bound); + const float y = clamp(oy + t * dy, -bound, bound); + const float z = clamp(oz + t * dz, -bound, bound); + + const float dt = clamp(t * dt_gamma, dt_min, dt_max); + + // get mip level + const int level = max(mip_from_pos(x, y, z, C), mip_from_dt(dt, H, C)); // range in [0, C - 1] + + const float mip_bound = fminf(scalbnf(1.0f, level), bound); + const float mip_rbound = 1 / mip_bound; + + // convert to nearest grid position + const int nx = clamp(0.5 * (x * mip_rbound + 1) * H, 0.0f, (float)(H - 1)); + const int ny = clamp(0.5 * (y * mip_rbound + 1) * H, 0.0f, (float)(H - 1)); + const int nz = clamp(0.5 * (z * mip_rbound + 1) * H, 0.0f, (float)(H - 1)); + + // query grid + const uint32_t index = level * H3 + __morton3D(nx, ny, nz); + const bool occ = grid[index / 8] & (1 << (index % 8)); + + // if occpuied, advance a small step, and write to output + if (occ) { + // write step + xyzs[0] = x; + xyzs[1] = y; + xyzs[2] = z; + dirs[0] = dx; + dirs[1] = dy; + dirs[2] = dz; + t += dt; + deltas[0] = dt; + deltas[1] = t; // used to calc depth + xyzs += 3; + dirs += 3; + deltas += 2; + step++; + // else, skip a large step (basically skip a voxel grid) + } else { + // calc distance to next voxel + const float tx = (((nx + 0.5f + 0.5f * signf(dx)) * rH * 2 - 1) * mip_bound - x) * rdx; + const float ty = (((ny + 0.5f + 0.5f * signf(dy)) * rH * 2 - 1) * mip_bound - y) * rdy; + const float tz = (((nz + 0.5f + 0.5f * signf(dz)) * rH * 2 - 1) * mip_bound - z) * rdz; + const float tt = t + fmaxf(0.0f, fminf(tx, fminf(ty, tz))); + // step until next voxel + do { + t += clamp(t * dt_gamma, dt_min, dt_max); + } while (t < tt); + } + } +} + +void march_rays_train(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor grid, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t N, const uint32_t C, const uint32_t H, const uint32_t M, const at::Tensor nears, const at::Tensor fars, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, at::Tensor rays, at::Tensor counter, at::Tensor noises) { + + static constexpr uint32_t N_THREAD = 128; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + rays_o.scalar_type(), "march_rays_train", ([&] { + kernel_march_rays_train<<>>(rays_o.data_ptr(), rays_d.data_ptr(), grid.data_ptr(), bound, dt_gamma, max_steps, N, C, H, M, nears.data_ptr(), fars.data_ptr(), xyzs.data_ptr(), dirs.data_ptr(), deltas.data_ptr(), rays.data_ptr(), counter.data_ptr(), noises.data_ptr()); + })); +} + + +// grad_xyzs/dirs: [M, 3] +// rays: [N, 3] +// deltas: [M, 2] +// grad_rays_o/d: [N, 3] +template +__global__ void kernel_march_rays_train_backward( + const scalar_t * __restrict__ grad_xyzs, + const scalar_t * __restrict__ grad_dirs, + const int * __restrict__ rays, + const scalar_t * __restrict__ deltas, + const uint32_t N, const uint32_t M, + scalar_t * grad_rays_o, + scalar_t * grad_rays_d +) { + // parallel per ray + const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; + if (n >= N) return; + + // locate + grad_rays_o += n * 3; + grad_rays_d += n * 3; + + uint32_t index = rays[n * 3]; + uint32_t offset = rays[n * 3 + 1]; + uint32_t num_steps = rays[n * 3 + 2]; + + // empty ray, or ray that exceed max step count. + if (num_steps == 0 || offset + num_steps > M) return; + + grad_xyzs += offset * 3; + grad_dirs += offset * 3; + deltas += offset * 2; + + // accumulate + uint32_t step = 0; + while (step < num_steps) { + + grad_rays_o[0] += grad_xyzs[0]; + grad_rays_o[1] += grad_xyzs[1]; + grad_rays_o[2] += grad_xyzs[2]; + + grad_rays_d[0] += grad_xyzs[0] * deltas[1] + grad_dirs[0]; + grad_rays_d[1] += grad_xyzs[1] * deltas[1] + grad_dirs[1]; + grad_rays_d[2] += grad_xyzs[2] * deltas[1] + grad_dirs[2]; + + // locate + grad_xyzs += 3; + grad_dirs += 3; + deltas += 2; + + step++; + } +} + +void march_rays_train_backward(const at::Tensor grad_xyzs, const at::Tensor grad_dirs, const at::Tensor rays, const at::Tensor deltas, const uint32_t N, const uint32_t M, at::Tensor grad_rays_o, at::Tensor grad_rays_d) { + + static constexpr uint32_t N_THREAD = 128; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + grad_xyzs.scalar_type(), "march_rays_train_backward", ([&] { + kernel_march_rays_train_backward<<>>(grad_xyzs.data_ptr(), grad_dirs.data_ptr(), rays.data_ptr(), deltas.data_ptr(), N, M, grad_rays_o.data_ptr(), grad_rays_d.data_ptr()); + })); +} + + +// sigmas: [M] +// rgbs: [M, 3] +// deltas: [M, 2] +// rays: [N, 3], idx, offset, num_steps +// weights_sum: [N], final pixel alpha +// depth: [N,] +// image: [N, 3] +template +__global__ void kernel_composite_rays_train_forward( + const scalar_t * __restrict__ sigmas, + const scalar_t * __restrict__ rgbs, + const scalar_t * __restrict__ ambient, + const scalar_t * __restrict__ deltas, + const int * __restrict__ rays, + const uint32_t M, const uint32_t N, const float T_thresh, + scalar_t * weights_sum, + scalar_t * ambient_sum, + scalar_t * depth, + scalar_t * image +) { + // parallel per ray + const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; + if (n >= N) return; + + // locate + uint32_t index = rays[n * 3]; + uint32_t offset = rays[n * 3 + 1]; + uint32_t num_steps = rays[n * 3 + 2]; + + // empty ray, or ray that exceed max step count. + if (num_steps == 0 || offset + num_steps > M) { + weights_sum[index] = 0; + ambient_sum[index] = 0; + depth[index] = 0; + image[index * 3] = 0; + image[index * 3 + 1] = 0; + image[index * 3 + 2] = 0; + return; + } + + sigmas += offset; + rgbs += offset * 3; + ambient += offset; + deltas += offset * 2; + + // accumulate + uint32_t step = 0; + + scalar_t T = 1.0f; + scalar_t r = 0, g = 0, b = 0, ws = 0, d = 0, amb = 0; + + while (step < num_steps) { + + const scalar_t alpha = 1.0f - __expf(- sigmas[0] * deltas[0]); + const scalar_t weight = alpha * T; + + r += weight * rgbs[0]; + g += weight * rgbs[1]; + b += weight * rgbs[2]; + + d += weight * deltas[1]; + + ws += weight; + + amb += ambient[0]; + + T *= 1.0f - alpha; + + // minimal remained transmittence + if (T < T_thresh) break; + + //printf("[n=%d] num_steps=%d, alpha=%f, w=%f, T=%f, sum_dt=%f, d=%f\n", n, step, alpha, weight, T, sum_delta, d); + + // locate + sigmas++; + rgbs += 3; + ambient++; + deltas += 2; + + step++; + } + + //printf("[n=%d] rgb=(%f, %f, %f), d=%f\n", n, r, g, b, d); + + // write + weights_sum[index] = ws; // weights_sum + ambient_sum[index] = amb; + depth[index] = d; + image[index * 3] = r; + image[index * 3 + 1] = g; + image[index * 3 + 2] = b; +} + + +void composite_rays_train_forward(const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor ambient, const at::Tensor deltas, const at::Tensor rays, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor weights_sum, at::Tensor ambient_sum, at::Tensor depth, at::Tensor image) { + + static constexpr uint32_t N_THREAD = 128; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + sigmas.scalar_type(), "composite_rays_train_forward", ([&] { + kernel_composite_rays_train_forward<<>>(sigmas.data_ptr(), rgbs.data_ptr(), ambient.data_ptr(), deltas.data_ptr(), rays.data_ptr(), M, N, T_thresh, weights_sum.data_ptr(), ambient_sum.data_ptr(), depth.data_ptr(), image.data_ptr()); + })); +} + + +// grad_weights_sum: [N,] +// grad: [N, 3] +// sigmas: [M] +// rgbs: [M, 3] +// deltas: [M, 2] +// rays: [N, 3], idx, offset, num_steps +// weights_sum: [N,], weights_sum here +// image: [N, 3] +// grad_sigmas: [M] +// grad_rgbs: [M, 3] +template +__global__ void kernel_composite_rays_train_backward( + const scalar_t * __restrict__ grad_weights_sum, + const scalar_t * __restrict__ grad_ambient_sum, + const scalar_t * __restrict__ grad_image, + const scalar_t * __restrict__ sigmas, + const scalar_t * __restrict__ rgbs, + const scalar_t * __restrict__ ambient, + const scalar_t * __restrict__ deltas, + const int * __restrict__ rays, + const scalar_t * __restrict__ weights_sum, + const scalar_t * __restrict__ ambient_sum, + const scalar_t * __restrict__ image, + const uint32_t M, const uint32_t N, const float T_thresh, + scalar_t * grad_sigmas, + scalar_t * grad_rgbs, + scalar_t * grad_ambient +) { + // parallel per ray + const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; + if (n >= N) return; + + // locate + uint32_t index = rays[n * 3]; + uint32_t offset = rays[n * 3 + 1]; + uint32_t num_steps = rays[n * 3 + 2]; + + if (num_steps == 0 || offset + num_steps > M) return; + + grad_weights_sum += index; + grad_ambient_sum += index; + grad_image += index * 3; + weights_sum += index; + ambient_sum += index; + image += index * 3; + + sigmas += offset; + rgbs += offset * 3; + ambient += offset; + deltas += offset * 2; + + grad_sigmas += offset; + grad_rgbs += offset * 3; + grad_ambient += offset; + + // accumulate + uint32_t step = 0; + + scalar_t T = 1.0f; + const scalar_t r_final = image[0], g_final = image[1], b_final = image[2], ws_final = weights_sum[0]; + scalar_t r = 0, g = 0, b = 0, ws = 0; + + while (step < num_steps) { + + const scalar_t alpha = 1.0f - __expf(- sigmas[0] * deltas[0]); + const scalar_t weight = alpha * T; + + r += weight * rgbs[0]; + g += weight * rgbs[1]; + b += weight * rgbs[2]; + // amb += weight * ambient[0]; + ws += weight; + + T *= 1.0f - alpha; + + // check https://note.kiui.moe/others/nerf_gradient/ for the gradient calculation. + // write grad_rgbs + grad_rgbs[0] = grad_image[0] * weight; + grad_rgbs[1] = grad_image[1] * weight; + grad_rgbs[2] = grad_image[2] * weight; + + // write grad_ambient + grad_ambient[0] = grad_ambient_sum[0]; + + // write grad_sigmas + grad_sigmas[0] = deltas[0] * ( + grad_image[0] * (T * rgbs[0] - (r_final - r)) + + grad_image[1] * (T * rgbs[1] - (g_final - g)) + + grad_image[2] * (T * rgbs[2] - (b_final - b)) + + // grad_ambient_sum[0] * (T * ambient[0] - (amb_final - amb)) + + grad_weights_sum[0] * (1 - ws_final) + ); + + //printf("[n=%d] num_steps=%d, T=%f, grad_sigmas=%f, r_final=%f, r=%f\n", n, step, T, grad_sigmas[0], r_final, r); + // minimal remained transmittence + if (T < T_thresh) break; + + // locate + sigmas++; + rgbs += 3; + // ambient++; + deltas += 2; + grad_sigmas++; + grad_rgbs += 3; + grad_ambient++; + + step++; + } +} + + +void composite_rays_train_backward(const at::Tensor grad_weights_sum, const at::Tensor grad_ambient_sum, const at::Tensor grad_image, const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor ambient, const at::Tensor deltas, const at::Tensor rays, const at::Tensor weights_sum, const at::Tensor ambient_sum, const at::Tensor image, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor grad_sigmas, at::Tensor grad_rgbs, at::Tensor grad_ambient) { + + static constexpr uint32_t N_THREAD = 128; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + grad_image.scalar_type(), "composite_rays_train_backward", ([&] { + kernel_composite_rays_train_backward<<>>(grad_weights_sum.data_ptr(), grad_ambient_sum.data_ptr(), grad_image.data_ptr(), sigmas.data_ptr(), rgbs.data_ptr(), ambient.data_ptr(), deltas.data_ptr(), rays.data_ptr(), weights_sum.data_ptr(), ambient_sum.data_ptr(), image.data_ptr(), M, N, T_thresh, grad_sigmas.data_ptr(), grad_rgbs.data_ptr(), grad_ambient.data_ptr()); + })); +} + + +//////////////////////////////////////////////////// +///////////// infernce ///////////// +//////////////////////////////////////////////////// + +template +__global__ void kernel_march_rays( + const uint32_t n_alive, + const uint32_t n_step, + const int* __restrict__ rays_alive, + const scalar_t* __restrict__ rays_t, + const scalar_t* __restrict__ rays_o, + const scalar_t* __restrict__ rays_d, + const float bound, + const float dt_gamma, const uint32_t max_steps, + const uint32_t C, const uint32_t H, + const uint8_t * __restrict__ grid, + const scalar_t* __restrict__ nears, + const scalar_t* __restrict__ fars, + scalar_t* xyzs, scalar_t* dirs, scalar_t* deltas, + const scalar_t* __restrict__ noises +) { + const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; + if (n >= n_alive) return; + + const int index = rays_alive[n]; // ray id + const float noise = noises[n]; + + // locate + rays_o += index * 3; + rays_d += index * 3; + xyzs += n * n_step * 3; + dirs += n * n_step * 3; + deltas += n * n_step * 2; + + const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2]; + const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2]; + const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz; + const float rH = 1 / (float)H; + const float H3 = H * H * H; + + float t = rays_t[index]; // current ray's t + const float near = nears[index], far = fars[index]; + + const float dt_max = 2 * SQRT3() * (1 << (C - 1)) / H; + const float dt_min = fminf(dt_max, 2 * SQRT3() / max_steps); + + // march for n_step steps, record points + uint32_t step = 0; + + // introduce some randomness + t += clamp(t * dt_gamma, dt_min, dt_max) * noise; + + while (t < far && step < n_step) { + // current point + const float x = clamp(ox + t * dx, -bound, bound); + const float y = clamp(oy + t * dy, -bound, bound); + const float z = clamp(oz + t * dz, -bound, bound); + + const float dt = clamp(t * dt_gamma, dt_min, dt_max); + + // get mip level + const int level = max(mip_from_pos(x, y, z, C), mip_from_dt(dt, H, C)); // range in [0, C - 1] + + const float mip_bound = fminf(scalbnf(1, level), bound); + const float mip_rbound = 1 / mip_bound; + + // convert to nearest grid position + const int nx = clamp(0.5 * (x * mip_rbound + 1) * H, 0.0f, (float)(H - 1)); + const int ny = clamp(0.5 * (y * mip_rbound + 1) * H, 0.0f, (float)(H - 1)); + const int nz = clamp(0.5 * (z * mip_rbound + 1) * H, 0.0f, (float)(H - 1)); + + const uint32_t index = level * H3 + __morton3D(nx, ny, nz); + const bool occ = grid[index / 8] & (1 << (index % 8)); + + // if occpuied, advance a small step, and write to output + if (occ) { + // write step + xyzs[0] = x; + xyzs[1] = y; + xyzs[2] = z; + dirs[0] = dx; + dirs[1] = dy; + dirs[2] = dz; + // calc dt + t += dt; + deltas[0] = dt; + deltas[1] = t; // used to calc depth + // step + xyzs += 3; + dirs += 3; + deltas += 2; + step++; + + // else, skip a large step (basically skip a voxel grid) + } else { + // calc distance to next voxel + const float tx = (((nx + 0.5f + 0.5f * signf(dx)) * rH * 2 - 1) * mip_bound - x) * rdx; + const float ty = (((ny + 0.5f + 0.5f * signf(dy)) * rH * 2 - 1) * mip_bound - y) * rdy; + const float tz = (((nz + 0.5f + 0.5f * signf(dz)) * rH * 2 - 1) * mip_bound - z) * rdz; + const float tt = t + fmaxf(0.0f, fminf(tx, fminf(ty, tz))); + // step until next voxel + do { + t += clamp(t * dt_gamma, dt_min, dt_max); + } while (t < tt); + } + } +} + + +void march_rays(const uint32_t n_alive, const uint32_t n_step, const at::Tensor rays_alive, const at::Tensor rays_t, const at::Tensor rays_o, const at::Tensor rays_d, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t C, const uint32_t H, const at::Tensor grid, const at::Tensor near, const at::Tensor far, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, at::Tensor noises) { + static constexpr uint32_t N_THREAD = 128; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + rays_o.scalar_type(), "march_rays", ([&] { + kernel_march_rays<<>>(n_alive, n_step, rays_alive.data_ptr(), rays_t.data_ptr(), rays_o.data_ptr(), rays_d.data_ptr(), bound, dt_gamma, max_steps, C, H, grid.data_ptr(), near.data_ptr(), far.data_ptr(), xyzs.data_ptr(), dirs.data_ptr(), deltas.data_ptr(), noises.data_ptr()); + })); +} + + +template +__global__ void kernel_composite_rays( + const uint32_t n_alive, + const uint32_t n_step, + const float T_thresh, + int* rays_alive, + scalar_t* rays_t, + const scalar_t* __restrict__ sigmas, + const scalar_t* __restrict__ rgbs, + const scalar_t* __restrict__ deltas, + scalar_t* weights_sum, scalar_t* depth, scalar_t* image +) { + const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; + if (n >= n_alive) return; + + const int index = rays_alive[n]; // ray id + + // locate + sigmas += n * n_step; + rgbs += n * n_step * 3; + deltas += n * n_step * 2; + + rays_t += index; + weights_sum += index; + depth += index; + image += index * 3; + + scalar_t t = rays_t[0]; // current ray's t + + scalar_t weight_sum = weights_sum[0]; + scalar_t d = depth[0]; + scalar_t r = image[0]; + scalar_t g = image[1]; + scalar_t b = image[2]; + + // accumulate + uint32_t step = 0; + while (step < n_step) { + + // ray is terminated if delta == 0 + if (deltas[0] == 0) break; + + const scalar_t alpha = 1.0f - __expf(- sigmas[0] * deltas[0]); + + /* + T_0 = 1; T_i = \prod_{j=0}^{i-1} (1 - alpha_j) + w_i = alpha_i * T_i + --> + T_i = 1 - \sum_{j=0}^{i-1} w_j + */ + const scalar_t T = 1 - weight_sum; + const scalar_t weight = alpha * T; + weight_sum += weight; + + t = deltas[1]; + d += weight * t; + r += weight * rgbs[0]; + g += weight * rgbs[1]; + b += weight * rgbs[2]; + + //printf("[n=%d] num_steps=%d, alpha=%f, w=%f, T=%f, sum_dt=%f, d=%f\n", n, step, alpha, weight, T, sum_delta, d); + + // ray is terminated if T is too small + // use a larger bound to further accelerate inference + if (T < T_thresh) break; + + // locate + sigmas++; + rgbs += 3; + deltas += 2; + step++; + } + + //printf("[n=%d] rgb=(%f, %f, %f), d=%f\n", n, r, g, b, d); + + // rays_alive = -1 means ray is terminated early. + if (step < n_step) { + rays_alive[n] = -1; + } else { + rays_t[0] = t; + } + + weights_sum[0] = weight_sum; // this is the thing I needed! + depth[0] = d; + image[0] = r; + image[1] = g; + image[2] = b; +} + + +void composite_rays(const uint32_t n_alive, const uint32_t n_step, const float T_thresh, at::Tensor rays_alive, at::Tensor rays_t, at::Tensor sigmas, at::Tensor rgbs, at::Tensor deltas, at::Tensor weights, at::Tensor depth, at::Tensor image) { + static constexpr uint32_t N_THREAD = 128; + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + image.scalar_type(), "composite_rays", ([&] { + kernel_composite_rays<<>>(n_alive, n_step, T_thresh, rays_alive.data_ptr(), rays_t.data_ptr(), sigmas.data_ptr(), rgbs.data_ptr(), deltas.data_ptr(), weights.data_ptr(), depth.data_ptr(), image.data_ptr()); + })); +} + + + +template +__global__ void kernel_composite_rays_ambient( + const uint32_t n_alive, + const uint32_t n_step, + const float T_thresh, + int* rays_alive, + scalar_t* rays_t, + const scalar_t* __restrict__ sigmas, + const scalar_t* __restrict__ rgbs, + const scalar_t* __restrict__ deltas, + const scalar_t* __restrict__ ambients, + scalar_t* weights_sum, scalar_t* depth, scalar_t* image, scalar_t* ambient_sum +) { + const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; + if (n >= n_alive) return; + + const int index = rays_alive[n]; // ray id + + // locate + sigmas += n * n_step; + rgbs += n * n_step * 3; + deltas += n * n_step * 2; + ambients += n * n_step; + + rays_t += index; + weights_sum += index; + depth += index; + image += index * 3; + ambient_sum += index; + + scalar_t t = rays_t[0]; // current ray's t + + scalar_t weight_sum = weights_sum[0]; + scalar_t d = depth[0]; + scalar_t r = image[0]; + scalar_t g = image[1]; + scalar_t b = image[2]; + scalar_t a = ambient_sum[0]; + + // accumulate + uint32_t step = 0; + while (step < n_step) { + + // ray is terminated if delta == 0 + if (deltas[0] == 0) break; + + const scalar_t alpha = 1.0f - __expf(- sigmas[0] * deltas[0]); + + /* + T_0 = 1; T_i = \prod_{j=0}^{i-1} (1 - alpha_j) + w_i = alpha_i * T_i + --> + T_i = 1 - \sum_{j=0}^{i-1} w_j + */ + const scalar_t T = 1 - weight_sum; + const scalar_t weight = alpha * T; + weight_sum += weight; + + t = deltas[1]; + d += weight * t; + r += weight * rgbs[0]; + g += weight * rgbs[1]; + b += weight * rgbs[2]; + a += ambients[0]; + + //printf("[n=%d] num_steps=%d, alpha=%f, w=%f, T=%f, sum_dt=%f, d=%f\n", n, step, alpha, weight, T, sum_delta, d); + + // ray is terminated if T is too small + // use a larger bound to further accelerate inference + if (T < T_thresh) break; + + // locate + sigmas++; + rgbs += 3; + deltas += 2; + step++; + ambients++; + } + + //printf("[n=%d] rgb=(%f, %f, %f), d=%f\n", n, r, g, b, d); + + // rays_alive = -1 means ray is terminated early. + if (step < n_step) { + rays_alive[n] = -1; + } else { + rays_t[0] = t; + } + + weights_sum[0] = weight_sum; // this is the thing I needed! + depth[0] = d; + image[0] = r; + image[1] = g; + image[2] = b; + ambient_sum[0] = a; +} + + +void composite_rays_ambient(const uint32_t n_alive, const uint32_t n_step, const float T_thresh, at::Tensor rays_alive, at::Tensor rays_t, at::Tensor sigmas, at::Tensor rgbs, at::Tensor deltas, at::Tensor ambients, at::Tensor weights, at::Tensor depth, at::Tensor image, at::Tensor ambient_sum) { + static constexpr uint32_t N_THREAD = 128; + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + image.scalar_type(), "composite_rays_ambient", ([&] { + kernel_composite_rays_ambient<<>>(n_alive, n_step, T_thresh, rays_alive.data_ptr(), rays_t.data_ptr(), sigmas.data_ptr(), rgbs.data_ptr(), deltas.data_ptr(), ambients.data_ptr(), weights.data_ptr(), depth.data_ptr(), image.data_ptr(), ambient_sum.data_ptr()); + })); +} + + + + + + +// -------------------------------- sigma ambient ----------------------------- + +// sigmas: [M] +// rgbs: [M, 3] +// deltas: [M, 2] +// rays: [N, 3], idx, offset, num_steps +// weights_sum: [N], final pixel alpha +// depth: [N,] +// image: [N, 3] +template +__global__ void kernel_composite_rays_train_sigma_forward( + const scalar_t * __restrict__ sigmas, + const scalar_t * __restrict__ rgbs, + const scalar_t * __restrict__ ambient, + const scalar_t * __restrict__ deltas, + const int * __restrict__ rays, + const uint32_t M, const uint32_t N, const float T_thresh, + scalar_t * weights_sum, + scalar_t * ambient_sum, + scalar_t * depth, + scalar_t * image +) { + // parallel per ray + const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; + if (n >= N) return; + + // locate + uint32_t index = rays[n * 3]; + uint32_t offset = rays[n * 3 + 1]; + uint32_t num_steps = rays[n * 3 + 2]; + + // empty ray, or ray that exceed max step count. + if (num_steps == 0 || offset + num_steps > M) { + weights_sum[index] = 0; + ambient_sum[index] = 0; + depth[index] = 0; + image[index * 3] = 0; + image[index * 3 + 1] = 0; + image[index * 3 + 2] = 0; + return; + } + + sigmas += offset; + rgbs += offset * 3; + ambient += offset; + deltas += offset * 2; + + // accumulate + uint32_t step = 0; + + scalar_t T = 1.0f; + scalar_t r = 0, g = 0, b = 0, ws = 0, d = 0, amb = 0; + + while (step < num_steps) { + + const scalar_t alpha = 1.0f - __expf(- sigmas[0] * deltas[0]); + const scalar_t weight = alpha * T; + + r += weight * rgbs[0]; + g += weight * rgbs[1]; + b += weight * rgbs[2]; + + d += weight * deltas[1]; + + ws += weight; + + amb += weight * ambient[0]; + + T *= 1.0f - alpha; + + // minimal remained transmittence + if (T < T_thresh) break; + + //printf("[n=%d] num_steps=%d, alpha=%f, w=%f, T=%f, sum_dt=%f, d=%f\n", n, step, alpha, weight, T, sum_delta, d); + + // locate + sigmas++; + rgbs += 3; + ambient++; + deltas += 2; + + step++; + } + + //printf("[n=%d] rgb=(%f, %f, %f), d=%f\n", n, r, g, b, d); + + // write + weights_sum[index] = ws; // weights_sum + ambient_sum[index] = amb; + depth[index] = d; + image[index * 3] = r; + image[index * 3 + 1] = g; + image[index * 3 + 2] = b; +} + + +void composite_rays_train_sigma_forward(const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor ambient, const at::Tensor deltas, const at::Tensor rays, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor weights_sum, at::Tensor ambient_sum, at::Tensor depth, at::Tensor image) { + + static constexpr uint32_t N_THREAD = 128; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + sigmas.scalar_type(), "composite_rays_train_sigma_forward", ([&] { + kernel_composite_rays_train_sigma_forward<<>>(sigmas.data_ptr(), rgbs.data_ptr(), ambient.data_ptr(), deltas.data_ptr(), rays.data_ptr(), M, N, T_thresh, weights_sum.data_ptr(), ambient_sum.data_ptr(), depth.data_ptr(), image.data_ptr()); + })); +} + + +// grad_weights_sum: [N,] +// grad: [N, 3] +// sigmas: [M] +// rgbs: [M, 3] +// deltas: [M, 2] +// rays: [N, 3], idx, offset, num_steps +// weights_sum: [N,], weights_sum here +// image: [N, 3] +// grad_sigmas: [M] +// grad_rgbs: [M, 3] +template +__global__ void kernel_composite_rays_train_sigma_backward( + const scalar_t * __restrict__ grad_weights_sum, + const scalar_t * __restrict__ grad_ambient_sum, + const scalar_t * __restrict__ grad_image, + const scalar_t * __restrict__ sigmas, + const scalar_t * __restrict__ rgbs, + const scalar_t * __restrict__ ambient, + const scalar_t * __restrict__ deltas, + const int * __restrict__ rays, + const scalar_t * __restrict__ weights_sum, + const scalar_t * __restrict__ ambient_sum, + const scalar_t * __restrict__ image, + const uint32_t M, const uint32_t N, const float T_thresh, + scalar_t * grad_sigmas, + scalar_t * grad_rgbs, + scalar_t * grad_ambient +) { + // parallel per ray + const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; + if (n >= N) return; + + // locate + uint32_t index = rays[n * 3]; + uint32_t offset = rays[n * 3 + 1]; + uint32_t num_steps = rays[n * 3 + 2]; + + if (num_steps == 0 || offset + num_steps > M) return; + + grad_weights_sum += index; + grad_ambient_sum += index; + grad_image += index * 3; + weights_sum += index; + ambient_sum += index; + image += index * 3; + + sigmas += offset; + rgbs += offset * 3; + ambient += offset; + deltas += offset * 2; + + grad_sigmas += offset; + grad_rgbs += offset * 3; + grad_ambient += offset; + + // accumulate + uint32_t step = 0; + + scalar_t T = 1.0f; + const scalar_t r_final = image[0], g_final = image[1], b_final = image[2], ws_final = weights_sum[0], amb_final = ambient_sum[0]; + scalar_t r = 0, g = 0, b = 0, ws = 0, amb = 0; + + while (step < num_steps) { + + const scalar_t alpha = 1.0f - __expf(- sigmas[0] * deltas[0]); + const scalar_t weight = alpha * T; + + r += weight * rgbs[0]; + g += weight * rgbs[1]; + b += weight * rgbs[2]; + amb += weight * ambient[0]; + ws += weight; + + T *= 1.0f - alpha; + + // check https://note.kiui.moe/others/nerf_gradient/ for the gradient calculation. + // write grad_rgbs + grad_rgbs[0] = grad_image[0] * weight; + grad_rgbs[1] = grad_image[1] * weight; + grad_rgbs[2] = grad_image[2] * weight; + + // write grad_ambient + grad_ambient[0] = grad_ambient_sum[0] * weight; + + // write grad_sigmas + grad_sigmas[0] = deltas[0] * ( + grad_image[0] * (T * rgbs[0] - (r_final - r)) + + grad_image[1] * (T * rgbs[1] - (g_final - g)) + + grad_image[2] * (T * rgbs[2] - (b_final - b)) + + grad_ambient_sum[0] * (T * ambient[0] - (amb_final - amb)) + + grad_weights_sum[0] * (1 - ws_final) + ); + + //printf("[n=%d] num_steps=%d, T=%f, grad_sigmas=%f, r_final=%f, r=%f\n", n, step, T, grad_sigmas[0], r_final, r); + // minimal remained transmittence + if (T < T_thresh) break; + + // locate + sigmas++; + rgbs += 3; + ambient++; + deltas += 2; + grad_sigmas++; + grad_rgbs += 3; + grad_ambient++; + + step++; + } +} + + +void composite_rays_train_sigma_backward(const at::Tensor grad_weights_sum, const at::Tensor grad_ambient_sum, const at::Tensor grad_image, const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor ambient, const at::Tensor deltas, const at::Tensor rays, const at::Tensor weights_sum, const at::Tensor ambient_sum, const at::Tensor image, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor grad_sigmas, at::Tensor grad_rgbs, at::Tensor grad_ambient) { + + static constexpr uint32_t N_THREAD = 128; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + grad_image.scalar_type(), "composite_rays_train_sigma_backward", ([&] { + kernel_composite_rays_train_sigma_backward<<>>(grad_weights_sum.data_ptr(), grad_ambient_sum.data_ptr(), grad_image.data_ptr(), sigmas.data_ptr(), rgbs.data_ptr(), ambient.data_ptr(), deltas.data_ptr(), rays.data_ptr(), weights_sum.data_ptr(), ambient_sum.data_ptr(), image.data_ptr(), M, N, T_thresh, grad_sigmas.data_ptr(), grad_rgbs.data_ptr(), grad_ambient.data_ptr()); + })); +} + + +//////////////////////////////////////////////////// +///////////// infernce ///////////// +//////////////////////////////////////////////////// + + +template +__global__ void kernel_composite_rays_ambient_sigma( + const uint32_t n_alive, + const uint32_t n_step, + const float T_thresh, + int* rays_alive, + scalar_t* rays_t, + const scalar_t* __restrict__ sigmas, + const scalar_t* __restrict__ rgbs, + const scalar_t* __restrict__ deltas, + const scalar_t* __restrict__ ambients, + scalar_t* weights_sum, scalar_t* depth, scalar_t* image, scalar_t* ambient_sum +) { + const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; + if (n >= n_alive) return; + + const int index = rays_alive[n]; // ray id + + // locate + sigmas += n * n_step; + rgbs += n * n_step * 3; + deltas += n * n_step * 2; + ambients += n * n_step; + + rays_t += index; + weights_sum += index; + depth += index; + image += index * 3; + ambient_sum += index; + + scalar_t t = rays_t[0]; // current ray's t + + scalar_t weight_sum = weights_sum[0]; + scalar_t d = depth[0]; + scalar_t r = image[0]; + scalar_t g = image[1]; + scalar_t b = image[2]; + scalar_t a = ambient_sum[0]; + + // accumulate + uint32_t step = 0; + while (step < n_step) { + + // ray is terminated if delta == 0 + if (deltas[0] == 0) break; + + const scalar_t alpha = 1.0f - __expf(- sigmas[0] * deltas[0]); + + /* + T_0 = 1; T_i = \prod_{j=0}^{i-1} (1 - alpha_j) + w_i = alpha_i * T_i + --> + T_i = 1 - \sum_{j=0}^{i-1} w_j + */ + const scalar_t T = 1 - weight_sum; + const scalar_t weight = alpha * T; + weight_sum += weight; + + t = deltas[1]; + d += weight * t; + r += weight * rgbs[0]; + g += weight * rgbs[1]; + b += weight * rgbs[2]; + a += weight * ambients[0]; + + //printf("[n=%d] num_steps=%d, alpha=%f, w=%f, T=%f, sum_dt=%f, d=%f\n", n, step, alpha, weight, T, sum_delta, d); + + // ray is terminated if T is too small + // use a larger bound to further accelerate inference + if (T < T_thresh) break; + + // locate + sigmas++; + rgbs += 3; + deltas += 2; + step++; + ambients++; + } + + //printf("[n=%d] rgb=(%f, %f, %f), d=%f\n", n, r, g, b, d); + + // rays_alive = -1 means ray is terminated early. + if (step < n_step) { + rays_alive[n] = -1; + } else { + rays_t[0] = t; + } + + weights_sum[0] = weight_sum; // this is the thing I needed! + depth[0] = d; + image[0] = r; + image[1] = g; + image[2] = b; + ambient_sum[0] = a; +} + + +void composite_rays_ambient_sigma(const uint32_t n_alive, const uint32_t n_step, const float T_thresh, at::Tensor rays_alive, at::Tensor rays_t, at::Tensor sigmas, at::Tensor rgbs, at::Tensor deltas, at::Tensor ambients, at::Tensor weights, at::Tensor depth, at::Tensor image, at::Tensor ambient_sum) { + static constexpr uint32_t N_THREAD = 128; + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + image.scalar_type(), "composite_rays_ambient_sigma", ([&] { + kernel_composite_rays_ambient_sigma<<>>(n_alive, n_step, T_thresh, rays_alive.data_ptr(), rays_t.data_ptr(), sigmas.data_ptr(), rgbs.data_ptr(), deltas.data_ptr(), ambients.data_ptr(), weights.data_ptr(), depth.data_ptr(), image.data_ptr(), ambient_sum.data_ptr()); + })); +} + + + + + + + +// -------------------------------- uncertainty ----------------------------- + +// sigmas: [M] +// rgbs: [M, 3] +// deltas: [M, 2] +// rays: [N, 3], idx, offset, num_steps +// weights_sum: [N], final pixel alpha +// depth: [N,] +// image: [N, 3] +template +__global__ void kernel_composite_rays_train_uncertainty_forward( + const scalar_t * __restrict__ sigmas, + const scalar_t * __restrict__ rgbs, + const scalar_t * __restrict__ ambient, + const scalar_t * __restrict__ uncertainty, + const scalar_t * __restrict__ deltas, + const int * __restrict__ rays, + const uint32_t M, const uint32_t N, const float T_thresh, + scalar_t * weights_sum, + scalar_t * ambient_sum, + scalar_t * uncertainty_sum, + scalar_t * depth, + scalar_t * image +) { + // parallel per ray + const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; + if (n >= N) return; + + // locate + uint32_t index = rays[n * 3]; + uint32_t offset = rays[n * 3 + 1]; + uint32_t num_steps = rays[n * 3 + 2]; + + // empty ray, or ray that exceed max step count. + if (num_steps == 0 || offset + num_steps > M) { + weights_sum[index] = 0; + ambient_sum[index] = 0; + uncertainty_sum[index] = 0; + depth[index] = 0; + image[index * 3] = 0; + image[index * 3 + 1] = 0; + image[index * 3 + 2] = 0; + return; + } + + sigmas += offset; + rgbs += offset * 3; + ambient += offset; + uncertainty += offset; + deltas += offset * 2; + + // accumulate + uint32_t step = 0; + + scalar_t T = 1.0f; + scalar_t r = 0, g = 0, b = 0, ws = 0, d = 0, amb = 0, unc = 0; + + while (step < num_steps) { + + const scalar_t alpha = 1.0f - __expf(- sigmas[0] * deltas[0]); + const scalar_t weight = alpha * T; + + r += weight * rgbs[0]; + g += weight * rgbs[1]; + b += weight * rgbs[2]; + + d += weight * deltas[1]; + + ws += weight; + + amb += ambient[0]; + unc += weight * uncertainty[0]; + + T *= 1.0f - alpha; + + // minimal remained transmittence + if (T < T_thresh) break; + + //printf("[n=%d] num_steps=%d, alpha=%f, w=%f, T=%f, sum_dt=%f, d=%f\n", n, step, alpha, weight, T, sum_delta, d); + + // locate + sigmas++; + rgbs += 3; + ambient++; + uncertainty++; + deltas += 2; + + step++; + } + + //printf("[n=%d] rgb=(%f, %f, %f), d=%f\n", n, r, g, b, d); + + // write + weights_sum[index] = ws; // weights_sum + ambient_sum[index] = amb; + uncertainty_sum[index] = unc; + depth[index] = d; + image[index * 3] = r; + image[index * 3 + 1] = g; + image[index * 3 + 2] = b; +} + + +void composite_rays_train_uncertainty_forward(const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor ambient, const at::Tensor uncertainty, const at::Tensor deltas, const at::Tensor rays, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor weights_sum, at::Tensor ambient_sum, at::Tensor uncertainty_sum, at::Tensor depth, at::Tensor image) { + + static constexpr uint32_t N_THREAD = 128; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + sigmas.scalar_type(), "composite_rays_train_uncertainty_forward", ([&] { + kernel_composite_rays_train_uncertainty_forward<<>>(sigmas.data_ptr(), rgbs.data_ptr(), ambient.data_ptr(), uncertainty.data_ptr(), deltas.data_ptr(), rays.data_ptr(), M, N, T_thresh, weights_sum.data_ptr(), ambient_sum.data_ptr(), uncertainty_sum.data_ptr(), depth.data_ptr(), image.data_ptr()); + })); +} + + +// grad_weights_sum: [N,] +// grad: [N, 3] +// sigmas: [M] +// rgbs: [M, 3] +// deltas: [M, 2] +// rays: [N, 3], idx, offset, num_steps +// weights_sum: [N,], weights_sum here +// image: [N, 3] +// grad_sigmas: [M] +// grad_rgbs: [M, 3] +template +__global__ void kernel_composite_rays_train_uncertainty_backward( + const scalar_t * __restrict__ grad_weights_sum, + const scalar_t * __restrict__ grad_ambient_sum, + const scalar_t * __restrict__ grad_uncertainty_sum, + const scalar_t * __restrict__ grad_image, + const scalar_t * __restrict__ sigmas, + const scalar_t * __restrict__ rgbs, + const scalar_t * __restrict__ ambient, + const scalar_t * __restrict__ uncertainty, + const scalar_t * __restrict__ deltas, + const int * __restrict__ rays, + const scalar_t * __restrict__ weights_sum, + const scalar_t * __restrict__ ambient_sum, + const scalar_t * __restrict__ uncertainty_sum, + const scalar_t * __restrict__ image, + const uint32_t M, const uint32_t N, const float T_thresh, + scalar_t * grad_sigmas, + scalar_t * grad_rgbs, + scalar_t * grad_ambient, + scalar_t * grad_uncertainty +) { + // parallel per ray + const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; + if (n >= N) return; + + // locate + uint32_t index = rays[n * 3]; + uint32_t offset = rays[n * 3 + 1]; + uint32_t num_steps = rays[n * 3 + 2]; + + if (num_steps == 0 || offset + num_steps > M) return; + + grad_weights_sum += index; + grad_ambient_sum += index; + grad_uncertainty_sum += index; + grad_image += index * 3; + weights_sum += index; + ambient_sum += index; + uncertainty_sum += index; + image += index * 3; + + sigmas += offset; + rgbs += offset * 3; + ambient += offset; + uncertainty += offset; + deltas += offset * 2; + + grad_sigmas += offset; + grad_rgbs += offset * 3; + grad_ambient += offset; + grad_uncertainty += offset; + + // accumulate + uint32_t step = 0; + + scalar_t T = 1.0f; + const scalar_t r_final = image[0], g_final = image[1], b_final = image[2], ws_final = weights_sum[0], amb_final = ambient_sum[0], unc_final = uncertainty_sum[0]; + scalar_t r = 0, g = 0, b = 0, ws = 0, amb = 0, unc = 0; + + while (step < num_steps) { + + const scalar_t alpha = 1.0f - __expf(- sigmas[0] * deltas[0]); + const scalar_t weight = alpha * T; + + r += weight * rgbs[0]; + g += weight * rgbs[1]; + b += weight * rgbs[2]; + // amb += ambient[0]; + unc += weight * uncertainty[0]; + ws += weight; + + T *= 1.0f - alpha; + + // check https://note.kiui.moe/others/nerf_gradient/ for the gradient calculation. + // write grad_rgbs + grad_rgbs[0] = grad_image[0] * weight; + grad_rgbs[1] = grad_image[1] * weight; + grad_rgbs[2] = grad_image[2] * weight; + + // write grad_ambient + grad_ambient[0] = grad_ambient_sum[0]; + + // write grad_unc + grad_uncertainty[0] = grad_uncertainty_sum[0] * weight; + + // write grad_sigmas + grad_sigmas[0] = deltas[0] * ( + grad_image[0] * (T * rgbs[0] - (r_final - r)) + + grad_image[1] * (T * rgbs[1] - (g_final - g)) + + grad_image[2] * (T * rgbs[2] - (b_final - b)) + + // grad_ambient_sum[0] * (T * ambient[0] - (amb_final - amb)) + + grad_uncertainty_sum[0] * (T * uncertainty[0] - (unc_final - unc)) + + grad_weights_sum[0] * (1 - ws_final) + ); + + //printf("[n=%d] num_steps=%d, T=%f, grad_sigmas=%f, r_final=%f, r=%f\n", n, step, T, grad_sigmas[0], r_final, r); + // minimal remained transmittence + if (T < T_thresh) break; + + // locate + sigmas++; + rgbs += 3; + // ambient++; + uncertainty++; + deltas += 2; + grad_sigmas++; + grad_rgbs += 3; + grad_ambient++; + grad_uncertainty++; + + step++; + } +} + + +void composite_rays_train_uncertainty_backward(const at::Tensor grad_weights_sum, const at::Tensor grad_ambient_sum, const at::Tensor grad_uncertainty_sum, const at::Tensor grad_image, const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor ambient, const at::Tensor uncertainty, const at::Tensor deltas, const at::Tensor rays, const at::Tensor weights_sum, const at::Tensor ambient_sum, const at::Tensor uncertainty_sum, const at::Tensor image, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor grad_sigmas, at::Tensor grad_rgbs, at::Tensor grad_ambient, at::Tensor grad_uncertainty) { + + static constexpr uint32_t N_THREAD = 128; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + grad_image.scalar_type(), "composite_rays_train_uncertainty_backward", ([&] { + kernel_composite_rays_train_uncertainty_backward<<>>(grad_weights_sum.data_ptr(), grad_ambient_sum.data_ptr(), grad_uncertainty_sum.data_ptr(), grad_image.data_ptr(), sigmas.data_ptr(), rgbs.data_ptr(), ambient.data_ptr(), uncertainty.data_ptr(), deltas.data_ptr(), rays.data_ptr(), weights_sum.data_ptr(), ambient_sum.data_ptr(), uncertainty_sum.data_ptr(), image.data_ptr(), M, N, T_thresh, grad_sigmas.data_ptr(), grad_rgbs.data_ptr(), grad_ambient.data_ptr(), grad_uncertainty.data_ptr()); + })); +} + + +//////////////////////////////////////////////////// +///////////// infernce ///////////// +//////////////////////////////////////////////////// + + +template +__global__ void kernel_composite_rays_uncertainty( + const uint32_t n_alive, + const uint32_t n_step, + const float T_thresh, + int* rays_alive, + scalar_t* rays_t, + const scalar_t* __restrict__ sigmas, + const scalar_t* __restrict__ rgbs, + const scalar_t* __restrict__ deltas, + const scalar_t* __restrict__ ambients, + const scalar_t* __restrict__ uncertainties, + scalar_t* weights_sum, scalar_t* depth, scalar_t* image, scalar_t* ambient_sum, scalar_t* uncertainty_sum +) { + const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; + if (n >= n_alive) return; + + const int index = rays_alive[n]; // ray id + + // locate + sigmas += n * n_step; + rgbs += n * n_step * 3; + deltas += n * n_step * 2; + ambients += n * n_step; + uncertainties += n * n_step; + + rays_t += index; + weights_sum += index; + depth += index; + image += index * 3; + ambient_sum += index; + uncertainty_sum += index; + + scalar_t t = rays_t[0]; // current ray's t + + scalar_t weight_sum = weights_sum[0]; + scalar_t d = depth[0]; + scalar_t r = image[0]; + scalar_t g = image[1]; + scalar_t b = image[2]; + scalar_t a = ambient_sum[0]; + scalar_t u = uncertainty_sum[0]; + + // accumulate + uint32_t step = 0; + while (step < n_step) { + + // ray is terminated if delta == 0 + if (deltas[0] == 0) break; + + const scalar_t alpha = 1.0f - __expf(- sigmas[0] * deltas[0]); + + /* + T_0 = 1; T_i = \prod_{j=0}^{i-1} (1 - alpha_j) + w_i = alpha_i * T_i + --> + T_i = 1 - \sum_{j=0}^{i-1} w_j + */ + const scalar_t T = 1 - weight_sum; + const scalar_t weight = alpha * T; + weight_sum += weight; + + t = deltas[1]; + d += weight * t; + r += weight * rgbs[0]; + g += weight * rgbs[1]; + b += weight * rgbs[2]; + a += ambients[0]; + u += weight * uncertainties[0]; + + //printf("[n=%d] num_steps=%d, alpha=%f, w=%f, T=%f, sum_dt=%f, d=%f\n", n, step, alpha, weight, T, sum_delta, d); + + // ray is terminated if T is too small + // use a larger bound to further accelerate inference + if (T < T_thresh) break; + + // locate + sigmas++; + rgbs += 3; + deltas += 2; + step++; + ambients++; + uncertainties++; + } + + //printf("[n=%d] rgb=(%f, %f, %f), d=%f\n", n, r, g, b, d); + + // rays_alive = -1 means ray is terminated early. + if (step < n_step) { + rays_alive[n] = -1; + } else { + rays_t[0] = t; + } + + weights_sum[0] = weight_sum; // this is the thing I needed! + depth[0] = d; + image[0] = r; + image[1] = g; + image[2] = b; + ambient_sum[0] = a; + uncertainty_sum[0] = u; +} + + +void composite_rays_uncertainty(const uint32_t n_alive, const uint32_t n_step, const float T_thresh, at::Tensor rays_alive, at::Tensor rays_t, at::Tensor sigmas, at::Tensor rgbs, at::Tensor deltas, at::Tensor ambients, at::Tensor uncertainties, at::Tensor weights, at::Tensor depth, at::Tensor image, at::Tensor ambient_sum, at::Tensor uncertainty_sum) { + static constexpr uint32_t N_THREAD = 128; + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + image.scalar_type(), "composite_rays_uncertainty", ([&] { + kernel_composite_rays_uncertainty<<>>(n_alive, n_step, T_thresh, rays_alive.data_ptr(), rays_t.data_ptr(), sigmas.data_ptr(), rgbs.data_ptr(), deltas.data_ptr(), ambients.data_ptr(), uncertainties.data_ptr(), weights.data_ptr(), depth.data_ptr(), image.data_ptr(), ambient_sum.data_ptr(), uncertainty_sum.data_ptr()); + })); +} + + + + +// -------------------------------- triplane ----------------------------- + +// sigmas: [M] +// rgbs: [M, 3] +// deltas: [M, 2] +// rays: [N, 3], idx, offset, num_steps +// weights_sum: [N], final pixel alpha +// depth: [N,] +// image: [N, 3] +template +__global__ void kernel_composite_rays_train_triplane_forward( + const scalar_t * __restrict__ sigmas, + const scalar_t * __restrict__ rgbs, + const scalar_t * __restrict__ amb_aud, + const scalar_t * __restrict__ amb_eye, + const scalar_t * __restrict__ uncertainty, + const scalar_t * __restrict__ deltas, + const int * __restrict__ rays, + const uint32_t M, const uint32_t N, const float T_thresh, + scalar_t * weights_sum, + scalar_t * amb_aud_sum, + scalar_t * amb_eye_sum, + scalar_t * uncertainty_sum, + scalar_t * depth, + scalar_t * image +) { + // parallel per ray + const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; + if (n >= N) return; + + // locate + uint32_t index = rays[n * 3]; + uint32_t offset = rays[n * 3 + 1]; + uint32_t num_steps = rays[n * 3 + 2]; + + // empty ray, or ray that exceed max step count. + if (num_steps == 0 || offset + num_steps > M) { + weights_sum[index] = 0; + amb_aud_sum[index] = 0; + amb_eye_sum[index] = 0; + uncertainty_sum[index] = 0; + depth[index] = 0; + image[index * 3] = 0; + image[index * 3 + 1] = 0; + image[index * 3 + 2] = 0; + return; + } + + sigmas += offset; + rgbs += offset * 3; + amb_aud += offset; + amb_eye += offset; + uncertainty += offset; + deltas += offset * 2; + + // accumulate + uint32_t step = 0; + + scalar_t T = 1.0f; + scalar_t r = 0, g = 0, b = 0, ws = 0, d = 0, a_aud = 0, a_eye=0, unc = 0; + + while (step < num_steps) { + + const scalar_t alpha = 1.0f - __expf(- sigmas[0] * deltas[0]); + const scalar_t weight = alpha * T; + + r += weight * rgbs[0]; + g += weight * rgbs[1]; + b += weight * rgbs[2]; + + d += weight * deltas[1]; + + ws += weight; + + a_aud += amb_aud[0]; + a_eye += amb_eye[0]; + unc += weight * uncertainty[0]; + + T *= 1.0f - alpha; + + // minimal remained transmittence + if (T < T_thresh) break; + + //printf("[n=%d] num_steps=%d, alpha=%f, w=%f, T=%f, sum_dt=%f, d=%f\n", n, step, alpha, weight, T, sum_delta, d); + + // locate + sigmas++; + rgbs += 3; + amb_aud++; + amb_eye++; + uncertainty++; + deltas += 2; + + step++; + } + + //printf("[n=%d] rgb=(%f, %f, %f), d=%f\n", n, r, g, b, d); + + // write + weights_sum[index] = ws; // weights_sum + amb_aud_sum[index] = a_aud; + amb_eye_sum[index] = a_eye; + uncertainty_sum[index] = unc; + depth[index] = d; + image[index * 3] = r; + image[index * 3 + 1] = g; + image[index * 3 + 2] = b; +} + + +void composite_rays_train_triplane_forward(const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor amb_aud, const at::Tensor amb_eye, const at::Tensor uncertainty, const at::Tensor deltas, const at::Tensor rays, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor weights_sum, at::Tensor amb_aud_sum, at::Tensor amb_eye_sum, at::Tensor uncertainty_sum, at::Tensor depth, at::Tensor image) { + + static constexpr uint32_t N_THREAD = 128; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + sigmas.scalar_type(), "composite_rays_train_triplane_forward", ([&] { + kernel_composite_rays_train_triplane_forward<<>>(sigmas.data_ptr(), rgbs.data_ptr(), amb_aud.data_ptr(), amb_eye.data_ptr(), uncertainty.data_ptr(), deltas.data_ptr(), rays.data_ptr(), M, N, T_thresh, weights_sum.data_ptr(), amb_aud_sum.data_ptr(), amb_eye_sum.data_ptr(), uncertainty_sum.data_ptr(), depth.data_ptr(), image.data_ptr()); + })); +} + + +// grad_weights_sum: [N,] +// grad: [N, 3] +// sigmas: [M] +// rgbs: [M, 3] +// deltas: [M, 2] +// rays: [N, 3], idx, offset, num_steps +// weights_sum: [N,], weights_sum here +// image: [N, 3] +// grad_sigmas: [M] +// grad_rgbs: [M, 3] +template +__global__ void kernel_composite_rays_train_triplane_backward( + const scalar_t * __restrict__ grad_weights_sum, + const scalar_t * __restrict__ grad_amb_aud_sum, + const scalar_t * __restrict__ grad_amb_eye_sum, + const scalar_t * __restrict__ grad_uncertainty_sum, + const scalar_t * __restrict__ grad_image, + const scalar_t * __restrict__ sigmas, + const scalar_t * __restrict__ rgbs, + const scalar_t * __restrict__ amb_aud, + const scalar_t * __restrict__ amb_eye, + const scalar_t * __restrict__ uncertainty, + const scalar_t * __restrict__ deltas, + const int * __restrict__ rays, + const scalar_t * __restrict__ weights_sum, + const scalar_t * __restrict__ amb_aud_sum, + const scalar_t * __restrict__ amb_eye_sum, + const scalar_t * __restrict__ uncertainty_sum, + const scalar_t * __restrict__ image, + const uint32_t M, const uint32_t N, const float T_thresh, + scalar_t * grad_sigmas, + scalar_t * grad_rgbs, + scalar_t * grad_amb_aud, + scalar_t * grad_amb_eye, + scalar_t * grad_uncertainty +) { + // parallel per ray + const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; + if (n >= N) return; + + // locate + uint32_t index = rays[n * 3]; + uint32_t offset = rays[n * 3 + 1]; + uint32_t num_steps = rays[n * 3 + 2]; + + if (num_steps == 0 || offset + num_steps > M) return; + + grad_weights_sum += index; + grad_amb_aud_sum += index; + grad_amb_eye_sum += index; + grad_uncertainty_sum += index; + grad_image += index * 3; + weights_sum += index; + amb_aud_sum += index; + amb_eye_sum += index; + uncertainty_sum += index; + image += index * 3; + + sigmas += offset; + rgbs += offset * 3; + amb_aud += offset; + amb_eye += offset; + uncertainty += offset; + deltas += offset * 2; + + grad_sigmas += offset; + grad_rgbs += offset * 3; + grad_amb_aud += offset; + grad_amb_eye += offset; + grad_uncertainty += offset; + + // accumulate + uint32_t step = 0; + + scalar_t T = 1.0f; + const scalar_t r_final = image[0], g_final = image[1], b_final = image[2], ws_final = weights_sum[0], unc_final = uncertainty_sum[0]; + scalar_t r = 0, g = 0, b = 0, ws = 0, amb = 0, unc = 0; + + while (step < num_steps) { + + const scalar_t alpha = 1.0f - __expf(- sigmas[0] * deltas[0]); + const scalar_t weight = alpha * T; + + r += weight * rgbs[0]; + g += weight * rgbs[1]; + b += weight * rgbs[2]; + // amb += ambient[0]; + unc += weight * uncertainty[0]; + ws += weight; + + T *= 1.0f - alpha; + + // check https://note.kiui.moe/others/nerf_gradient/ for the gradient calculation. + // write grad_rgbs + grad_rgbs[0] = grad_image[0] * weight; + grad_rgbs[1] = grad_image[1] * weight; + grad_rgbs[2] = grad_image[2] * weight; + + // write grad_ambient + grad_amb_aud[0] = grad_amb_aud_sum[0]; + grad_amb_eye[0] = grad_amb_eye_sum[0]; + + // write grad_unc + grad_uncertainty[0] = grad_uncertainty_sum[0] * weight; + + // write grad_sigmas + grad_sigmas[0] = deltas[0] * ( + grad_image[0] * (T * rgbs[0] - (r_final - r)) + + grad_image[1] * (T * rgbs[1] - (g_final - g)) + + grad_image[2] * (T * rgbs[2] - (b_final - b)) + + // grad_ambient_sum[0] * (T * ambient[0] - (amb_final - amb)) + + grad_uncertainty_sum[0] * (T * uncertainty[0] - (unc_final - unc)) + + grad_weights_sum[0] * (1 - ws_final) + ); + + //printf("[n=%d] num_steps=%d, T=%f, grad_sigmas=%f, r_final=%f, r=%f\n", n, step, T, grad_sigmas[0], r_final, r); + // minimal remained transmittence + if (T < T_thresh) break; + + // locate + sigmas++; + rgbs += 3; + // ambient++; + uncertainty++; + deltas += 2; + grad_sigmas++; + grad_rgbs += 3; + grad_amb_aud++; + grad_amb_eye++; + grad_uncertainty++; + + step++; + } +} + + +void composite_rays_train_triplane_backward(const at::Tensor grad_weights_sum, const at::Tensor grad_amb_aud_sum, const at::Tensor grad_amb_eye_sum, const at::Tensor grad_uncertainty_sum, const at::Tensor grad_image, const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor amb_aud, const at::Tensor amb_eye, const at::Tensor uncertainty, const at::Tensor deltas, const at::Tensor rays, const at::Tensor weights_sum, const at::Tensor amb_aud_sum, const at::Tensor amb_eye_sum, const at::Tensor uncertainty_sum, const at::Tensor image, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor grad_sigmas, at::Tensor grad_rgbs, at::Tensor grad_amb_aud, at::Tensor grad_amb_eye, at::Tensor grad_uncertainty) { + + static constexpr uint32_t N_THREAD = 128; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + grad_image.scalar_type(), "composite_rays_train_triplane_backward", ([&] { + kernel_composite_rays_train_triplane_backward<<>>(grad_weights_sum.data_ptr(), grad_amb_aud_sum.data_ptr(), grad_amb_eye_sum.data_ptr(), grad_uncertainty_sum.data_ptr(), grad_image.data_ptr(), sigmas.data_ptr(), rgbs.data_ptr(), amb_aud.data_ptr(), amb_eye.data_ptr(), uncertainty.data_ptr(), deltas.data_ptr(), rays.data_ptr(), weights_sum.data_ptr(), amb_aud_sum.data_ptr(), amb_eye_sum.data_ptr(), uncertainty_sum.data_ptr(), image.data_ptr(), M, N, T_thresh, grad_sigmas.data_ptr(), grad_rgbs.data_ptr(), grad_amb_aud.data_ptr(), grad_amb_eye.data_ptr(), grad_uncertainty.data_ptr()); + })); +} + + +//////////////////////////////////////////////////// +///////////// infernce ///////////// +//////////////////////////////////////////////////// + + +template +__global__ void kernel_composite_rays_triplane( + const uint32_t n_alive, + const uint32_t n_step, + const float T_thresh, + int* rays_alive, + scalar_t* rays_t, + const scalar_t* __restrict__ sigmas, + const scalar_t* __restrict__ rgbs, + const scalar_t* __restrict__ deltas, + const scalar_t* __restrict__ ambs_aud, + const scalar_t* __restrict__ ambs_eye, + const scalar_t* __restrict__ uncertainties, + scalar_t* weights_sum, scalar_t* depth, scalar_t* image, scalar_t* amb_aud_sum, scalar_t* amb_eye_sum, scalar_t* uncertainty_sum +) { + const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; + if (n >= n_alive) return; + + const int index = rays_alive[n]; // ray id + + // locate + sigmas += n * n_step; + rgbs += n * n_step * 3; + deltas += n * n_step * 2; + ambs_aud += n * n_step; + ambs_eye += n * n_step; + uncertainties += n * n_step; + + rays_t += index; + weights_sum += index; + depth += index; + image += index * 3; + amb_aud_sum += index; + amb_eye_sum += index; + uncertainty_sum += index; + + scalar_t t = rays_t[0]; // current ray's t + + scalar_t weight_sum = weights_sum[0]; + scalar_t d = depth[0]; + scalar_t r = image[0]; + scalar_t g = image[1]; + scalar_t b = image[2]; + scalar_t a_aud = amb_aud_sum[0]; + scalar_t a_eye = amb_eye_sum[0]; + scalar_t u = uncertainty_sum[0]; + + // accumulate + uint32_t step = 0; + while (step < n_step) { + + // ray is terminated if delta == 0 + if (deltas[0] == 0) break; + + const scalar_t alpha = 1.0f - __expf(- sigmas[0] * deltas[0]); + + /* + T_0 = 1; T_i = \prod_{j=0}^{i-1} (1 - alpha_j) + w_i = alpha_i * T_i + --> + T_i = 1 - \sum_{j=0}^{i-1} w_j + */ + const scalar_t T = 1 - weight_sum; + const scalar_t weight = alpha * T; + weight_sum += weight; + + t = deltas[1]; + d += weight * t; + r += weight * rgbs[0]; + g += weight * rgbs[1]; + b += weight * rgbs[2]; + a_aud += ambs_aud[0]; + a_eye += ambs_eye[0]; + u += weight * uncertainties[0]; + + //printf("[n=%d] num_steps=%d, alpha=%f, w=%f, T=%f, sum_dt=%f, d=%f\n", n, step, alpha, weight, T, sum_delta, d); + + // ray is terminated if T is too small + // use a larger bound to further accelerate inference + if (T < T_thresh) break; + + // locate + sigmas++; + rgbs += 3; + deltas += 2; + step++; + ambs_aud++; + ambs_eye++; + uncertainties++; + } + + //printf("[n=%d] rgb=(%f, %f, %f), d=%f\n", n, r, g, b, d); + + // rays_alive = -1 means ray is terminated early. + if (step < n_step) { + rays_alive[n] = -1; + } else { + rays_t[0] = t; + } + + weights_sum[0] = weight_sum; // this is the thing I needed! + depth[0] = d; + image[0] = r; + image[1] = g; + image[2] = b; + amb_aud_sum[0] = a_aud; + amb_eye_sum[0] = a_eye; + uncertainty_sum[0] = u; +} + + +void composite_rays_triplane(const uint32_t n_alive, const uint32_t n_step, const float T_thresh, at::Tensor rays_alive, at::Tensor rays_t, at::Tensor sigmas, at::Tensor rgbs, at::Tensor deltas, at::Tensor ambs_aud, at::Tensor ambs_eye, at::Tensor uncertainties, at::Tensor weights, at::Tensor depth, at::Tensor image, at::Tensor amb_aud_sum, at::Tensor amb_eye_sum, at::Tensor uncertainty_sum) { + static constexpr uint32_t N_THREAD = 128; + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + image.scalar_type(), "composite_rays_triplane", ([&] { + kernel_composite_rays_triplane<<>>(n_alive, n_step, T_thresh, rays_alive.data_ptr(), rays_t.data_ptr(), sigmas.data_ptr(), rgbs.data_ptr(), deltas.data_ptr(), ambs_aud.data_ptr(), ambs_eye.data_ptr(), uncertainties.data_ptr(), weights.data_ptr(), depth.data_ptr(), image.data_ptr(), amb_aud_sum.data_ptr(), amb_eye_sum.data_ptr(), uncertainty_sum.data_ptr()); + })); +} diff --git a/raymarching/src/raymarching.h b/raymarching/src/raymarching.h new file mode 100644 index 0000000..cd08969 --- /dev/null +++ b/raymarching/src/raymarching.h @@ -0,0 +1,38 @@ +#pragma once + +#include +#include + + +void near_far_from_aabb(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor aabb, const uint32_t N, const float min_near, at::Tensor nears, at::Tensor fars); +void sph_from_ray(const at::Tensor rays_o, const at::Tensor rays_d, const float radius, const uint32_t N, at::Tensor coords); +void morton3D(const at::Tensor coords, const uint32_t N, at::Tensor indices); +void morton3D_invert(const at::Tensor indices, const uint32_t N, at::Tensor coords); +void packbits(const at::Tensor grid, const uint32_t N, const float density_thresh, at::Tensor bitfield); +void morton3D_dilation(const at::Tensor grid, const uint32_t C, const uint32_t H, at::Tensor grid_dilation); + +void march_rays_train(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor grid, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t N, const uint32_t C, const uint32_t H, const uint32_t M, const at::Tensor nears, const at::Tensor fars, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, at::Tensor rays, at::Tensor counter, at::Tensor noises); +void march_rays_train_backward(const at::Tensor grad_xyzs, const at::Tensor grad_dirs, const at::Tensor rays, const at::Tensor deltas, const uint32_t N, const uint32_t M, at::Tensor grad_rays_o, at::Tensor grad_rays_d); +void composite_rays_train_forward(const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor ambient, const at::Tensor deltas, const at::Tensor rays, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor weights_sum, at::Tensor ambient_sum, at::Tensor depth, at::Tensor image); +void composite_rays_train_backward(const at::Tensor grad_weights_sum, const at::Tensor grad_ambient_sum, const at::Tensor grad_image, const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor ambient, const at::Tensor deltas, const at::Tensor rays, const at::Tensor weights_sum, const at::Tensor ambient_sum, const at::Tensor image, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor grad_sigmas, at::Tensor grad_rgbs, at::Tensor grad_ambient); + +void march_rays(const uint32_t n_alive, const uint32_t n_step, const at::Tensor rays_alive, const at::Tensor rays_t, const at::Tensor rays_o, const at::Tensor rays_d, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t C, const uint32_t H, const at::Tensor grid, const at::Tensor nears, const at::Tensor fars, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, at::Tensor noises); +void composite_rays(const uint32_t n_alive, const uint32_t n_step, const float T_thresh, at::Tensor rays_alive, at::Tensor rays_t, at::Tensor sigmas, at::Tensor rgbs, at::Tensor deltas, at::Tensor weights_sum, at::Tensor depth, at::Tensor image); +void composite_rays_ambient(const uint32_t n_alive, const uint32_t n_step, const float T_thresh, at::Tensor rays_alive, at::Tensor rays_t, at::Tensor sigmas, at::Tensor rgbs, at::Tensor deltas, at::Tensor ambients, at::Tensor weights, at::Tensor depth, at::Tensor image, at::Tensor ambient_sum); + + +void composite_rays_train_sigma_forward(const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor ambient, const at::Tensor deltas, const at::Tensor rays, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor weights_sum, at::Tensor ambient_sum, at::Tensor depth, at::Tensor image); +void composite_rays_train_sigma_backward(const at::Tensor grad_weights_sum, const at::Tensor grad_ambient_sum, const at::Tensor grad_image, const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor ambient, const at::Tensor deltas, const at::Tensor rays, const at::Tensor weights_sum, const at::Tensor ambient_sum, const at::Tensor image, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor grad_sigmas, at::Tensor grad_rgbs, at::Tensor grad_ambient); + +void composite_rays_ambient_sigma(const uint32_t n_alive, const uint32_t n_step, const float T_thresh, at::Tensor rays_alive, at::Tensor rays_t, at::Tensor sigmas, at::Tensor rgbs, at::Tensor deltas, at::Tensor ambients, at::Tensor weights, at::Tensor depth, at::Tensor image, at::Tensor ambient_sum); + + +// uncertainty +void composite_rays_train_uncertainty_forward(const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor ambient, const at::Tensor uncertainty, const at::Tensor deltas, const at::Tensor rays, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor weights_sum, at::Tensor ambient_sum, at::Tensor uncertainty_sum, at::Tensor depth, at::Tensor image); +void composite_rays_train_uncertainty_backward(const at::Tensor grad_weights_sum, const at::Tensor grad_ambient_sum, const at::Tensor grad_uncertainty_sum, const at::Tensor grad_image, const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor ambient, const at::Tensor uncertainty, const at::Tensor deltas, const at::Tensor rays, const at::Tensor weights_sum, const at::Tensor ambient_sum, const at::Tensor uncertainty_sum, const at::Tensor image, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor grad_sigmas, at::Tensor grad_rgbs, at::Tensor grad_ambient, at::Tensor grad_uncertainty); +void composite_rays_uncertainty(const uint32_t n_alive, const uint32_t n_step, const float T_thresh, at::Tensor rays_alive, at::Tensor rays_t, at::Tensor sigmas, at::Tensor rgbs, at::Tensor deltas, at::Tensor ambients, at::Tensor uncertainties, at::Tensor weights, at::Tensor depth, at::Tensor image, at::Tensor ambient_sum, at::Tensor uncertainty_sum); + +// triplane +void composite_rays_train_triplane_forward(const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor amb_aud, const at::Tensor amb_eye, const at::Tensor uncertainty, const at::Tensor deltas, const at::Tensor rays, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor weights_sum, at::Tensor amb_aud_sum, at::Tensor amb_eye_sum, at::Tensor uncertainty_sum, at::Tensor depth, at::Tensor image); +void composite_rays_train_triplane_backward(const at::Tensor grad_weights_sum, const at::Tensor grad_amb_aud_sum, const at::Tensor grad_amb_eye_sum, const at::Tensor grad_uncertainty_sum, const at::Tensor grad_image, const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor amb_aud, const at::Tensor amb_eye, const at::Tensor uncertainty, const at::Tensor deltas, const at::Tensor rays, const at::Tensor weights_sum, const at::Tensor amb_aud_sum, const at::Tensor amb_eye_sum, const at::Tensor uncertainty_sum, const at::Tensor image, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor grad_sigmas, at::Tensor grad_rgbs, at::Tensor grad_amb_aud, at::Tensor grad_amb_eye, at::Tensor grad_uncertainty); +void composite_rays_triplane(const uint32_t n_alive, const uint32_t n_step, const float T_thresh, at::Tensor rays_alive, at::Tensor rays_t, at::Tensor sigmas, at::Tensor rgbs, at::Tensor deltas, at::Tensor ambs_aud, at::Tensor ambs_eye, at::Tensor uncertainties, at::Tensor weights, at::Tensor depth, at::Tensor image, at::Tensor amb_aud_sum, at::Tensor amb_eye_sum, at::Tensor uncertainty_sum); \ No newline at end of file diff --git a/readme.md b/readme.md new file mode 100644 index 0000000..482e514 --- /dev/null +++ b/readme.md @@ -0,0 +1,140 @@ +# ER-NeRF + +This is the official repo for our ICCV2023 paper **Efficient Region-Aware Neural Radiance Fields for High-Fidelity Talking Portrait Synthesis**. + +![image](assets/main.png) + +## Install + +Tested on Ubuntu 18.04, Pytorch 1.12 and CUDA 11.3. + +### Install dependency + +```bash +conda install pytorch==1.12.1 cudatoolkit=11.3 -c pytorch +pip install -r requirements.txt +pip install "git+https://github.com/facebookresearch/pytorch3d.git" +pip install tensorflow-gpu==2.8.0 +``` + +### Preparation + +- Prepare face-parsing model. + + ```bash + wget https://github.com/YudongGuo/AD-NeRF/blob/master/data_util/face_parsing/79999_iter.pth?raw=true -O data_utils/face_parsing/79999_iter.pth + ``` + +- Prepare the 3DMM model for head pose estimation. + + ```bash + wget https://github.com/YudongGuo/AD-NeRF/blob/master/data_util/face_tracking/3DMM/exp_info.npy?raw=true -O data_utils/face_tracking/3DMM/exp_info.npy + wget https://github.com/YudongGuo/AD-NeRF/blob/master/data_util/face_tracking/3DMM/keys_info.npy?raw=true -O data_utils/face_tracking/3DMM/keys_info.npy + wget https://github.com/YudongGuo/AD-NeRF/blob/master/data_util/face_tracking/3DMM/sub_mesh.obj?raw=true -O data_utils/face_tracking/3DMM/sub_mesh.obj + wget https://github.com/YudongGuo/AD-NeRF/blob/master/data_util/face_tracking/3DMM/topology_info.npy?raw=true -O data_utils/face_tracking/3DMM/topology_info.npy + ``` + +- Download 3DMM model from [Basel Face Model 2009](https://faces.dmi.unibas.ch/bfm/main.php?nav=1-1-0&id=details): + + ``` + cp 01_MorphableModel.mat data_util/face_tracking/3DMM/ + cd data_util/face_tracking + python convert_BFM.py + ``` + +## Datasets and pretrained models + +We get the experiment videos mainly from [DFRF](https://github.com/sstzal/DFRF) and YouTube. Due to copyright restrictions, we can't distribute them. You can download these videos and crop them by youself. Here is an example training video (Obama) from AD-NeRF with the resolution of 450x450. + +``` +mkdir -p data/obama +wget https://github.com/YudongGuo/AD-NeRF/blob/master/dataset/vids/Obama.mp4?raw=true -O data/obama/obama.mp4 +``` + +We also provide pretrained checkpoints on the Obama video clip. You can download and test them after completing the data pre-processing step by: + +```bash +python main.py data/obama/ --workspace trial_obama/ -O --test --ckpt trial_obama/checkpoints/ngp.pth # head +python main.py data/obama/ --workspace trial_obama/ -O --test --torso --ckpt trial_obama_torso/checkpoints/ngp.pth # head+torso +``` + +The test results should be about: + +| setting | PSNR | LPIPS | LMD | +| ---------- | ------ | ------ | ----- | +| head | 35.607 | 0.0178 | 2.525 | +| head+torso | 26.594 | 0.0446 | 2.550 | + +## Usage + +### Pre-processing Custom Training Video + +* Put training video under `data//.mp4`. + + The video **must be 25FPS, with all frames containing the talking person**. + The resolution should be about 512x512, and duration about 1-5 min. + +* Run script to process the video. (may take several hours) + + ```bash + python data_utils/process.py data//.mp4 + ``` + +### Audio Pre-process + +In our paper, we use DeepSpeech features for evaluation: + +```bash +python data_utils/deepspeech_features/extract_ds_features.py --input data/.wav # save to data/.npy +``` + +You can also try to extract audio features via Wav2Vec like [RAD-NeRF](https://github.com/ashawkey/RAD-NeRF) by: + +```bash +python nerf/asr.py --wav data/.wav --save_feats # save to data/_eo.npy +``` + +### Train + +First time running will take some time to compile the CUDA extensions. + +```bash +# train (head and lpips finetune) +python main.py data/obama/ --workspace trial_obama/ -O --iters 100000 +python main.py data/obama/ --workspace trial_obama/ -O --iters 125000 --finetune_lips --patch_size 32 + +# train (torso) +# .pth should be the latest checkpoint in trial_obama +python main.py data/obama/ --workspace trial_obama_torso/ -O --torso --head_ckpt .pth --iters 200000 +``` + +### Test + +```bash +# test on the test split +python main.py data/obama/ --workspace trial_obama/ -O --test # only render the head and use GT image for torso +python main.py data/obama/ --workspace trial_obama_torso/ -O --torso --test # render both head and torso +``` + +### Inference with target audio + +```bash +python main.py data/obama/ --workspace trial_obama_torso/ -O --torso --test --test_train --aud data/