Generative Query Network in PyTorch.
- Python == 3.7
- PyTorch == 1.5.0
Additional requirements for the example code:
- torchvision == 0.6.0
- tqdm == 4.46.0
- tensorflow == 2.2.0
- tensorboardX == 2.0
- matplotlib == 3.2.1
Clone the repository.

```bash
git clone https://github.com/rnagumo/gqnlib.git
cd gqnlib
```
Install the package in a virtual environment.

```bash
python3 -m venv .venv
source .venv/bin/activate
pip3 install --upgrade pip
pip3 install .
```
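A quick import check (a sketch; it only touches names that appear in the examples below) can confirm the install inside the activated environment:

```python
# Sanity check inside the activated virtual environment.
import torch

import gqnlib

print(torch.__version__)              # expected: 1.5.0
print(gqnlib.GenerativeQueryNetwork)  # the model class used below
```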
Alternatively, use Docker with the NVIDIA Container Toolkit. Docker 19.03+ can run containers with GPUs.

```bash
docker build -t gqnlib .
docker run --gpus all -it gqnlib bash
```
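Inside the container, GPU visibility can be verified with a standard PyTorch check (not specific to gqnlib):

```python
import torch

# Prints True when the container was started with --gpus all
# on a host with a working NVIDIA driver.
print(torch.cuda.is_available())
```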
Install the other requirements for the sample code.

```bash
pip3 install tqdm==4.46.0 tensorflow==2.2.0 tensorboardX==2.0 matplotlib==3.2.1 torchvision==0.6.0
```
Datasets are provided by DeepMind: the GQN dataset and the SLIM dataset.

The following command downloads the specified dataset and converts the TFRecord files into gzipped PyTorch files. The shell script uses the gsutil command, which should be installed in advance (read here).
Caution: This process takes a very long time. For example, the shepard_metzler_5_parts dataset, which is the smallest one, takes 2~3 hours on my PC with 32 GB of memory.

Caution: This process creates very large files. For example, the original shepard_metzler_5_parts dataset contains 900 files (17 GB) for training and 100 files (5 GB) for testing, and the converted dataset contains 2,100 files (47 GB) for training and 400 files (12 GB) for testing.

```bash
bash bin/download_scene.sh shepard_metzler_5_parts
```
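After conversion finishes, the data can be sanity-checked with the SceneDataset class from the training example below; a minimal sketch, assuming the converted files live under the same path used there:

```python
import gqnlib

# Load one converted scene and inspect its tensors.
root = "./data/shepard_metzler_5_parts_torch/train/"
dataset = gqnlib.SceneDataset(root, 20)
images, viewpoints = dataset[0][0]
print(images.size(), viewpoints.size())
```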
Run training. bin/train.sh contains the necessary settings. Training also takes a very long time, roughly 10~30 hours.

```bash
bash bin/train.sh
```
Alternatively, the model can be trained directly from Python.

```python
import pathlib

import torch

import gqnlib

# Prepare dataset and model
root = "./data/shepard_metzler_5_parts_torch/train/"
dataset = gqnlib.SceneDataset(root, 20)
model = gqnlib.GenerativeQueryNetwork()
optimizer = torch.optim.Adam(model.parameters())

model.train()
for batch in dataset:
    for data in batch:
        # Partition each scene into context and query sets
        data = gqnlib.partition_scene(*data)

        # Inference: the model returns a dict of losses
        optimizer.zero_grad()
        loss_dict = model(*data)

        # Backward
        loss = loss_dict["loss"].mean()
        loss.backward()
        optimizer.step()

# Save checkpoints
p = pathlib.Path("./logs/tmp")
p.mkdir(parents=True, exist_ok=True)
cp = {"model_state_dict": model.state_dict(),
      "optimizer_state_dict": optimizer.state_dict()}
torch.save(cp, p / "example.pt")
```
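Since tensorboardX is listed in the requirements, the training loss can also be logged for TensorBoard. A minimal sketch reusing the dataset, model, and optimizer defined above; the log directory and scalar tag are illustrative, not part of gqnlib:

```python
from tensorboardX import SummaryWriter

writer = SummaryWriter("./logs/tmp")  # illustrative log directory

global_step = 0
for batch in dataset:
    for data in batch:
        data = gqnlib.partition_scene(*data)
        optimizer.zero_grad()
        loss_dict = model(*data)
        loss = loss_dict["loss"].mean()
        loss.backward()
        optimizer.step()

        # Record the scalar loss; view with `tensorboard --logdir ./logs/tmp`.
        writer.add_scalar("train/loss", loss.item(), global_step)
        global_step += 1

writer.close()
```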
A saved checkpoint can be loaded for reconstruction and sampling.

```python
import torch

import gqnlib

# Load pre-trained model
model = gqnlib.GenerativeQueryNetwork()
cp = torch.load("./logs/tmp/example.pt")
model.load_state_dict(cp["model_state_dict"])
model.eval()  # switch to evaluation mode

# Data: take the first scene and split it into context and query
root = "./data/shepard_metzler_5_parts_torch/train/"
dataset = gqnlib.SceneDataset(root, 20)
images, viewpoints = dataset[0][0]
x_c, v_c, x_q, v_q = gqnlib.partition_scene(images, viewpoints)

# Reconstruct query images and sample new ones from the context
with torch.no_grad():
    recon = model.reconstruct(x_c, v_c, x_q, v_q)
    sample = model.sample(x_c, v_c, v_q)

print(recon.size())   # -> torch.Size([20, 1, 3, 64, 64])
print(sample.size())  # -> torch.Size([20, 1, 3, 64, 64])
```
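Because matplotlib and torchvision are among the requirements, the reconstructions can be tiled and saved as an image; a minimal sketch, assuming the [batch, query, channel, height, width] layout printed above:

```python
import matplotlib.pyplot as plt
from torchvision.utils import make_grid

# Drop the query dimension and tile the 20 reconstructions into a 5-column grid.
grid = make_grid(recon.squeeze(1), nrow=5)

# make_grid returns (C, H, W); matplotlib expects (H, W, C).
plt.imshow(grid.permute(1, 2, 0).clamp(0, 1).numpy())
plt.axis("off")
plt.savefig("recon.png")
```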
- S. M. Ali Eslami et al., "Neural scene representation and rendering," Science, Vol. 360, Issue 6394, pp. 1204-1210, 15 Jun 2018.
- A. Kumar et al., "Consistent Generative Query Network," arXiv.
- T. Ramalho et al., "Encoding Spatial Relations from Natural Language," arXiv.
- D. Rosenbaum et al., "Learning models for visual 3D localization with implicit mapping," arXiv.
- DeepMind, blog post.