Eagle-E/gqnlib

 
 


gqnlib (Work in progress)

A PyTorch implementation of the Generative Query Network (GQN).

Requirements

  • Python == 3.7
  • PyTorch == 1.5.0

Requirements for example code

  • torchvision == 0.6.0
  • tqdm == 4.46.0
  • tensorflow == 2.2.0
  • tensorboardX == 2.0
  • matplotlib == 3.2.1
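
The pinned versions above can be checked against an environment with a small helper. This is a minimal sketch and not part of gqnlib: `PINNED` and `find_mismatches` are hypothetical names, and the version lookup is passed in as a callable so the sketch does not assume a particular packaging API.

```python
# Pinned versions copied from the requirement lists above.
PINNED = {
    "torch": "1.5.0",
    "torchvision": "0.6.0",
    "tqdm": "4.46.0",
    "tensorflow": "2.2.0",
    "tensorboardX": "2.0",
    "matplotlib": "3.2.1",
}


def find_mismatches(pinned, get_version):
    """Return {package: (expected, found)} for every package that differs.

    `get_version` is any callable that maps a package name to its installed
    version string and raises an exception when the package is missing.
    """
    mismatches = {}
    for name, expected in pinned.items():
        try:
            found = get_version(name)
        except Exception:
            found = None  # treat a failed lookup as "not installed"
        # Ignore a local build tag after "+", e.g. "1.5.0+cu101".
        base = found.split("+")[0] if found is not None else None
        if base != expected:
            mismatches[name] = (expected, found)
    return mismatches
```

On Python 3.8 or later, `importlib.metadata.version` can serve as `get_version`; on Python 3.7 the `importlib_metadata` backport offers the same function. An empty result means every pinned package matched.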

How to use

Set up environments

Clone the repository.

git clone https://github.com/rnagumo/gqnlib.git
cd gqnlib

Install the package in a virtual environment.

python3 -m venv .venv
source .venv/bin/activate
pip3 install --upgrade pip
pip3 install .

Alternatively, use Docker with the NVIDIA Container Toolkit. Docker 19.03 or later can run containers with GPU access.

docker build -t gqnlib .
docker run --gpus all -it gqnlib bash

Install the additional requirements for the example code.

pip3 install tqdm==4.46.0 tensorflow==2.2.0 tensorboardX==2.0 matplotlib==3.2.1 torchvision==0.6.0

Prepare dataset

The datasets are provided by DeepMind as the GQN dataset and the SLIM dataset.

The download script below fetches the specified dataset and converts the tfrecords into gzipped torch files. The script uses the gsutil command, which must be installed in advance (read here).

Caution: This process takes a very long time. For example, shepard_metzler_5_parts, the smallest dataset, takes 2 to 3 hours on my PC with 32 GB of memory.

Caution: This process creates very large files. For example, the original shepard_metzler_5_parts dataset contains 900 files (17 GB) for training and 100 files (5 GB) for testing, and the converted dataset contains 2,100 files (47 GB) for training and 400 files (12 GB) for testing.
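
Given those sizes, it may be worth checking free disk space before starting. The following is a minimal sketch using only the standard library. The helper names are hypothetical, and the default requirement assumes the original and converted files sit on disk at the same time and that the quoted figures are decimal gigabytes; neither assumption is confirmed by the script itself.

```python
import shutil

# Sizes quoted above for shepard_metzler_5_parts, in GB.
ORIGINAL_GB = 17 + 5    # tfrecords: train + test
CONVERTED_GB = 47 + 12  # gzipped torch files: train + test


def free_space_gb(path="."):
    """Free space on the filesystem containing `path`, in GB (10**9 bytes)."""
    return shutil.disk_usage(path).free / 1e9


def has_room(path=".", required_gb=ORIGINAL_GB + CONVERTED_GB):
    """True if the filesystem holding `path` has at least `required_gb` free.

    The default assumes both copies of the dataset coexist at peak usage.
    """
    return free_space_gb(path) >= required_gb
```

For example, `has_room("./data")` could be called before the download, provided `./data` already exists.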

bash bin/download_scene.sh shepard_metzler_5_parts
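
Once the script finishes, the output can be compared with the file counts and sizes quoted above. This is a minimal sketch using only the standard library; `summarize_dir` is a hypothetical helper, and the directory in the usage note is the one used in the examples later in this README.

```python
import pathlib


def summarize_dir(root):
    """Return (file_count, total_bytes) for all regular files under `root`."""
    root = pathlib.Path(root)
    count, total = 0, 0
    for path in root.rglob("*"):
        if path.is_file():
            count += 1
            total += path.stat().st_size
    return count, total
```

Calling `summarize_dir("./data/shepard_metzler_5_parts_torch/train/")` should then report a file count and byte total in the neighborhood of the figures above, if the conversion completed.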

Run experiment

Run training with bin/train.sh, which contains the necessary settings. Training takes a very long time, roughly 10 to 30 hours.

bash bin/train.sh

Example

Training

import pathlib
import torch
import gqnlib

# Prepare dataset and model
root = "./data/shepard_metzler_5_parts_torch/train/"
dataset = gqnlib.SceneDataset(root, 20)
model = gqnlib.GenerativeQueryNetwork()
optimizer = torch.optim.Adam(model.parameters())

model.train()
for batch in dataset:
    for data in batch:
        # Partition data into context and query
        data = gqnlib.partition_scene(*data)

        # Inference
        optimizer.zero_grad()
        loss_dict = model(*data)

        # Backward
        loss = loss_dict["loss"].mean()
        loss.backward()
        optimizer.step()

# Save checkpoints (create parent directories such as ./logs if missing)
p = pathlib.Path("./logs/tmp")
p.mkdir(parents=True, exist_ok=True)

cp = {"model_state_dict": model.state_dict(),
      "optimizer_state_dict": optimizer.state_dict()}
torch.save(cp, p / "example.pt")
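
The partition step in the loop above splits each scene into context views, which the model conditions on, and a query view, which it has to predict. The idea can be illustrated in plain Python. This is only a sketch of the concept, not the implementation of `gqnlib.partition_scene`: the function name `partition_views`, its parameters, and the use of plain lists instead of tensors are all assumptions made for illustration.

```python
import random


def partition_views(images, viewpoints, num_context, rng=random):
    """Split one scene's views into context pairs and a single query pair.

    Illustrative only; gqnlib.partition_scene may behave differently.
    """
    if len(images) != len(viewpoints):
        raise ValueError("images and viewpoints must have the same length")
    if not 0 < num_context < len(images):
        raise ValueError("need at least one context view and one query view")

    # Sample without replacement so the query view never appears in the context.
    chosen = rng.sample(range(len(images)), num_context + 1)
    context_idx, query_idx = chosen[:-1], chosen[-1]

    x_c = [images[i] for i in context_idx]      # context images
    v_c = [viewpoints[i] for i in context_idx]  # context viewpoints
    x_q, v_q = images[query_idx], viewpoints[query_idx]  # query pair
    return x_c, v_c, x_q, v_q
```

The return order mirrors the `x_c, v_c, x_q, v_q` naming used in the pre-trained model example. Passing a seeded `random.Random` instance as `rng` makes the split reproducible.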

Use pre-trained model

import torch
import gqnlib

# Load pre-trained model
model = gqnlib.GenerativeQueryNetwork()
cp = torch.load("./logs/tmp/example.pt")
model.load_state_dict(cp["model_state_dict"])
model.eval()  # switch to evaluation mode before inference

# Data
root = "./data/shepard_metzler_5_parts_torch/train/"
dataset = gqnlib.SceneDataset(root, 20)
images, viewpoints = dataset[0][0]
x_c, v_c, x_q, v_q = gqnlib.partition_scene(images, viewpoints)

# Reconstruct and sample
with torch.no_grad():
    recon = model.reconstruct(x_c, v_c, x_q, v_q)
    sample = model.sample(x_c, v_c, v_q)

print(recon.size())  # -> torch.Size([20, 1, 3, 64, 64])
print(sample.size())  # -> torch.Size([20, 1, 3, 64, 64])

Reference

Original papers

Datasets

  • Datasets by DeepMind for GQN. GitHub
  • Datasets by DeepMind for SLIM. GitHub

Code

  • mushoku, chainer-gqn. GitHub
  • iShohei220, torch-gqn. GitHub
  • wohlert, generative-query-network-pytorch. GitHub
  • l3robot, gqn_datasets_translator. GitHub
