
PyTorch implementation of Diffusion Bridges in Function Spaces (DBFS), NeurIPS 2024


bw-park/DBFS


Stochastic Optimal Control for Diffusion Bridges in Function Spaces (DBFS)

Byoungwoo Park¹ · Jungwon Choi¹ · Sungbin Lim²,³ · Juho Lee¹
¹KAIST   ²Korea University   ³LG AI Research

NeurIPS 2024

Stochastic Optimal Control for Diffusion Bridges in Function Spaces (DBFS) extends previous bridge matching algorithms to learn diffusion models between two infinite-dimensional distributions in a resolution-free manner.
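For readers new to bridge matching, here is a minimal sketch of its most basic ingredient in the generic, finite-dimensional setting: sampling an intermediate point of a Brownian bridge pinned at a pair of endpoints (x0, xT). This is the standard textbook construction, not the function-space bridge used in DBFS; the function name sample_brownian_bridge and the noise scale sigma are hypothetical.

```python
import math
import random


def sample_brownian_bridge(x0, xT, t, T=1.0, sigma=1.0, rng=random):
    """Sample X_t of a Brownian bridge pinned at X_0 = x0 and X_T = xT.

    Generic finite-dimensional construction (not the DBFS function-space bridge):
        X_t = (1 - t/T) * x0 + (t/T) * xT + sigma * sqrt(t * (T - t) / T) * eps,
    where eps ~ N(0, 1) is drawn independently for every coordinate.
    """
    if not 0.0 <= t <= T:
        raise ValueError("t must lie in [0, T]")
    w = t / T
    std = sigma * math.sqrt(t * (T - t) / T)  # bridge noise vanishes at both endpoints
    return [(1.0 - w) * a + w * b + std * rng.gauss(0.0, 1.0) for a, b in zip(x0, xT)]


# Toy paired endpoints standing in for flattened samples from pi_0 and pi_T.
x0 = [0.0, 0.0, 0.0]
xT = [1.0, 2.0, 3.0]
print(sample_brownian_bridge(x0, xT, t=0.5, sigma=0.1))
```

The mean interpolates linearly between the endpoints while the variance sigma² t(T − t)/T peaks at mid-time, which is why samples collapse onto x0 and xT at t = 0 and t = T.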

Examples

pi_0 ⇆ pi_T Results (left: pi_0 → pi_T, right: pi_0 ← pi_T)

EMNIST ⇆ MNIST (32x32, observed)

EMNIST ⇆ MNIST (64x64, unseen)

EMNIST ⇆ MNIST (128x128, unseen)

AFHQ-64 Wild ⇆ Cat (64x64, observed)

AFHQ-64 Wild ⇆ Cat (128x128, unseen)
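One common way to make an image model resolution-free is to have it consume pixel coordinates alongside pixel values, so the same network can be queried on a 32x32, 64x64, or 128x128 grid. The sketch below only builds such coordinate grids; it assumes normalized (x, y) pixel-center coordinates in [-1, 1], the helper make_coordinate_grid is hypothetical, and it is not a description of how the DBFS networks are parameterized.

```python
def make_coordinate_grid(height, width):
    """Return row-major (x, y) pixel-center coordinates inside [-1, 1]."""

    def centers(n):
        # Centers of n equal-width cells spanning [-1, 1].
        return [-1.0 + (2 * i + 1) / n for i in range(n)]

    ys, xs = centers(height), centers(width)
    return [(x, y) for y in ys for x in xs]


# The same coordinate convention covers the observed and unseen resolutions above.
for res in (32, 64, 128):
    grid = make_coordinate_grid(res, res)
    print(res, len(grid), grid[0], grid[-1])
```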

Installation

This code is developed with Python 3 and PyTorch. To set up an environment with the required packages:

  1. Create a virtual environment, for example:
conda create -n dbfs pip
conda activate dbfs
  2. Install PyTorch according to the official instructions.
  3. Install the requirements:
pip install -r requirements.txt
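After installation, a quick sanity check can confirm that the core packages are importable and whether PyTorch sees a GPU. This is a minimal sketch using only the standard library for the lookup; the package list is an assumption (the authoritative list is requirements.txt), and missing_packages is a hypothetical helper.

```python
import importlib.util


def missing_packages(names):
    """Return the top-level module names in `names` that cannot be found."""
    return [name for name in names if importlib.util.find_spec(name) is None]


# Assumed module names; adjust to match requirements.txt.
missing = missing_packages(["torch", "numpy"])
if missing:
    print("Missing packages:", ", ".join(missing))
else:
    import torch

    print("PyTorch", torch.__version__, "| CUDA available:", torch.cuda.is_available())
```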

Download AFHQ dataset

Download the AFHQ dataset from stargan-v2 and save it in the dbfs/data directory.

You can also download the dataset with the following command:

bash download.sh afhq-dataset
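Once the download finishes, it can help to confirm that the images are where you expect before training. A minimal sketch, assuming the stargan-v2 layout afhq/{train,val}/{cat,dog,wild} under dbfs/data; the exact path that dbfs_afhq.py reads from may differ, and count_images is a hypothetical helper.

```python
from pathlib import Path

IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}


def count_images(root):
    """Count image files in every <split>/<class> subdirectory of `root`."""
    counts = {}
    root = Path(root)
    for split_dir in sorted(p for p in root.iterdir() if p.is_dir()):
        for class_dir in sorted(p for p in split_dir.iterdir() if p.is_dir()):
            n = sum(
                1
                for f in class_dir.iterdir()
                if f.is_file() and f.suffix.lower() in IMAGE_SUFFIXES
            )
            counts[f"{split_dir.name}/{class_dir.name}"] = n
    return counts


if __name__ == "__main__":
    # Assumed location; adjust if your copy of AFHQ lives elsewhere.
    afhq_root = Path("dbfs/data/afhq")
    if afhq_root.is_dir():
        for key, n in count_images(afhq_root).items():
            print(f"{key}: {n} images")
    else:
        print(f"{afhq_root} not found")
```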

Sampling from trained models

You can download the model checkpoints from Google Drive and save them in the dbfs/checkpoint directory.

See dbfs/dbfs_{DATASET}_sample.ipynb for sampling from the trained models.
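Sampling from a diffusion bridge amounts to numerically integrating a learned SDE from one endpoint distribution toward the other. The sketch below is a generic Euler–Maruyama integrator, assuming dynamics of the form dX_t = u(t, X_t) dt + sigma dW_t; the drift callable stands in for a trained network, all names are hypothetical, and this is not the code in the sampling notebooks.

```python
import math
import random


def euler_maruyama(x0, drift, sigma=1.0, T=1.0, n_steps=100, rng=random):
    """Simulate dX_t = drift(t, X_t) dt + sigma dW_t on [0, T] and return X_T.

    `x0` is a list of floats; `drift(t, x)` must return a list of the same length.
    """
    dt = T / n_steps
    sqrt_dt = math.sqrt(dt)
    x = list(x0)
    for k in range(n_steps):
        t = k * dt  # left endpoint of the step, so t < T always holds
        u = drift(t, x)
        x = [xi + ui * dt + sigma * sqrt_dt * rng.gauss(0.0, 1.0) for xi, ui in zip(x, u)]
    return x


# Toy stand-in for a trained control: the Brownian-bridge drift (target - x) / (T - t),
# which steers every path toward `target` at t = T = 1.
target = [1.0, -1.0]


def bridge_drift(t, x):
    return [(b - xi) / (1.0 - t) for b, xi in zip(target, x)]


print(euler_maruyama([0.0, 0.0], bridge_drift, sigma=0.1, n_steps=200))
```

Evaluating the drift at the left endpoint of each step keeps the toy bridge drift finite; on the last step it pulls the state onto the target, leaving only one step of noise.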

Training from scratch

We train DBFS on one or more NVIDIA A6000 GPUs for each dataset.

You can also adjust the --batch_dim and --nproc-per-node options according to your local resources.

EMNIST ⇆ MNIST

For a single GPU

CUDA_VISIBLE_DEVICES=0 python dbfs_mnist.py

For multiple GPUs

CUDA_VISIBLE_DEVICES=0,1 torchrun --nproc-per-node 2 dbfs_mnist.py
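torchrun launches one process per GPU and tells each process its role through environment variables such as RANK, LOCAL_RANK, and WORLD_SIZE. A minimal sketch of how a training script can read them and fall back to single-process defaults when started with plain python; get_dist_info is hypothetical and not taken from dbfs_mnist.py.

```python
import os


def get_dist_info(env=None):
    """Read the process layout exported by torchrun; default to one process."""
    env = os.environ if env is None else env
    rank = int(env.get("RANK", 0))              # global index of this process
    local_rank = int(env.get("LOCAL_RANK", 0))  # index of this process on its node
    world_size = int(env.get("WORLD_SIZE", 1))  # total number of processes
    return rank, local_rank, world_size


rank, local_rank, world_size = get_dist_info()
print(f"rank {rank}/{world_size}, local GPU index {local_rank}")
```

With --nproc-per-node 2 and CUDA_VISIBLE_DEVICES=0,1, the two processes receive LOCAL_RANK 0 and 1, which index the two visible devices (for example through torch.device("cuda", local_rank)).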

AFHQ-64 Wild ⇆ Cat

For multiple GPUs

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc-per-node 8 dbfs_afhq.py

The training logs are available on Weights & Biases for reproducibility.

Reference

If you find our work useful for your research, please consider citing it:

@inproceedings{park2024stochastic,
  title={Stochastic Optimal Control for Diffusion Bridges in Function Spaces},
  author={Byoungwoo Park and Jungwon Choi and Sungbin Lim and Juho Lee},
  booktitle={The Thirty-eighth Annual Conference on Neural Information Processing Systems},
  year={2024},
  url={https://openreview.net/forum?id=WyQW4G57Zd}
}

Acknowledgements

Our code builds upon outstanding open-source projects and papers:
