ACFlow: Flow Models for Arbitrary Conditional Likelihoods

This is the official implementation of ACFlow.

Get Started

Prerequisites

Refer to requirements.txt for the required packages.

Download data

Download CelebA, CIFAR10, MNIST and Omniglot to your local workspace. You may need to update the path for each dataset in the datasets folder accordingly.

MNIST and CIFAR10 can be downloaded with torchvision. Links for CelebA and Omniglot are provided here. Please cite their work if you use this repo.

Train and Test

You can train your own model with the scripts provided below, or download our pretrained weights from here.

CelebA

  • Train with Gaussian base likelihood
python scripts/train.py --cfg_file=./exp/celeba/rnvp/params.json
  • Train with autoregressive likelihood
python scripts/train_tan.py --cfg_file=./exp/celeba/tan/params.json
  • Compute the log likelihood on the test set, and compute PSNR and PRD scores from the samples.
python scripts/test.py --cfg_file=./exp/celeba/rnvp/params.json

NOTE: you can run this script multiple times with different random seeds to obtain the mean score and standard deviation.
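As a rough illustration of the PSNR metric and of how per-seed scores from repeated runs can be summarized, here is a minimal pure-Python sketch. The helpers `psnr` and `summarize` are hypothetical and not part of this repo; pixel values are assumed to be scaled to [0, max_value], and the repo's own evaluation code may compute PSNR over a different set of pixels or scaling.

```python
import math
import statistics
from typing import Sequence, Tuple


def psnr(reference: Sequence[float], imputed: Sequence[float], max_value: float = 1.0) -> float:
    """Peak signal-to-noise ratio (dB) between two flattened images.

    Assumes both inputs hold pixel values in [0, max_value].
    """
    if len(reference) != len(imputed) or not reference:
        raise ValueError("inputs must be non-empty and of equal length")
    # Mean squared error over all compared pixels.
    mse = sum((r - x) ** 2 for r, x in zip(reference, imputed)) / len(reference)
    if mse == 0:
        return math.inf  # identical images
    return 10.0 * math.log10(max_value ** 2 / mse)


def summarize(scores: Sequence[float]) -> Tuple[float, float]:
    """Mean and sample standard deviation across runs (needs at least 2 scores)."""
    return statistics.mean(scores), statistics.stdev(scores)
```

For example, `summarize([psnr(ref, imp_seed0), psnr(ref, imp_seed1), ...])` would give the mean and spread across seeds, where `imp_seed0` and friends are hypothetical imputations from separate runs.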

  • Compute the joint likelihood p(x).
python scripts/test_joint.py --cfg_file=./exp/celeba/rnvp/params.json
  • Sample from an arbitrary conditional distribution p(x_u | x_o) for multiple imputation.
python scripts/sample.py --cfg_file=./exp/celeba/rnvp/params.json
  • Sample the 'Best Guess' single imputation.
python scripts/sample_single.py --cfg_file=./exp/celeba/rnvp/params.json
  • Sample from the joint distribution p(x).
python scripts/sample_joint.py --cfg_file=./exp/celeba/rnvp/params.json
  • Run Gibbs sampling.
python scripts/gibbs_sampling.py --cfg_file=./exp/celeba/rnvp/params.json

Sampling the upper and lower halves conditioned on the remaining half.
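The conditioning pattern in this figure can be expressed as a binary mask over pixels, where observed pixels form x_o and the rest form x_u. The sketch below uses hypothetical helpers (`half_mask`, `split`) on a single-channel image stored as nested lists; it is not this repo's mask format, which may use a different convention or tensor layout.

```python
from typing import List, Tuple

Image = List[List[float]]
Mask = List[List[int]]


def half_mask(height: int, width: int, observe: str = "lower") -> Mask:
    """Binary mask: 1 marks an observed pixel (x_o), 0 an unobserved one (x_u).

    observe="lower" keeps the bottom half observed, so the upper half is imputed;
    observe="upper" does the opposite.
    """
    if observe not in ("upper", "lower"):
        raise ValueError("observe must be 'upper' or 'lower'")
    half = height // 2
    mask = []
    for row in range(height):
        in_upper = row < half
        observed = in_upper if observe == "upper" else not in_upper
        mask.append([1 if observed else 0] * width)
    return mask


def split(image: Image, mask: Mask) -> Tuple[List[float], List[float]]:
    """Split an image into flattened observed (x_o) and unobserved (x_u) values."""
    x_o, x_u = [], []
    for img_row, mask_row in zip(image, mask):
        for value, m in zip(img_row, mask_row):
            (x_o if m else x_u).append(value)
    return x_o, x_u
```

Swapping `observe` between "lower" and "upper" gives the two conditioning directions shown in the figure.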

Gibbs Sampling
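In general, Gibbs sampling draws from a joint distribution by repeatedly resampling one part of x from its conditional given the rest. The toy sketch below illustrates only that loop: it swaps the learned ACFlow conditionals for the exact conditionals of a standard bivariate Gaussian with correlation rho (an assumption purely for illustration), and `gibbs` and `gaussian_conditional` are hypothetical names, not functions from scripts/gibbs_sampling.py.

```python
import math
import random
from typing import Callable, List

# A conditional sampler: given the current state x and the index u of the
# coordinate to resample, draw a new value from p(x_u | x_o).
ConditionalSampler = Callable[[List[float], int, random.Random], float]


def gaussian_conditional(rho: float) -> ConditionalSampler:
    """Exact conditionals of a standard bivariate Gaussian with correlation rho.

    x_u | x_o ~ N(rho * x_o, 1 - rho^2); stands in for a learned model here.
    """
    scale = math.sqrt(1.0 - rho ** 2)

    def sample(x: List[float], u: int, rng: random.Random) -> float:
        o = 1 - u  # the other (observed) coordinate in this 2-D toy
        return rng.gauss(rho * x[o], scale)

    return sample


def gibbs(sampler: ConditionalSampler, x0: List[float], steps: int,
          seed: int = 0) -> List[List[float]]:
    """Systematic-scan Gibbs sampling: resample each coordinate given the rest."""
    rng = random.Random(seed)
    x = list(x0)
    chain = []
    for _ in range(steps):
        for u in range(len(x)):
            x[u] = sampler(x, u, rng)
        chain.append(list(x))  # record the state after a full sweep
    return chain
```

Starting from a point far from the mode (e.g. `[3.0, -3.0]`) and discarding an initial burn-in, the remaining states approximate samples from the joint; with an image model, each "coordinate" would instead be a block of pixels resampled given the rest.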

MNIST

Similar commands can be run; the config files are provided in the exp/mnist folder.

Omniglot

Similar commands can be run; the config files are provided in the exp/omniglot folder.

CIFAR10

Similar commands can be run; the config files are provided in the exp/cifar folder.

Acknowledgements

The code for evaluating FID and PRD is adapted from their public implementations. Please cite their work if you use this repo.