# VAExperiment

This repository contains experiments with Variational Autoencoders (VAEs) using KL divergence and Wasserstein distance as regularization losses.

## Table of Contents

- [Setup](#setup)
- [Project Structure](#project-structure)
- [Experiment Description](#experiment-description)
- [How to Run](#how-to-run)
- [Results](#results)
- [Future Work](#future-work)
- [Acknowledgments](#acknowledgments)
- [License](#license)
## Setup

### Requirements

- Python 3.8 or higher
- PyTorch
- ClearML for experiment tracking
- Matplotlib, Seaborn, and other dependencies as specified in the code

### Installation

1. Clone the repository:

   ```bash
   git clone https://github.com/kezouke/VAExperiment.git
   cd VAExperiment
   ```

2. Set up your ClearML credentials in the notebook:

   ```
   %env CLEARML_API_ACCESS_KEY=YOUR_ACCESS_KEY
   %env CLEARML_API_SECRET_KEY=YOUR_SECRET_KEY
   ```
## Project Structure

- `README.md`: This file.
- `vae_experiment.ipynb`: The main notebook containing the experiment code.
- `weights/`: Directory for saved trained models.
- `data/`: Directory for the MNIST dataset.
## Experiment Description

### VAE Architecture

- Encoder: two-layer fully connected network mapping the input to the latent space.
- Decoder: two-layer fully connected network reconstructing the input from the latent space.
- Latent dimension: 10
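The exact layer sizes live in the notebook; the sketch below is a minimal PyTorch rendering of this architecture. The 784-dimensional flattened MNIST input and the 400-unit hidden layer are illustrative assumptions, not values taken from the notebook.

```python
import torch
import torch.nn as nn

class VAE(nn.Module):
    """Two-layer fully connected encoder/decoder with a 10-D latent space."""

    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=10):
        super().__init__()
        # Encoder: input -> hidden -> (mu, log_var)
        self.enc_hidden = nn.Linear(input_dim, hidden_dim)
        self.enc_mu = nn.Linear(hidden_dim, latent_dim)
        self.enc_log_var = nn.Linear(hidden_dim, latent_dim)
        # Decoder: latent -> hidden -> reconstruction
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid(),  # MNIST pixels lie in [0, 1]
        )

    def encode(self, x):
        h = torch.relu(self.enc_hidden(x))
        return self.enc_mu(h), self.enc_log_var(h)

    def reparameterize(self, mu, log_var):
        # Sample z = mu + sigma * eps so gradients flow through mu and log_var.
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + std * eps

    def forward(self, x):
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        return self.decoder(z), mu, log_var
```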
### KL Divergence Loss

Regularizes the latent distribution to match a standard Gaussian prior.
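For a Gaussian encoder and a standard normal prior, this term has a closed form; a minimal sketch (the reduction and weighting used in the notebook may differ):

```python
import torch

def kl_divergence(mu, log_var):
    # KL(N(mu, sigma^2) || N(0, I)), summed over latent dimensions
    # and averaged over the batch.
    return torch.mean(
        -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=1)
    )
```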
### Wasserstein Loss

- Uses a critic network to approximate the Wasserstein distance between the latent distribution and the prior.
- Includes a gradient penalty to enforce Lipschitz continuity of the critic.
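A sketch of a WGAN-GP-style critic objective on latent samples is shown below. The critic's hidden width, the penalty weight of 10, and the convention of treating prior samples as the "real" distribution are all assumptions, not details confirmed by this README.

```python
import torch
import torch.nn as nn

# Critic scoring 10-D latent vectors; the hidden width (128) is an assumption.
critic = nn.Sequential(
    nn.Linear(10, 128), nn.LeakyReLU(0.2),
    nn.Linear(128, 128), nn.LeakyReLU(0.2),
    nn.Linear(128, 1),
)

def gradient_penalty(critic, z_prior, z_latent):
    # Penalize the critic's gradient norm deviating from 1 on random
    # interpolates between prior samples and encoder samples (WGAN-GP).
    alpha = torch.rand(z_prior.size(0), 1, device=z_prior.device)
    interp = alpha * z_prior.detach() + (1 - alpha) * z_latent.detach()
    interp.requires_grad_(True)
    scores = critic(interp)
    grads = torch.autograd.grad(outputs=scores, inputs=interp,
                                grad_outputs=torch.ones_like(scores),
                                create_graph=True)[0]
    return ((grads.norm(2, dim=1) - 1) ** 2).mean()

def critic_loss(critic, z_prior, z_latent, gp_weight=10.0):
    # The critic learns to score prior samples high and encoder samples low;
    # gp_weight = 10 is the usual WGAN-GP default, assumed here.
    gp = gradient_penalty(critic, z_prior, z_latent)
    return critic(z_latent).mean() - critic(z_prior).mean() + gp_weight * gp
```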
### Dataset and Training Parameters

- Dataset: MNIST handwritten digits.
- Batch size: 64
- Learning rate: 1e-3
- Epochs: 35
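Putting the pieces together, a minimal KL-VAE training loop with the parameters above might look like the following. The Adam optimizer and the binary cross-entropy reconstruction term are assumptions about the notebook's choices; `VAE` and `kl_divergence` refer to the sketches earlier in this section.

```python
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

train_loader = DataLoader(
    datasets.MNIST("data/", train=True, download=True,
                   transform=transforms.ToTensor()),
    batch_size=64, shuffle=True,
)

model = VAE()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(35):
    for x, _ in train_loader:
        x = x.view(x.size(0), -1)  # flatten 28x28 images to 784
        recon, mu, log_var = model(x)
        # Per-sample BCE reconstruction term plus the KL regularizer.
        loss = torch.nn.functional.binary_cross_entropy(
            recon, x, reduction="sum") / x.size(0) + kl_divergence(mu, log_var)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
```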
### Visualizations

- Reconstructions: comparing original and reconstructed images.
- Latent space visualization: using t-SNE to visualize the latent space.
- Latent distributions: KDE plots of latent dimensions.
- Interpolation: visualizing smooth transitions between digits in latent space.
- Reconstruction error: histogram of reconstruction errors.
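As an example of the t-SNE visualization, the latent means for a batch of test images can be projected to 2-D along these lines (a sketch reusing the `model` from above; the batch size and plot styling are arbitrary choices):

```python
import matplotlib.pyplot as plt
import torch
from sklearn.manifold import TSNE
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Encode one batch of held-out test images with the trained VAE.
test_loader = DataLoader(
    datasets.MNIST("data/", train=False, download=True,
                   transform=transforms.ToTensor()),
    batch_size=2000, shuffle=False,
)
images, labels = next(iter(test_loader))

model.eval()
with torch.no_grad():
    mu, _ = model.encode(images.view(images.size(0), -1))

# Project the 10-D latent means to 2-D and color points by digit class.
embedded = TSNE(n_components=2).fit_transform(mu.numpy())
plt.scatter(embedded[:, 0], embedded[:, 1], c=labels, cmap="tab10", s=5)
plt.colorbar(label="digit")
plt.title("t-SNE of the VAE latent space")
plt.show()
```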
## How to Run

1. Train the VAE with KL loss:
   - Uncomment and run the training cell in the notebook.
   - The model will be saved to `weights/vae_kl_divergence.pth`.
2. Train the W-VAE with Wasserstein loss:
   - Uncomment and run the training cell for the Wasserstein VAE.
   - The model and critic will be saved to `weights/vae_wass_distance.pth` and `weights/critic.pth`.
3. Generate visualizations:
   - Run the corresponding plotting functions in the notebook.
   - Results will be displayed inline and saved as images.
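The saved weights can be reloaded later without retraining; a sketch assuming the notebook saves `state_dict()` objects (an assumption suggested by the `.pth` paths) and the `VAE` class sketched above:

```python
import torch

# Rebuild the architecture, then load the saved parameters.
model = VAE()
model.load_state_dict(torch.load("weights/vae_kl_divergence.pth",
                                 map_location="cpu"))
model.eval()  # switch to inference mode for evaluation and plotting
```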
## Results

Please read `report.md` to get insights from VAExperiment :)
## Future Work

- Hyperparameter tuning: experiment with different latent dimensions, learning rates, and critic network architectures.
- Advanced losses: explore other forms of regularizers or hybrid losses.
- Applications: use the trained VAEs for tasks such as anomaly detection or data generation.
## Acknowledgments

- Inspired by the ClearML tutorial for experiment tracking.
- MNIST dataset courtesy of Yann LeCun and contributors.
## License

This project is licensed under the MIT License - see the `LICENSE` file for details.