diff --git a/README.md b/README.md
index 7e6aa1b..55337d5 100644
--- a/README.md
+++ b/README.md
@@ -1,10 +1,9 @@
-# Neural Network Pruning
-[Lucas Liebenwein](https://people.csail.mit.edu/lucasl/),
-[Cenk Baykal](http://www.mit.edu/~baykal/),
-[Alaa Maalouf](https://www.linkedin.com/in/alaa-maalouf/),
-[Igor Gilitschenski](https://www.gilitschenski.org/igor/),
-[Dan Feldman](http://people.csail.mit.edu/dannyf/),
-[Daniela Rus](http://danielarus.csail.mit.edu/)
+# torchprune
+Main contributors of this code base:
+[Lucas Liebenwein](http://www.mit.edu/~lucasl/),
+[Cenk Baykal](http://www.mit.edu/~baykal/).
+
+Please check the individual paper folders for the authors of each paper.
@@ -15,10 +14,11 @@ This repository contains code to reproduce the results from the following papers: | Paper | Venue | Title & Link | | :---: | :---: | :--- | +| **Node** | NeurIPS 2021 | [Sparse Flows: Pruning Continuous-depth Models](https://proceedings.neurips.cc/paper/2021/hash/bf1b2f4b901c21a1d8645018ea9aeb05-Abstract.html) | | **ALDS** | NeurIPS 2021 | [Compressing Neural Networks: Towards Determining the Optimal Layer-wise Decomposition](https://arxiv.org/abs/2107.11442) | | **Lost** | MLSys 2021 | [Lost in Pruning: The Effects of Pruning Neural Networks beyond Test Accuracy](https://proceedings.mlsys.org/paper/2021/hash/2a79ea27c279e471f4d180b08d62b00a-Abstract.html) | | **PFP** | ICLR 2020 | [Provable Filter Pruning for Efficient Neural Networks](https://openreview.net/forum?id=BJxkOlSYDH) | -| **SiPP** | arXiv | [SiPPing Neural Networks: Sensitivity-informed Provable Pruning of Neural Networks](https://arxiv.org/abs/1910.05422) | +| **SiPP** | SIAM 2022 | [SiPPing Neural Networks: Sensitivity-informed Provable Pruning of Neural Networks](https://doi.org/10.1137/20M1383239) | ### Packages In addition, the repo also contains two stand-alone python packages that @@ -35,6 +35,7 @@ about the paper and scripts and parameter configuration to reproduce the exact results from the paper. | Paper | Location | | :---: | :---: | +| **Node** | [paper/node](./paper/node) | | **ALDS** | [paper/alds](./paper/alds) | | **Lost** | [paper/lost](./paper/lost) | | **PFP** | [paper/pfp](./paper/pfp) | @@ -98,14 +99,27 @@ using the codebase. | --- | --- | | [src/torchprune/README.md](./src/torchprune) | more details to prune neural networks, how to use and setup the data sets, how to implement custom pruning methods, and how to add your data sets and networks. | | [src/experiment/README.md](./src/experiment) | more details on how to configure and run your own experiments, and more information on how to re-produce the results. | +| [paper/node/README.md](./paper/node) | check out for more information on the [Node](https://proceedings.neurips.cc/paper/2021/hash/bf1b2f4b901c21a1d8645018ea9aeb05-Abstract.html) paper. | | [paper/alds/README.md](./paper/alds) | check out for more information on the [ALDS](https://arxiv.org/abs/2107.11442) paper. | | [paper/lost/README.md](./paper/lost) | check out for more information on the [Lost](https://proceedings.mlsys.org/paper/2021/hash/2a79ea27c279e471f4d180b08d62b00a-Abstract.html) paper. | | [paper/pfp/README.md](./paper/pfp) | check out for more information on the [PFP](https://openreview.net/forum?id=BJxkOlSYDH) paper. | -| [paper/sipp/README.md](./paper/sipp) | check out for more information on the [SiPP](https://arxiv.org/abs/1910.05422) paper. | +| [paper/sipp/README.md](./paper/sipp) | check out for more information on the [SiPP](https://doi.org/10.1137/20M1383239) paper. | ## Citations Please cite the respective papers when using our work. 
+### [Sparse flows: Pruning continuous-depth models](https://proceedings.neurips.cc/paper/2021/hash/bf1b2f4b901c21a1d8645018ea9aeb05-Abstract.html) +``` +@article{liebenwein2021sparse, + title={Sparse flows: Pruning continuous-depth models}, + author={Liebenwein, Lucas and Hasani, Ramin and Amini, Alexander and Rus, Daniela}, + journal={Advances in Neural Information Processing Systems}, + volume={34}, + pages={22628--22642}, + year={2021} +} +``` + ### [Towards Determining the Optimal Layer-wise Decomposition](https://arxiv.org/abs/2107.11442) ``` @inproceedings{liebenwein2021alds, @@ -140,12 +154,16 @@ url={https://openreview.net/forum?id=BJxkOlSYDH} } ``` -### [SiPPing Neural Networks](https://arxiv.org/abs/1910.05422) +### [SiPPing Neural Networks](https://doi.org/10.1137/20M1383239) (Weight Pruning) ``` -@article{baykal2019sipping, -title={SiPPing Neural Networks: Sensitivity-informed Provable Pruning of Neural Networks}, -author={Baykal, Cenk and Liebenwein, Lucas and Gilitschenski, Igor and Feldman, Dan and Rus, Daniela}, -journal={arXiv preprint arXiv:1910.05422}, -year={2019} +@article{baykal2022sensitivity, + title={Sensitivity-informed provable pruning of neural networks}, + author={Baykal, Cenk and Liebenwein, Lucas and Gilitschenski, Igor and Feldman, Dan and Rus, Daniela}, + journal={SIAM Journal on Mathematics of Data Science}, + volume={4}, + number={1}, + pages={26--45}, + year={2022}, + publisher={SIAM} } ``` \ No newline at end of file diff --git a/misc/imgs/node_overview.png b/misc/imgs/node_overview.png new file mode 100644 index 0000000..ff19684 Binary files /dev/null and b/misc/imgs/node_overview.png differ diff --git a/misc/requirements.txt b/misc/requirements.txt index efe3eb2..704fb17 100644 --- a/misc/requirements.txt +++ b/misc/requirements.txt @@ -3,10 +3,10 @@ -e ./src/experiment # We need those with special tags unfortunately... 
--f https://download.pytorch.org/whl/torch_stable.html -torch==1.7.1+cu110 -torchvision==0.8.2+cu110 -torchaudio===0.7.2 +-f https://download.pytorch.org/whl/lts/1.8/torch_lts.html +torch==1.8.2+cu111 +torchvision==0.9.2+cu111 +torchaudio==0.8.2 # Some extra requirements for the code base jupyter diff --git a/paper/alds/param/imagenet/prune/mbv2.yaml b/paper/alds/param/imagenet/prune/mobilenet_v2.yaml similarity index 100% rename from paper/alds/param/imagenet/prune/mbv2.yaml rename to paper/alds/param/imagenet/prune/mobilenet_v2.yaml diff --git a/paper/alds/param/imagenet/retrain/mbv2.yaml b/paper/alds/param/imagenet/retrain/mobilenet_v2.yaml similarity index 100% rename from paper/alds/param/imagenet/retrain/mbv2.yaml rename to paper/alds/param/imagenet/retrain/mobilenet_v2.yaml diff --git a/paper/alds/script/results_viewer.py b/paper/alds/script/results_viewer.py index 95b405a..f2c8d8b 100644 --- a/paper/alds/script/results_viewer.py +++ b/paper/alds/script/results_viewer.py @@ -38,7 +38,7 @@ TABLE_BOLD_THRESHOLD = 0.005 # auto-discover files from folder without "common.yaml" -FILES = glob.glob(os.path.join(FOLDER, "[!common]*.yaml")) +FILES = glob.glob(os.path.join(FOLDER, "*[!common]*.yaml")) def key_files(item): @@ -52,6 +52,7 @@ def key_files(item): "resnet18", "resnet101", "wide_resnet50_2", + "mobilenet_v2", "deeplabv3_resnet50", ] @@ -127,6 +128,8 @@ def get_results(file, logger, legend_on): elif "imagenet/prune" in file: graphers[0]._figure.gca().set_xlim([0, 87]) graphers[0]._figure.gca().set_ylim([-87, 5]) + elif "imagenet/retrain/mobilenet_v2" in file: + graphers[0]._figure.gca().set_ylim([-5, 0.5]) elif "imagenet/retrain/" in file: graphers[0]._figure.gca().set_ylim([-3.5, 1.5]) elif "imagenet/retraincascade" in file: @@ -317,6 +320,7 @@ def generate_table_entries( "resnet18": "ResNet18", "resnet101": "ResNet101", "wide_resnet50_2": "WRN50-2", + "mobilenet_v2": "MobileNetV2", "deeplabv3_resnet50": "DeeplabV3-ResNet50", } diff --git a/paper/node/README.md b/paper/node/README.md new file mode 100644 index 0000000..c24e74d --- /dev/null +++ b/paper/node/README.md @@ -0,0 +1,68 @@ +# Sparse flows: Pruning continuous-depth models +[Lucas Liebenwein*](https://people.csail.mit.edu/lucasl/), +[Ramin Hasani*](http://www.raminhasani.com), +[Alexander Amini](https://www.mit.edu/~amini/), +[Daniela Rus](http://danielarus.csail.mit.edu/) + +***Equal contribution** + +
+<img src="../../misc/imgs/node_overview.png" width="100%">
+
+Continuous deep learning architectures enable learning of flexible
+probabilistic models for predictive modeling as neural ordinary differential
+equations (ODEs), and for generative modeling as continuous normalizing flows.
+In this work, we design a framework to decipher the internal dynamics of these
+continuous-depth models by pruning their network architectures. Our empirical
+results suggest that pruning improves generalization for neural ODEs in
+generative modeling. We empirically show that the improvement is because
+pruning helps avoid mode collapse and flattens the loss surface. Moreover,
+pruning finds efficient neural ODE representations with up to 98% fewer
+parameters than the original network, without loss of accuracy. We hope
+our results will invigorate further research into the performance-size
+trade-offs of modern continuous-depth models.
+
+## Setup
+Check out the main [README.md](../../README.md) and the respective packages for
+more information on the code base.
+
+## Overview
+
+### Run compression experiments
+The experiment configurations are located [here](./param). To reproduce the
+experiments for a specific configuration, run:
+```bash
+python -m experiment.main param/toy/ffjord/spirals/vanilla_l4_h64.yaml
+```
+
+The pruning experiments run fully automatically and store all results; see the
+sketch below for further example invocations.
+
+### Experimental evaluations
+
+The [script](./script) folder contains the evaluation and plotting scripts
+used to evaluate and analyze the various experiments. Please take a look at
+each of them to understand how to load and analyze the pruning experiments.
+
+Each plot and experiment presented in the paper can be reproduced this way.
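+
+As a sketch of further invocations, any other configuration under
+[param](./param) can be launched the same way; for example, the tabular
+Miniboone experiment and the multiscale CNF on MNIST (both config files ship
+with this folder, and the relative paths follow the same convention as the
+toy example above):
+```bash
+# tabular density estimation on Miniboone
+python -m experiment.main param/tabular/miniboone/l2_hm20_f1_softplus.yaml
+
+# multiscale continuous normalizing flow on MNIST
+python -m experiment.main param/cnf/mnist_multiscale.yaml
+```
+
+## Citation
+Please cite the following paper when using our work.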
+ +### Paper link +[Sparse flows: Pruning continuous-depth models](https://proceedings.neurips.cc/paper/2021/hash/bf1b2f4b901c21a1d8645018ea9aeb05-Abstract.html) + +### Bibtex +``` +@article{liebenwein2021sparse, + title={Sparse flows: Pruning continuous-depth models}, + author={Liebenwein, Lucas and Hasani, Ramin and Amini, Alexander and Rus, Daniela}, + journal={Advances in Neural Information Processing Systems}, + volume={34}, + pages={22628--22642}, + year={2021} +} +``` diff --git a/paper/node/param/cnf/cifar_multiscale.yaml b/paper/node/param/cnf/cifar_multiscale.yaml new file mode 100644 index 0000000..c08c237 --- /dev/null +++ b/paper/node/param/cnf/cifar_multiscale.yaml @@ -0,0 +1,68 @@ +network: + name: "ffjord_multiscale_cifar" + dataset: "CIFAR10" + outputSize: 10 + +training: + transformsTrain: + - type: RandomHorizontalFlip + kwargs: {} + transformsTest: [] + transformsFinal: + - type: Resize + kwargs: { size: 32 } + - type: ToTensor + kwargs: {} + - type: RandomNoise + kwargs: { "normalization": 255.0 } + + loss: "NLLBitsLoss" + lossKwargs: {} + + metricsTest: + - type: NLLBits + kwargs: {} + - type: Dummy + kwargs: {} + + batchSize: 200 # don't change that since it's hard-coded + + optimizer: "Adam" + optimizerKwargs: + lr: 1.0e-3 + weight_decay: 0.0 + + numEpochs: 50 + earlyStopEpoch: 0 + enableAMP: False + + lrSchedulers: + - type: MultiStepLR + stepKwargs: { milestones: [45] } + kwargs: { gamma: 0.1 } + +file: "paper/node/param/directories.yaml" + +retraining: + startEpoch: 0 + +experiments: + methods: + - "ThresNet" + - "FilterThresNet" + mode: "cascade" + + numRepetitions: 1 + numNets: 1 + + plotting: + minVal: 0.02 + maxVal: 0.85 + + spacing: + - type: "geometric" + numIntervals: 12 + maxVal: 0.80 + minVal: 0.05 + + retrainIterations: -1 diff --git a/paper/node/param/cnf/mnist_multiscale.yaml b/paper/node/param/cnf/mnist_multiscale.yaml new file mode 100644 index 0000000..fab6fac --- /dev/null +++ b/paper/node/param/cnf/mnist_multiscale.yaml @@ -0,0 +1,66 @@ +network: + name: "ffjord_multiscale_mnist" + dataset: "MNIST" + outputSize: 10 + +training: + transformsTrain: [] + transformsTest: [] + transformsFinal: + - type: Resize + kwargs: { size: 28 } + - type: ToTensor + kwargs: {} + - type: RandomNoise + kwargs: { "normalization": 255.0 } + + loss: "NLLBitsLoss" + lossKwargs: {} + + metricsTest: + - type: NLLBits + kwargs: {} + - type: Dummy + kwargs: {} + + batchSize: 200 # don't change that since it's hard-coded + + optimizer: "Adam" + optimizerKwargs: + lr: 1.0e-3 + weight_decay: 0.0 + + numEpochs: 50 + earlyStopEpoch: 0 + enableAMP: False + + lrSchedulers: + - type: MultiStepLR + stepKwargs: { milestones: [45] } + kwargs: { gamma: 0.1 } + +file: "paper/node/param/directories.yaml" + +retraining: + startEpoch: 0 + +experiments: + methods: + - "ThresNet" + - "FilterThresNet" + mode: "cascade" + + numRepetitions: 1 + numNets: 1 + + plotting: + minVal: 0.02 + maxVal: 0.85 + + spacing: + - type: "geometric" + numIntervals: 12 + maxVal: 0.80 + minVal: 0.05 + + retrainIterations: -1 diff --git a/paper/node/param/directories.yaml b/paper/node/param/directories.yaml new file mode 100644 index 0000000..ab6b9e8 --- /dev/null +++ b/paper/node/param/directories.yaml @@ -0,0 +1,6 @@ +# relative directories from where main.py was called +directories: + results: "./data/node/results" + trained_networks: null + training_data: "./data/training" + local_data: "./local" diff --git a/paper/node/param/tabular/bsds300/l3_hm20_f2_softplus.yaml 
b/paper/node/param/tabular/bsds300/l3_hm20_f2_softplus.yaml new file mode 100644 index 0000000..ced304b --- /dev/null +++ b/paper/node/param/tabular/bsds300/l3_hm20_f2_softplus.yaml @@ -0,0 +1,61 @@ +network: + name: "ffjord_l3_hm20_f2_softplus" + dataset: "Bsds300" + outputSize: 63 + +training: + transformsTrain: [] + transformsTest: [] + transformsFinal: [] + + loss: "NLLNatsLoss" + lossKwargs: {} + + metricsTest: + - type: NLLNats + kwargs: {} + - type: Dummy + kwargs: {} + + batchSize: 10000 + + optimizer: "Adam" + optimizerKwargs: + lr: 1.0e-3 + weight_decay: 1.0e-6 + + numEpochs: 100 + earlyStopEpoch: 0 + + enableAMP: False + + lrSchedulers: + - type: MultiStepLR + stepKwargs: { milestones: [96, 99] } + kwargs: { gamma: 0.1 } + +file: "paper/node/param/directories.yaml" + +retraining: + startEpoch: 0 + +experiments: + methods: + - "ThresNet" + - "FilterThresNet" + mode: "cascade" + + numRepetitions: 1 + numNets: 1 + + plotting: + minVal: 0.02 + maxVal: 0.85 + + spacing: + - type: "geometric" + numIntervals: 8 + maxVal: 0.70 + minVal: 0.10 + + retrainIterations: -1 diff --git a/paper/node/param/tabular/gas/l3_hm20_f5_tanh.yaml b/paper/node/param/tabular/gas/l3_hm20_f5_tanh.yaml new file mode 100644 index 0000000..7678a0b --- /dev/null +++ b/paper/node/param/tabular/gas/l3_hm20_f5_tanh.yaml @@ -0,0 +1,61 @@ +network: + name: "ffjord_l3_hm20_f5_tanh" + dataset: "Gas" + outputSize: 8 + +training: + transformsTrain: [] + transformsTest: [] + transformsFinal: [] + + loss: "NLLNatsLoss" + lossKwargs: {} + + metricsTest: + - type: NLLNats + kwargs: {} + - type: Dummy + kwargs: {} + + batchSize: 1000 + + optimizer: "Adam" + optimizerKwargs: + lr: 1.0e-3 + weight_decay: 1.0e-6 + + numEpochs: 30 + earlyStopEpoch: 0 + + enableAMP: False + + lrSchedulers: + - type: MultiStepLR + stepKwargs: { milestones: [25, 28] } + kwargs: { gamma: 0.1 } + +file: "paper/node/param/directories.yaml" + +retraining: + startEpoch: 0 + +experiments: + methods: + - "ThresNet" + - "FilterThresNet" + mode: "cascade" + + numRepetitions: 1 + numNets: 1 + + plotting: + minVal: 0.02 + maxVal: 0.85 + + spacing: + - type: "geometric" + numIntervals: 8 + maxVal: 0.70 + minVal: 0.10 + + retrainIterations: -1 diff --git a/paper/node/param/tabular/hepmass/l2_hm10_f10_softplus.yaml b/paper/node/param/tabular/hepmass/l2_hm10_f10_softplus.yaml new file mode 100644 index 0000000..04f28ea --- /dev/null +++ b/paper/node/param/tabular/hepmass/l2_hm10_f10_softplus.yaml @@ -0,0 +1,68 @@ +network: + name: "ffjord_l2_hm10_f10_softplus" + dataset: "Hepmass" + outputSize: 21 + +training: + transformsTrain: [] + transformsTest: [] + transformsFinal: [] + + loss: "NLLNatsLoss" + lossKwargs: {} + + metricsTest: + - type: NLLNats + kwargs: {} + - type: Dummy + kwargs: {} + + batchSize: 10000 + + optimizer: "Adam" + optimizerKwargs: + lr: 1.0e-3 + weight_decay: 1.0e-6 + + numEpochs: 400 + earlyStopEpoch: 0 + + enableAMP: False + + lrSchedulers: + - type: MultiStepLR + stepKwargs: { milestones: [325, 375] } + kwargs: { gamma: 0.1 } + +file: "paper/node/param/directories.yaml" + +retraining: + startEpoch: 0 + numEpochs: 300 + earlyStopEpoch: 0 + + lrSchedulers: + - type: MultiStepLR + stepKwargs: { milestones: [250, 295] } + kwargs: { gamma: 0.1 } + +experiments: + methods: + - "ThresNet" + - "FilterThresNet" + mode: "cascade" + + numRepetitions: 1 + numNets: 1 + + plotting: + minVal: 0.02 + maxVal: 0.85 + + spacing: + - type: "geometric" + numIntervals: 12 + maxVal: 0.80 + minVal: 0.05 + + retrainIterations: -1 diff --git 
a/paper/node/param/tabular/miniboone/l2_hm20_f1_softplus.yaml b/paper/node/param/tabular/miniboone/l2_hm20_f1_softplus.yaml new file mode 100644 index 0000000..9cd4394 --- /dev/null +++ b/paper/node/param/tabular/miniboone/l2_hm20_f1_softplus.yaml @@ -0,0 +1,65 @@ +network: + name: "ffjord_l2_hm20_f1_softplus" + dataset: "Miniboone" + outputSize: 43 + +training: + transformsTrain: [] + transformsTest: [] + transformsFinal: [] + + loss: "NLLNatsLoss" + lossKwargs: {} + + metricsTest: + - type: NLLNats + kwargs: {} + - type: Dummy + kwargs: {} + + batchSize: 1000 + + optimizer: "Adam" + optimizerKwargs: + lr: 1.0e-3 + weight_decay: 1.0e-6 + + numEpochs: 400 + earlyStopEpoch: 0 + + enableAMP: False + + lrSchedulers: + - type: MultiStepLR + stepKwargs: { milestones: [300, 350] } + kwargs: { gamma: 0.1 } + +file: "paper/node/param/directories.yaml" + +retraining: + startEpoch: 0 + +experiments: + methods: + - "ThresNet" + - "FilterThresNet" + mode: "cascade" + + numRepetitions: 1 + numNets: 3 + + plotting: + minVal: 0.02 + maxVal: 0.85 + + spacing: + - type: "geometric" + numIntervals: 12 + maxVal: 0.80 + minVal: 0.05 + - type: "linear" + numIntervals: 4 + maxVal: 0.04 + minVal: 0.01 + + retrainIterations: -1 diff --git a/paper/node/param/tabular/power/l3_hm10_f5_tanh.yaml b/paper/node/param/tabular/power/l3_hm10_f5_tanh.yaml new file mode 100644 index 0000000..71b07b4 --- /dev/null +++ b/paper/node/param/tabular/power/l3_hm10_f5_tanh.yaml @@ -0,0 +1,61 @@ +network: + name: "ffjord_l3_hm10_f5_tanh" + dataset: "Power" + outputSize: 6 + +training: + transformsTrain: [] + transformsTest: [] + transformsFinal: [] + + loss: "NLLNatsLoss" + lossKwargs: {} + + metricsTest: + - type: NLLNats + kwargs: {} + - type: Dummy + kwargs: {} + + batchSize: 10000 + + optimizer: "Adam" + optimizerKwargs: + lr: 1.0e-3 + weight_decay: 1.0e-6 + + numEpochs: 100 + earlyStopEpoch: 0 + + enableAMP: False + + lrSchedulers: + - type: MultiStepLR + stepKwargs: { milestones: [90, 97] } + kwargs: { gamma: 0.1 } + +file: "paper/node/param/directories.yaml" + +retraining: + startEpoch: 0 + +experiments: + methods: + - "ThresNet" + - "FilterThresNet" + mode: "cascade" + + numRepetitions: 1 + numNets: 1 + + plotting: + minVal: 0.02 + maxVal: 0.85 + + spacing: + - type: "geometric" + numIntervals: 8 + maxVal: 0.70 + minVal: 0.10 + + retrainIterations: -1 diff --git a/paper/node/param/toy/ffjord/common/experiment.yaml b/paper/node/param/toy/ffjord/common/experiment.yaml new file mode 100644 index 0000000..511c93b --- /dev/null +++ b/paper/node/param/toy/ffjord/common/experiment.yaml @@ -0,0 +1,29 @@ +file: "paper/node/param/directories.yaml" + +retraining: + startEpoch: 0 + +experiments: + methods: + - "FilterThresNet" + - "ThresNet" + mode: "cascade" + + numRepetitions: 1 + numNets: 3 + + plotting: + minVal: 0.02 + maxVal: 0.85 + + spacing: + - type: "geometric" + numIntervals: 20 + maxVal: 0.80 + minVal: 0.20 + - type: "linear" + numIntervals: 9 + maxVal: 0.18 + minVal: 0.02 + + retrainIterations: -1 diff --git a/paper/node/param/toy/ffjord/common/sweep_activation_da.yaml b/paper/node/param/toy/ffjord/common/sweep_activation_da.yaml new file mode 100644 index 0000000..fcbe321 --- /dev/null +++ b/paper/node/param/toy/ffjord/common/sweep_activation_da.yaml @@ -0,0 +1,12 @@ +file: "paper/node/param/toy/ffjord/common/experiment.yaml" + +# we vary the activation function +customizations: + - key: ["network", "name"] + value: "ffjord_l4_h64_sigmoid_da" + - key: ["network", "name"] + value: "ffjord_l4_h64_softplus_da" + - key: 
["network", "name"] + value: "ffjord_l4_h64_tanh_da" + - key: ["network", "name"] + value: "ffjord_l4_h64_relu_da" diff --git a/paper/node/param/toy/ffjord/common/sweep_model_da.yaml b/paper/node/param/toy/ffjord/common/sweep_model_da.yaml new file mode 100644 index 0000000..d906115 --- /dev/null +++ b/paper/node/param/toy/ffjord/common/sweep_model_da.yaml @@ -0,0 +1,12 @@ +file: "paper/node/param/toy/ffjord/common/experiment.yaml" + +# now we vary the architecture +customizations: + - key: ["network", "name"] + value: "ffjord_l4_h64_sigmoid_da" + - key: ["network", "name"] + value: "ffjord_l8_h64_sigmoid_da" + - key: ["network", "name"] + value: "ffjord_l2_h128_sigmoid_da" + - key: ["network", "name"] + value: "ffjord_l2_h64_sigmoid_da" diff --git a/paper/node/param/toy/ffjord/common/sweep_model_large.yaml b/paper/node/param/toy/ffjord/common/sweep_model_large.yaml new file mode 100644 index 0000000..ad6de78 --- /dev/null +++ b/paper/node/param/toy/ffjord/common/sweep_model_large.yaml @@ -0,0 +1,14 @@ +file: "paper/node/param/toy/ffjord/common/experiment.yaml" + +# now we vary the architecture +customizations: + - key: ["network", "name"] + value: "ffjord_l8_h37_sigmoid_da" + - key: ["network", "name"] + value: "ffjord_l6_h45_sigmoid_da" + - key: ["network", "name"] + value: "ffjord_l4_h64_sigmoid_da" + - key: ["network", "name"] + value: "ffjord_l3_h90_sigmoid_da" + - key: ["network", "name"] + value: "ffjord_l2_h1700_sigmoid_da" diff --git a/paper/node/param/toy/ffjord/common/sweep_model_med.yaml b/paper/node/param/toy/ffjord/common/sweep_model_med.yaml new file mode 100644 index 0000000..d37fc6b --- /dev/null +++ b/paper/node/param/toy/ffjord/common/sweep_model_med.yaml @@ -0,0 +1,14 @@ +file: "paper/node/param/toy/ffjord/common/experiment.yaml" + +# now we vary the architecture with approximately fixed # of parameters +customizations: + - key: ["network", "name"] + value: "ffjord_l8_h18_sigmoid_da" + - key: ["network", "name"] + value: "ffjord_l6_h22_sigmoid_da" + - key: ["network", "name"] + value: "ffjord_l4_h30_sigmoid_da" + - key: ["network", "name"] + value: "ffjord_l3_h43_sigmoid_da" + - key: ["network", "name"] + value: "ffjord_l2_h400_sigmoid_da" diff --git a/paper/node/param/toy/ffjord/common/sweep_model_small.yaml b/paper/node/param/toy/ffjord/common/sweep_model_small.yaml new file mode 100644 index 0000000..9dcccf5 --- /dev/null +++ b/paper/node/param/toy/ffjord/common/sweep_model_small.yaml @@ -0,0 +1,14 @@ +file: "paper/node/param/toy/ffjord/common/experiment.yaml" + +# now we vary the architecture with approximately fixed # of parameters +customizations: + - key: ["network", "name"] + value: "ffjord_l8_h10_sigmoid_da" + - key: ["network", "name"] + value: "ffjord_l6_h12_sigmoid_da" + - key: ["network", "name"] + value: "ffjord_l4_h17_sigmoid_da" + - key: ["network", "name"] + value: "ffjord_l3_h23_sigmoid_da" + - key: ["network", "name"] + value: "ffjord_l2_h128_sigmoid_da" diff --git a/paper/node/param/toy/ffjord/common/sweep_solver.yaml b/paper/node/param/toy/ffjord/common/sweep_solver.yaml new file mode 100644 index 0000000..4f75ce2 --- /dev/null +++ b/paper/node/param/toy/ffjord/common/sweep_solver.yaml @@ -0,0 +1,16 @@ +file: "paper/node/param/toy/ffjord/common/experiment.yaml" + +# we vary the solver +customizations: + - key: ["network", "name"] + value: "ffjord_l4_h64_sigmoid_rk4_autograd" + - key: ["network", "name"] + value: "ffjord_l4_h64_sigmoid_rk4_adjoint" + - key: ["network", "name"] + value: "ffjord_l4_h64_sigmoid_dopri_autograd" + - key: ["network", 
"name"] + value: "ffjord_l4_h64_sigmoid_dopri_adjoint" + - key: ["network", "name"] + value: "ffjord_l4_h64_sigmoid_euler_autograd" + - key: ["network", "name"] + value: "ffjord_l4_h64_sigmoid_euler_adjoint" diff --git a/paper/node/param/toy/ffjord/gaussians/l2_h128_sigmoid_da.yaml b/paper/node/param/toy/ffjord/gaussians/l2_h128_sigmoid_da.yaml new file mode 100644 index 0000000..28818a0 --- /dev/null +++ b/paper/node/param/toy/ffjord/gaussians/l2_h128_sigmoid_da.yaml @@ -0,0 +1,8 @@ +network: + file: "paper/node/param/toy/ffjord/gaussians/network_da.yaml" + name: "ffjord_l2_h128_sigmoid_da" + +training: + file: "paper/node/param/toy/ffjord/gaussians/training.yaml" + +file: "paper/node/param/toy/ffjord/common/experiment.yaml" diff --git a/paper/node/param/toy/ffjord/gaussians/l4_h64_sigmoid_da.yaml b/paper/node/param/toy/ffjord/gaussians/l4_h64_sigmoid_da.yaml new file mode 100644 index 0000000..05ed81c --- /dev/null +++ b/paper/node/param/toy/ffjord/gaussians/l4_h64_sigmoid_da.yaml @@ -0,0 +1,7 @@ +network: + file: "paper/node/param/toy/ffjord/gaussians/network_da.yaml" + +training: + file: "paper/node/param/toy/ffjord/gaussians/training.yaml" + +file: "paper/node/param/toy/ffjord/common/experiment.yaml" diff --git a/paper/node/param/toy/ffjord/gaussians/network_da.yaml b/paper/node/param/toy/ffjord/gaussians/network_da.yaml new file mode 100644 index 0000000..4dcacdc --- /dev/null +++ b/paper/node/param/toy/ffjord/gaussians/network_da.yaml @@ -0,0 +1,3 @@ +name: "ffjord_l4_h64_sigmoid_da" +dataset: "ToyGaussians" +outputSize: 2 diff --git a/paper/node/param/toy/ffjord/gaussians/sweep_activation_da.yaml b/paper/node/param/toy/ffjord/gaussians/sweep_activation_da.yaml new file mode 100644 index 0000000..b1c848b --- /dev/null +++ b/paper/node/param/toy/ffjord/gaussians/sweep_activation_da.yaml @@ -0,0 +1,9 @@ +# Base Experiment is the classic retraining experiment +network: + file: "paper/node/param/toy/ffjord/gaussians/network_da.yaml" + +training: + file: "paper/node/param/toy/ffjord/gaussians/training.yaml" + +# now we vary the activation function +file: "paper/node/param/toy/ffjord/common/sweep_activation_da.yaml" diff --git a/paper/node/param/toy/ffjord/gaussians/sweep_model_da.yaml b/paper/node/param/toy/ffjord/gaussians/sweep_model_da.yaml new file mode 100644 index 0000000..9c1c004 --- /dev/null +++ b/paper/node/param/toy/ffjord/gaussians/sweep_model_da.yaml @@ -0,0 +1,9 @@ +# Base Experiment is the classic retraining experiment +network: + file: "paper/node/param/toy/ffjord/gaussians/network_da.yaml" + +training: + file: "paper/node/param/toy/ffjord/gaussians/training.yaml" + +# now we vary the model +file: "paper/node/param/toy/ffjord/common/sweep_model_da.yaml" diff --git a/paper/node/param/toy/ffjord/gaussians/sweep_model_small.yaml b/paper/node/param/toy/ffjord/gaussians/sweep_model_small.yaml new file mode 100644 index 0000000..472c5fa --- /dev/null +++ b/paper/node/param/toy/ffjord/gaussians/sweep_model_small.yaml @@ -0,0 +1,9 @@ +# Base Experiment is the classic retraining experiment +network: + file: "paper/node/param/toy/ffjord/gaussians/network_da.yaml" + +training: + file: "paper/node/param/toy/ffjord/gaussians/training.yaml" + +# now we vary the model +file: "paper/node/param/toy/ffjord/common/sweep_model_small.yaml" diff --git a/paper/node/param/toy/ffjord/gaussians/sweep_solver.yaml b/paper/node/param/toy/ffjord/gaussians/sweep_solver.yaml new file mode 100644 index 0000000..80f89fa --- /dev/null +++ b/paper/node/param/toy/ffjord/gaussians/sweep_solver.yaml @@ -0,0 
+1,9 @@ +# Base Experiment is the classic retraining experiment +network: + file: "paper/node/param/toy/ffjord/gaussians/network.yaml" + +training: + file: "paper/node/param/toy/ffjord/gaussians/training.yaml" + +# now we vary the solver +file: "paper/node/param/toy/ffjord/common/sweep_solver.yaml" diff --git a/paper/node/param/toy/ffjord/gaussians/training.yaml b/paper/node/param/toy/ffjord/gaussians/training.yaml new file mode 100644 index 0000000..e9b2554 --- /dev/null +++ b/paper/node/param/toy/ffjord/gaussians/training.yaml @@ -0,0 +1,25 @@ +transformsTrain: [] +transformsTest: [] +transformsFinal: [] + +loss: "NLLPriorLoss" +lossKwargs: {} + +metricsTest: + - type: NLLPrior + kwargs: {} + - type: Dummy + kwargs: {} + +batchSize: 1024 + +optimizer: "AdamW" +optimizerKwargs: + lr: 5.0e-3 + weight_decay: 1.0e-5 + +numEpochs: 100 + +enableAMP: False + +lrSchedulers: [] diff --git a/paper/node/param/toy/ffjord/gaussians/vanilla_l2_h128.yaml b/paper/node/param/toy/ffjord/gaussians/vanilla_l2_h128.yaml new file mode 100644 index 0000000..2326bdd --- /dev/null +++ b/paper/node/param/toy/ffjord/gaussians/vanilla_l2_h128.yaml @@ -0,0 +1,8 @@ +network: + file: "paper/node/param/toy/ffjord/gaussians/network_da.yaml" + name: "cnf_l2_h128_sigmoid_da" + +training: + file: "paper/node/param/toy/ffjord/gaussians/training.yaml" + +file: "paper/node/param/toy/ffjord/common/experiment.yaml" diff --git a/paper/node/param/toy/ffjord/gaussiansspiral/l4_h64_sigmoid_da.yaml b/paper/node/param/toy/ffjord/gaussiansspiral/l4_h64_sigmoid_da.yaml new file mode 100644 index 0000000..cac948f --- /dev/null +++ b/paper/node/param/toy/ffjord/gaussiansspiral/l4_h64_sigmoid_da.yaml @@ -0,0 +1,7 @@ +network: + file: "paper/node/param/toy/ffjord/gaussiansspiral/network_da.yaml" + +training: + file: "paper/node/param/toy/ffjord/gaussiansspiral/training.yaml" + +file: "paper/node/param/toy/ffjord/common/experiment.yaml" diff --git a/paper/node/param/toy/ffjord/gaussiansspiral/network_da.yaml b/paper/node/param/toy/ffjord/gaussiansspiral/network_da.yaml new file mode 100644 index 0000000..5e1865e --- /dev/null +++ b/paper/node/param/toy/ffjord/gaussiansspiral/network_da.yaml @@ -0,0 +1,3 @@ +name: "ffjord_l4_h64_sigmoid_da" +dataset: "ToyGaussiansSpiral" +outputSize: 2 diff --git a/paper/node/param/toy/ffjord/gaussiansspiral/sweep_act_da.yaml b/paper/node/param/toy/ffjord/gaussiansspiral/sweep_act_da.yaml new file mode 100644 index 0000000..5abf6ba --- /dev/null +++ b/paper/node/param/toy/ffjord/gaussiansspiral/sweep_act_da.yaml @@ -0,0 +1,9 @@ +# Base Experiment is the classic retraining experiment +network: + file: "paper/node/param/toy/ffjord/gaussiansspiral/network_da.yaml" + +training: + file: "paper/node/param/toy/ffjord/gaussiansspiral/training.yaml" + +# now we vary the activation function +file: "paper/node/param/toy/ffjord/common/sweep_activation_da.yaml" diff --git a/paper/node/param/toy/ffjord/gaussiansspiral/sweep_model_da.yaml b/paper/node/param/toy/ffjord/gaussiansspiral/sweep_model_da.yaml new file mode 100644 index 0000000..9dc88e0 --- /dev/null +++ b/paper/node/param/toy/ffjord/gaussiansspiral/sweep_model_da.yaml @@ -0,0 +1,9 @@ +# Base Experiment is the classic retraining experiment +network: + file: "paper/node/param/toy/ffjord/gaussiansspiral/network_da.yaml" + +training: + file: "paper/node/param/toy/ffjord/gaussiansspiral/training.yaml" + +# now we vary the model +file: "paper/node/param/toy/ffjord/common/sweep_model_da.yaml" diff --git a/paper/node/param/toy/ffjord/gaussiansspiral/sweep_model_med.yaml 
b/paper/node/param/toy/ffjord/gaussiansspiral/sweep_model_med.yaml new file mode 100644 index 0000000..b395378 --- /dev/null +++ b/paper/node/param/toy/ffjord/gaussiansspiral/sweep_model_med.yaml @@ -0,0 +1,9 @@ +# Base Experiment is the classic retraining experiment +network: + file: "paper/node/param/toy/ffjord/gaussiansspiral/network_da.yaml" + +training: + file: "paper/node/param/toy/ffjord/gaussiansspiral/training.yaml" + +# now we vary the model +file: "paper/node/param/toy/ffjord/common/sweep_model_med.yaml" diff --git a/paper/node/param/toy/ffjord/gaussiansspiral/sweep_opt_ref.yaml b/paper/node/param/toy/ffjord/gaussiansspiral/sweep_opt_ref.yaml new file mode 100644 index 0000000..6134a06 --- /dev/null +++ b/paper/node/param/toy/ffjord/gaussiansspiral/sweep_opt_ref.yaml @@ -0,0 +1,56 @@ +# Base Experiment is the classic retraining experiment +network: + file: "paper/node/param/toy/ffjord/gaussiansspiral/network_da.yaml" + +training: + file: "paper/node/param/toy/ffjord/gaussiansspiral/training.yaml" + +file: "paper/node/param/directories.yaml" + +retraining: + startEpoch: 0 + +experiments: + methods: [] + mode: "cascade" + + numRepetitions: 1 + numNets: 1 + + plotting: + minVal: 0.02 + maxVal: 0.85 + + spacing: + - type: "geometric" + numIntervals: 2 + maxVal: 0.80 + minVal: 0.20 + + retrainIterations: -1 + +customizations: + - key: ["training", "optimizerKwargs"] + value: { "lr": 0.05, "weight_decay": 0.01 } + - key: ["training", "optimizerKwargs"] + value: { "lr": 0.05, "weight_decay": 1.0e-6 } + - key: ["training", "optimizerKwargs"] + value: { "lr": 0.05, "weight_decay": 1.0e-5 } + - key: ["training", "optimizerKwargs"] + value: { "lr": 0.05, "weight_decay": 1.0e-4 } + - key: ["training", "optimizerKwargs"] + value: { "lr": 0.01, "weight_decay": 0.01 } + - key: ["training", "optimizerKwargs"] + value: { "lr": 0.01, "weight_decay": 1.0e-6 } + - key: ["training", "optimizerKwargs"] + value: { "lr": 0.01, "weight_decay": 1.0e-5 } + - key: ["training", "optimizerKwargs"] + value: { "lr": 0.01, "weight_decay": 1.0e-4 } + - key: ["training", "optimizerKwargs"] + value: { "lr": 0.005, "weight_decay": 0.01 } + - key: ["training", "optimizerKwargs"] + value: { "lr": 0.005, "weight_decay": 1.0e-7 } + - key: ["training", "optimizerKwargs"] + value: { "lr": 0.005, "weight_decay": 1.0e-6 } + - key: ["training", "optimizerKwargs"] + value: { "lr": 0.005, "weight_decay": 1.0e-5 } diff --git a/paper/node/param/toy/ffjord/gaussiansspiral/sweep_solver.yaml b/paper/node/param/toy/ffjord/gaussiansspiral/sweep_solver.yaml new file mode 100644 index 0000000..784f5b9 --- /dev/null +++ b/paper/node/param/toy/ffjord/gaussiansspiral/sweep_solver.yaml @@ -0,0 +1,9 @@ +# Base Experiment is the classic retraining experiment +network: + file: "paper/node/param/toy/ffjord/gaussiansspiral/network_da.yaml" + +training: + file: "paper/node/param/toy/ffjord/gaussiansspiral/training.yaml" + +# now we vary the solver +file: "paper/node/param/toy/ffjord/common/sweep_solver.yaml" diff --git a/paper/node/param/toy/ffjord/gaussiansspiral/training.yaml b/paper/node/param/toy/ffjord/gaussiansspiral/training.yaml new file mode 100644 index 0000000..21f219c --- /dev/null +++ b/paper/node/param/toy/ffjord/gaussiansspiral/training.yaml @@ -0,0 +1,25 @@ +transformsTrain: [] +transformsTest: [] +transformsFinal: [] + +loss: "NLLPriorLoss" +lossKwargs: {} + +metricsTest: + - type: NLLPrior + kwargs: {} + - type: Dummy + kwargs: {} + +batchSize: 1024 + +optimizer: "AdamW" +optimizerKwargs: + lr: 0.05 + weight_decay: 0.01 + 
+numEpochs: 100 + +enableAMP: False + +lrSchedulers: [] diff --git a/paper/node/param/toy/ffjord/gaussiansspiral/vanilla_l4_h64.yaml b/paper/node/param/toy/ffjord/gaussiansspiral/vanilla_l4_h64.yaml new file mode 100644 index 0000000..feb1d10 --- /dev/null +++ b/paper/node/param/toy/ffjord/gaussiansspiral/vanilla_l4_h64.yaml @@ -0,0 +1,8 @@ +network: + file: "paper/node/param/toy/ffjord/gaussiansspiral/network_da.yaml" + name: "cnf_l4_h64_sigmoid_da" + +training: + file: "paper/node/param/toy/ffjord/gaussiansspiral/training.yaml" + +file: "paper/node/param/toy/ffjord/common/experiment.yaml" diff --git a/paper/node/param/toy/ffjord/spirals/l4_h64_sigmoid_da.yaml b/paper/node/param/toy/ffjord/spirals/l4_h64_sigmoid_da.yaml new file mode 100644 index 0000000..b4e4bef --- /dev/null +++ b/paper/node/param/toy/ffjord/spirals/l4_h64_sigmoid_da.yaml @@ -0,0 +1,7 @@ +network: + file: "paper/node/param/toy/ffjord/spirals/network_da.yaml" + +training: + file: "paper/node/param/toy/ffjord/spirals/training.yaml" + +file: "paper/node/param/toy/ffjord/common/experiment.yaml" diff --git a/paper/node/param/toy/ffjord/spirals/network_da.yaml b/paper/node/param/toy/ffjord/spirals/network_da.yaml new file mode 100644 index 0000000..3965cba --- /dev/null +++ b/paper/node/param/toy/ffjord/spirals/network_da.yaml @@ -0,0 +1,3 @@ +name: "ffjord_l4_h64_sigmoid_da" +dataset: "ToySpirals2" +outputSize: 2 diff --git a/paper/node/param/toy/ffjord/spirals/sweep_act_da.yaml b/paper/node/param/toy/ffjord/spirals/sweep_act_da.yaml new file mode 100644 index 0000000..c6f1dd0 --- /dev/null +++ b/paper/node/param/toy/ffjord/spirals/sweep_act_da.yaml @@ -0,0 +1,9 @@ +# Base Experiment is the classic retraining experiment +network: + file: "paper/node/param/toy/ffjord/spirals/network_da.yaml" + +training: + file: "paper/node/param/toy/ffjord/spirals/training.yaml" + +# now we vary the activation function +file: "paper/node/param/toy/ffjord/common/sweep_activation_da.yaml" diff --git a/paper/node/param/toy/ffjord/spirals/sweep_model_da.yaml b/paper/node/param/toy/ffjord/spirals/sweep_model_da.yaml new file mode 100644 index 0000000..22365fd --- /dev/null +++ b/paper/node/param/toy/ffjord/spirals/sweep_model_da.yaml @@ -0,0 +1,9 @@ +# Base Experiment is the classic retraining experiment +network: + file: "paper/node/param/toy/ffjord/spirals/network_da.yaml" + +training: + file: "paper/node/param/toy/ffjord/spirals/training.yaml" + +# now we vary the model +file: "paper/node/param/toy/ffjord/common/sweep_model_da.yaml" diff --git a/paper/node/param/toy/ffjord/spirals/sweep_model_large.yaml b/paper/node/param/toy/ffjord/spirals/sweep_model_large.yaml new file mode 100644 index 0000000..b0a30e1 --- /dev/null +++ b/paper/node/param/toy/ffjord/spirals/sweep_model_large.yaml @@ -0,0 +1,9 @@ +# Base Experiment is the classic retraining experiment +network: + file: "paper/node/param/toy/ffjord/spirals/network_da.yaml" + +training: + file: "paper/node/param/toy/ffjord/spirals/training.yaml" + +# now we vary the model +file: "paper/node/param/toy/ffjord/common/sweep_model_large.yaml" diff --git a/paper/node/param/toy/ffjord/spirals/sweep_opt_ref.yaml b/paper/node/param/toy/ffjord/spirals/sweep_opt_ref.yaml new file mode 100644 index 0000000..43067d6 --- /dev/null +++ b/paper/node/param/toy/ffjord/spirals/sweep_opt_ref.yaml @@ -0,0 +1,56 @@ +# Base Experiment is the classic retraining experiment +network: + file: "paper/node/param/toy/ffjord/spirals/network_da.yaml" + +training: + file: "paper/node/param/toy/ffjord/spirals/training.yaml" + 
+file: "paper/node/param/directories.yaml" + +retraining: + startEpoch: 0 + +experiments: + methods: [] + mode: "cascade" + + numRepetitions: 1 + numNets: 1 + + plotting: + minVal: 0.02 + maxVal: 0.85 + + spacing: + - type: "geometric" + numIntervals: 2 + maxVal: 0.80 + minVal: 0.20 + + retrainIterations: -1 + +customizations: + - key: ["training", "optimizerKwargs"] + value: { "lr": 0.05, "weight_decay": 0.01 } + - key: ["training", "optimizerKwargs"] + value: { "lr": 0.05, "weight_decay": 1.0e-6 } + - key: ["training", "optimizerKwargs"] + value: { "lr": 0.05, "weight_decay": 1.0e-5 } + - key: ["training", "optimizerKwargs"] + value: { "lr": 0.05, "weight_decay": 1.0e-4 } + - key: ["training", "optimizerKwargs"] + value: { "lr": 0.01, "weight_decay": 0.01 } + - key: ["training", "optimizerKwargs"] + value: { "lr": 0.01, "weight_decay": 1.0e-6 } + - key: ["training", "optimizerKwargs"] + value: { "lr": 0.01, "weight_decay": 1.0e-5 } + - key: ["training", "optimizerKwargs"] + value: { "lr": 0.01, "weight_decay": 1.0e-4 } + - key: ["training", "optimizerKwargs"] + value: { "lr": 0.005, "weight_decay": 0.01 } + - key: ["training", "optimizerKwargs"] + value: { "lr": 0.005, "weight_decay": 1.0e-7 } + - key: ["training", "optimizerKwargs"] + value: { "lr": 0.005, "weight_decay": 1.0e-6 } + - key: ["training", "optimizerKwargs"] + value: { "lr": 0.005, "weight_decay": 1.0e-5 } diff --git a/paper/node/param/toy/ffjord/spirals/training.yaml b/paper/node/param/toy/ffjord/spirals/training.yaml new file mode 100644 index 0000000..7551456 --- /dev/null +++ b/paper/node/param/toy/ffjord/spirals/training.yaml @@ -0,0 +1,25 @@ +transformsTrain: [] +transformsTest: [] +transformsFinal: [] + +loss: "NLLPriorLoss" +lossKwargs: {} + +metricsTest: + - type: NLLPrior + kwargs: {} + - type: Dummy + kwargs: {} + +batchSize: 1024 + +optimizer: "AdamW" +optimizerKwargs: + lr: 0.05 + weight_decay: 1.0e-6 + +numEpochs: 100 + +enableAMP: False + +lrSchedulers: [] diff --git a/paper/node/param/toy/ffjord/spirals/vanilla_l4_h64.yaml b/paper/node/param/toy/ffjord/spirals/vanilla_l4_h64.yaml new file mode 100644 index 0000000..d8cc39e --- /dev/null +++ b/paper/node/param/toy/ffjord/spirals/vanilla_l4_h64.yaml @@ -0,0 +1,8 @@ +network: + file: "paper/node/param/toy/ffjord/spirals/network_da.yaml" + name: "cnf_l4_h64_sigmoid_da_high_tol" + +training: + file: "paper/node/param/toy/ffjord/spirals/training.yaml" + +file: "paper/node/param/toy/ffjord/common/experiment.yaml" diff --git a/paper/node/param/toy/node/common/experiment.yaml b/paper/node/param/toy/node/common/experiment.yaml new file mode 100644 index 0000000..511c93b --- /dev/null +++ b/paper/node/param/toy/node/common/experiment.yaml @@ -0,0 +1,29 @@ +file: "paper/node/param/directories.yaml" + +retraining: + startEpoch: 0 + +experiments: + methods: + - "FilterThresNet" + - "ThresNet" + mode: "cascade" + + numRepetitions: 1 + numNets: 3 + + plotting: + minVal: 0.02 + maxVal: 0.85 + + spacing: + - type: "geometric" + numIntervals: 20 + maxVal: 0.80 + minVal: 0.20 + - type: "linear" + numIntervals: 9 + maxVal: 0.18 + minVal: 0.02 + + retrainIterations: -1 diff --git a/paper/node/param/toy/node/common/sweep_activation_da.yaml b/paper/node/param/toy/node/common/sweep_activation_da.yaml new file mode 100644 index 0000000..b6601af --- /dev/null +++ b/paper/node/param/toy/node/common/sweep_activation_da.yaml @@ -0,0 +1,12 @@ +file: "paper/node/param/toy/node/common/experiment.yaml" + +# now we vary the activation function +customizations: + - key: ["network", "name"] + 
value: "node_l2_h64_tanh_da" + - key: ["network", "name"] + value: "node_l2_h64_sigmoid_da" + - key: ["network", "name"] + value: "node_l2_h64_softplus_da" + - key: ["network", "name"] + value: "node_l2_h64_relu_da" diff --git a/paper/node/param/toy/node/common/sweep_model_da.yaml b/paper/node/param/toy/node/common/sweep_model_da.yaml new file mode 100644 index 0000000..7076374 --- /dev/null +++ b/paper/node/param/toy/node/common/sweep_model_da.yaml @@ -0,0 +1,14 @@ +file: "paper/node/param/toy/node/common/experiment.yaml" + +# now we vary the architecture +customizations: + - key: ["network", "name"] + value: "node_l2_h32_tanh_da" + - key: ["network", "name"] + value: "node_l2_h64_tanh_da" + - key: ["network", "name"] + value: "node_l2_h128_tanh_da" + - key: ["network", "name"] + value: "node_l4_h32_tanh_da" + - key: ["network", "name"] + value: "node_l4_h128_tanh_da" diff --git a/paper/node/param/toy/node/common/sweep_opt.yaml b/paper/node/param/toy/node/common/sweep_opt.yaml new file mode 100644 index 0000000..c6c86e0 --- /dev/null +++ b/paper/node/param/toy/node/common/sweep_opt.yaml @@ -0,0 +1,28 @@ +file: "paper/node/param/toy/node/common/experiment.yaml" + +# now we vary the optimizer +customizations: + - key: ["training", "optimizerKwargs"] + value: { "lr": 0.05 } + - key: ["training", "optimizerKwargs"] + value: { "lr": 0.05, "weight_decay": 1.0e-6 } + - key: ["training", "optimizerKwargs"] + value: { "lr": 0.05, "weight_decay": 1.0e-5 } + - key: ["training", "optimizerKwargs"] + value: { "lr": 0.05, "weight_decay": 1.0e-4 } + - key: ["training", "optimizerKwargs"] + value: { "lr": 0.01 } + - key: ["training", "optimizerKwargs"] + value: { "lr": 0.01, "weight_decay": 1.0e-6 } + - key: ["training", "optimizerKwargs"] + value: { "lr": 0.01, "weight_decay": 1.0e-5 } + - key: ["training", "optimizerKwargs"] + value: { "lr": 0.01, "weight_decay": 1.0e-4 } + - key: ["training", "optimizerKwargs"] + value: { "lr": 0.005 } + - key: ["training", "optimizerKwargs"] + value: { "lr": 0.005, "weight_decay": 1.0e-7 } + - key: ["training", "optimizerKwargs"] + value: { "lr": 0.005, "weight_decay": 1.0e-6 } + - key: ["training", "optimizerKwargs"] + value: { "lr": 0.005, "weight_decay": 1.0e-5 } diff --git a/paper/node/param/toy/node/common/sweep_solver.yaml b/paper/node/param/toy/node/common/sweep_solver.yaml new file mode 100644 index 0000000..9f610b5 --- /dev/null +++ b/paper/node/param/toy/node/common/sweep_solver.yaml @@ -0,0 +1,16 @@ +file: "paper/node/param/toy/node/common/experiment.yaml" + +# now we vary the solver +customizations: + - key: ["network", "name"] + value: "node_l4_h32_tanh_dopri_adjoint" + - key: ["network", "name"] + value: "node_l4_h32_tanh_dopri_autograd" + - key: ["network", "name"] + value: "node_l4_h32_tanh_rk4_adjoint" + - key: ["network", "name"] + value: "node_l4_h32_tanh_rk4_autograd" + - key: ["network", "name"] + value: "node_l4_h32_tanh_euler_adjoint" + - key: ["network", "name"] + value: "node_l4_h32_tanh_euler_autograd" diff --git a/paper/node/param/toy/node/concentric/l2_h128_tanh_da.yaml b/paper/node/param/toy/node/concentric/l2_h128_tanh_da.yaml new file mode 100644 index 0000000..3425d86 --- /dev/null +++ b/paper/node/param/toy/node/concentric/l2_h128_tanh_da.yaml @@ -0,0 +1,8 @@ +network: + name: "node_l2_h128_tanh_da" + file: "paper/node/param/toy/node/concentric/network_da.yaml" + +training: + file: "paper/node/param/toy/node/concentric/training.yaml" + +file: "paper/node/param/toy/node/common/experiment.yaml" diff --git 
a/paper/node/param/toy/node/concentric/l2_h64_tanh_da.yaml b/paper/node/param/toy/node/concentric/l2_h64_tanh_da.yaml new file mode 100644 index 0000000..eb16d82 --- /dev/null +++ b/paper/node/param/toy/node/concentric/l2_h64_tanh_da.yaml @@ -0,0 +1,7 @@ +network: + file: "paper/node/param/toy/node/concentric/network_da.yaml" + +training: + file: "paper/node/param/toy/node/concentric/training.yaml" + +file: "paper/node/param/toy/node/common/experiment.yaml" diff --git a/paper/node/param/toy/node/concentric/network_da.yaml b/paper/node/param/toy/node/concentric/network_da.yaml new file mode 100644 index 0000000..d5f9fd2 --- /dev/null +++ b/paper/node/param/toy/node/concentric/network_da.yaml @@ -0,0 +1,3 @@ +name: "node_l2_h64_tanh_da" +dataset: "ToyConcentric" +outputSize: 2 diff --git a/paper/node/param/toy/node/concentric/sweep_activation_da.yaml b/paper/node/param/toy/node/concentric/sweep_activation_da.yaml new file mode 100644 index 0000000..b23ae9c --- /dev/null +++ b/paper/node/param/toy/node/concentric/sweep_activation_da.yaml @@ -0,0 +1,9 @@ +# Base Experiment is the classic retraining experiment +network: + file: "paper/node/param/toy/node/concentric/network_da.yaml" + +training: + file: "paper/node/param/toy/node/concentric/training.yaml" + +# now we vary the activation function +file: "paper/node/param/toy/node/common/sweep_activation_da.yaml" diff --git a/paper/node/param/toy/node/concentric/sweep_model_da.yaml b/paper/node/param/toy/node/concentric/sweep_model_da.yaml new file mode 100644 index 0000000..f3ce6bc --- /dev/null +++ b/paper/node/param/toy/node/concentric/sweep_model_da.yaml @@ -0,0 +1,9 @@ +# Base Experiment is the classic retraining experiment +network: + file: "paper/node/param/toy/node/concentric/network_da.yaml" + +training: + file: "paper/node/param/toy/node/concentric/training.yaml" + +# now we vary the activation function +file: "paper/node/param/toy/node/common/sweep_model_da.yaml" diff --git a/paper/node/param/toy/node/concentric/sweep_opt_da.yaml b/paper/node/param/toy/node/concentric/sweep_opt_da.yaml new file mode 100644 index 0000000..727d46b --- /dev/null +++ b/paper/node/param/toy/node/concentric/sweep_opt_da.yaml @@ -0,0 +1,9 @@ +# Base Experiment is the classic retraining experiment +network: + file: "paper/node/param/toy/node/concentric/network_da.yaml" + +training: + file: "paper/node/param/toy/node/concentric/training.yaml" + +# now we vary the activation function +file: "paper/node/param/toy/node/common/sweep_opt.yaml" diff --git a/paper/node/param/toy/node/concentric/sweep_solver.yaml b/paper/node/param/toy/node/concentric/sweep_solver.yaml new file mode 100644 index 0000000..0d9a67b --- /dev/null +++ b/paper/node/param/toy/node/concentric/sweep_solver.yaml @@ -0,0 +1,9 @@ +# Base Experiment is the classic retraining experiment +network: + file: "paper/node/param/toy/node/concentric/network.yaml" + +training: + file: "paper/node/param/toy/node/concentric/training.yaml" + +# now we vary the activation function +file: "paper/node/param/toy/node/common/sweep_solver.yaml" diff --git a/paper/node/param/toy/node/concentric/training.yaml b/paper/node/param/toy/node/concentric/training.yaml new file mode 100644 index 0000000..6c7e65d --- /dev/null +++ b/paper/node/param/toy/node/concentric/training.yaml @@ -0,0 +1,25 @@ +transformsTrain: [] +transformsTest: [] +transformsFinal: [] + +loss: "CrossEntropyLoss" +lossKwargs: { reduction: mean } + +metricsTest: + - type: TopK + kwargs: { topk: 1 } + - type: MCorr + kwargs: {} + +batchSize: 128 + 
+optimizer: "Adam" +optimizerKwargs: + lr: 0.01 + weight_decay: 1.0e-5 + +numEpochs: 50 + +enableAMP: False + +lrSchedulers: [] diff --git a/paper/node/param/toy/node/moons/l2_h128_tanh_da.yaml b/paper/node/param/toy/node/moons/l2_h128_tanh_da.yaml new file mode 100644 index 0000000..c6849fd --- /dev/null +++ b/paper/node/param/toy/node/moons/l2_h128_tanh_da.yaml @@ -0,0 +1,8 @@ +network: + name: "node_l2_h128_tanh_da" + file: "paper/node/param/toy/node/moons/network_da.yaml" + +training: + file: "paper/node/param/toy/node/moons/training.yaml" + +file: "paper/node/param/toy/node/common/experiment.yaml" diff --git a/paper/node/param/toy/node/moons/l2_h32_tanh_da.yaml b/paper/node/param/toy/node/moons/l2_h32_tanh_da.yaml new file mode 100644 index 0000000..3c6ab7d --- /dev/null +++ b/paper/node/param/toy/node/moons/l2_h32_tanh_da.yaml @@ -0,0 +1,8 @@ +network: + name: "node_l2_h32_tanh_da" + file: "paper/node/param/toy/node/moons/network_da.yaml" + +training: + file: "paper/node/param/toy/node/moons/training.yaml" + +file: "paper/node/param/toy/node/common/experiment.yaml" diff --git a/paper/node/param/toy/node/moons/l2_h3_tanh_da.yaml b/paper/node/param/toy/node/moons/l2_h3_tanh_da.yaml new file mode 100644 index 0000000..d1dbf91 --- /dev/null +++ b/paper/node/param/toy/node/moons/l2_h3_tanh_da.yaml @@ -0,0 +1,8 @@ +network: + name: "node_l2_h3_tanh_da" + file: "paper/node/param/toy/node/moons/network_da.yaml" + +training: + file: "paper/node/param/toy/node/moons/training.yaml" + +file: "paper/node/param/toy/node/common/experiment.yaml" diff --git a/paper/node/param/toy/node/moons/l2_h64_tanh_da.yaml b/paper/node/param/toy/node/moons/l2_h64_tanh_da.yaml new file mode 100644 index 0000000..3386573 --- /dev/null +++ b/paper/node/param/toy/node/moons/l2_h64_tanh_da.yaml @@ -0,0 +1,7 @@ +network: + file: "paper/node/param/toy/node/moons/network_da.yaml" + +training: + file: "paper/node/param/toy/node/moons/training.yaml" + +file: "paper/node/param/toy/node/common/experiment.yaml" diff --git a/paper/node/param/toy/node/moons/network_da.yaml b/paper/node/param/toy/node/moons/network_da.yaml new file mode 100644 index 0000000..0aa3519 --- /dev/null +++ b/paper/node/param/toy/node/moons/network_da.yaml @@ -0,0 +1,3 @@ +name: "node_l2_h64_tanh_da" +dataset: "ToyMoons" +outputSize: 2 diff --git a/paper/node/param/toy/node/moons/sweep_activation_da.yaml b/paper/node/param/toy/node/moons/sweep_activation_da.yaml new file mode 100644 index 0000000..b24561c --- /dev/null +++ b/paper/node/param/toy/node/moons/sweep_activation_da.yaml @@ -0,0 +1,9 @@ +# Base Experiment is the classic retraining experiment +network: + file: "paper/node/param/toy/node/moons/network_da.yaml" + +training: + file: "paper/node/param/toy/node/moons/training.yaml" + +# now we vary the activation function +file: "paper/node/param/toy/node/common/sweep_activation_da.yaml" diff --git a/paper/node/param/toy/node/moons/sweep_model_da.yaml b/paper/node/param/toy/node/moons/sweep_model_da.yaml new file mode 100644 index 0000000..ff7a051 --- /dev/null +++ b/paper/node/param/toy/node/moons/sweep_model_da.yaml @@ -0,0 +1,9 @@ +# Base Experiment is the classic retraining experiment +network: + file: "paper/node/param/toy/node/moons/network_da.yaml" + +training: + file: "paper/node/param/toy/node/moons/training.yaml" + +# now we vary the activation function +file: "paper/node/param/toy/node/common/sweep_model_da.yaml" diff --git a/paper/node/param/toy/node/moons/sweep_opt_da.yaml b/paper/node/param/toy/node/moons/sweep_opt_da.yaml new file 
mode 100644 index 0000000..d56369c --- /dev/null +++ b/paper/node/param/toy/node/moons/sweep_opt_da.yaml @@ -0,0 +1,9 @@ +# Base Experiment is the classic retraining experiment +network: + file: "paper/node/param/toy/node/moons/network_da.yaml" + +training: + file: "paper/node/param/toy/node/moons/training.yaml" + +# now we vary the activation function +file: "paper/node/param/toy/node/common/sweep_opt.yaml" diff --git a/paper/node/param/toy/node/moons/sweep_solver.yaml b/paper/node/param/toy/node/moons/sweep_solver.yaml new file mode 100644 index 0000000..a66cc4a --- /dev/null +++ b/paper/node/param/toy/node/moons/sweep_solver.yaml @@ -0,0 +1,9 @@ +# Base Experiment is the classic retraining experiment +network: + file: "paper/node/param/toy/node/moons/network.yaml" + +training: + file: "paper/node/param/toy/node/moons/training.yaml" + +# now we vary the activation function +file: "paper/node/param/toy/node/common/sweep_solver.yaml" diff --git a/paper/node/param/toy/node/moons/training.yaml b/paper/node/param/toy/node/moons/training.yaml new file mode 100644 index 0000000..e091f28 --- /dev/null +++ b/paper/node/param/toy/node/moons/training.yaml @@ -0,0 +1,25 @@ +transformsTrain: [] +transformsTest: [] +transformsFinal: [] + +loss: "CrossEntropyLoss" +lossKwargs: { reduction: mean } + +metricsTest: + - type: TopK + kwargs: { topk: 1 } + - type: MCorr + kwargs: {} + +batchSize: 128 + +optimizer: "Adam" +optimizerKwargs: + lr: 0.01 + weight_decay: 1.0e-4 + +numEpochs: 50 + +enableAMP: False + +lrSchedulers: [] diff --git a/paper/node/param/toy/node/spirals/l2_h64_tanh_da.yaml b/paper/node/param/toy/node/spirals/l2_h64_tanh_da.yaml new file mode 100644 index 0000000..300f3f0 --- /dev/null +++ b/paper/node/param/toy/node/spirals/l2_h64_tanh_da.yaml @@ -0,0 +1,7 @@ +network: + file: "paper/node/param/toy/node/spirals/network_da.yaml" + +training: + file: "paper/node/param/toy/node/spirals/training.yaml" + +file: "paper/node/param/toy/node/common/experiment.yaml" diff --git a/paper/node/param/toy/node/spirals/network_da.yaml b/paper/node/param/toy/node/spirals/network_da.yaml new file mode 100644 index 0000000..d8538c1 --- /dev/null +++ b/paper/node/param/toy/node/spirals/network_da.yaml @@ -0,0 +1,3 @@ +name: "node_l2_h64_tanh_da" +dataset: "ToySpirals" +outputSize: 2 diff --git a/paper/node/param/toy/node/spirals/sweep_activation_da.yaml b/paper/node/param/toy/node/spirals/sweep_activation_da.yaml new file mode 100644 index 0000000..9f7cc31 --- /dev/null +++ b/paper/node/param/toy/node/spirals/sweep_activation_da.yaml @@ -0,0 +1,9 @@ +# Base Experiment is the classic retraining experiment +network: + file: "paper/node/param/toy/node/spirals/network_da.yaml" + +training: + file: "paper/node/param/toy/node/spirals/training.yaml" + +# now we vary the activation function +file: "paper/node/param/toy/node/common/sweep_activation_da.yaml" diff --git a/paper/node/param/toy/node/spirals/sweep_model_da.yaml b/paper/node/param/toy/node/spirals/sweep_model_da.yaml new file mode 100644 index 0000000..73be2cc --- /dev/null +++ b/paper/node/param/toy/node/spirals/sweep_model_da.yaml @@ -0,0 +1,9 @@ +# Base Experiment is the classic retraining experiment +network: + file: "paper/node/param/toy/node/spirals/network_da.yaml" + +training: + file: "paper/node/param/toy/node/spirals/training.yaml" + +# now we vary the activation function +file: "paper/node/param/toy/node/common/sweep_model_da.yaml" diff --git a/paper/node/param/toy/node/spirals/sweep_opt_da.yaml 
b/paper/node/param/toy/node/spirals/sweep_opt_da.yaml new file mode 100644 index 0000000..546f1b0 --- /dev/null +++ b/paper/node/param/toy/node/spirals/sweep_opt_da.yaml @@ -0,0 +1,9 @@ +# Base Experiment is the classic retraining experiment +network: + file: "paper/node/param/toy/node/spirals/network_da.yaml" + +training: + file: "paper/node/param/toy/node/spirals/training.yaml" + +# now we vary the activation function +file: "paper/node/param/toy/node/common/sweep_opt.yaml" diff --git a/paper/node/param/toy/node/spirals/sweep_solver.yaml b/paper/node/param/toy/node/spirals/sweep_solver.yaml new file mode 100644 index 0000000..b5eac12 --- /dev/null +++ b/paper/node/param/toy/node/spirals/sweep_solver.yaml @@ -0,0 +1,9 @@ +# Base Experiment is the classic retraining experiment +network: + file: "paper/node/param/toy/node/spirals/network.yaml" + +training: + file: "paper/node/param/toy/node/spirals/training.yaml" + +# now we vary the activation function +file: "paper/node/param/toy/node/common/sweep_solver.yaml" diff --git a/paper/node/param/toy/node/spirals/training.yaml b/paper/node/param/toy/node/spirals/training.yaml new file mode 100644 index 0000000..ae28b7e --- /dev/null +++ b/paper/node/param/toy/node/spirals/training.yaml @@ -0,0 +1,25 @@ +transformsTrain: [] +transformsTest: [] +transformsFinal: [] + +loss: "CrossEntropyLoss" +lossKwargs: { reduction: mean } + +metricsTest: + - type: TopK + kwargs: { topk: 1 } + - type: MCorr + kwargs: {} + +batchSize: 128 + +optimizer: "Adam" +optimizerKwargs: + lr: 0.01 + weight_decay: 1.0e-5 + +numEpochs: 100 + +enableAMP: False + +lrSchedulers: [] diff --git a/paper/node/script/plot_datasets.py b/paper/node/script/plot_datasets.py new file mode 100644 index 0000000..0622b39 --- /dev/null +++ b/paper/node/script/plot_datasets.py @@ -0,0 +1,67 @@ +# %% Setup script +import random +import os +import sys + +import numpy as np +from IPython import get_ipython +import matplotlib.colors as mcolors +import matplotlib.pyplot as plt +import torch + +from torchprune.util import datasets + +IN_JUPYTER = True +try: + get_ipython().run_line_magic("matplotlib", "inline") # show plots +except AttributeError: + IN_JUPYTER = False + +# switch to root folder for data +folder = os.path.abspath("") +if "paper/node/script" in folder: + src_folder = os.path.join(folder, "../../..") + os.chdir(src_folder) + +# %% define plotting function +def plot_dataset(dset): + # put them in a loader and retrieve batch + loader = torch.utils.data.DataLoader( + dataset=dset, + batch_size=min(len(dset), 1000), + num_workers=0, + shuffle=False, + ) + x_data, y_data = next(loader.__iter__()) + + fig = plt.figure(figsize=(3, 3)) + ax = fig.add_subplot(111) + ax.scatter(x_data[:, 0], x_data[:, 1], s=1, c=y_data) + + +# %% go through datasets and plot each of them +dset_list = [ + # "ToyConcentric", + # "ToyMoons", + # "ToySpirals", + # "ToySpirals2", + # "ToyGaussians", + # "ToyGaussiansSpiral", + # "ToyDiffeqml", + "Bsds300", + "Hepmass", + "Miniboone", + "Power", + "Gas", +] +dsets = [] +for dset_name in dset_list: + dsets.append( + getattr(datasets, dset_name)( + root="./local", + file_dir="./data/training", + download=True, + train=False, + ) + ) + plot_dataset(dsets[-1]) diff --git a/paper/node/script/plots2d.py b/paper/node/script/plots2d.py new file mode 100644 index 0000000..56cdc21 --- /dev/null +++ b/paper/node/script/plots2d.py @@ -0,0 +1,209 @@ +import matplotlib.pyplot as plt +import numpy as np +import torch +import torch.nn as nn +import os + +from torchdyn import utils 
as plot + + +def get_mesh(X_data, N=1000): + X_data = X_data.detach() + spacing = [torch.linspace(x_i.min(), x_i.max(), N) for x_i in X_data.T] + return torch.stack(torch.meshgrid(*spacing), dim=-1) + + +def plot_for_sweep(**kwargs): + """Plot the desired sweep plot.""" + plot_2d_boundary(**kwargs) + + +def plot_dataset(x_data, y_data, **kwargs): + x_data = x_data.detach().cpu() + y_data = y_data.detach().cpu() + colors = ["orange", "blue"] + fig = plt.figure(figsize=(3, 3)) + ax = fig.add_subplot(111) + for i in range(len(x_data)): + ax.scatter( + x_data[i, 0], x_data[i, 1], s=1, color=colors[y_data[i].int()] + ) + + +def plot_2d_boundary( + model, + x_data, + y_data, + mesh, + num_classes=2, + axis=None, + **kwargs, +): + x_data = x_data.detach().cpu() + y_data = y_data.detach().cpu() + preds = torch.argmax(nn.Softmax(-1)(model(mesh)), dim=-1) + preds = preds.detach().cpu().reshape(mesh.size(0), mesh.size(1)) + if axis is None: + plt.figure(figsize=(8, 4)) + axis = plt.gca() + + contour_colors = ["navy", "tab:orange"] + scatter_colors = ["midnightblue", "darkorange"] + axis.contourf( + mesh[:, :, 0].detach().cpu(), + mesh[:, :, 1].detach().cpu(), + preds, + colors=contour_colors, + alpha=0.4, + levels=1, + ) + for i in range(num_classes): + axis.scatter( + x_data[y_data == i, 0], + x_data[y_data == i, 1], + alpha=1.0, + s=6.0, + linewidths=0, + c=scatter_colors[i], + edgecolors=None, + ) + + +def plot_static_vector_field(model, x_data, t=0.0, N=100, axis=None, **kwargs): + device = next(model.parameters()).device + x = torch.linspace(x_data[:, 0].min(), x_data[:, 0].max(), N) + y = torch.linspace(x_data[:, 1].min(), x_data[:, 1].max(), N) + X, Y = torch.meshgrid(x, y) + + U, V = torch.zeros_like(X), torch.zeros_like(Y) + + for i in range(N): + for j in range(N): + p = torch.cat( + [X[i, j].reshape(1, 1), Y[i, j].reshape(1, 1)], 1 + ).to(device) + O = model.defunc(t, p).detach().cpu() + U[i, j], V[i, j] = O[0, 0], O[0, 1] + + # convert to cpu numpy + X, Y, U, V = [tnsr.cpu().numpy() for tnsr in (X, Y, U, V)] + + if axis is None: + fig = plt.figure(figsize=(3, 3)) + axis = fig.add_subplot(111) + axis.contourf( + X, + Y, + np.sqrt(U ** 2 + V ** 2), + cmap="RdYlBu", + levels=1000, + alpha=0.6, + ) + axis.streamplot( + X.T, + Y.T, + U.T, + V.T, + color="k", + density=1.5, + linewidth=0.7, + arrowsize=0.7, + arrowstyle="-|>", + ) + + axis.set_xlim([x.min(), x.max()]) + axis.set_ylim([y.min(), y.max()]) + axis.set_xlabel(r"$h_0$") + axis.set_ylabel(r"$h_1$") + axis.set_title("Learned Vector Field") + + +def plot_2D_state_space(trajectory, y_data, n_lines, **kwargs): + plot.plot_2D_state_space(trajectory, y_data, n_lines) + + +def plot_2D_depth_trajectory( + s_span, trajectory, y_data, axis1=None, axis2=None, **kwargs +): + if axis1 is None or axis2 is None: + fig = plt.figure(figsize=(8, 2)) + axis1 = fig.add_subplot(121) + axis2 = fig.add_subplot(122) + + colors = ["midnightblue", "darkorange"] + + for i, label in enumerate(y_data): + color = colors[int(label)] + axis1.plot(s_span, trajectory[:, i, 0], color=color, alpha=0.1) + axis2.plot(s_span, trajectory[:, i, 1], color=color, alpha=0.1) + + axis1.set_xlabel(r"Depth") + axis1.set_ylabel(r"Dim. 0") + axis2.set_xlabel(r"Depth") + axis2.set_ylabel(r"Dim. 
1")
+
+
+def prepare_data(model, loader, compute_yhat=False, **kwargs):
+    """Prepare and return the required data."""
+    # setup
+    model = model.model
+    device = next(model.parameters()).device
+    plt.style.use("default")
+
+    # collect data from loader
+    x_data, y_data = None, None
+    for x_b, y_b in loader:
+        if x_data is None:
+            x_data = x_b
+            y_data = y_b
+        else:
+            x_data = torch.cat((x_data, x_b))
+            y_data = torch.cat((y_data, y_b))
+    x_data, y_data = x_data.to(device), y_data.to(device)
+    s_span = torch.linspace(0, 1, 100)
+    trajectory = model.trajectory(x_data, s_span.to(device)).detach().cpu()
+    mesh = get_mesh(x_data).to(device)
+
+    data = {
+        "x_data": x_data,
+        "y_data": y_data,
+        "n_lines": len(x_data),
+        "model": model,
+        "device": device,
+        "s_span": s_span,
+        "trajectory": trajectory,
+        "mesh": mesh,
+    }
+    if compute_yhat:
+        data["y_hat"] = model(x_data).argmax(dim=1)
+
+    return data
+
+
+def plot_all(model, loader, plot_folder=None, all_p=False):
+    # default plotting style
+    plt.style.use("default")
+
+    # retrieve plotting kwargs
+    kwargs_plot = prepare_data(model, loader, all_p)
+
+    def _plot_and_save(plt_handle, plt_name):
+        plt_handle(**kwargs_plot)
+        if plot_folder is not None:
+            os.makedirs(plot_folder, exist_ok=True)
+            fig = plt.gcf()
+            fig.savefig(
+                os.path.join(plot_folder, f"{plt_name}.pdf"),
+                bbox_inches="tight",
+            )
+            plt.close(fig)
+
+    _plot_and_save(plot_2d_boundary, "2d_boundary")
+    _plot_and_save(plot_static_vector_field, "static_vector_field")
+
+    if not all_p:
+        return
+
+    _plot_and_save(plot_2D_state_space, "2D_state_space")
+    _plot_and_save(plot_2D_depth_trajectory, "2D_depth_trajectory")
+    _plot_and_save(plot_dataset, "dataset")
diff --git a/paper/node/script/plots_cnf.py b/paper/node/script/plots_cnf.py
new file mode 100644
index 0000000..f44642e
--- /dev/null
+++ b/paper/node/script/plots_cnf.py
@@ -0,0 +1,215 @@
+import copy
+import os
+import numpy as np
+import matplotlib.pyplot as plt
+import torch
+from torchdyn.nn import Augmenter
+
+
+def plot_dataset(x_data, **kwargs):
+    x_data = x_data.detach().cpu()
+    plt.figure(figsize=(3, 3))
+    plot_samples(x_data, axis=plt.gca())
+
+
+def plot_for_sweep(**kwargs):
+    """Plot the desired sweep plot."""
+    plot_samples(**kwargs)
+
+
+def plot_samples(x_sampled, axis, color="midnightblue", **kwargs):
+    x_sampled = x_sampled.detach().cpu()
+    if x_sampled.shape[1] > 2:
+        x_sampled = x_sampled[:, 1:3]
+    axis.scatter(
+        x_sampled[:, 0],
+        x_sampled[:, 1],
+        s=0.2,
+        alpha=0.8,
+        linewidths=0,
+        c=color,
+        edgecolors=None,
+    )
+    axis.set_xlim([-2, 2])
+    axis.set_ylim([-2, 2])
+
+
+def plot_samples_density(x_data, x_sampled, **kwargs):
+    x_data = x_data.detach().cpu()
+    x_sampled = x_sampled.detach().cpu()
+    plt.figure(figsize=(12, 4))
+    plt.subplot(121)
+    plot_samples(x_sampled, axis=plt.gca())
+
+    plt.subplot(122)
+    plot_samples(x_data, axis=plt.gca(), color="red")
+    plt.xlim(-2, 2)
+    plt.ylim(-2, 2)
+
+
+def plot_flow(sample, trajectory, **kwargs):
+    traj = trajectory.detach().cpu()
+    sample = sample.detach().cpu()
+    n = 2000
+    plt.figure(figsize=(6, 6))
+    plt.scatter(sample[:n, 0], sample[:n, 1], s=10, alpha=0.8, c="black")
+    plt.scatter(traj[:, :n, 0], traj[:, :n, 1], s=0.2, alpha=0.2, c="olive")
+    plt.scatter(traj[-1, :n, 0], traj[-1, :n, 1], s=4, alpha=1, c="blue")
+    plt.legend(["Prior sample z(S)", "Flow", "z(0)"])
+
+
+def plot_2D_depth_trajectory(
+    s_span, trajectory, axis1=None, axis2=None, num_lines=200, **kwargs
+):
+    if axis1 is None or axis2 is None:
+        fig = plt.figure(figsize=(8, 2))
+        axis1 = fig.add_subplot(121)
+        axis2 = fig.add_subplot(122)
+
+    # trajectory has shape [len(s_span), num_data_points, dim] originally
+    trajectory = trajectory.detach().permute(1, 2, 0).cpu()
+
+    # subsample trajectories
+    num_lines = min(num_lines, len(trajectory))
+    trajectory = trajectory[torch.randperm(len(trajectory))[:num_lines]]
+    for traj_one in trajectory:
+        axis1.plot(s_span, traj_one[0], alpha=0.2)
+        axis2.plot(s_span, traj_one[1], alpha=0.2)
+
+    axis1.set_xlabel(r"Depth")
+    axis1.set_ylabel(r"Dim. 0")
+    axis2.set_xlabel(r"Depth")
+    axis2.set_ylabel(r"Dim. 1")
+
+
+def plot_static_vector_field(model, N=100, axis=None, **kwargs):
+    device = next(model.parameters()).device
+    model = model[1].defunc.m.net
+    x = torch.linspace(-2, 2, N)
+    y = torch.linspace(-2, 2, N)
+
+    X, Y = torch.meshgrid(x, y)
+    U, V = torch.zeros(N, N), torch.zeros(N, N)
+
+    for i in range(N):
+        for j in range(N):
+            p = torch.cat(
+                [X[i, j].reshape(1, 1), Y[i, j].reshape(1, 1)], 1
+            ).to(device)
+            O = model(p).detach().cpu()
+            U[i, j], V[i, j] = O[0, 0], O[0, 1]
+
+    # convert to cpu numpy
+    X, Y, U, V = [tnsr.cpu().numpy() for tnsr in (X, Y, U, V)]
+
+    if axis is None:
+        fig = plt.figure(figsize=(3, 3))
+        axis = fig.add_subplot(111)
+
+    axis.contourf(
+        X,
+        Y,
+        np.sqrt(U ** 2 + V ** 2),
+        cmap="RdYlBu",
+        levels=1000,
+        alpha=0.6,
+    )
+
+    axis.streamplot(
+        X.T,
+        Y.T,
+        U.T,
+        V.T,
+        color="k",
+        density=1.5,
+        linewidth=0.7,
+        arrowsize=0.7,
+        arrowstyle="<|-",
+    )
+
+    axis.set_xlim([x.min(), x.max()])
+    axis.set_ylim([y.min(), y.max()])
+    axis.set_xlabel(r"$h_0$")
+    axis.set_ylabel(r"$h_1$")
+    axis.set_title("Learned Vector Field")
+
+
+def prepare_data(model, loader, collect_from_loader=False, n_samp=2 ** 14):
+    """Prepare model and data and return as dict."""
+    device = next(model.parameters()).device
+
+    # collect prior from model
+    prior = None
+    for x_b, _ in loader:
+        prior = model(x_b.to(device))["prior"]
+        break
+
+    # extract ffjord model
+    model = model.model
+
+    # set s-span but keep old one around!
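+    # (the span is reversed to [1, 0] below so that model(sample) carries
+    # draws from the prior back to data space; the original span is
+    # restored further down)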
+    s_span_backup = copy.deepcopy(model[1].s_span)
+    model[1].s_span = torch.linspace(1, 0, 2).to(device)
+
+    # s_span for trajectory
+    s_span_traj = torch.linspace(1, 0, 100)
+
+    # preparing some data and samples for plotting
+    sample = prior.sample(torch.Size([n_samp])).to(device)
+    x_sampled = model(sample)
+    trajectory = model[1].trajectory(
+        Augmenter(1, 1)(sample),
+        s_span=s_span_traj.to(device),
+    )
+    # scrapping first dimension := jacobian trace
+    trajectory = trajectory[:, :, 1:]
+
+    # restore s-span
+    model[1].s_span = s_span_backup
+
+    data = {
+        "sample": sample,
+        "x_sampled": x_sampled,
+        "s_span": s_span_traj,
+        "trajectory": trajectory,
+        "model": model,
+        "device": device,
+    }
+
+    # collect data from loader
+    if collect_from_loader:
+        x_data = None
+        for x_b, _ in loader:
+            if x_data is None:
+                x_data = x_b
+            else:
+                x_data = torch.cat((x_data, x_b))
+        x_data = x_data.to(device)
+        data["x_data"] = x_data
+
+    return data
+
+
+def plot_all(model, loader, plot_folder=None, all_p=False):
+    plt.style.use("default")
+
+    # retrieve plotting kwargs
+    kwargs_plot = prepare_data(model, loader, True)
+
+    def _plot_and_save(plt_handle, plt_name):
+        plt_handle(**kwargs_plot)
+        if plot_folder is not None:
+            os.makedirs(plot_folder, exist_ok=True)
+            fig = plt.gcf()
+            fig.savefig(
+                os.path.join(plot_folder, f"{plt_name}.pdf"),
+                bbox_inches="tight",
+            )
+            plt.close(fig)
+
+    _plot_and_save(plot_samples_density, "samples_density")
+    _plot_and_save(plot_static_vector_field, "static_vector_field")
+
+    if all_p:
+        _plot_and_save(plot_dataset, "dataset")
+        _plot_and_save(plot_flow, "flow")
diff --git a/paper/node/script/sizes/nn_.py b/paper/node/script/sizes/nn_.py
new file mode 100644
index 0000000..b90c45e
--- /dev/null
+++ b/paper/node/script/sizes/nn_.py
@@ -0,0 +1,567 @@
+#!/usr/bin/env python2
+# -*- coding: utf-8 -*-
+"""
+Created on Mon Dec 11 13:58:12 2017
+
+@author: CW
+"""
+
+
+import math
+import numpy as np
+
+import torch
+import torch.nn as nn
+from torch.nn import Module
+from torch.nn import functional as F
+from torch.nn.parameter import Parameter
+from torch.nn.modules.utils import _pair
+from torch.autograd import Variable
+
+# aliasing
+N_ = None
+
+
+delta = 1e-6
+softplus_ = nn.Softplus()
+softplus = lambda x: softplus_(x) + delta
+sigmoid_ = nn.Sigmoid()
+sigmoid = lambda x: sigmoid_(x) * (1 - delta) + 0.5 * delta
+sigmoid2 = lambda x: sigmoid(x) * 2.0
+logsigmoid = lambda x: -softplus(-x)
+log = lambda x: torch.log(x * 1e2) - np.log(1e2)
+logit = lambda x: log(x) - log(1 - x)
+
+
+def softmax(x, dim=-1):
+    e_x = torch.exp(x - x.max(dim=dim, keepdim=True)[0])
+    out = e_x / e_x.sum(dim=dim, keepdim=True)
+    return out
+
+
+sum1 = lambda x: x.sum(1)
+sum_from_one = (
+    lambda x: sum_from_one(sum1(x)) if len(x.size()) > 2 else sum1(x)
+)
+
+
+class Sigmoid(Module):
+    def forward(self, x):
+        return sigmoid(x)
+
+
+class WNlinear(Module):
+    def __init__(
+        self, in_features, out_features, bias=True, mask=N_, norm=True
+    ):
+        super(WNlinear, self).__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.register_buffer("mask", mask)
+        self.norm = norm
+        self.direction = Parameter(torch.Tensor(out_features, in_features))
+        self.scale = Parameter(torch.Tensor(out_features))
+        if bias:
+            self.bias = Parameter(torch.Tensor(out_features))
+        else:
+            self.register_parameter("bias", N_)
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        stdv = 1.0 / math.sqrt(self.direction.size(1))
+
self.direction.data.uniform_(-stdv, stdv) + self.scale.data.uniform_(1, 1) + if self.bias is not N_: + self.bias.data.uniform_(-stdv, stdv) + + def forward(self, input): + if self.norm: + dir_ = self.direction + direction = dir_.div(dir_.pow(2).sum(1).sqrt()[:, N_]) + weight = self.scale[:, N_].mul(direction) + else: + weight = self.scale[:, N_].mul(self.direction) + if self.mask is not N_: + # weight = weight * getattr(self.mask, + # ('cpu', 'cuda')[weight.is_cuda])() + weight = weight * Variable(self.mask) + return F.linear(input, weight, self.bias) + + def __repr__(self): + return ( + self.__class__.__name__ + + "(" + + "in_features=" + + str(self.in_features) + + ", out_features=" + + str(self.out_features) + + ")" + ) + + +class CWNlinear(Module): + def __init__( + self, in_features, out_features, context_features, mask=N_, norm=True + ): + super(CWNlinear, self).__init__() + self.in_features = in_features + self.out_features = out_features + self.context_features = context_features + self.register_buffer("mask", mask) + self.norm = norm + self.direction = Parameter(torch.Tensor(out_features, in_features)) + self.cscale = nn.Linear(context_features, out_features) + self.cbias = nn.Linear(context_features, out_features) + self.reset_parameters() + self.cscale.weight.data.normal_(0, 0.001) + self.cbias.weight.data.normal_(0, 0.001) + + def reset_parameters(self): + self.direction.data.normal_(0, 0.001) + + def forward(self, inputs): + input, context = inputs + scale = self.cscale(context) + bias = self.cbias(context) + if self.norm: + dir_ = self.direction + direction = dir_.div(dir_.pow(2).sum(1).sqrt()[:, N_]) + weight = direction + else: + weight = self.direction + if self.mask is not N_: + # weight = weight * getattr(self.mask, + # ('cpu', 'cuda')[weight.is_cuda])() + weight = weight * Variable(self.mask) + return scale * F.linear(input, weight, None) + bias, context + + def __repr__(self): + return ( + self.__class__.__name__ + + "(" + + "in_features=" + + str(self.in_features) + + ", out_features=" + + str(self.out_features) + + ")" + ) + + +class WNBilinear(Module): + def __init__(self, in1_features, in2_features, out_features, bias=True): + super(WNBilinear, self).__init__() + self.in1_features = in1_features + self.in2_features = in2_features + self.out_features = out_features + self.direction = Parameter( + torch.Tensor(out_features, in1_features, in2_features) + ) + self.scale = Parameter(torch.Tensor(out_features)) + if bias: + self.bias = Parameter(torch.Tensor(out_features)) + else: + self.register_parameter("bias", N_) + self.reset_parameters() + + def reset_parameters(self): + stdv = 1.0 / math.sqrt(self.direction.size(1)) + self.direction.data.uniform_(-stdv, stdv) + self.scale.data.uniform_(1, 1) + if self.bias is not N_: + self.bias.data.uniform_(-stdv, stdv) + + def forward(self, input1, input2): + dir_ = self.direction + direction = dir_.div(dir_.pow(2).sum(1).sum(1).sqrt()[:, N_, N_]) + weight = self.scale[:, N_, N_].mul(direction) + return F.bilinear(input1, input2, weight, self.bias) + + def __repr__(self): + return ( + self.__class__.__name__ + + "(" + + "in1_features=" + + str(self.in1_features) + + ", in2_features=" + + str(self.in2_features) + + ", out_features=" + + str(self.out_features) + + ")" + ) + + +class _WNconvNd(Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + bias, + ): + super(_WNconvNd, self).__init__() + if in_channels % groups != 0: + 
raise ValueError("in_channels must be divisible by groups") + if out_channels % groups != 0: + raise ValueError("out_channels must be divisible by groups") + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.dilation = dilation + self.transposed = transposed + self.output_padding = output_padding + self.groups = groups + + # weight – filters tensor (out_channels x in_channels/groups x kH x kW) + if transposed: + self.direction = Parameter( + torch.Tensor(in_channels, out_channels // groups, *kernel_size) + ) + self.scale = Parameter(torch.Tensor(in_channels)) + else: + self.direction = Parameter( + torch.Tensor(out_channels, in_channels // groups, *kernel_size) + ) + self.scale = Parameter(torch.Tensor(out_channels)) + if bias: + self.bias = Parameter(torch.Tensor(out_channels)) + else: + self.register_parameter("bias", N_) + self.reset_parameters() + + def reset_parameters(self): + n = self.in_channels + for k in self.kernel_size: + n *= k + stdv = 1.0 / math.sqrt(n) + self.direction.data.uniform_(-stdv, stdv) + self.scale.data.uniform_(1, 1) + if self.bias is not N_: + self.bias.data.uniform_(-stdv, stdv) + + def __repr__(self): + s = ( + "{name} ({in_channels}, {out_channels}, kernel_size={kernel_size}" + ", stride={stride}" + ) + if self.padding != (0,) * len(self.padding): + s += ", padding={padding}" + if self.dilation != (1,) * len(self.dilation): + s += ", dilation={dilation}" + if self.output_padding != (0,) * len(self.output_padding): + s += ", output_padding={output_padding}" + if self.groups != 1: + s += ", groups={groups}" + if self.bias is N_: + s += ", bias=False" + s += ")" + return s.format(name=self.__class__.__name__, **self.__dict__) + + +class WNconv2d(_WNconvNd): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + mask=N_, + norm=True, + ): + kernel_size = _pair(kernel_size) + stride = _pair(stride) + padding = _pair(padding) + dilation = _pair(dilation) + super(WNconv2d, self).__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + False, + _pair(0), + groups, + bias, + ) + + self.register_buffer("mask", mask) + self.norm = norm + + def forward(self, input): + if self.norm: + dir_ = self.direction + direction = dir_.div( + dir_.pow(2).sum(1).sum(1).sum(1).sqrt()[:, N_, N_, N_] + ) + weight = self.scale[:, N_, N_, N_].mul(direction) + else: + weight = self.scale[:, N_, N_, N_].mul(self.direction) + if self.mask is not None: + weight = weight * Variable(self.mask) + return F.conv2d( + input, + weight, + self.bias, + self.stride, + self.padding, + self.dilation, + self.groups, + ) + + +class CWNconv2d(_WNconvNd): + def __init__( + self, + context_features, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + mask=N_, + norm=True, + ): + kernel_size = _pair(kernel_size) + stride = _pair(stride) + padding = _pair(padding) + dilation = _pair(dilation) + super(CWNconv2d, self).__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + False, + _pair(0), + groups, + False, + ) + + self.register_buffer("mask", mask) + self.norm = norm + self.cscale = nn.Linear(context_features, out_channels) + self.cbias = nn.Linear(context_features, out_channels) + + def forward(self, inputs): + input, context = inputs + scale = self.cscale(context)[:, :, N_, N_] + bias = self.cbias(context)[:, 
:, N_, N_] + if self.norm: + dir_ = self.direction + direction = dir_.div( + dir_.pow(2).sum(1).sum(1).sum(1).sqrt()[:, N_, N_, N_] + ) + weight = direction + else: + weight = self.direction + if self.mask is not None: + weight = weight * Variable(self.mask) + pre = F.conv2d( + input, + weight, + None, + self.stride, + self.padding, + self.dilation, + self.groups, + ) + return pre * scale + bias, context + + +class ResConv2d(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + activation=nn.ReLU(), + oper=WNconv2d, + ): + super(ResConv2d, self).__init__() + + self.conv_0h = oper( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + ) + self.conv_h1 = oper(out_channels, out_channels, 3, 1, 1, 1, 1, True) + self.conv_01 = oper( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + ) + + self.activation = activation + + def forward(self, input): + h = self.activation(self.conv_0h(input)) + out_nonlinear = self.conv_h1(h) + out_skip = self.conv_01(input) + return out_nonlinear + out_skip + + +class ResLinear(nn.Module): + def __init__( + self, + in_features, + out_features, + bias=True, + same_dim=False, + activation=nn.ReLU(), + oper=WNlinear, + ): + super(ResLinear, self).__init__() + + self.same_dim = same_dim + + self.dot_0h = oper(in_features, out_features, bias) + self.dot_h1 = oper(out_features, out_features, bias) + if not same_dim: + self.dot_01 = oper(in_features, out_features, bias) + + self.activation = activation + + def forward(self, input): + h = self.activation(self.dot_0h(input)) + out_nonlinear = self.dot_h1(h) + out_skip = input if self.same_dim else self.dot_01(input) + return out_nonlinear + out_skip + + +class GatingLinear(nn.Module): + def __init__(self, in_features, out_features, oper=WNlinear, **kwargs): + super(GatingLinear, self).__init__() + + self.dot = oper(in_features, out_features, **kwargs) + self.gate = oper(in_features, out_features, **kwargs) + + def forward(self, input): + h = self.dot(input) + s = sigmoid_(self.gate(input)) + return s * h + + +class Reshape(nn.Module): + def __init__(self, shape): + super(Reshape, self).__init__() + self.shape = shape + + def forward(self, input): + return input.view(self.shape) + + +class Slice(nn.Module): + def __init__(self, slc): + super(Slice, self).__init__() + self.slc = slc + + def forward(self, input): + return input.__getitem__(self.slc) + + +class SliceFactory(object): + def __init__(self): + pass + + def __getitem__(self, slc): + return Slice(slc) + + +slicer = SliceFactory() + + +class Lambda(nn.Module): + def __init__(self, function): + super(Lambda, self).__init__() + self.function = function + + def forward(self, input): + return self.function(input) + + +class SequentialFlow(nn.Sequential): + def sample(self, n=1, context=None, **kwargs): + dim = self[0].dim + if isinstance(dim, int): + dim = [ + dim, + ] + + spl = torch.autograd.Variable(torch.FloatTensor(n, *dim).normal_()) + lgd = torch.autograd.Variable( + torch.from_numpy(np.random.rand(n).astype("float32")) + ) + if context is None: + context = torch.autograd.Variable( + torch.from_numpy( + np.zeros((n, self[0].context_dim)).astype("float32") + ) + ) + + if hasattr(self, "gpu"): + if self.gpu: + spl = spl.cuda() + lgd = lgd.cuda() + context = context.cuda() + + return self.forward((spl, lgd, context)) + + def cuda(self): + self.gpu = True + return 
super(SequentialFlow, self).cuda()
+
+
+class ContextWrapper(nn.Module):
+    def __init__(self, module):
+        super(ContextWrapper, self).__init__()
+        self.module = module
+
+    def forward(self, inputs):
+        input, context = inputs
+        output = self.module.forward(input)
+        return output, context
+
+
+if __name__ == "__main__":
+
+    mdl = CWNlinear(2, 5, 3)
+
+    inp = torch.autograd.Variable(
+        torch.from_numpy(np.random.rand(2, 2).astype("float32"))
+    )
+    con = torch.autograd.Variable(
+        torch.from_numpy(np.random.rand(2, 3).astype("float32"))
+    )
+
+    print(mdl((inp, con))[0].size())
diff --git a/paper/node/script/sizes/sizes_maf.py b/paper/node/script/sizes/sizes_maf.py
new file mode 100644
index 0000000..39de27a
--- /dev/null
+++ b/paper/node/script/sizes/sizes_maf.py
@@ -0,0 +1,61 @@
+# %% just compute a couple of sizes
+
+
+def made_size(l, h, d, k=None):
+    return int(1.5 * d * h + 0.5 * (l - 1) * h * h)
+
+
+def nvp_size(l, h, d, k=10):
+    return int(2 * k * d * h + 2 * k * (l - 1) * h * h)
+
+
+def maf_size(l, h, d, k=10):
+    return int(1.5 * k * d * h + 0.5 * k * (l - 1) * h * h)
+
+
+def format_as_str(num):
+    if num / 1e9 > 1:
+        factor, suffix = 1e9, "B"
+    elif num / 1e6 > 1:
+        factor, suffix = 1e6, "M"
+    elif num / 1e3 > 1:
+        factor, suffix = 1e3, "K"
+    else:
+        factor, suffix = 1e0, ""
+
+    num_factored = num / factor
+
+    if num_factored / 1e2 > 1 or True:
+        num_rounded = str(int(round(num_factored)))
+    elif num_factored / 1e1 > 1:
+        num_rounded = f"{num_factored:.1f}"
+    else:
+        num_rounded = f"{num_factored:.2f}"
+
+    return f"{num_rounded}{suffix} % {num}"
+
+
+datasets = {
+    "power": {"l": 2, "h": 100, "d": 6},
+    "gas": {"l": 2, "h": 100, "d": 8},
+    "hepmass": {"l": 2, "h": 512, "d": 21},
+    "miniboone": {"l": 2, "h": 512, "d": 43},
+    "bsds300": {"l": 2, "h": 1024, "d": 63},
+    "mnist": {"l": 1, "h": 1024, "d": 784, "k": 10},
+    "cifar": {"l": 2, "h": 2048, "d": 3072, "k": 10},
+}
+
+networks = {
+    "MADE": made_size,
+    "RealNVP": nvp_size,
+    "MAF": maf_size,
+}
+
+for net, handle in networks.items():
+    print(net)
+    for dset, s_kwargs in datasets.items():
+        print(f"{dset}: #params: {format_as_str(handle(**s_kwargs))}")
+    print("\n")
+
+
+# miniboone ffjord: 820613
\ No newline at end of file
diff --git a/paper/node/script/sizes/sizes_naf.py b/paper/node/script/sizes/sizes_naf.py
new file mode 100644
index 0000000..27b03f2
--- /dev/null
+++ b/paper/node/script/sizes/sizes_naf.py
@@ -0,0 +1,1082 @@
+# %%
+
+import argparse
+import json
+import os
+import time
+from functools import reduce
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.utils.data as data
+
+# import torch.optim as optim
+from torch.autograd import Variable
+from torchvision.transforms import transforms
+from torch.nn import Module
+from torch.nn.parameter import Parameter
+from torch.nn import functional as F
+
+import nn_ as nn_
+from nn_ import log
+
+
+sum_from_one = nn_.sum_from_one
+
+
+"""
+Created on Mon Dec 11 13:58:12 2017
+
+@author: CW
+"""
+
+
+# aliasing
+N_ = None
+
+
+delta = 1e-6
+softplus_ = nn.Softplus()
+softplus = lambda x: softplus_(x) + delta
+
+
+tile = lambda x, r: np.tile(x, r).reshape(x.shape[0], x.shape[1] * r)
+
+
+# %------------ MADE ------------%
+
+
+def get_rank(max_rank, num_out):
+    rank_out = np.array([])
+    while len(rank_out) < num_out:
+
rank_out = np.concatenate([rank_out, np.arange(max_rank)]) + excess = len(rank_out) - num_out + remove_ind = np.random.choice(max_rank, excess, False) + rank_out = np.delete(rank_out, remove_ind) + np.random.shuffle(rank_out) + return rank_out.astype("float32") + + +def get_mask_from_ranks(r1, r2): + return (r2[:, None] >= r1[None, :]).astype("float32") + + +def get_masks_all(ds, fixed_order=False, derank=1): + # ds: list of dimensions dx, d1, d2, ... dh, dx, + # (2 in/output + h hidden layers) + # derank only used for self connection, dim > 1 + dx = ds[0] + ms = list() + rx = get_rank(dx, dx) + if fixed_order: + rx = np.sort(rx) + r1 = rx + if dx != 1: + for d in ds[1:-1]: + r2 = get_rank(dx - derank, d) + ms.append(get_mask_from_ranks(r1, r2)) + r1 = r2 + r2 = rx - derank + ms.append(get_mask_from_ranks(r1, r2)) + else: + ms = [ + np.zeros([ds[i + 1], ds[i]]).astype("float32") + for i in range(len(ds) - 1) + ] + if derank == 1: + assert np.all(np.diag(reduce(np.dot, ms[::-1])) == 0), "wrong masks" + + return ms, rx + + +def get_masks(dim, dh, num_layers, num_outlayers, fixed_order=False, derank=1): + ms, rx = get_masks_all( + [ + dim, + ] + + [dh for i in range(num_layers - 1)] + + [ + dim, + ], + fixed_order, + derank, + ) + ml = ms[-1] + ml_ = ( + ( + ml.transpose(1, 0)[:, :, None] + * ( + [ + np.cast["float32"](1), + ] + * num_outlayers + ) + ) + .reshape(dh, dim * num_outlayers) + .transpose(1, 0) + ) + ms[-1] = ml_ + return ms, rx + + +class MADE(Module): + def __init__( + self, + dim, + hid_dim, + num_layers, + num_outlayers=1, + activation=nn.ELU(), + fixed_order=False, + derank=1, + ): + super(MADE, self).__init__() + + oper = nn_.WNlinear + + self.dim = dim + self.hid_dim = hid_dim + self.num_layers = num_layers + self.num_outlayers = num_outlayers + self.activation = activation + + ms, rx = get_masks( + dim, hid_dim, num_layers, num_outlayers, fixed_order, derank + ) + ms = [m for m in map(torch.from_numpy, ms)] + self.rx = rx + + sequels = list() + for i in range(num_layers - 1): + if i == 0: + sequels.append(oper(dim, hid_dim, True, ms[i], False)) + sequels.append(activation) + else: + sequels.append(oper(hid_dim, hid_dim, True, ms[i], False)) + sequels.append(activation) + + self.input_to_hidden = nn.Sequential(*sequels) + self.hidden_to_output = oper( + hid_dim, dim * num_outlayers, True, ms[-1] + ) + + def forward(self, input): + hid = self.input_to_hidden(input) + return self.hidden_to_output(hid).view( + -1, self.dim, self.num_outlayers + ) + + def randomize(self): + ms, rx = get_masks( + self.dim, self.hid_dim, self.num_layers, self.num_outlayers + ) + for i in range(self.num_layers - 1): + mask = torch.from_numpy(ms[i]) + if self.input_to_hidden[i * 2].mask.is_cuda: + mask = mask.cuda() + self.input_to_hidden[i * 2].mask.data.zero_().add_(mask) + self.rx = rx + + +class cMADE(Module): + def __init__( + self, + dim, + hid_dim, + context_dim, + num_layers, + num_outlayers=1, + activation=nn.ELU(), + fixed_order=False, + derank=1, + ): + super(cMADE, self).__init__() + + oper = nn_.CWNlinear + + self.dim = dim + self.hid_dim = hid_dim + self.num_layers = num_layers + self.context_dim = context_dim + self.num_outlayers = num_outlayers + self.activation = nn_.Lambda(lambda x: (activation(x[0]), x[1])) + + ms, rx = get_masks( + dim, hid_dim, num_layers, num_outlayers, fixed_order, derank + ) + ms = [m for m in map(torch.from_numpy, ms)] + self.rx = rx + + sequels = list() + for i in range(num_layers - 1): + if i == 0: + sequels.append(oper(dim, hid_dim, context_dim, ms[i], 
False))
+                sequels.append(self.activation)
+            else:
+                sequels.append(
+                    oper(hid_dim, hid_dim, context_dim, ms[i], False)
+                )
+                sequels.append(self.activation)
+
+        self.input_to_hidden = nn.Sequential(*sequels)
+        self.hidden_to_output = oper(
+            hid_dim, dim * num_outlayers, context_dim, ms[-1]
+        )
+
+    def forward(self, inputs):
+        input, context = inputs
+        hid, _ = self.input_to_hidden((input, context))
+        out, _ = self.hidden_to_output((hid, context))
+        return out.view(-1, self.dim, self.num_outlayers), context
+
+    def randomize(self):
+        ms, rx = get_masks(
+            self.dim, self.hid_dim, self.num_layers, self.num_outlayers
+        )
+        for i in range(self.num_layers - 1):
+            mask = torch.from_numpy(ms[i])
+            if self.input_to_hidden[i * 2].mask.is_cuda:
+                mask = mask.cuda()
+            self.input_to_hidden[i * 2].mask.zero_().add_(mask)
+        self.rx = rx
+
+
+class BaseFlow(Module):
+    def sample(self, n=1, context=None, **kwargs):
+        dim = self.dim
+        if isinstance(self.dim, int):
+            dim = [
+                dim,
+            ]
+
+        spl = Variable(torch.FloatTensor(n, *dim).normal_())
+        lgd = Variable(torch.from_numpy(np.zeros(n).astype("float32")))
+        if context is None:
+            context = Variable(
+                torch.from_numpy(
+                    np.ones((n, self.context_dim)).astype("float32")
+                )
+            )
+
+        if hasattr(self, "gpu"):
+            if self.gpu:
+                spl = spl.cuda()
+                lgd = lgd.cuda()
+                context = context.cuda()
+
+        return self.forward((spl, lgd, context))
+
+    def cuda(self):
+        self.gpu = True
+        return super(BaseFlow, self).cuda()
+
+
+class LinearFlow(BaseFlow):
+    def __init__(
+        self, dim, context_dim, oper=nn_.ResLinear, realify=nn_.softplus
+    ):
+        super(LinearFlow, self).__init__()
+        self.realify = realify
+
+        self.dim = dim
+        self.context_dim = context_dim
+
+        if type(dim) is int:
+            dim_ = dim
+        else:
+            dim_ = np.prod(dim)
+
+        self.mean = oper(context_dim, dim_)
+        self.lstd = oper(context_dim, dim_)
+
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        if isinstance(self.mean, nn_.ResLinear):
+            self.mean.dot_01.scale.data.uniform_(-0.001, 0.001)
+            self.mean.dot_h1.scale.data.uniform_(-0.001, 0.001)
+            self.mean.dot_01.bias.data.uniform_(-0.001, 0.001)
+            self.mean.dot_h1.bias.data.uniform_(-0.001, 0.001)
+            self.lstd.dot_01.scale.data.uniform_(-0.001, 0.001)
+            self.lstd.dot_h1.scale.data.uniform_(-0.001, 0.001)
+            if self.realify == nn_.softplus:
+                inv = np.log(np.exp(1 - nn_.delta) - 1) * 0.5
+                self.lstd.dot_01.bias.data.uniform_(inv - 0.001, inv + 0.001)
+                self.lstd.dot_h1.bias.data.uniform_(inv - 0.001, inv + 0.001)
+            else:
+                self.lstd.dot_01.bias.data.uniform_(-0.001, 0.001)
+                self.lstd.dot_h1.bias.data.uniform_(-0.001, 0.001)
+        elif isinstance(self.mean, nn.Linear):
+            self.mean.weight.data.uniform_(-0.001, 0.001)
+            self.mean.bias.data.uniform_(-0.001, 0.001)
+            self.lstd.weight.data.uniform_(-0.001, 0.001)
+            if self.realify == nn_.softplus:
+                inv = np.log(np.exp(1 - nn_.delta) - 1) * 0.5
+                self.lstd.bias.data.uniform_(inv - 0.001, inv + 0.001)
+            else:
+                self.lstd.bias.data.uniform_(-0.001, 0.001)
+
+    def forward(self, inputs):
+        x, logdet, context = inputs
+        mean = self.mean(context)
+        lstd = self.lstd(context)
+        std = self.realify(lstd)
+
+        if type(self.dim) is int:
+            x_ = mean + std * x
+        else:
+            size = x.size()
+            x_ = mean.view(size) + std.view(size) * x
+        logdet_ = sum_from_one(torch.log(std)) + logdet
+        return x_, logdet_, context
+
+
+class BlockAffineFlow(Module):
+    # NICE, volume preserving
+    # x2' = x2 + nonLinfunc(x1)
+
+    def __init__(self, dim1, dim2, context_dim, hid_dim, activation=nn.ELU()):
+        super(BlockAffineFlow, self).__init__()
+        self.dim1 = dim1
+
self.dim2 = dim2 + self.actv = activation + + self.hid = nn_.WNBilinear(dim1, context_dim, hid_dim) + self.shift = nn_.WNBilinear(hid_dim, context_dim, dim2) + + def forward(self, inputs): + x, logdet, context = inputs + x1, x2 = x + + hid = self.actv(self.hid(x1, context)) + shift = self.shift(hid, context) + + x2_ = x2 + shift + + return (x1, x2_), 0, context + + +class IAF(BaseFlow): + def __init__( + self, + dim, + hid_dim, + context_dim, + num_layers, + activation=nn.ELU(), + realify=nn_.sigmoid, + fixed_order=False, + ): + super(IAF, self).__init__() + self.realify = realify + + self.dim = dim + self.context_dim = context_dim + + if type(dim) is int: + self.mdl = cMADE( + dim, + hid_dim, + context_dim, + num_layers, + 2, + activation, + fixed_order, + ) + self.reset_parameters() + + def reset_parameters(self): + self.mdl.hidden_to_output.cscale.weight.data.uniform_(-0.001, 0.001) + self.mdl.hidden_to_output.cscale.bias.data.uniform_(0.0, 0.0) + self.mdl.hidden_to_output.cbias.weight.data.uniform_(-0.001, 0.001) + self.mdl.hidden_to_output.cbias.bias.data.uniform_(0.0, 0.0) + if self.realify == nn_.softplus: + inv = np.log(np.exp(1 - nn_.delta) - 1) + self.mdl.hidden_to_output.cbias.bias.data[1::2].uniform_(inv, inv) + elif self.realify == nn_.sigmoid: + self.mdl.hidden_to_output.cbias.bias.data[1::2].uniform_(2.0, 2.0) + + def forward(self, inputs): + x, logdet, context = inputs + out, _ = self.mdl((x, context)) + if isinstance(self.mdl, cMADE): + mean = out[:, :, 0] + lstd = out[:, :, 1] + + std = self.realify(lstd) + + if self.realify == nn_.softplus: + x_ = mean + std * x + elif self.realify == nn_.sigmoid: + x_ = (-std + 1.0) * mean + std * x + elif self.realify == nn_.sigmoid2: + x_ = (-std + 2.0) * mean + std * x + logdet_ = sum_from_one(torch.log(std)) + logdet + return x_, logdet_, context + + +class IAF_VP(BaseFlow): + def __init__( + self, + dim, + hid_dim, + context_dim, + num_layers, + activation=nn.ELU(), + fixed_order=True, + ): + super(IAF_VP, self).__init__() + + self.dim = dim + self.context_dim = context_dim + + if type(dim) is int: + self.mdl = cMADE( + dim, + hid_dim, + context_dim, + num_layers, + 1, + activation, + fixed_order, + ) + self.reset_parameters() + + def reset_parameters(self): + self.mdl.hidden_to_output.cscale.weight.data.uniform_(-0.001, 0.001) + self.mdl.hidden_to_output.cscale.bias.data.uniform_(0.0, 0.0) + self.mdl.hidden_to_output.cbias.weight.data.uniform_(-0.001, 0.001) + self.mdl.hidden_to_output.cbias.bias.data.uniform_(0.0, 0.0) + + def forward(self, inputs): + x, logdet, context = inputs + out, _ = self.mdl((x, context)) + mean = out[:, :, 0] + x_ = mean + x + return x_, logdet, context + + +class IAF_DSF(BaseFlow): + + mollify = 0.0 + + def __init__( + self, + dim, + hid_dim, + context_dim, + num_layers, + activation=nn.ELU(), + fixed_order=False, + num_ds_dim=4, + num_ds_layers=1, + num_ds_multiplier=3, + ): + super(IAF_DSF, self).__init__() + + self.dim = dim + self.context_dim = context_dim + self.num_ds_dim = num_ds_dim + self.num_ds_layers = num_ds_layers + + if type(dim) is int: + self.mdl = cMADE( + dim, + hid_dim, + context_dim, + num_layers, + num_ds_multiplier * (hid_dim // dim) * num_ds_layers, + activation, + fixed_order, + ) + self.out_to_dsparams = nn.Conv1d( + num_ds_multiplier * (hid_dim // dim) * num_ds_layers, + 3 * num_ds_layers * num_ds_dim, + 1, + ) + self.reset_parameters() + + self.sf = SigmoidFlow(num_ds_dim) + + def reset_parameters(self): + self.out_to_dsparams.weight.data.uniform_(-0.001, 0.001) + 
self.out_to_dsparams.bias.data.uniform_(0.0, 0.0)
+
+        inv = np.log(np.exp(1 - nn_.delta) - 1)
+        for l in range(self.num_ds_layers):
+            nc = self.num_ds_dim
+            nparams = nc * 3
+            s = l * nparams
+            self.out_to_dsparams.bias.data[s : s + nc].uniform_(inv, inv)
+
+    def forward(self, inputs):
+        x, logdet, context = inputs
+        out, _ = self.mdl((x, context))
+        if isinstance(self.mdl, cMADE):
+            out = out.permute(0, 2, 1)
+            dsparams = self.out_to_dsparams(out).permute(0, 2, 1)
+            nparams = self.num_ds_dim * 3
+
+        mollify = self.mollify
+        h = x.view(x.size(0), -1)
+        for i in range(self.num_ds_layers):
+            params = dsparams[:, :, i * nparams : (i + 1) * nparams]
+            h, logdet = self.sf(h, logdet, params, mollify)
+
+        return h, logdet, context
+
+
+class SigmoidFlow(BaseFlow):
+    def __init__(self, num_ds_dim=4):
+        super(SigmoidFlow, self).__init__()
+        self.num_ds_dim = num_ds_dim
+
+        self.act_a = lambda x: nn_.softplus(x)
+        self.act_b = lambda x: x
+        self.act_w = lambda x: nn_.softmax(x, dim=2)
+
+    def forward(self, x, logdet, dsparams, mollify=0.0, delta=nn_.delta):
+
+        ndim = self.num_ds_dim
+        a_ = self.act_a(dsparams[:, :, 0 * ndim : 1 * ndim])
+        b_ = self.act_b(dsparams[:, :, 1 * ndim : 2 * ndim])
+        w = self.act_w(dsparams[:, :, 2 * ndim : 3 * ndim])
+
+        a = a_ * (1 - mollify) + 1.0 * mollify
+        b = b_ * (1 - mollify) + 0.0 * mollify
+
+        pre_sigm = a * x[:, :, None] + b
+        sigm = torch.sigmoid(pre_sigm)
+        x_pre = torch.sum(w * sigm, dim=2)
+        x_pre_clipped = x_pre * (1 - delta) + delta * 0.5
+        x_ = log(x_pre_clipped) - log(1 - x_pre_clipped)
+        xnew = x_
+
+        logj = (
+            F.log_softmax(dsparams[:, :, 2 * ndim : 3 * ndim], dim=2)
+            + nn_.logsigmoid(pre_sigm)
+            + nn_.logsigmoid(-pre_sigm)
+            + log(a)
+        )
+
+        logj = torch.logsumexp(logj, 2)
+        logdet_ = (
+            logj
+            + np.log(1 - delta)
+            - (log(x_pre_clipped) + log(-x_pre_clipped + 1))
+        )
+        logdet = logdet_.sum(1) + logdet
+
+        return xnew, logdet
+
+
+class IAF_DDSF(BaseFlow):
+    def __init__(
+        self,
+        dim,
+        hid_dim,
+        context_dim,
+        num_layers,
+        activation=nn.ELU(),
+        fixed_order=False,
+        num_ds_dim=4,
+        num_ds_layers=1,
+        num_ds_multiplier=3,
+    ):
+        super(IAF_DDSF, self).__init__()
+
+        self.dim = dim
+        self.context_dim = context_dim
+        self.num_ds_dim = num_ds_dim
+        self.num_ds_layers = num_ds_layers
+
+        if type(dim) is int:
+            self.mdl = cMADE(
+                dim,
+                hid_dim,
+                context_dim,
+                num_layers,
+                int(num_ds_multiplier * (hid_dim / dim) * num_ds_layers),
+                activation,
+                fixed_order,
+            )
+
+        num_dsparams = 0
+        for i in range(num_ds_layers):
+            if i == 0:
+                in_dim = 1
+            else:
+                in_dim = num_ds_dim
+            if i == num_ds_layers - 1:
+                out_dim = 1
+            else:
+                out_dim = num_ds_dim
+
+            u_dim = in_dim
+            w_dim = num_ds_dim
+            a_dim = b_dim = num_ds_dim
+            num_dsparams += u_dim + w_dim + a_dim + b_dim
+
+            self.add_module(
+                "sf{}".format(i), DenseSigmoidFlow(in_dim, num_ds_dim, out_dim)
+            )
+        if type(dim) is int:
+            self.out_to_dsparams = nn.Conv1d(
+                int(num_ds_multiplier * (hid_dim / dim) * num_ds_layers),
+                int(num_dsparams),
+                1,
+            )
+        else:
+            self.out_to_dsparams = nn.Conv1d(
+                int(num_ds_multiplier * (hid_dim / dim[0]) * num_ds_layers),
+                num_dsparams,
+                1,
+            )
+
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        self.out_to_dsparams.weight.data.uniform_(-0.001, 0.001)
+        self.out_to_dsparams.bias.data.uniform_(0.0, 0.0)
+
+    def forward(self, inputs):
+        x, logdet, context = inputs
+        out, _ = self.mdl((x, context))
+        out = out.permute(0, 2, 1)
+        dsparams = self.out_to_dsparams(out).permute(0, 2, 1)
+
+        start = 0
+
+        h = x.view(x.size(0), -1)[:, :, None]
+        n = x.size(0)
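+        # dsparams concatenates, per DenseSigmoidFlow layer, the parameter
+        # blocks a, b, w (num_ds_dim each) and u (in_dim); the loop below
+        # slices out one block per layer and feeds it to sf{i}.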
+        dim = self.dim if type(self.dim) is int else self.dim[0]
+        lgd = Variable(
+            torch.from_numpy(np.zeros((n, dim, 1, 1)).astype("float32"))
+        )
+        if self.out_to_dsparams.weight.is_cuda:
+            lgd = lgd.cuda()
+        for i in range(self.num_ds_layers):
+            if i == 0:
+                in_dim = 1
+            else:
+                in_dim = self.num_ds_dim
+            if i == self.num_ds_layers - 1:
+                out_dim = 1
+            else:
+                out_dim = self.num_ds_dim
+
+            u_dim = in_dim
+            w_dim = self.num_ds_dim
+            a_dim = b_dim = self.num_ds_dim
+            end = start + u_dim + w_dim + a_dim + b_dim
+
+            params = dsparams[:, :, start:end]
+            h, lgd = getattr(self, "sf{}".format(i))(h, lgd, params)
+            start = end
+
+        assert out_dim == 1, "last dsf out dim should be 1"
+        return h[:, :, 0], lgd[:, :, 0, 0].sum(1) + logdet, context
+
+
+class DenseSigmoidFlow(BaseFlow):
+    def __init__(self, in_dim, hidden_dim, out_dim):
+        super(DenseSigmoidFlow, self).__init__()
+        self.in_dim = in_dim
+        self.hidden_dim = hidden_dim
+        self.out_dim = out_dim
+
+        self.act_a = lambda x: nn_.softplus(x)
+        self.act_b = lambda x: x
+        self.act_w = lambda x: nn_.softmax(x, dim=3)
+        self.act_u = lambda x: nn_.softmax(x, dim=3)
+
+        self.u_ = Parameter(torch.Tensor(hidden_dim, in_dim))
+        self.w_ = Parameter(torch.Tensor(out_dim, hidden_dim))
+
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        self.u_.data.uniform_(-0.001, 0.001)
+        self.w_.data.uniform_(-0.001, 0.001)
+
+    def forward(self, x, logdet, dsparams):
+        inv = np.log(np.exp(1 - nn_.delta) - 1)
+        ndim = self.hidden_dim
+        pre_u = (
+            self.u_[None, None, :, :]
+            + dsparams[:, :, -self.in_dim :][:, :, None, :]
+        )
+        pre_w = (
+            self.w_[None, None, :, :]
+            + dsparams[:, :, 2 * ndim : 3 * ndim][:, :, None, :]
+        )
+        a = self.act_a(dsparams[:, :, 0 * ndim : 1 * ndim] + inv)
+        b = self.act_b(dsparams[:, :, 1 * ndim : 2 * ndim])
+        w = self.act_w(pre_w)
+        u = self.act_u(pre_u)
+
+        pre_sigm = torch.sum(u * a[:, :, :, None] * x[:, :, None, :], 3) + b
+        sigm = torch.sigmoid(pre_sigm)
+        x_pre = torch.sum(w * sigm[:, :, None, :], dim=3)
+        x_pre_clipped = x_pre * (1 - nn_.delta) + nn_.delta * 0.5
+        x_ = log(x_pre_clipped) - log(1 - x_pre_clipped)
+        xnew = x_
+
+        logj = (
+            F.log_softmax(pre_w, dim=3)
+            + nn_.logsigmoid(pre_sigm[:, :, None, :])
+            + nn_.logsigmoid(-pre_sigm[:, :, None, :])
+            + log(a[:, :, None, :])
+        )
+        # n, d, d2, dh
+
+        logj = (
+            logj[:, :, :, :, None]
+            + F.log_softmax(pre_u, dim=3)[:, :, None, :, :]
+        )
+        # n, d, d2, dh, d1
+
+        logj = torch.logsumexp(logj, 3)
+        # n, d, d2, d1
+
+        logdet_ = (
+            logj
+            + np.log(1 - nn_.delta)
+            - (log(x_pre_clipped) + log(-x_pre_clipped + 1))[:, :, :, None]
+        )
+
+        logdet = torch.logsumexp(
+            logdet_[:, :, :, :, None] + logdet[:, :, None, :, :], 3
+        )
+        # n, d, d2, d1, d0 -> n, d, d2, d0
+
+        return xnew, logdet
+
+
+class FlipFlow(BaseFlow):
+    def __init__(self, dim):
+        self.dim = dim
+        super(FlipFlow, self).__init__()
+
+    def forward(self, inputs):
+        input, logdet, context = inputs
+
+        dim = self.dim
+        index = Variable(
+            getattr(
+                torch.arange(input.size(dim) - 1, -1, -1),
+                ("cpu", "cuda")[input.is_cuda],
+            )().long()
+        )
+
+        output = torch.index_select(input, dim, index)
+
+        return output, logdet, context
+
+
+class Sigmoid(BaseFlow):
+    def __init__(self):
+        super(Sigmoid, self).__init__()
+
+    def forward(self, inputs):
+        if len(inputs) == 2:
+            input, logdet = inputs
+        elif len(inputs) == 3:
+            input, logdet, context = inputs
+        else:
+            raise (Exception("inputs length not correct"))
+
+        output = F.sigmoid(input)
+        logdet += sum_from_one(-F.softplus(input) - F.softplus(-input))
+
+        if len(inputs) == 2:
+            return
output, logdet + elif len(inputs) == 3: + return output, logdet, context + else: + raise (Exception("inputs length not correct")) + + +class Logit(BaseFlow): + def __init__(self): + super(Logit, self).__init__() + + def forward(self, inputs): + if len(inputs) == 2: + input, logdet = inputs + elif len(inputs) == 3: + input, logdet, context = inputs + else: + raise (Exception("inputs length not correct")) + + output = log(input) - log(1 - input) + logdet -= sum_from_one(log(input) + log(-input + 1)) + + if len(inputs) == 2: + return output, logdet + elif len(inputs) == 3: + return output, logdet, context + else: + raise (Exception("inputs length not correct")) + + +class Shift(BaseFlow): + def __init__(self, b): + self.b = b + super(Shift, self).__init__() + + def forward(self, inputs): + if len(inputs) == 2: + input, logdet = inputs + elif len(inputs) == 3: + input, logdet, context = inputs + else: + raise (Exception("inputs length not correct")) + + output = input + self.b + + if len(inputs) == 2: + return output, logdet + elif len(inputs) == 3: + return output, logdet, context + else: + raise (Exception("inputs length not correct")) + + +class Scale(BaseFlow): + def __init__(self, g): + self.g = g + super(Scale, self).__init__() + + def forward(self, inputs): + if len(inputs) == 2: + input, logdet = inputs + elif len(inputs) == 3: + input, logdet, context = inputs + else: + raise (Exception("inputs length not correct")) + + output = input * self.g + logdet += np.log(np.abs(self.g)) * np.prod(input.size()[1:]) + + if len(inputs) == 2: + return output, logdet + elif len(inputs) == 3: + return output, logdet, context + else: + raise (Exception("inputs length not correct")) + + +class MAF(object): + def __init__(self, args, p): + + self.args = args + self.__dict__.update(args.__dict__) + self.p = p + + dim = p + dimc = 1 + dimh = args.dimh + flowtype = args.flowtype + num_flow_layers = args.num_flow_layers + num_ds_dim = args.num_ds_dim + num_ds_layers = args.num_ds_layers + fixed_order = args.fixed_order + + act = nn.ELU() + if flowtype == "affine": + flow = IAF + elif flowtype == "dsf": + flow = lambda **kwargs: IAF_DSF( + num_ds_dim=num_ds_dim, num_ds_layers=num_ds_layers, **kwargs + ) + elif flowtype == "ddsf": + flow = lambda **kwargs: IAF_DDSF( + num_ds_dim=num_ds_dim, num_ds_layers=num_ds_layers, **kwargs + ) + + sequels = [ + nn_.SequentialFlow( + flow( + dim=dim, + hid_dim=dimh, + context_dim=dimc, + num_layers=args.num_hid_layers + 1, + activation=act, + fixed_order=fixed_order, + ), + FlipFlow(1), + ) + for i in range(num_flow_layers) + ] + [ + LinearFlow(dim, dimc), + ] + + self.flow = nn.Sequential(*sequels) + + +# ============================================================================= +# main +# ============================================================================= + + +"""parsing and configuration""" + + +def parse_args(): + desc = "MAF" + parser = argparse.ArgumentParser(description=desc) + + parser.add_argument( + "--dataset", + type=str, + default="miniboone", + choices=["power", "gas", "hepmass", "miniboone", "bsds300"], + ) + parser.add_argument( + "--epoch", type=int, default=400, help="The number of epochs to run" + ) + parser.add_argument( + "--batch_size", type=int, default=100, help="The size of batch" + ) + parser.add_argument( + "--save_dir", + type=str, + default="models", + help="Directory name to save the model", + ) + parser.add_argument( + "--result_dir", + type=str, + default="results", + help="Directory name to save the generated images", + ) + 
parser.add_argument( + "--log_dir", + type=str, + default="logs", + help="Directory name to save training logs", + ) + parser.add_argument("--seed", type=int, default=1993, help="Random seed") + parser.add_argument( + "--fn", type=str, default="0", help="Filename of model to be loaded" + ) + parser.add_argument( + "--to_train", type=int, default=1, help="1 if to train 0 if not" + ) + parser.add_argument("--lr", type=float, default=0.0001) + parser.add_argument("--clip", type=float, default=5.0) + parser.add_argument("--beta1", type=float, default=0.9) + parser.add_argument("--beta2", type=float, default=0.999) + parser.add_argument("--amsgrad", type=int, default=0) + parser.add_argument("--polyak", type=float, default=0.0) + parser.add_argument("--cuda", type=bool, default=False) + + parser.add_argument("--dimh", type=int, default=100) + parser.add_argument("--flowtype", type=str, default="ddsf") + parser.add_argument("--num_flow_layers", type=int, default=10) + parser.add_argument("--num_hid_layers", type=int, default=1) + parser.add_argument("--num_ds_dim", type=int, default=16) + parser.add_argument("--num_ds_layers", type=int, default=1) + parser.add_argument( + "--fixed_order", + type=bool, + default=True, + help="Fix the made ordering to be the given order", + ) + + return check_args(parser.parse_args()) + + +"""checking arguments""" + + +def check_args(args): + # --save_dir + if not os.path.exists(args.save_dir): + os.makedirs(args.save_dir) + + # --result_dir + if not os.path.exists(args.result_dir + "_" + args.dataset): + os.makedirs(args.result_dir + "_" + args.dataset) + + # --result_dir + if not os.path.exists(args.log_dir): + os.makedirs(args.log_dir) + + # --epoch + try: + assert args.epoch >= 1 + except: + print("number of epochs must be larger than or equal to one") + + # --batch_size + try: + assert args.batch_size >= 1 + except: + print("batch size must be larger than or equal to one") + + return args + + +datasets = { + "power": {"d": 6, "dimh": 100, "num_hid_layers": 2}, + "gas": {"d": 8, "dimh": 100, "num_hid_layers": 2}, + "hepmass": {"d": 21, "dimh": 512, "num_hid_layers": 2}, + "miniboone": {"d": 43, "dimh": 512, "num_hid_layers": 1}, + "bsds300": {"d": 63, "dimh": 1024, "num_hid_layers": 2}, +} + + +def format_as_str(num): + if num / 1e9 > 1: + factor, suffix = 1e9, "B" + elif num / 1e6 > 1: + factor, suffix = 1e6, "M" + elif num / 1e3 > 1: + factor, suffix = 1e3, "K" + else: + factor, suffix = 1e0, "" + + num_factored = num / factor + + if num_factored / 1e2 > 1 or True: + num_rounded = str(int(round(num_factored))) + elif num_factored / 1e1 > 1: + num_rounded = f"{num_factored:.1f}" + else: + num_rounded = f"{num_factored:.2f}" + + return f"{num_rounded}{suffix} % {num}" + + +def naf_size(d, **kwargs): + from torchprune.util.net import NetHandle + + args = parse_args() + for key, val in kwargs.items(): + setattr(args, key, val) + model = MAF(args, d) + model = NetHandle(model.flow) + return model.size() + + +for dset, s_kwargs in datasets.items(): + print(f"{dset}: #params: {format_as_str(naf_size(**s_kwargs))}") + print("\n") diff --git a/paper/node/script/sizes/sizes_sos.py b/paper/node/script/sizes/sizes_sos.py new file mode 100644 index 0000000..af5dc28 --- /dev/null +++ b/paper/node/script/sizes/sizes_sos.py @@ -0,0 +1,459 @@ +# %% +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +import math + +from torchprune.util.net import NetHandle + + +def get_mask(in_features, out_features, in_flow_features, mask_type=None): + 
""" + mask_type: input | None | output + + See Figure 1 for a better illustration: + https://arxiv.org/pdf/1502.03509.pdf + """ + if mask_type == "input": + in_degrees = torch.arange(in_features) % in_flow_features + else: + in_degrees = torch.arange(in_features) % (in_flow_features - 1) + + if mask_type == "output": + out_degrees = torch.arange(out_features) % in_flow_features - 1 + else: + out_degrees = torch.arange(out_features) % (in_flow_features - 1) + + return (out_degrees.unsqueeze(-1) >= in_degrees.unsqueeze(0)).float() + + +class MaskedLinear(nn.Linear): + def __init__(self, in_features, out_features, mask, bias=True): + super(MaskedLinear, self).__init__(in_features, out_features, bias) + self.register_buffer("mask", mask) + + def forward(self, inputs): + return F.linear(inputs, self.weight * self.mask, self.bias) + + +class ConditionerNet(nn.Module): + def __init__(self, input_size, hidden_size, k, m, n_layers=1): + super().__init__() + self.k = k + self.m = m + self.input_size = input_size + self.output_size = k * self.m * input_size + input_size + self.network = self._make_net( + input_size, hidden_size, self.output_size, n_layers + ) + + def _make_net(self, input_size, hidden_size, output_size, n_layers): + if self.input_size > 1: + input_mask = get_mask( + input_size, hidden_size, input_size, mask_type="input" + ) + hidden_mask = get_mask(hidden_size, hidden_size, input_size) + output_mask = get_mask( + hidden_size, output_size, input_size, mask_type="output" + ) + + network = nn.Sequential( + MaskedLinear(input_size, hidden_size, input_mask), + nn.ReLU(), + MaskedLinear(hidden_size, hidden_size, hidden_mask), + nn.ReLU(), + MaskedLinear(hidden_size, output_size, output_mask), + ) + else: + network = nn.Sequential( + nn.Linear(input_size, hidden_size), + nn.ReLU(), + nn.Linear(hidden_size, hidden_size), + nn.ReLU(), + nn.Linear(hidden_size, output_size), + ) + + """ + for module in network.modules(): + if isinstance(module, nn.Linear): + nn.init.orthogonal_(module.weight) + module.bias.data.fill_(0) + """ + + return network + + def forward(self, inputs): + batch_size = inputs.size(0) + params = self.network(inputs) + i = self.k * self.m * self.input_size + c = ( + params[:, :i] + .view(batch_size, -1, self.input_size) + .transpose(1, 2) + .view(batch_size, self.input_size, self.k, self.m, 1) + ) + const = params[:, i:].view(batch_size, self.input_size) + C = torch.matmul(c, c.transpose(3, 4)) + return C, const + + +# +# SOS Block: +# + + +class SOSFlow(nn.Module): + @staticmethod + def power(z, k): + return z ** (torch.arange(k).float().to(z.device)) + + def __init__(self, input_size, hidden_size, k, r, n_layers=1): + super().__init__() + self.k = k + self.m = r + 1 + + self.conditioner = ConditionerNet( + input_size, hidden_size, k, self.m, n_layers + ) + self.register_buffer("filter", self._make_filter()) + + def _make_filter(self): + n = torch.arange(self.m).unsqueeze(1) + e = torch.ones(self.m).unsqueeze(1).long() + filter = (n.mm(e.transpose(0, 1))) + (e.mm(n.transpose(0, 1))) + 1 + return filter.float() + + def forward(self, inputs, mode="direct"): + batch_size, input_size = inputs.size(0), inputs.size(1) + C, const = self.conditioner(inputs) + X = SOSFlow.power(inputs.unsqueeze(-1), self.m).view( + batch_size, input_size, 1, self.m, 1 + ) # bs x d x 1 x m x 1 + Z = self._transform(X, C / self.filter) * inputs + const + logdet = torch.log(torch.abs(self._transform(X, C))).sum( + dim=1, keepdim=True + ) + return Z, logdet + + def _transform(self, X, C): + CX = 
torch.matmul(C, X) # bs x d x k x m x 1 + XCX = torch.matmul(X.transpose(3, 4), CX) # bs x d x k x 1 x 1 + summed = XCX.squeeze(-1).squeeze(-1).sum(-1) # bs x d + return summed + + def _jacob(self, inputs, mode="direct"): + X = inputs.clone() + X.requires_grad_() + X.retain_grad() + d = X.size(0) + + X_in = X.unsqueeze(0) + C, const = self.conditioner(X_in) + Xpow = SOSFlow.power(X_in.unsqueeze(-1), self.m).view( + 1, d, 1, self.m, 1 + ) # bs x d x 1 x m x 1 + Z = (self._transform(Xpow, C / self.filter) * X_in + const).view(-1) + + J = torch.zeros(d, d) + for i in range(d): + self.zero_grad() + Z[i].backward(retain_graph=True) + J[i, :] = X.grad + + del X, X_in, C, const, Xpow, Z + return J + + +class MADE(nn.Module): + """An implementation of MADE + (https://arxiv.org/abs/1502.03509s). + """ + + def __init__(self, num_inputs, num_hidden, act="relu", pre_exp_tanh=False): + super(MADE, self).__init__() + + activations = {"relu": nn.ReLU, "sigmoid": nn.Sigmoid, "tanh": nn.Tanh} + act_func = activations[act] + + input_mask = get_mask( + num_inputs, num_hidden, num_inputs, mask_type="input" + ) + hidden_mask = get_mask(num_hidden, num_hidden, num_inputs) + output_mask = get_mask( + num_hidden, num_inputs * 2, num_inputs, mask_type="output" + ) + + self.joiner = MaskedLinear(num_inputs, num_hidden, input_mask) + + self.trunk = nn.Sequential( + act_func(), + MaskedLinear(num_hidden, num_hidden, hidden_mask), + act_func(), + MaskedLinear(num_hidden, num_inputs * 2, output_mask), + ) + + def forward(self, inputs, mode="direct"): + if mode == "direct": + h = self.joiner(inputs) + m, a = self.trunk(h).chunk(2, 1) + u = (inputs - m) * torch.exp(-a) + return u, -a.sum(-1, keepdim=True) + + else: + x = torch.zeros_like(inputs) + for i_col in range(inputs.shape[1]): + h = self.joiner(x) + m, a = self.trunk(h).chunk(2, 1) + x[:, i_col] = ( + inputs[:, i_col] * torch.exp(a[:, i_col]) + m[:, i_col] + ) + return x, -a.sum(-1, keepdim=True) + + def _jacob(self, inputs): + X = inputs.clone() + X.requires_grad_() + X.retain_grad() + d = X.size(0) + + X_in = X.unsqueeze(0) + + h = self.joiner(X_in) + m, a = self.trunk(h).chunk(2, 1) + u = ((X_in - m) * torch.exp(-a)).view(-1) + + J = torch.zeros(d, d) + for i in range(d): + self.zero_grad() + u[i].backward(retain_graph=True) + J[i, :] = X.grad + + del X, X_in, h, m, a, u + return J + + +class BatchNormFlow(nn.Module): + """An implementation of a batch normalization layer from + Density estimation using Real NVP + (https://arxiv.org/abs/1605.08803). 
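+
+    Note: as written, batch statistics are used in both the direct and the
+    inverse pass (the running-statistics branch is disabled).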
+ """ + + def __init__(self, num_inputs, momentum=0.0, eps=1e-5): + super(BatchNormFlow, self).__init__() + + self.log_gamma = nn.Parameter(torch.zeros(num_inputs)) + self.beta = nn.Parameter(torch.zeros(num_inputs)) + self.momentum = momentum + self.eps = eps + + self.register_buffer("running_mean", torch.zeros(num_inputs)) + self.register_buffer("running_var", torch.ones(num_inputs)) + + def forward(self, inputs, mode="direct"): + if mode == "direct": + if True: # self.training: + self.batch_mean = inputs.mean(0) + self.batch_var = (inputs - self.batch_mean).pow(2).mean( + 0 + ) + self.eps + + self.running_mean.mul_(self.momentum) + self.running_var.mul_(self.momentum) + + self.running_mean.add_( + self.batch_mean.data * (1 - self.momentum) + ) + self.running_var.add_( + self.batch_var.data * (1 - self.momentum) + ) + + mean = self.batch_mean + var = self.batch_var + else: + mean = self.running_mean + var = self.running_var + + x_hat = (inputs - mean) / var.sqrt() + y = torch.exp(self.log_gamma) * x_hat + self.beta + return y, (self.log_gamma - 0.5 * torch.log(var)).sum( + -1, keepdim=True + ) + else: + if True: # self.training: + mean = self.batch_mean + var = self.batch_var + else: + mean = self.running_mean + var = self.running_var + + x_hat = (inputs - self.beta) / torch.exp(self.log_gamma) + + y = x_hat * var.sqrt() + mean + + return y, (-self.log_gamma + 0.5 * torch.log(var)).sum( + -1, keepdim=True + ) + + def _jacob(self, X): + return None + + +class Reverse(nn.Module): + """An implementation of a reversing layer from + Density estimation using Real NVP + (https://arxiv.org/abs/1605.08803). + """ + + def __init__(self, num_inputs): + super(Reverse, self).__init__() + self.perm = np.array(np.arange(0, num_inputs)[::-1]) + self.inv_perm = np.argsort(self.perm) + + def forward(self, inputs, mode="direct"): + if mode == "direct": + return inputs[:, self.perm], torch.zeros( + inputs.size(0), 1, device=inputs.device + ) + else: + return inputs[:, self.inv_perm], torch.zeros( + inputs.size(0), 1, device=inputs.device + ) + + def _jacob(self, X): + return None + + +class FlowSequential(nn.Sequential): + """A sequential container for flows. + In addition to a forward pass it implements a backward pass and + computes log jacobians. + """ + + def forward(self, inputs, mode="direct", logdets=None): + """Performs a forward or backward pass for flow modules. 
+        Args:
+            inputs: a tuple of inputs and logdets
+            mode: to run direct computation or inverse
+        """
+        self.num_inputs = inputs.size(-1)
+
+        if logdets is None:
+            logdets = torch.zeros(inputs.size(0), 1, device=inputs.device)
+
+        assert mode in ["direct", "inverse"]
+        if mode == "direct":
+            for module in self._modules.values():
+                inputs, logdet = module(inputs, mode)
+                logdets += logdet
+        else:
+            for module in reversed(self._modules.values()):
+                inputs, logdet = module(inputs, mode)
+                logdets += logdet
+
+        return inputs, logdets
+
+    def evaluate(self, inputs):
+        N = len(self._modules)
+        outputs = torch.zeros(
+            N + 1, inputs.size(0), inputs.size(1), device=inputs.device
+        )
+        outputs[0, :, :] = inputs
+        logdets = torch.zeros(N, inputs.size(0), 1, device=inputs.device)
+        for i in range(N):
+            outputs[i + 1, :, :], logdets[i, :, :] = self._modules[str(i)](
+                outputs[i, :, :], mode="direct"
+            )
+        return outputs, logdets
+
+    def log_probs(self, inputs):
+        # change of variables: log p(x) = log N(u; 0, I) + log |det du/dx|
+        u, log_jacob = self(inputs)
+        log_probs = (-0.5 * u.pow(2) - 0.5 * math.log(2 * math.pi)).sum(
+            -1, keepdim=True
+        )
+        return (log_probs + log_jacob).sum(-1, keepdim=True)
+
+    def sample(self, num_samples=None, noise=None, cond_inputs=None):
+        if noise is None:
+            noise = torch.Tensor(num_samples, self.num_inputs).normal_()
+        device = next(self.parameters()).device
+        noise = noise.to(device)
+        # NOTE: this flow is unconditional; `forward` accepts no conditioning
+        # inputs, so `cond_inputs` is kept only for API compatibility and is
+        # ignored (passing it positionally would collide with `mode` and
+        # raise a TypeError).
+        samples = self.forward(noise, mode="inverse")[0]
+        return samples
+
+    def jacobians(self, X):
+        assert len(X.size()) == 1
+        N = len(self._modules)
+        num_inputs = X.size(-1)
+        jacobians = torch.zeros(N, num_inputs, num_inputs)
+
+        n_jacob = 0
+        for i in range(N):
+            J_i = self._modules[str(i)]._jacob(X)
+            if J_i is not None:
+                jacobians[n_jacob, :, :] = J_i
+                n_jacob += 1
+            del J_i
+
+        return jacobians[:n_jacob, :, :]
+
+
+def build_model(input_size, hidden_size, k, r, n_blocks):
+    modules = []
+    for i in range(n_blocks):
+        modules += [
+            SOSFlow(input_size, hidden_size, k, r),
+            BatchNormFlow(input_size),
+            Reverse(input_size),
+        ]
+    model = FlowSequential(*modules)
+
+    for module in model.modules():
+        if isinstance(module, nn.Linear):
+            nn.init.orthogonal_(module.weight)
+
+    return model
+
+
+datasets = {
+    "power": {"hidden_size": 100, "k": 5, "r": 4, "d": 6, "n_blocks": 8},
+    "gas": {"hidden_size": 100, "k": 5, "r": 4, "d": 8, "n_blocks": 8},
+    "hepmass": {"hidden_size": 512, "k": 5, "r": 4, "d": 21, "n_blocks": 8},
+    "miniboone": {"hidden_size": 512, "k": 5, "r": 4, "d": 43, "n_blocks": 8},
+    "bsds300": {"hidden_size": 512, "k": 5, "r": 4, "d": 63, "n_blocks": 8},
+    "mnist": {"hidden_size": 100, "k": 5, "r": 4, "d": 784, "n_blocks": 8},
+    "cifar": {"hidden_size": 100, "k": 5, "r": 4, "d": 3072, "n_blocks": 8},
+}
+
+
+def format_as_str(num):
+    if num / 1e9 > 1:
+        factor, suffix = 1e9, "B"
+    elif num / 1e6 > 1:
+        factor, suffix = 1e6, "M"
+    elif num / 1e3 > 1:
+        factor, suffix = 1e3, "K"
+    else:
+        factor, suffix = 1e0, ""
+
+    num_factored = num / factor
+
+    # NOTE: the "or True" deliberately short-circuits to integer rounding;
+    # the finer-grained branches below are currently disabled.
+    if num_factored / 1e2 > 1 or True:
+        num_rounded = str(int(round(num_factored)))
+    elif num_factored / 1e1 > 1:
+        num_rounded = f"{num_factored:.1f}"
+    else:
+        num_rounded = f"{num_factored:.2f}"
+
+    return f"{num_rounded}{suffix} % {num}"
+
+
+def sos_size(d, hidden_size, k, r, n_blocks):
+    model = build_model(d, hidden_size, k, r, n_blocks)
+    model = NetHandle(model)
+    return model.size()
+
+
+for dset, s_kwargs in datasets.items():
+    print(f"{dset}: #params: {format_as_str(sos_size(**s_kwargs))}")
diff --git
a/paper/node/script/sparsehessian.py b/paper/node/script/sparsehessian.py
new file mode 100644
index 0000000..408fe49
--- /dev/null
+++ b/paper/node/script/sparsehessian.py
@@ -0,0 +1,321 @@
+# *
+# @file Different utility functions
+# Copyright (c) Zhewei Yao, Amir Gholami
+# All rights reserved.
+# This file is part of PyHessian library.
+#
+# PyHessian is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# PyHessian is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with PyHessian. If not, see <http://www.gnu.org/licenses/>.
+
+
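As a quick sanity check, the SOS flow utilities defined in the script earlier in this diff can be exercised end to end. The following is a minimal usage sketch (an annotation, not part of the diff itself): it assumes `build_model` and the flow classes above are in scope and mirrors the settings of the `power` entry in the `datasets` dict.

```python
import torch

# Build an 8-block SOS flow for 6-dimensional data (the "power" configuration).
model = build_model(input_size=6, hidden_size=100, k=5, r=4, n_blocks=8)

# Log-likelihood of a batch under the flow via the change-of-variables formula.
x = torch.randn(128, 6)
log_px = model.log_probs(x)
print(log_px.shape)  # torch.Size([128, 1])

# Rough parameter count, analogous to what the script prints per dataset.
n_params = sum(p.numel() for p in model.parameters())
print(f"power: #params: {n_params}")
```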
diff --git a/src/torchprune/torchprune/util/models/cnn/__init__.py b/src/torchprune/torchprune/util/external/ffjord/__init__.py similarity index 100% rename from src/torchprune/torchprune/util/models/cnn/__init__.py rename to src/torchprune/torchprune/util/external/ffjord/__init__.py diff --git a/src/torchprune/torchprune/util/external/ffjord/assets/github_flow.gif b/src/torchprune/torchprune/util/external/ffjord/assets/github_flow.gif new file mode 100644 index 0000000..26e8d54 Binary files /dev/null and b/src/torchprune/torchprune/util/external/ffjord/assets/github_flow.gif differ diff --git a/src/torchprune/torchprune/util/external/ffjord/datasets/LICENSE.txt b/src/torchprune/torchprune/util/external/ffjord/datasets/LICENSE.txt new file mode 100644 index 0000000..6faf4ff --- /dev/null +++ b/src/torchprune/torchprune/util/external/ffjord/datasets/LICENSE.txt @@ -0,0 +1,26 @@ +Copyright (c) 2017, George Papamakarios +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +The views and conclusions contained in the software and documentation are those +of the authors and should not be interpreted as representing official policies, +either expressed or implied, of anybody else. diff --git a/src/torchprune/torchprune/util/external/ffjord/datasets/__init__.py b/src/torchprune/torchprune/util/external/ffjord/datasets/__init__.py new file mode 100644 index 0000000..7220778 --- /dev/null +++ b/src/torchprune/torchprune/util/external/ffjord/datasets/__init__.py @@ -0,0 +1,5 @@ +from .power import POWER +from .gas import GAS +from .hepmass import HEPMASS +from .miniboone import MINIBOONE +from .bsds300 import BSDS300 diff --git a/src/torchprune/torchprune/util/external/ffjord/datasets/bsds300.py b/src/torchprune/torchprune/util/external/ffjord/datasets/bsds300.py new file mode 100644 index 0000000..0d1e026 --- /dev/null +++ b/src/torchprune/torchprune/util/external/ffjord/datasets/bsds300.py @@ -0,0 +1,33 @@ +import os +import numpy as np +import h5py + + +class BSDS300: + """ + A dataset of patches from BSDS300. + """ + + class Data: + """ + Constructs the dataset. 
+ """ + + def __init__(self, data): + + self.x = data[:] + self.N = self.x.shape[0] + + def __init__(self, root): + + # load dataset + f = h5py.File(os.path.join(root, "BSDS300", "BSDS300.hdf5"), "r") + + self.trn = self.Data(f["train"]) + self.val = self.Data(f["validation"]) + self.tst = self.Data(f["test"]) + + self.n_dims = self.trn.x.shape[1] + self.image_size = [int(np.sqrt(self.n_dims + 1))] * 2 + + f.close() diff --git a/src/torchprune/torchprune/util/external/ffjord/datasets/gas.py b/src/torchprune/torchprune/util/external/ffjord/datasets/gas.py new file mode 100644 index 0000000..5bd68c9 --- /dev/null +++ b/src/torchprune/torchprune/util/external/ffjord/datasets/gas.py @@ -0,0 +1,69 @@ +import os +import pandas as pd +import numpy as np + + +class GAS: + class Data: + def __init__(self, data): + + self.x = data.astype(np.float32) + self.N = self.x.shape[0] + + def __init__(self, root): + + file = os.path.join(root, "gas", "ethylene_CO.pickle") + trn, val, tst = load_data_and_clean_and_split(file) + + self.trn = self.Data(trn) + self.val = self.Data(val) + self.tst = self.Data(tst) + + self.n_dims = self.trn.x.shape[1] + + +def load_data(file): + + data = pd.read_pickle(file) + # data = pd.read_pickle(file).sample(frac=0.25) + # data.to_pickle(file) + data.drop("Meth", axis=1, inplace=True) + data.drop("Eth", axis=1, inplace=True) + data.drop("Time", axis=1, inplace=True) + return data + + +def get_correlation_numbers(data): + C = data.corr() + A = C > 0.98 + B = A.to_numpy().sum(axis=1) + return B + + +def load_data_and_clean(file): + + data = load_data(file) + B = get_correlation_numbers(data) + + while np.any(B > 1): + col_to_remove = np.where(B > 1)[0][0] + col_name = data.columns[col_to_remove] + data.drop(col_name, axis=1, inplace=True) + B = get_correlation_numbers(data) + # print(data.corr()) + data = (data - data.mean()) / data.std() + + return data + + +def load_data_and_clean_and_split(file): + + data = load_data_and_clean(file).to_numpy() + N_test = int(0.1 * data.shape[0]) + data_test = data[-N_test:] + data_train = data[0:-N_test] + N_validate = int(0.1 * data_train.shape[0]) + data_validate = data_train[-N_validate:] + data_train = data_train[0:-N_validate] + + return data_train, data_validate, data_test diff --git a/src/torchprune/torchprune/util/external/ffjord/datasets/hepmass.py b/src/torchprune/torchprune/util/external/ffjord/datasets/hepmass.py new file mode 100644 index 0000000..2fe35d9 --- /dev/null +++ b/src/torchprune/torchprune/util/external/ffjord/datasets/hepmass.py @@ -0,0 +1,112 @@ +import os +import pandas as pd +import numpy as np +from collections import Counter + + +class HEPMASS: + """ + The HEPMASS data set. + http://archive.ics.uci.edu/ml/datasets/HEPMASS + """ + + class Data: + def __init__(self, data): + + self.x = data.astype(np.float32) + self.N = self.x.shape[0] + + def __init__(self, root): + + path = os.path.join(root, "hepmass") + trn, val, tst = load_data_no_discrete_normalised_as_array(path) + + self.trn = self.Data(trn) + self.val = self.Data(val) + self.tst = self.Data(tst) + + self.n_dims = self.trn.x.shape[1] + + +def load_data(path): + + data_train = pd.read_csv( + filepath_or_buffer=os.path.join(path, "1000_train.csv"), + index_col=False, + ) + data_test = pd.read_csv( + filepath_or_buffer=os.path.join(path, "1000_test.csv"), index_col=False + ) + + return data_train, data_test + + +def load_data_no_discrete(path): + """ + Loads the positive class examples from the first 10 percent of the dataset. 
+ """ + data_train, data_test = load_data(path) + + # Gets rid of any background noise examples i.e. class label 0. + data_train = data_train[data_train[data_train.columns[0]] == 1] + data_train = data_train.drop(data_train.columns[0], axis=1) + data_test = data_test[data_test[data_test.columns[0]] == 1] + data_test = data_test.drop(data_test.columns[0], axis=1) + # Because the data set is messed up! + data_test = data_test.drop(data_test.columns[-1], axis=1) + + return data_train, data_test + + +def load_data_no_discrete_normalised(path): + + data_train, data_test = load_data_no_discrete(path) + mu = data_train.mean() + s = data_train.std() + data_train = (data_train - mu) / s + data_test = (data_test - mu) / s + + return data_train, data_test + + +def load_data_no_discrete_normalised_as_array(path): + + data_train, data_test = load_data_no_discrete_normalised(path) + data_train, data_test = data_train.to_numpy(), data_test.to_numpy() + + i = 0 + # Remove any features that have too many re-occurring real values. + features_to_remove = [] + for feature in data_train.T: + c = Counter(feature) + max_count = np.array([v for k, v in sorted(c.items())])[0] + if max_count > 5: + features_to_remove.append(i) + i += 1 + data_train = data_train[ + :, + np.array( + [ + i + for i in range(data_train.shape[1]) + if i not in features_to_remove + ] + ), + ] + data_test = data_test[ + :, + np.array( + [ + i + for i in range(data_test.shape[1]) + if i not in features_to_remove + ] + ), + ] + + N = data_train.shape[0] + N_validate = int(N * 0.1) + data_validate = data_train[-N_validate:] + data_train = data_train[0:-N_validate] + + return data_train, data_validate, data_test diff --git a/src/torchprune/torchprune/util/external/ffjord/datasets/miniboone.py b/src/torchprune/torchprune/util/external/ffjord/datasets/miniboone.py new file mode 100644 index 0000000..10f8881 --- /dev/null +++ b/src/torchprune/torchprune/util/external/ffjord/datasets/miniboone.py @@ -0,0 +1,66 @@ +import os +import numpy as np + + +class MINIBOONE: + class Data: + def __init__(self, data): + + self.x = data.astype(np.float32) + self.N = self.x.shape[0] + + def __init__(self, root): + + file = os.path.join(root, "miniboone", "data.npy") + trn, val, tst = load_data_normalised(file) + + self.trn = self.Data(trn) + self.val = self.Data(val) + self.tst = self.Data(tst) + + self.n_dims = self.trn.x.shape[1] + + +def load_data(root_path): + # NOTE: To remember how the pre-processing was done. + # data = pd.read_csv(root_path, names=[str(x) for x in range(50)], delim_whitespace=True) + # print data.head() + # data = data.as_matrix() + # # Remove some random outliers + # indices = (data[:, 0] < -100) + # data = data[~indices] + # + # i = 0 + # # Remove any features that have too many re-occuring real values. 
+ # features_to_remove = [] + # for feature in data.T: + # c = Counter(feature) + # max_count = np.array([v for k, v in sorted(c.iteritems())])[0] + # if max_count > 5: + # features_to_remove.append(i) + # i += 1 + # data = data[:, np.array([i for i in range(data.shape[1]) if i not in features_to_remove])] + # np.save("~/data/miniboone/data.npy", data) + + data = np.load(root_path) + N_test = int(0.1 * data.shape[0]) + data_test = data[-N_test:] + data = data[0:-N_test] + N_validate = int(0.1 * data.shape[0]) + data_validate = data[-N_validate:] + data_train = data[0:-N_validate] + + return data_train, data_validate, data_test + + +def load_data_normalised(root_path): + + data_train, data_validate, data_test = load_data(root_path) + data = np.vstack((data_train, data_validate)) + mu = data.mean(axis=0) + s = data.std(axis=0) + data_train = (data_train - mu) / s + data_validate = (data_validate - mu) / s + data_test = (data_test - mu) / s + + return data_train, data_validate, data_test diff --git a/src/torchprune/torchprune/util/external/ffjord/datasets/power.py b/src/torchprune/torchprune/util/external/ffjord/datasets/power.py new file mode 100644 index 0000000..5539732 --- /dev/null +++ b/src/torchprune/torchprune/util/external/ffjord/datasets/power.py @@ -0,0 +1,71 @@ +import os +import numpy as np + + +class POWER: + class Data: + def __init__(self, data): + + self.x = data.astype(np.float32) + self.N = self.x.shape[0] + + def __init__(self, root): + + trn, val, tst = load_data_normalised(root) + + self.trn = self.Data(trn) + self.val = self.Data(val) + self.tst = self.Data(tst) + + self.n_dims = self.trn.x.shape[1] + + +def load_data(root): + return np.load(os.path.join(root, "power", "data.npy")) + + +def load_data_split_with_noise(root): + + rng = np.random.RandomState(42) + + data = load_data(root) + rng.shuffle(data) + N = data.shape[0] + + data = np.delete(data, 3, axis=1) + data = np.delete(data, 1, axis=1) + ############################ + # Add noise + ############################ + # global_intensity_noise = 0.1*rng.rand(N, 1) + voltage_noise = 0.01 * rng.rand(N, 1) + # grp_noise = 0.001*rng.rand(N, 1) + gap_noise = 0.001 * rng.rand(N, 1) + sm_noise = rng.rand(N, 3) + time_noise = np.zeros((N, 1)) + # noise = np.hstack((gap_noise, grp_noise, voltage_noise, global_intensity_noise, sm_noise, time_noise)) + # noise = np.hstack((gap_noise, grp_noise, voltage_noise, sm_noise, time_noise)) + noise = np.hstack((gap_noise, voltage_noise, sm_noise, time_noise)) + data = data + noise + + N_test = int(0.1 * data.shape[0]) + data_test = data[-N_test:] + data = data[0:-N_test] + N_validate = int(0.1 * data.shape[0]) + data_validate = data[-N_validate:] + data_train = data[0:-N_validate] + + return data_train, data_validate, data_test + + +def load_data_normalised(root): + + data_train, data_validate, data_test = load_data_split_with_noise(root) + data = np.vstack((data_train, data_validate)) + mu = data.mean(axis=0) + s = data.std(axis=0) + data_train = (data_train - mu) / s + data_validate = (data_validate - mu) / s + data_test = (data_test - mu) / s + + return data_train, data_validate, data_test diff --git a/src/torchprune/torchprune/util/external/ffjord/diagnostics/approx_error_1d.py b/src/torchprune/torchprune/util/external/ffjord/diagnostics/approx_error_1d.py new file mode 100644 index 0000000..c4a52d7 --- /dev/null +++ b/src/torchprune/torchprune/util/external/ffjord/diagnostics/approx_error_1d.py @@ -0,0 +1,265 @@ +from inspect import getsourcefile +import sys +import os + 
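+# Overview (annotation added here for readability; not part of the upstream
+# FFJORD source): this script trains a 1D continuous normalizing flow on a
+# three-component Gaussian mixture,
+#     p(x) = (1/3) * [N(x; -2.8, 0.4) + N(x; -0.9, 0.4) + N(x; 2.0, 0.4)],
+# and `evaluate()` then measures |1 - \int p_model(x) dx|, approximated by a
+# log-domain Riemann sum over a fine grid, for a range of ODE solver
+# tolerances.
+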
+current_path = os.path.abspath(getsourcefile(lambda: 0)) +current_dir = os.path.dirname(current_path) +parent_dir = current_dir[:current_dir.rfind(os.path.sep)] +sys.path.insert(0, parent_dir) + +import matplotlib +matplotlib.use('Agg') +import matplotlib.pyplot as plt + +import numpy as np +import argparse +import os +import time + +import torch +import torch.optim as optim + +import lib.utils as utils +import lib.layers.odefunc as odefunc + +from train_misc import standard_normal_logprob +from train_misc import set_cnf_options, count_nfe, count_parameters, count_total_time +from train_misc import build_model_tabular + +import seaborn as sns +sns.set_style("whitegrid") +colors = ["windows blue", "amber", "greyish", "faded green", "dusty purple"] +sns.palplot(sns.xkcd_palette(colors)) + +SOLVERS = ["dopri5", "bdf", "rk4", "midpoint", 'adams', 'explicit_adams', 'fixed_adams'] +parser = argparse.ArgumentParser('Continuous Normalizing Flow') +parser.add_argument( + "--layer_type", type=str, default="concatsquash", + choices=["ignore", "concat", "concat_v2", "squash", "concatsquash", "concatcoord", "hyper", "blend"] +) +parser.add_argument('--dims', type=str, default='64-64-64') +parser.add_argument("--num_blocks", type=int, default=1, help='Number of stacked CNFs.') +parser.add_argument('--time_length', type=float, default=0.5) +parser.add_argument('--train_T', type=eval, default=True) +parser.add_argument("--divergence_fn", type=str, default="brute_force", choices=["brute_force", "approximate"]) +parser.add_argument("--nonlinearity", type=str, default="tanh", choices=odefunc.NONLINEARITIES) + +parser.add_argument('--solver', type=str, default='dopri5', choices=SOLVERS) +parser.add_argument('--atol', type=float, default=1e-5) +parser.add_argument('--rtol', type=float, default=1e-5) +parser.add_argument("--step_size", type=float, default=None, help="Optional fixed step size.") + +parser.add_argument('--test_solver', type=str, default=None, choices=SOLVERS + [None]) +parser.add_argument('--test_atol', type=float, default=None) +parser.add_argument('--test_rtol', type=float, default=None) + +parser.add_argument('--residual', type=eval, default=False, choices=[True, False]) +parser.add_argument('--rademacher', type=eval, default=False, choices=[True, False]) +parser.add_argument('--spectral_norm', type=eval, default=False, choices=[True, False]) +parser.add_argument('--batch_norm', type=eval, default=False, choices=[True, False]) +parser.add_argument('--bn_lag', type=float, default=0) + +parser.add_argument('--niters', type=int, default=10000) +parser.add_argument('--batch_size', type=int, default=100) +parser.add_argument('--test_batch_size', type=int, default=1000) +parser.add_argument('--lr', type=float, default=1e-3) +parser.add_argument('--weight_decay', type=float, default=1e-5) + +# Track quantities +parser.add_argument('--l1int', type=float, default=None, help="int_t ||f||_1") +parser.add_argument('--l2int', type=float, default=None, help="int_t ||f||_2") +parser.add_argument('--dl2int', type=float, default=None, help="int_t ||f^T df/dt||_2") +parser.add_argument('--JFrobint', type=float, default=None, help="int_t ||df/dx||_F") +parser.add_argument('--JdiagFrobint', type=float, default=None, help="int_t ||df_i/dx_i||_F") +parser.add_argument('--JoffdiagFrobint', type=float, default=None, help="int_t ||df/dx - df_i/dx_i||_F") + +parser.add_argument('--save', type=str, default='experiments/approx_error_1d') +parser.add_argument('--viz_freq', type=int, default=100) 
+parser.add_argument('--val_freq', type=int, default=100) +parser.add_argument('--log_freq', type=int, default=10) +parser.add_argument('--gpu', type=int, default=0) +args = parser.parse_args() + +# logger +utils.makedirs(args.save) +logger = utils.get_logger(logpath=os.path.join(args.save, 'logs'), filepath=os.path.abspath(__file__)) + +if args.layer_type == "blend": + logger.info("!! Setting time_length from None to 1.0 due to use of Blend layers.") + args.time_length = 1.0 + +logger.info(args) + +device = torch.device('cuda:' + str(args.gpu) if torch.cuda.is_available() else 'cpu') + + +def normal_log_density(x, mean=0, stdev=1): + term = (x - mean) / stdev + return -0.5 * (np.log(2 * np.pi) + 2 * np.log(stdev) + term * term) + + +def data_sample(batch_size): + x1 = np.random.randn(batch_size) * np.sqrt(0.4) - 2.8 + x2 = np.random.randn(batch_size) * np.sqrt(0.4) - 0.9 + x3 = np.random.randn(batch_size) * np.sqrt(0.4) + 2. + xs = np.concatenate([x1[:, None], x2[:, None], x3[:, None]], 1) + k = np.random.randint(0, 3, batch_size) + x = xs[np.arange(batch_size), k] + return torch.tensor(x[:, None]).float().to(device) + + +def data_density(x): + p1 = normal_log_density(x, mean=-2.8, stdev=np.sqrt(0.4)) + p2 = normal_log_density(x, mean=-0.9, stdev=np.sqrt(0.4)) + p3 = normal_log_density(x, mean=2.0, stdev=np.sqrt(0.4)) + return torch.log(p1.exp() / 3 + p2.exp() / 3 + p3.exp() / 3) + + +def model_density(x, model): + x = x.to(device) + z, delta_logp = model(x, torch.zeros_like(x)) + logpx = standard_normal_logprob(z) - delta_logp + return logpx + + +def model_sample(model, batch_size): + z = torch.randn(batch_size, 1) + logqz = standard_normal_logprob(z) + x, logqx = model(z, logqz, reverse=True) + return x, logqx + + +def compute_loss(args, model, batch_size=None): + if batch_size is None: batch_size = args.batch_size + + x = data_sample(batch_size) + logpx = model_density(x, model) + return -torch.mean(logpx) + + +def train(): + + model = build_model_tabular(args, 1).to(device) + set_cnf_options(args, model) + + logger.info(model) + logger.info("Number of trainable parameters: {}".format(count_parameters(model))) + + optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) + + time_meter = utils.RunningAverageMeter(0.93) + loss_meter = utils.RunningAverageMeter(0.93) + nfef_meter = utils.RunningAverageMeter(0.93) + nfeb_meter = utils.RunningAverageMeter(0.93) + tt_meter = utils.RunningAverageMeter(0.93) + + end = time.time() + best_loss = float('inf') + model.train() + for itr in range(1, args.niters + 1): + optimizer.zero_grad() + + loss = compute_loss(args, model) + loss_meter.update(loss.item()) + + total_time = count_total_time(model) + nfe_forward = count_nfe(model) + + loss.backward() + optimizer.step() + + nfe_total = count_nfe(model) + nfe_backward = nfe_total - nfe_forward + nfef_meter.update(nfe_forward) + nfeb_meter.update(nfe_backward) + + time_meter.update(time.time() - end) + tt_meter.update(total_time) + + log_message = ( + 'Iter {:04d} | Time {:.4f}({:.4f}) | Loss {:.6f}({:.6f}) | NFE Forward {:.0f}({:.1f})' + ' | NFE Backward {:.0f}({:.1f}) | CNF Time {:.4f}({:.4f})'.format( + itr, time_meter.val, time_meter.avg, loss_meter.val, loss_meter.avg, nfef_meter.val, nfef_meter.avg, + nfeb_meter.val, nfeb_meter.avg, tt_meter.val, tt_meter.avg + ) + ) + logger.info(log_message) + + if itr % args.val_freq == 0 or itr == args.niters: + with torch.no_grad(): + model.eval() + test_loss = compute_loss(args, model, batch_size=args.test_batch_size) + 
test_nfe = count_nfe(model) + log_message = '[TEST] Iter {:04d} | Test Loss {:.6f} | NFE {:.0f}'.format(itr, test_loss, test_nfe) + logger.info(log_message) + + if test_loss.item() < best_loss: + best_loss = test_loss.item() + utils.makedirs(args.save) + torch.save({ + 'args': args, + 'state_dict': model.state_dict(), + }, os.path.join(args.save, 'checkpt.pth')) + model.train() + + if itr % args.viz_freq == 0: + with torch.no_grad(): + model.eval() + + xx = torch.linspace(-10, 10, 10000).view(-1, 1) + true_p = data_density(xx) + plt.plot(xx.view(-1).cpu().numpy(), true_p.view(-1).exp().cpu().numpy(), label='True') + + true_p = model_density(xx, model) + plt.plot(xx.view(-1).cpu().numpy(), true_p.view(-1).exp().cpu().numpy(), label='Model') + + utils.makedirs(os.path.join(args.save, 'figs')) + plt.savefig(os.path.join(args.save, 'figs', '{:06d}.jpg'.format(itr))) + plt.close() + + model.train() + + end = time.time() + + logger.info('Training has finished.') + + +def evaluate(): + model = build_model_tabular(args, 1).to(device) + set_cnf_options(args, model) + + checkpt = torch.load(os.path.join(args.save, 'checkpt.pth')) + model.load_state_dict(checkpt['state_dict']) + model.to(device) + + tols = [1e-1, 1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8] + errors = [] + with torch.no_grad(): + for tol in tols: + args.rtol = tol + args.atol = tol + set_cnf_options(args, model) + + xx = torch.linspace(-15, 15, 500000).view(-1, 1).to(device) + prob_xx = model_density(xx, model).double().view(-1).cpu() + xx = xx.double().cpu().view(-1) + dxx = torch.log(xx[1:] - xx[:-1]) + num_integral = torch.logsumexp(prob_xx[:-1] + dxx, 0).exp() + errors.append(float(torch.abs(num_integral - 1.))) + + print(errors[-1]) + + plt.figure(figsize=(5, 3)) + plt.plot(tols, errors, linewidth=3, marker='o', markersize=7) + # plt.plot([-1, 0.2], [-1, 0.2], '--', color='grey', linewidth=1) + plt.xscale("log", nonposx='clip') + # plt.yscale("log", nonposy='clip') + plt.xlabel('Solver Tolerance', fontsize=17) + plt.ylabel('$| 1 - \int p(x) |$', fontsize=17) + plt.tight_layout() + plt.savefig('ode_solver_error_vs_tol.pdf') + + +if __name__ == '__main__': + # train() + evaluate() diff --git a/src/torchprune/torchprune/util/external/ffjord/diagnostics/plot_bottleneck_losses.py b/src/torchprune/torchprune/util/external/ffjord/diagnostics/plot_bottleneck_losses.py new file mode 100644 index 0000000..82761c0 --- /dev/null +++ b/src/torchprune/torchprune/util/external/ffjord/diagnostics/plot_bottleneck_losses.py @@ -0,0 +1,70 @@ +import re +import numpy as np +import matplotlib +matplotlib.use('Agg') +import matplotlib.pyplot as plt +import scipy.signal +import scipy.ndimage + +# BASE = "experiments/cnf_mnist_64-64-128-128-64-64/logs" +# RESIDUAL = "experiments/cnf_mnist_64-64-128-128-64-64_residual/logs" +# RADEMACHER = "experiments/cnf_mnist_64-64-128-128-64-64_rademacher/logs" + +BOTTLENECK = "experiments/cnf_mnist_bottleneck_64-64-128-5-128-64-64/logs" +BOTTLENECK_EST = "experiments/cnf_mnist_bottleneck_64-64-128-5-128-64-64_ae-est/logs" +RAD_BOTTLENECK = "experiments/cnf_mnist_bottleneck_64-64-128-5-128-64-64_rademacher/logs" +RAD_BOTTLENECK_EST = "experiments/cnf_mnist_bottleneck_64-64-128-5-128-64-64_ae-est_rademacher/logs" + +# ET_ALL = "experiments/cnf_mnist_bottleneck_64-64-128-5-128-64-64_ae-est_residual_rademacher/logs" + + +def get_losses(filename): + with open(filename, "r") as f: + lines = f.readlines() + + losses = [] + + for line in lines: + w = re.findall(r"Bit/dim [^|(]*\([0-9\.]*\)", line) + if w: w = 
re.findall(r"\([0-9\.]*\)", w[0]) + if w: w = re.findall(r"[0-9\.]+", w[0]) + if w: + losses.append(float(w[0])) + + return losses + + +bottleneck_loss = get_losses(BOTTLENECK) +bottleneck_est_loss = get_losses(BOTTLENECK_EST) +rademacher_bottleneck_loss = get_losses(RAD_BOTTLENECK) +rademacher_bottleneck_est_loss = get_losses(RAD_BOTTLENECK_EST) + +bottleneck_loss = scipy.signal.medfilt(bottleneck_loss, 21) +bottleneck_est_loss = scipy.signal.medfilt(bottleneck_est_loss, 21) +rademacher_bottleneck_loss = scipy.signal.medfilt(rademacher_bottleneck_loss, 21) +rademacher_bottleneck_est_loss = scipy.signal.medfilt(rademacher_bottleneck_est_loss, 21) + +import seaborn as sns +sns.set_style("whitegrid") +colors = ["windows blue", "amber", "greyish", "faded green", "dusty purple"] +sns.palplot(sns.xkcd_palette(colors)) + +import brewer2mpl +line_colors = brewer2mpl.get_map('Set2', 'qualitative', 4).mpl_colors +dark_colors = brewer2mpl.get_map('Dark2', 'qualitative', 4).mpl_colors +# plt.style.use('ggplot') + +plt.figure(figsize=(4, 3)) +plt.plot(np.arange(len(bottleneck_loss)) / 30, bottleneck_loss, ':', color=line_colors[1], label="Gaussian w/o Trick") +plt.plot(np.arange(len(bottleneck_est_loss)) / 30, bottleneck_est_loss, color=dark_colors[1], label="Gaussian w/ Trick") +plt.plot(np.arange(len(rademacher_bottleneck_loss)) / 30, rademacher_bottleneck_loss, ':', color=line_colors[2], label="Rademacher w/o Trick") +plt.plot(np.arange(len(rademacher_bottleneck_est_loss)) / 30, rademacher_bottleneck_est_loss, color=dark_colors[2], label="Rademacher w/ Trick") + +plt.legend(frameon=True, fontsize=10.5, loc='upper right') +plt.ylim([1.1, 1.7]) +# plt.yscale("log", nonposy='clip') +plt.xlabel("Epoch", fontsize=18) +plt.ylabel("Bits/dim", fontsize=18) +plt.xlim([0, 170]) +plt.tight_layout() +plt.savefig('bottleneck_losses.pdf') diff --git a/src/torchprune/torchprune/util/external/ffjord/diagnostics/plot_compare_multiscale.py b/src/torchprune/torchprune/util/external/ffjord/diagnostics/plot_compare_multiscale.py new file mode 100644 index 0000000..36ff110 --- /dev/null +++ b/src/torchprune/torchprune/util/external/ffjord/diagnostics/plot_compare_multiscale.py @@ -0,0 +1,60 @@ +import re +import matplotlib +matplotlib.use('Agg') +import matplotlib.pyplot as plt + +MNIST_SINGLESCALE = "diagnostics/mnist.log" +MNIST_MULTISCALE = "diagnostics/mnist_multiscale.log" + + +def get_values(filename): + with open(filename, "r") as f: + lines = f.readlines() + + losses = [] + nfes = [] + + for line in lines: + + w = re.findall(r"Steps [^|(]*\([0-9\.]*\)", line) + if w: w = re.findall(r"\([0-9\.]*\)", w[0]) + if w: w = re.findall(r"[0-9\.]+", w[0]) + if w: + nfes.append(float(w[0])) + + w = re.findall(r"Bit/dim [^|(]*\([0-9\.]*\)", line) + if w: w = re.findall(r"\([0-9\.]*\)", w[0]) + if w: w = re.findall(r"[0-9\.]+", w[0]) + if w: + losses.append(float(w[0])) + + return losses, nfes + + +mnist_singlescale_loss, mnist_singlescale_nfes = get_values(MNIST_SINGLESCALE) +mnist_multiscale_loss, mnist_multiscale_nfes = get_values(MNIST_MULTISCALE) + +import brewer2mpl +line_colors = brewer2mpl.get_map('Set2', 'qualitative', 4).mpl_colors +dark_colors = brewer2mpl.get_map('Dark2', 'qualitative', 4).mpl_colors + +import seaborn as sns +sns.set_style("whitegrid") +colors = ["windows blue", "amber", "greyish", "faded green", "dusty purple"] +sns.palplot(sns.xkcd_palette(colors)) + +plt.figure(figsize=(4, 2.6)) +plt.scatter(mnist_singlescale_nfes[::10], mnist_singlescale_loss[::10], color=line_colors[1], label="Single 
FFJORD") +plt.scatter(mnist_multiscale_nfes[::10], mnist_multiscale_loss[::10], color=line_colors[2], label="Multiscale FFJORD") + +plt.ylim([0.9, 1.25]) +plt.legend(frameon=True, fontsize=10.5) +plt.xlabel("NFE", fontsize=18) +plt.ylabel("Bits/dim", fontsize=18) + +ax = plt.gca() +ax.tick_params(axis='both', which='major', labelsize=14) +ax.tick_params(axis='both', which='minor', labelsize=10) + +plt.tight_layout() +plt.savefig('multiscale_loss_vs_nfe.pdf') diff --git a/src/torchprune/torchprune/util/external/ffjord/diagnostics/plot_flows.py b/src/torchprune/torchprune/util/external/ffjord/diagnostics/plot_flows.py new file mode 100644 index 0000000..d284a33 --- /dev/null +++ b/src/torchprune/torchprune/util/external/ffjord/diagnostics/plot_flows.py @@ -0,0 +1,157 @@ +from inspect import getsourcefile +import sys +import os + +current_path = os.path.abspath(getsourcefile(lambda: 0)) +current_dir = os.path.dirname(current_path) +parent_dir = current_dir[:current_dir.rfind(os.path.sep)] +sys.path.insert(0, parent_dir) + +import matplotlib +matplotlib.use('Agg') +import matplotlib.pyplot as plt + +import argparse +import os + +import torch + +import lib.toy_data as toy_data +import lib.utils as utils +import lib.visualize_flow as viz_flow +import lib.layers.odefunc as odefunc +import lib.layers as layers + +from train_misc import standard_normal_logprob +from train_misc import build_model_tabular, count_parameters + +SOLVERS = ["dopri5", "bdf", "rk4", "midpoint", 'adams', 'explicit_adams', 'fixed_adams'] +parser = argparse.ArgumentParser('Continuous Normalizing Flow') +parser.add_argument( + '--data', choices=['swissroll', '8gaussians', 'pinwheel', 'circles', 'moons', '2spirals', 'checkerboard', 'rings'], + type=str, default='pinwheel' +) + +parser.add_argument('--discrete', action='store_true') + +parser.add_argument('--depth', help='number of coupling layers', type=int, default=10) +parser.add_argument('--glow', type=eval, choices=[True, False], default=False) + +parser.add_argument( + "--layer_type", type=str, default="concatsquash", + choices=["ignore", "concat", "concat_v2", "squash", "concatsquash", "concatcoord", "hyper", "blend"] +) +parser.add_argument('--dims', type=str, default='64-64-64') +parser.add_argument("--num_blocks", type=int, default=1, help='Number of stacked CNFs.') +parser.add_argument('--time_length', type=float, default=0.5) +parser.add_argument('--train_T', type=eval, default=True) +parser.add_argument("--divergence_fn", type=str, default="brute_force", choices=["brute_force", "approximate"]) +parser.add_argument("--nonlinearity", type=str, default="tanh", choices=odefunc.NONLINEARITIES) + +parser.add_argument('--solver', type=str, default='dopri5', choices=SOLVERS) +parser.add_argument('--atol', type=float, default=1e-5) +parser.add_argument('--rtol', type=float, default=1e-5) +parser.add_argument("--step_size", type=float, default=None, help="Optional fixed step size.") + +parser.add_argument('--test_solver', type=str, default=None, choices=SOLVERS + [None]) +parser.add_argument('--test_atol', type=float, default=None) +parser.add_argument('--test_rtol', type=float, default=None) + +parser.add_argument('--residual', type=eval, default=False, choices=[True, False]) +parser.add_argument('--rademacher', type=eval, default=False, choices=[True, False]) +parser.add_argument('--spectral_norm', type=eval, default=False, choices=[True, False]) +parser.add_argument('--batch_norm', type=eval, default=False, choices=[True, False]) +parser.add_argument('--bn_lag', 
type=float, default=0) + +parser.add_argument('--niters', type=int, default=2500) +parser.add_argument('--batch_size', type=int, default=100) +parser.add_argument('--test_batch_size', type=int, default=1000) +parser.add_argument('--lr', type=float, default=1e-3) +parser.add_argument('--weight_decay', type=float, default=1e-5) + +parser.add_argument('--checkpt', type=str, required=True) +parser.add_argument('--save', type=str, default='experiments/cnf') +parser.add_argument('--viz_freq', type=int, default=100) +parser.add_argument('--val_freq', type=int, default=100) +parser.add_argument('--log_freq', type=int, default=10) +parser.add_argument('--gpu', type=int, default=0) +args = parser.parse_args() + +device = torch.device('cuda:' + str(args.gpu) if torch.cuda.is_available() else 'cpu') + + +def construct_discrete_model(): + + chain = [] + for i in range(args.depth): + if args.glow: chain.append(layers.BruteForceLayer(2)) + chain.append(layers.CouplingLayer(2, swap=i % 2 == 0)) + return layers.SequentialFlow(chain) + + +def get_transforms(model): + + def sample_fn(z, logpz=None): + if logpz is not None: + return model(z, logpz, reverse=True) + else: + return model(z, reverse=True) + + def density_fn(x, logpx=None): + if logpx is not None: + return model(x, logpx, reverse=False) + else: + return model(x, reverse=False) + + return sample_fn, density_fn + + +if __name__ == '__main__': + + if args.discrete: + model = construct_discrete_model().to(device) + model.load_state_dict(torch.load(args.checkpt)['state_dict']) + else: + model = build_model_tabular(args, 2).to(device) + + sd = torch.load(args.checkpt)['state_dict'] + fixed_sd = {} + for k, v in sd.items(): + fixed_sd[k.replace('odefunc.odefunc', 'odefunc')] = v + model.load_state_dict(fixed_sd) + + print(model) + print("Number of trainable parameters: {}".format(count_parameters(model))) + + model.eval() + p_samples = toy_data.inf_train_gen(args.data, batch_size=800**2) + + with torch.no_grad(): + sample_fn, density_fn = get_transforms(model) + + plt.figure(figsize=(10, 10)) + ax = ax = plt.gca() + viz_flow.plt_samples(p_samples, ax, npts=800) + plt.subplots_adjust(left=0, right=1, top=1, bottom=0) + fig_filename = os.path.join(args.save, 'figs', 'true_samples.jpg') + utils.makedirs(os.path.dirname(fig_filename)) + plt.savefig(fig_filename) + plt.close() + + plt.figure(figsize=(10, 10)) + ax = ax = plt.gca() + viz_flow.plt_flow_density(standard_normal_logprob, density_fn, ax, npts=800, memory=200, device=device) + plt.subplots_adjust(left=0, right=1, top=1, bottom=0) + fig_filename = os.path.join(args.save, 'figs', 'model_density.jpg') + utils.makedirs(os.path.dirname(fig_filename)) + plt.savefig(fig_filename) + plt.close() + + plt.figure(figsize=(10, 10)) + ax = ax = plt.gca() + viz_flow.plt_flow_samples(torch.randn, sample_fn, ax, npts=800, memory=200, device=device) + plt.subplots_adjust(left=0, right=1, top=1, bottom=0) + fig_filename = os.path.join(args.save, 'figs', 'model_samples.jpg') + utils.makedirs(os.path.dirname(fig_filename)) + plt.savefig(fig_filename) + plt.close() diff --git a/src/torchprune/torchprune/util/external/ffjord/diagnostics/plot_nfe_vs_dim_vae.py b/src/torchprune/torchprune/util/external/ffjord/diagnostics/plot_nfe_vs_dim_vae.py new file mode 100644 index 0000000..26126c5 --- /dev/null +++ b/src/torchprune/torchprune/util/external/ffjord/diagnostics/plot_nfe_vs_dim_vae.py @@ -0,0 +1,50 @@ +import os.path +import re +import numpy as np +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt 
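+# Overview (annotation added here for readability; not part of the upstream
+# FFJORD source): each snapshot directory below contains the training log of
+# a CNF-VAE run with a different latent dimensionality. The loop scrapes the
+# "NFE Forward <n>" counts from the log lines, smooths them with a Gaussian
+# filter, and plots the number of function evaluations against the epoch
+# (the `(np.arange(len(nfes)) + 1) / 50` rescaling assumes 50 logged
+# iterations per epoch).
+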
+import scipy.ndimage + +import seaborn as sns +sns.set_style("whitegrid") +colors = ["windows blue", "amber", "greyish", "faded green", "dusty purple"] +sns.palplot(sns.xkcd_palette(colors)) + +dims = [16, 32, 48, 64] +dirs = [ + 'vae_mnist_cnf_num_flows_4_256-256_num_blocks_1__2018-09-16_17_27_03', + 'vae_mnist_cnf_num_flows_4_256-256_num_blocks_1__2018-09-16_17_26_41', + 'vae_mnist_cnf_num_flows_4_256-256_num_blocks_1__2018-09-16_17_23_35', + 'vae_mnist_cnf_num_flows_4_256-256_num_blocks_1__2018-09-16_17_25_03', +] + +nfe_all = [] + +for dim, dirname in zip(dims, dirs): + with open(os.path.join('snapshots', dirname, 'logs'), 'r') as f: + lines = f.readlines() + + nfes_ = [] + + for line in lines: + w = re.findall(r"NFE Forward [0-9]*", line) + if w: w = re.findall(r"[0-9]+", w[0]) + if w: + nfes_.append(float(w[0])) + + nfe_all.append(nfes_) + +plt.figure(figsize=(4, 2.4)) +for i, (dim, nfes) in enumerate(zip(dims, nfe_all)): + nfes = np.array(nfes) + xx = (np.arange(len(nfes)) + 1) / 50 + nfes = scipy.ndimage.gaussian_filter(nfes, 101) + plt.plot(xx, nfes, '--', label='Dim {}'.format(dim)) + +plt.legend(frameon=True, fontsize=10.5) +plt.xlabel('Epoch', fontsize=18) +plt.ylabel('NFE', fontsize=18) +plt.xlim([0, 200]) +plt.tight_layout() +plt.savefig("nfes_vs_dim_vae.pdf") diff --git a/src/torchprune/torchprune/util/external/ffjord/diagnostics/plot_sn_losses.py b/src/torchprune/torchprune/util/external/ffjord/diagnostics/plot_sn_losses.py new file mode 100644 index 0000000..fc425f8 --- /dev/null +++ b/src/torchprune/torchprune/util/external/ffjord/diagnostics/plot_sn_losses.py @@ -0,0 +1,81 @@ +import re +import matplotlib +matplotlib.use('Agg') +import matplotlib.pyplot as plt + +CIFAR10 = "diagnostics/cifar10_multiscale.log" +CIFAR10_SN = "diagnostics/cifar10_multiscale_sn.log" + +MNIST = "diagnostics/mnist_multiscale.log" +MNIST_SN = "diagnostics/mnist_multiscale_sn.log" + + +def get_values(filename): + with open(filename, "r") as f: + lines = f.readlines() + + losses = [] + nfes = [] + + for line in lines: + + w = re.findall(r"Steps [^|(]*\([0-9\.]*\)", line) + if w: w = re.findall(r"\([0-9\.]*\)", w[0]) + if w: w = re.findall(r"[0-9\.]+", w[0]) + if w: + nfes.append(float(w[0])) + + w = re.findall(r"Bit/dim [^|(]*\([0-9\.]*\)", line) + if w: w = re.findall(r"\([0-9\.]*\)", w[0]) + if w: w = re.findall(r"[0-9\.]+", w[0]) + if w: + losses.append(float(w[0])) + + return losses, nfes + + +cifar10_loss, cifar10_nfes = get_values(CIFAR10) +cifar10_sn_loss, cifar10_sn_nfes = get_values(CIFAR10_SN) +mnist_loss, mnist_nfes = get_values(MNIST) +mnist_sn_loss, mnist_sn_nfes = get_values(MNIST_SN) + +import brewer2mpl +line_colors = brewer2mpl.get_map('Set2', 'qualitative', 4).mpl_colors +dark_colors = brewer2mpl.get_map('Dark2', 'qualitative', 4).mpl_colors +plt.style.use('ggplot') + +# CIFAR10 plot +plt.figure(figsize=(6, 7)) +plt.scatter(cifar10_nfes, cifar10_loss, color=line_colors[1], label="w/o Spectral Norm") +plt.scatter(cifar10_sn_nfes, cifar10_sn_loss, color=line_colors[2], label="w/ Spectral Norm") + +plt.ylim([3, 5]) +plt.legend(fontsize=18) +plt.xlabel("NFE", fontsize=30) +plt.ylabel("Bits/dim", fontsize=30) + +ax = plt.gca() +ax.tick_params(axis='both', which='major', labelsize=24) +ax.tick_params(axis='both', which='minor', labelsize=16) +ax.yaxis.set_ticks([3, 3.5, 4, 4.5, 5]) + +plt.tight_layout() +plt.savefig('cifar10_sn_loss_vs_nfe.pdf') + +# MNIST plot +plt.figure(figsize=(6, 7)) +plt.scatter(mnist_nfes, mnist_loss, color=line_colors[1], label="w/o Spectral Norm") 
+plt.scatter(mnist_sn_nfes, mnist_sn_loss, color=line_colors[2], label="w/ Spectral Norm") + +plt.ylim([0.9, 2]) +plt.legend(fontsize=18) +plt.xlabel("NFE", fontsize=30) +plt.ylabel("Bits/dim", fontsize=30) + +ax = plt.gca() +ax.tick_params(axis='both', which='major', labelsize=24) +ax.tick_params(axis='both', which='minor', labelsize=16) +# ax.yaxis.set_ticks([3, 3.5, 4, 4.5, 5]) + +plt.tight_layout() +plt.savefig('mnist_sn_loss_vs_nfe.pdf') diff --git a/src/torchprune/torchprune/util/external/ffjord/diagnostics/scrap_log.py b/src/torchprune/torchprune/util/external/ffjord/diagnostics/scrap_log.py new file mode 100644 index 0000000..c08e99b --- /dev/null +++ b/src/torchprune/torchprune/util/external/ffjord/diagnostics/scrap_log.py @@ -0,0 +1,64 @@ +import os +import re +import csv + + +def log_to_csv(log_filename, csv_filename): + with open(log_filename, 'r') as f: + lines = f.readlines() + + with open(csv_filename, 'w', newline='') as csvfile: + fieldnames = None + writer = None + + for line in lines: + if line.startswith('Iter'): + # A dictionary of quantity : value. + quants = _line_to_dict(line) + # Create writer and write header. + if fieldnames is None: + fieldnames = quants.keys() + writer = csv.DictWriter(csvfile, fieldnames=fieldnames) + writer.writeheader() + # Write a line. + writer.writerow(quants) + + +def _line_to_dict(line): + line = re.sub(':', '', line) # strip colons. + line = re.sub('\([^)]*\)', '', line) # strip running averages. + + quants = {} + for quant_str in line.split('|'): + quant_str = quant_str.strip() # strip beginning and ending whitespaces. + key, val = quant_str.split(' ') + quants[key] = val + + return quants + + +def plot_pairplot(csv_filename, fig_filename, top=None): + import seaborn as sns + import pandas as pd + + sns.set(style="ticks", color_codes=True) + quants = pd.read_csv(csv_filename) + if top is not None: + quants = quants[:top] + + g = sns.pairplot(quants, kind='reg', diag_kind='kde', markers='.') + g.savefig(fig_filename) + + +if __name__ == '__main__': + + import argparse + parser = argparse.ArgumentParser() + parser.add_argument('--log', type=str, required=True) + parser.add_argument('--top_iters', type=int, default=None) + args = parser.parse_args() + + print('Parsing log into csv.') + log_to_csv(args.log, args.log + '.csv') + print('Creating correlation plot.') + plot_pairplot(args.log + '.csv', os.path.join(os.path.dirname(args.log), 'quants.png'), args.top_iters) diff --git a/src/torchprune/torchprune/util/external/ffjord/diagnostics/viz_cnf.py b/src/torchprune/torchprune/util/external/ffjord/diagnostics/viz_cnf.py new file mode 100644 index 0000000..3cb73c6 --- /dev/null +++ b/src/torchprune/torchprune/util/external/ffjord/diagnostics/viz_cnf.py @@ -0,0 +1,257 @@ +from inspect import getsourcefile +import sys +import os +import subprocess + +current_path = os.path.abspath(getsourcefile(lambda: 0)) +current_dir = os.path.dirname(current_path) +parent_dir = current_dir[:current_dir.rfind(os.path.sep)] +sys.path.insert(0, parent_dir) + +import argparse +import torch +import torchvision.datasets as dset +import torchvision.transforms as tforms +from torchvision.utils import save_image + +import lib.layers as layers +import lib.spectral_norm as spectral_norm +import lib.utils as utils + + +def add_noise(x): + """ + [0, 1] -> [0, 255] -> add noise -> [0, 1] + """ + noise = x.new().resize_as_(x).uniform_() + x = x * 255 + noise + x = x / 256 + return x + + +def get_dataset(args): + trans = lambda im_size: 
tforms.Compose([tforms.Resize(im_size), tforms.ToTensor(), add_noise]) + + if args.data == "mnist": + im_dim = 1 + im_size = 28 if args.imagesize is None else args.imagesize + train_set = dset.MNIST(root="./data", train=True, transform=trans(im_size), download=True) + test_set = dset.MNIST(root="./data", train=False, transform=trans(im_size), download=True) + elif args.data == "svhn": + im_dim = 3 + im_size = 32 if args.imagesize is None else args.imagesize + train_set = dset.SVHN(root="./data", split="train", transform=trans(im_size), download=True) + test_set = dset.SVHN(root="./data", split="test", transform=trans(im_size), download=True) + elif args.data == "cifar10": + im_dim = 3 + im_size = 32 if args.imagesize is None else args.imagesize + train_set = dset.CIFAR10(root="./data", train=True, transform=trans(im_size), download=True) + test_set = dset.CIFAR10(root="./data", train=False, transform=trans(im_size), download=True) + elif args.dataset == 'celeba': + im_dim = 3 + im_size = 64 if args.imagesize is None else args.imagesize + train_set = dset.CelebA( + train=True, transform=tforms.Compose([ + tforms.ToPILImage(), + tforms.Resize(im_size), + tforms.RandomHorizontalFlip(), + tforms.ToTensor(), + add_noise, + ]) + ) + test_set = dset.CelebA( + train=False, transform=tforms.Compose([ + tforms.ToPILImage(), + tforms.Resize(args.imagesize), + tforms.ToTensor(), + add_noise, + ]) + ) + data_shape = (im_dim, im_size, im_size) + + train_loader = torch.utils.data.DataLoader(dataset=train_set, batch_size=args.batch_size, shuffle=True) + test_loader = torch.utils.data.DataLoader(dataset=test_set, batch_size=args.batch_size, shuffle=False) + return train_loader, test_loader, data_shape + + +def add_spectral_norm(model): + def recursive_apply_sn(parent_module): + for child_name in list(parent_module._modules.keys()): + child_module = parent_module._modules[child_name] + classname = child_module.__class__.__name__ + if classname.find('Conv') != -1 and 'weight' in child_module._parameters: + del parent_module._modules[child_name] + parent_module.add_module(child_name, spectral_norm.spectral_norm(child_module, 'weight')) + else: + recursive_apply_sn(child_module) + + recursive_apply_sn(model) + + +def build_model(args, state_dict): + # load dataset + train_loader, test_loader, data_shape = get_dataset(args) + + hidden_dims = tuple(map(int, args.dims.split(","))) + strides = tuple(map(int, args.strides.split(","))) + + # neural net that parameterizes the velocity field + if args.autoencode: + + def build_cnf(): + autoencoder_diffeq = layers.AutoencoderDiffEqNet( + hidden_dims=hidden_dims, + input_shape=data_shape, + strides=strides, + conv=args.conv, + layer_type=args.layer_type, + nonlinearity=args.nonlinearity, + ) + odefunc = layers.AutoencoderODEfunc( + autoencoder_diffeq=autoencoder_diffeq, + divergence_fn=args.divergence_fn, + residual=args.residual, + rademacher=args.rademacher, + ) + cnf = layers.CNF( + odefunc=odefunc, + T=args.time_length, + solver=args.solver, + ) + return cnf + else: + + def build_cnf(): + diffeq = layers.ODEnet( + hidden_dims=hidden_dims, + input_shape=data_shape, + strides=strides, + conv=args.conv, + layer_type=args.layer_type, + nonlinearity=args.nonlinearity, + ) + odefunc = layers.ODEfunc( + diffeq=diffeq, + divergence_fn=args.divergence_fn, + residual=args.residual, + rademacher=args.rademacher, + ) + cnf = layers.CNF( + odefunc=odefunc, + T=args.time_length, + solver=args.solver, + ) + return cnf + + chain = [layers.LogitTransform(alpha=args.alpha), 
build_cnf()] + if args.batch_norm: + chain.append(layers.MovingBatchNorm2d(data_shape[0])) + model = layers.SequentialFlow(chain) + + if args.spectral_norm: + add_spectral_norm(model) + + model.load_state_dict(state_dict) + + return model, test_loader.dataset + + +if __name__ == '__main__': + parser = argparse.ArgumentParser("Visualizes experiments trained using train_cnf.py.") + parser.add_argument("--checkpt", type=str, required=True) + parser.add_argument("--nsamples", type=int, default=50) + parser.add_argument("--ntimes", type=int, default=100) + parser.add_argument("--save", type=str, default="imgs") + args = parser.parse_args() + + checkpt = torch.load(args.checkpt, map_location=lambda storage, loc: storage) + ck_args = checkpt["args"] + state_dict = checkpt["state_dict"] + + model, test_set = build_model(ck_args, state_dict) + real_samples = torch.stack([test_set[i][0] for i in range(args.nsamples)], dim=0) + data_shape = real_samples.shape[1:] + fake_latents = torch.randn(args.nsamples, *data_shape) + + # Transfer to GPU if available. + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + print("Running on {}".format(device)) + model.to(device) + real_samples = real_samples.to(device) + fake_latents = fake_latents.to(device) + + # Construct fake samples + fake_samples = model(fake_latents, reverse=True).view(-1, *data_shape) + samples = torch.cat([real_samples, fake_samples], dim=0) + + still_diffeq = torch.zeros_like(samples) + im_indx = 0 + + # Image-saving helper function + def save_im(im, diffeq): + global im_indx + filename = os.path.join(current_dir, args.save, "flow_%05d.png" % im_indx) + utils.makedirs(os.path.dirname(filename)) + + diffeq = diffeq.clone() + de_min, de_max = float(diffeq.min()), float(diffeq.max()) + diffeq.clamp_(min=de_min, max=de_max) + diffeq.add_(-de_min).div_(de_max - de_min + 1e-5) + + assert im.shape == diffeq.shape + shape = im.shape + interleaved = torch.stack([im, diffeq]).transpose(0, 1).contiguous().view(2 * shape[0], *shape[1:]) + save_image(interleaved, filename, nrow=20, padding=0, range=(0, 1)) + im_indx += 1 + + # Still frames with image samples. + for _ in range(30): + save_im(samples, still_diffeq) + + # Forward image to latent. + logits = model.chain[0](samples) + for i in range(1, len(model.chain)): + assert isinstance(model.chain[i], layers.CNF) + cnf = model.chain[i] + tt = torch.linspace(cnf.integration_times[0], cnf.integration_times[-1], args.ntimes) + z_t = cnf(logits, integration_times=tt) + logits = z_t[-1] + + # transform back to image space + im_t = model.chain[0](z_t.view(args.ntimes * args.nsamples * 2, *data_shape), + reverse=True).view(args.ntimes, 2 * args.nsamples, *data_shape) + + # save each step as an image + for t, im in zip(tt, im_t): + diffeq = cnf.odefunc(t, (im, None))[0] + diffeq = model.chain[0](diffeq, reverse=True) + save_im(im, diffeq) + + # Still frames with latent samples. + latents = model.chain[0](logits, reverse=True) + for _ in range(30): + save_im(latents, still_diffeq) + + # Forward image to latent. 
+ for i in range(len(model.chain) - 1, 0, -1): + assert isinstance(model.chain[i], layers.CNF) + cnf = model.chain[i] + tt = torch.linspace(cnf.integration_times[-1], cnf.integration_times[0], args.ntimes) + z_t = cnf(logits, integration_times=tt) + logits = z_t[-1] + + # transform back to image space + im_t = model.chain[0](z_t.view(args.ntimes * args.nsamples * 2, *data_shape), + reverse=True).view(args.ntimes, 2 * args.nsamples, *data_shape) + # save each step as an image + for t, im in zip(tt, im_t): + diffeq = cnf.odefunc(t, (im, None))[0] + diffeq = model.chain[0](diffeq, reverse=True) + save_im(im, -diffeq) + + # Combine the images into a movie + bashCommand = r"ffmpeg -y -i {}/flow_%05d.png {}".format( + os.path.join(current_dir, args.save), os.path.join(current_dir, args.save, "flow.mp4") + ) + process = subprocess.Popen(bashCommand.split(), stdout=subprocess.PIPE) + output, error = process.communicate() diff --git a/src/torchprune/torchprune/util/external/ffjord/diagnostics/viz_high_fidelity_toy.py b/src/torchprune/torchprune/util/external/ffjord/diagnostics/viz_high_fidelity_toy.py new file mode 100644 index 0000000..b4181c2 --- /dev/null +++ b/src/torchprune/torchprune/util/external/ffjord/diagnostics/viz_high_fidelity_toy.py @@ -0,0 +1,132 @@ +import os +import math +from tqdm import tqdm +import numpy as np +import matplotlib +matplotlib.use('Agg') +import matplotlib.pyplot as plt +import torch + + +def standard_normal_logprob(z): + logZ = -0.5 * math.log(2 * math.pi) + return logZ - z.pow(2) / 2 + + +def makedirs(dirname): + if not os.path.exists(dirname): + os.makedirs(dirname) + + +def save_density_traj(model, data_samples, savedir, ntimes=101, memory=0.01, device='cpu'): + model.eval() + + # sample from a grid + npts = 800 + side = np.linspace(-4, 4, npts) + xx, yy = np.meshgrid(side, side) + xx = torch.from_numpy(xx).type(torch.float32).to(device) + yy = torch.from_numpy(yy).type(torch.float32).to(device) + z_grid = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1)], 1) + + with torch.no_grad(): + # We expect the model is a chain of CNF layers wrapped in a SequentialFlow container. 
+        for cnf in model.chain:
+            end_time = cnf.sqrt_end_time * cnf.sqrt_end_time
+            viz_times = torch.linspace(0., end_time, ntimes)
+
+            logpz_grid = [standard_normal_logprob(z_grid).sum(1, keepdim=True)]
+            for t in tqdm(viz_times[1:]):
+                inds = torch.arange(0, z_grid.shape[0]).to(torch.int64)
+                logpz_t = []
+                for ii in torch.split(inds, int(z_grid.shape[0] * memory)):
+                    z0, delta_logp = cnf(
+                        z_grid[ii],
+                        torch.zeros(z_grid[ii].shape[0], 1).to(z_grid),
+                        integration_times=torch.tensor([0., t.item()])
+                    )
+                    logpz_t.append(standard_normal_logprob(z0).sum(1, keepdim=True) - delta_logp)
+                logpz_grid.append(torch.cat(logpz_t, 0))
+            # Note: converting to numpy inside the loop effectively assumes a single
+            # CNF in model.chain; a second iteration would fail on numpy inputs.
+            logpz_grid = torch.stack(logpz_grid, 0).cpu().detach().numpy()
+            z_grid = z_grid.cpu().detach().numpy()
+
+    plt.figure(figsize=(8, 8))
+    for t in range(logpz_grid.shape[0]):
+
+        plt.clf()
+        ax = plt.gca()
+
+        # plot the density
+        z, logqz = z_grid, logpz_grid[t]
+
+        xx = z[:, 0].reshape(npts, npts)
+        yy = z[:, 1].reshape(npts, npts)
+        qz = np.exp(logqz).reshape(npts, npts)
+
+        plt.pcolormesh(xx, yy, qz, cmap='binary')
+        ax.set_xlim(-4, 4)
+        ax.set_ylim(-4, 4)
+        cmap = matplotlib.cm.get_cmap('binary')
+        # set_axis_bgcolor was removed in matplotlib 2.x; set_facecolor is the replacement.
+        ax.set_facecolor(cmap(0.))
+        ax.invert_yaxis()
+        ax.get_xaxis().set_ticks([])
+        ax.get_yaxis().set_ticks([])
+        plt.tight_layout()
+
+        makedirs(savedir)
+        plt.savefig(os.path.join(savedir, f"viz-{t:05d}.jpg"))
+
+
+def trajectory_to_video(savedir):
+    import subprocess
+    bashCommand = 'ffmpeg -y -i {} {}'.format(os.path.join(savedir, 'viz-%05d.jpg'), os.path.join(savedir, 'traj.mp4'))
+    process = subprocess.Popen(bashCommand.split(), stdout=subprocess.PIPE)
+    output, error = process.communicate()
+
+
+if __name__ == '__main__':
+    import argparse
+    import sys
+
+    sys.path.append(os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), '..')))
+
+    import lib.toy_data as toy_data
+    from train_misc import count_parameters
+    from train_misc import set_cnf_options, add_spectral_norm, create_regularization_fns
+    from train_misc import build_model_tabular
+
+    def get_ckpt_model_and_data(args):
+        # Load checkpoint.
+        checkpt = torch.load(args.checkpt, map_location=lambda storage, loc: storage)
+        ckpt_args = checkpt['args']
+        state_dict = checkpt['state_dict']
+
+        # Construct model and restore checkpoint.
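+        # The regularization functions are only needed to rebuild the exact module
+        # structure used at training time; load_state_dict requires the rebuilt
+        # architecture to match the checkpoint.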
+        regularization_fns, regularization_coeffs = create_regularization_fns(ckpt_args)
+        model = build_model_tabular(ckpt_args, 2, regularization_fns).to(device)
+        if ckpt_args.spectral_norm: add_spectral_norm(model)
+        set_cnf_options(ckpt_args, model)
+
+        model.load_state_dict(state_dict)
+        model.to(device)
+
+        print(model)
+        print("Number of trainable parameters: {}".format(count_parameters(model)))
+
+        # Load samples from dataset
+        data_samples = toy_data.inf_train_gen(ckpt_args.data, batch_size=2000)
+
+        return model, data_samples
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--checkpt', type=str, required=True)
+    parser.add_argument('--ntimes', type=int, default=101)
+    parser.add_argument('--memory', type=float, default=0.01, help='The higher this number, the more memory is consumed.')
+    parser.add_argument('--save', type=str, default='trajectory')
+    args = parser.parse_args()
+
+    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+    model, data_samples = get_ckpt_model_and_data(args)
+    save_density_traj(model, data_samples, args.save, ntimes=args.ntimes, memory=args.memory, device=device)
+    trajectory_to_video(args.save)
diff --git a/src/torchprune/torchprune/util/external/ffjord/diagnostics/viz_multiscale.py b/src/torchprune/torchprune/util/external/ffjord/diagnostics/viz_multiscale.py
new file mode 100644
index 0000000..b3fcd9d
--- /dev/null
+++ b/src/torchprune/torchprune/util/external/ffjord/diagnostics/viz_multiscale.py
@@ -0,0 +1,222 @@
+from inspect import getsourcefile
+import sys
+import os
+import math
+
+current_path = os.path.abspath(getsourcefile(lambda: 0))
+current_dir = os.path.dirname(current_path)
+parent_dir = current_dir[:current_dir.rfind(os.path.sep)]
+sys.path.insert(0, parent_dir)
+
+import argparse
+
+import lib.layers as layers
+import lib.odenvp as odenvp
+import torch
+import torchvision.transforms as tforms
+import torchvision.datasets as dset
+from torchvision.utils import save_image
+import lib.utils as utils
+
+from train_misc import add_spectral_norm, set_cnf_options, count_parameters
+
+parser = argparse.ArgumentParser("Continuous Normalizing Flow")
+parser.add_argument("--checkpt", type=str, required=True)
+parser.add_argument("--data", choices=["mnist", "svhn", "cifar10", 'lsun_church'], type=str, default="cifar10")
+parser.add_argument("--dims", type=str, default="64,64,64")
+parser.add_argument("--num_blocks", type=int, default=2, help='Number of stacked CNFs.')
+parser.add_argument("--divergence_fn", type=str, default="approximate", choices=["brute_force", "approximate"])
+parser.add_argument(
+    "--nonlinearity", type=str, default="softplus", choices=["tanh", "relu", "softplus", "elu", "swish"]
+)
+parser.add_argument("--conv", type=eval, default=True, choices=[True, False])
+
+parser.add_argument('--solver', type=str, default='dopri5')
+parser.add_argument('--atol', type=float, default=1e-5)
+parser.add_argument('--rtol', type=float, default=1e-5)
+parser.add_argument("--step_size", type=float, default=None, help="Optional fixed step size.")
+
+parser.add_argument('--test_solver', type=str, default=None)
+parser.add_argument('--test_atol', type=float, default=None)
+parser.add_argument('--test_rtol', type=float, default=None)
+
+parser.add_argument("--imagesize", type=int, default=None)
+parser.add_argument("--alpha", type=float, default=-1.0)
+parser.add_argument('--time_length', type=float, default=1.0)
+parser.add_argument('--train_T', type=eval, default=True)
+
+parser.add_argument("--add_noise", type=eval,
default=True, choices=[True, False]) +parser.add_argument('--rademacher', type=eval, default=True, choices=[True, False]) +parser.add_argument('--residual', type=eval, default=False, choices=[True, False]) +parser.add_argument('--spectral_norm', type=eval, default=False, choices=[True, False]) + +parser.add_argument('--ntimes', type=int, default=50) +parser.add_argument('--save', type=str, default='img_trajectory') + +args = parser.parse_args() + +BATCH_SIZE = 8 * 8 + + +def add_noise(x): + """ + [0, 1] -> [0, 255] -> add noise -> [0, 1] + """ + if args.add_noise: + noise = x.new().resize_as_(x).uniform_() + x = x * 255 + noise + x = x / 256 + return x + + +def get_dataset(args): + trans = lambda im_size: tforms.Compose([tforms.Resize(im_size), tforms.ToTensor(), add_noise]) + + if args.data == "mnist": + im_dim = 1 + im_size = 28 if args.imagesize is None else args.imagesize + train_set = dset.MNIST(root="./data", train=True, transform=trans(im_size), download=True) + elif args.data == "cifar10": + im_dim = 3 + im_size = 32 if args.imagesize is None else args.imagesize + train_set = dset.CIFAR10( + root="./data", train=True, transform=tforms.Compose([ + tforms.Resize(im_size), + tforms.RandomHorizontalFlip(), + tforms.ToTensor(), + add_noise, + ]), download=True + ) + elif args.data == 'lsun_church': + im_dim = 3 + im_size = 64 if args.imagesize is None else args.imagesize + train_set = dset.LSUN( + 'data', ['church_outdoor_train'], transform=tforms.Compose([ + tforms.Resize(96), + tforms.RandomCrop(64), + tforms.Resize(im_size), + tforms.ToTensor(), + add_noise, + ]) + ) + data_shape = (im_dim, im_size, im_size) + if not args.conv: + data_shape = (im_dim * im_size * im_size,) + + return train_set, data_shape + + +def create_model(args, data_shape): + hidden_dims = tuple(map(int, args.dims.split(","))) + + model = odenvp.ODENVP( + (BATCH_SIZE, *data_shape), + n_blocks=args.num_blocks, + intermediate_dims=hidden_dims, + nonlinearity=args.nonlinearity, + alpha=args.alpha, + cnf_kwargs={"T": args.time_length, "train_T": args.train_T}, + ) + if args.spectral_norm: add_spectral_norm(model) + set_cnf_options(args, model) + return model + + +if __name__ == '__main__': + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + cvt = lambda x: x.type(torch.float32).to(device, non_blocking=True) + + # load dataset + train_set, data_shape = get_dataset(args) + + # build model + model = create_model(args, data_shape) + + print(model) + print("Number of trainable parameters: {}".format(count_parameters(model))) + + # restore parameters + checkpt = torch.load(args.checkpt, map_location=lambda storage, loc: storage) + pruned_sd = {} + for k, v in checkpt['state_dict'].items(): + pruned_sd[k.replace('odefunc.odefunc', 'odefunc')] = v + model.load_state_dict(pruned_sd) + + train_loader = torch.utils.data.DataLoader(dataset=train_set, batch_size=BATCH_SIZE, shuffle=True) + + data_samples, _ = train_loader.__iter__().__next__() + + # cosine interpolate between 4 real images. 
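+    # The four images are first encoded to latents z[0..3]; a sqrt(B) x sqrt(B)
+    # grid of angle pairs (phi0, phi1) in [0, pi/2] then blends them as
+    #   z = cos(phi0) * (cos(phi1) * z[0] + sin(phi1) * z[1])
+    #     + sin(phi0) * (cos(phi1) * z[2] + sin(phi1) * z[3])
+    # and the blended latents are decoded back to image space.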
+    z = data_samples[:4]
+    print('Inferring base values for 4 example images.')
+    z = model(z)
+
+    phi0 = torch.linspace(0, 0.5, int(math.sqrt(BATCH_SIZE))) * math.pi
+    phi1 = torch.linspace(0, 0.5, int(math.sqrt(BATCH_SIZE))) * math.pi
+    phi0, phi1 = torch.meshgrid([phi0, phi1])
+    phi0, phi1 = phi0.contiguous().view(-1, 1), phi1.contiguous().view(-1, 1)
+    z = torch.cos(phi0) * (torch.cos(phi1) * z[0] + torch.sin(phi1) * z[1]) + \
+        torch.sin(phi0) * (torch.cos(phi1) * z[2] + torch.sin(phi1) * z[3])
+    print('Reconstructing images from latent interpolation.')
+    z = model(z, reverse=True)
+
+    non_cnf_layers = []
+
+    utils.makedirs(args.save)
+    img_idx = 0
+
+    def save_imgs_figure(xs):
+        global img_idx
+        save_image(
+            list(xs),
+            os.path.join(args.save, "img_{:05d}.jpg".format(img_idx)), nrow=int(math.sqrt(BATCH_SIZE)), normalize=True,
+            range=(0, 1)
+        )
+        img_idx += 1

+    class FactorOut(torch.nn.Module):
+
+        def __init__(self, factor_out):
+            super(FactorOut, self).__init__()
+            self.factor_out = factor_out
+
+        def forward(self, x, reverse=True):
+            assert reverse
+            T = x.shape[0] // self.factor_out.shape[0]
+            return torch.cat([x, self.factor_out.repeat(T, *([1] * (self.factor_out.ndimension() - 1)))], 1)
+
+    time_ratio = 1.0
+    print('Visualizing transformations.')
+    with torch.no_grad():
+        for idx, stacked_layers in enumerate(model.transforms):
+            for layer in stacked_layers.chain:
+                print(z.shape)
+                print(non_cnf_layers)
+                if isinstance(layer, layers.CNF):
+                    # linspace over time, and visualize by reversing through previous non_cnf_layers.
+                    cnf = layer
+                    end_time = (cnf.sqrt_end_time * cnf.sqrt_end_time)
+                    ntimes = int(args.ntimes * time_ratio)
+                    integration_times = torch.linspace(0, end_time.item(), ntimes)
+                    z_traj = cnf(z, integration_times=integration_times)
+
+                    # reverse z(t) for all times to the input space
+                    z_flatten = z_traj.view(ntimes * BATCH_SIZE, *z_traj.shape[2:])
+                    for prev_layer in non_cnf_layers[::-1]:
+                        z_flatten = prev_layer(z_flatten, reverse=True)
+                    z_inv = z_flatten.view(ntimes, BATCH_SIZE, *data_shape)
+                    for t in range(1, z_inv.shape[0]):
+                        z_t = z_inv[t]
+                        save_imgs_figure(z_t)
+                    z = z_traj[-1]
+                else:
+                    # update z and place in non_cnf_layers.
+                    z = layer(z)
+                    non_cnf_layers.append(layer)
+            if idx < len(model.transforms) - 1:
+                d = z.shape[1] // 2
+                z, factor_out = z[:, :d], z[:, d:]
+                non_cnf_layers.append(FactorOut(factor_out))
+
+                # After every factor-out, we halve the number of visualization frames.
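+                # (Presumably this keeps the total frame count bounded: each later
+                # scale operates on half the dimensions and adds its own frames.)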
+                time_ratio = time_ratio / 2
diff --git a/src/torchprune/torchprune/util/external/ffjord/diagnostics/viz_toy.py b/src/torchprune/torchprune/util/external/ffjord/diagnostics/viz_toy.py
new file mode 100644
index 0000000..7ebf029
--- /dev/null
+++ b/src/torchprune/torchprune/util/external/ffjord/diagnostics/viz_toy.py
@@ -0,0 +1,179 @@
+import os
+import math
+import numpy as np
+import matplotlib
+matplotlib.use('Agg')
+import matplotlib.pyplot as plt
+import torch
+
+
+def standard_normal_logprob(z):
+    logZ = -0.5 * math.log(2 * math.pi)
+    return logZ - z.pow(2) / 2
+
+
+def makedirs(dirname):
+    if not os.path.exists(dirname):
+        os.makedirs(dirname)
+
+
+def save_trajectory(model, data_samples, savedir, ntimes=101, memory=0.01, device='cpu'):
+    model.eval()
+
+    # Sample from prior
+    z_samples = torch.randn(2000, 2).to(device)
+
+    # sample from a grid
+    npts = 800
+    side = np.linspace(-4, 4, npts)
+    xx, yy = np.meshgrid(side, side)
+    xx = torch.from_numpy(xx).type(torch.float32).to(device)
+    yy = torch.from_numpy(yy).type(torch.float32).to(device)
+    z_grid = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1)], 1)
+
+    with torch.no_grad():
+        # We expect the model to be a chain of CNF layers wrapped in a SequentialFlow container.
+        logp_samples = torch.sum(standard_normal_logprob(z_samples), 1, keepdim=True)
+        logp_grid = torch.sum(standard_normal_logprob(z_grid), 1, keepdim=True)
+        t = 0
+        for cnf in model.chain:
+            end_time = (cnf.sqrt_end_time * cnf.sqrt_end_time)
+            integration_times = torch.linspace(0, end_time, ntimes)
+
+            z_traj, _ = cnf(z_samples, logp_samples, integration_times=integration_times, reverse=True)
+            z_traj = z_traj.cpu().numpy()
+
+            grid_z_traj, grid_logpz_traj = [], []
+            inds = torch.arange(0, z_grid.shape[0]).to(torch.int64)
+            for ii in torch.split(inds, int(z_grid.shape[0] * memory)):
+                _grid_z_traj, _grid_logpz_traj = cnf(
+                    z_grid[ii], logp_grid[ii], integration_times=integration_times, reverse=True
+                )
+                _grid_z_traj, _grid_logpz_traj = _grid_z_traj.cpu().numpy(), _grid_logpz_traj.cpu().numpy()
+                grid_z_traj.append(_grid_z_traj)
+                grid_logpz_traj.append(_grid_logpz_traj)
+            grid_z_traj = np.concatenate(grid_z_traj, axis=1)
+            grid_logpz_traj = np.concatenate(grid_logpz_traj, axis=1)
+
+            plt.figure(figsize=(8, 8))
+            for _ in range(z_traj.shape[0]):
+
+                plt.clf()
+
+                # plot target potential function
+                ax = plt.subplot(2, 2, 1, aspect="equal")
+
+                ax.hist2d(data_samples[:, 0], data_samples[:, 1], range=[[-4, 4], [-4, 4]], bins=200)
+                ax.invert_yaxis()
+                ax.get_xaxis().set_ticks([])
+                ax.get_yaxis().set_ticks([])
+                ax.set_title("Target", fontsize=32)
+
+                # plot the density
+                ax = plt.subplot(2, 2, 2, aspect="equal")
+
+                z, logqz = grid_z_traj[t], grid_logpz_traj[t]
+
+                xx = z[:, 0].reshape(npts, npts)
+                yy = z[:, 1].reshape(npts, npts)
+                qz = np.exp(logqz).reshape(npts, npts)
+
+                plt.pcolormesh(xx, yy, qz)
+                ax.set_xlim(-4, 4)
+                ax.set_ylim(-4, 4)
+                cmap = matplotlib.cm.get_cmap(None)
+                # set_axis_bgcolor was removed in matplotlib 2.x; set_facecolor is the replacement.
+                ax.set_facecolor(cmap(0.))
+                ax.invert_yaxis()
+                ax.get_xaxis().set_ticks([])
+                ax.get_yaxis().set_ticks([])
+                ax.set_title("Density", fontsize=32)
+
+                # plot the samples
+                ax = plt.subplot(2, 2, 3, aspect="equal")
+
+                zk = z_traj[t]
+                ax.hist2d(zk[:, 0], zk[:, 1], range=[[-4, 4], [-4, 4]], bins=200)
+                ax.invert_yaxis()
+                ax.get_xaxis().set_ticks([])
+                ax.get_yaxis().set_ticks([])
+                ax.set_title("Samples", fontsize=32)
+
+                # plot vector field
+                ax = plt.subplot(2, 2, 4, aspect="equal")
+
+                K = 13j
+                y, x = np.mgrid[-4:4:K, -4:4:K]
+                K = int(K.imag)
+                zs = torch.from_numpy(np.stack([x, y],
-1).reshape(K * K, 2)).to(device, torch.float32) + logps = torch.zeros(zs.shape[0], 1).to(device, torch.float32) + dydt = cnf.odefunc(integration_times[t], (zs, logps))[0] + dydt = -dydt.cpu().detach().numpy() + dydt = dydt.reshape(K, K, 2) + + logmag = 2 * np.log(np.hypot(dydt[:, :, 0], dydt[:, :, 1])) + ax.quiver( + x, y, dydt[:, :, 0], dydt[:, :, 1], + np.exp(logmag), cmap="coolwarm", scale=20., width=0.015, pivot="mid" + ) + ax.set_xlim(-4, 4) + ax.set_ylim(-4, 4) + ax.axis("off") + ax.set_title("Vector Field", fontsize=32) + + makedirs(savedir) + plt.savefig(os.path.join(savedir, f"viz-{t:05d}.jpg")) + t += 1 + + +def trajectory_to_video(savedir): + import subprocess + bashCommand = 'ffmpeg -y -i {} {}'.format(os.path.join(savedir, 'viz-%05d.jpg'), os.path.join(savedir, 'traj.mp4')) + process = subprocess.Popen(bashCommand.split(), stdout=subprocess.PIPE) + output, error = process.communicate() + + +if __name__ == '__main__': + import argparse + import sys + + sys.path.append(os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), '..'))) + + import lib.toy_data as toy_data + from train_misc import count_parameters + from train_misc import set_cnf_options, add_spectral_norm, create_regularization_fns + from train_misc import build_model_tabular + + def get_ckpt_model_and_data(args): + # Load checkpoint. + checkpt = torch.load(args.checkpt, map_location=lambda storage, loc: storage) + ckpt_args = checkpt['args'] + state_dict = checkpt['state_dict'] + + # Construct model and restore checkpoint. + regularization_fns, regularization_coeffs = create_regularization_fns(ckpt_args) + model = build_model_tabular(ckpt_args, 2, regularization_fns).to(device) + if ckpt_args.spectral_norm: add_spectral_norm(model) + set_cnf_options(ckpt_args, model) + + model.load_state_dict(state_dict) + model.to(device) + + print(model) + print("Number of trainable parameters: {}".format(count_parameters(model))) + + # Load samples from dataset + data_samples = toy_data.inf_train_gen(ckpt_args.data, batch_size=2000) + + return model, data_samples + + parser = argparse.ArgumentParser() + parser.add_argument('--checkpt', type=str, required=True) + parser.add_argument('--ntimes', type=int, default=101) + parser.add_argument('--memory', type=float, default=0.01, help='Higher this number, the more memory is consumed.') + parser.add_argument('--save', type=str, default='trajectory') + args = parser.parse_args() + + device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') + model, data_samples = get_ckpt_model_and_data(args) + save_trajectory(model, data_samples, args.save, ntimes=args.ntimes, memory=args.memory, device=device) + trajectory_to_video(args.save) diff --git a/src/torchprune/torchprune/util/external/ffjord/imgs/github.png b/src/torchprune/torchprune/util/external/ffjord/imgs/github.png new file mode 100644 index 0000000..487ed2f Binary files /dev/null and b/src/torchprune/torchprune/util/external/ffjord/imgs/github.png differ diff --git a/src/torchprune/torchprune/util/external/ffjord/imgs/maple_leaf.jpg b/src/torchprune/torchprune/util/external/ffjord/imgs/maple_leaf.jpg new file mode 100644 index 0000000..05bcc19 Binary files /dev/null and b/src/torchprune/torchprune/util/external/ffjord/imgs/maple_leaf.jpg differ diff --git a/src/torchprune/torchprune/util/external/ffjord/train_cnf.py b/src/torchprune/torchprune/util/external/ffjord/train_cnf.py new file mode 100644 index 0000000..afbc364 --- /dev/null +++ b/src/torchprune/torchprune/util/external/ffjord/train_cnf.py @@ 
-0,0 +1,444 @@
+import argparse
+import os
+import time
+import numpy as np
+
+import torch
+import torch.optim as optim
+import torchvision.datasets as dset
+import torchvision.transforms as tforms
+from torchvision.utils import save_image
+
+import lib.layers as layers
+import lib.utils as utils
+import lib.odenvp as odenvp
+import lib.multiscale_parallel as multiscale_parallel
+
+from train_misc import standard_normal_logprob
+from train_misc import set_cnf_options, count_nfe, count_parameters, count_total_time
+from train_misc import add_spectral_norm, spectral_norm_power_iteration
+from train_misc import create_regularization_fns, get_regularization, append_regularization_to_log
+
+# Enable cuDNN autotuning for faster convolutions.
+torch.backends.cudnn.benchmark = True
+SOLVERS = ["dopri5", "bdf", "rk4", "midpoint", 'adams', 'explicit_adams']
+parser = argparse.ArgumentParser("Continuous Normalizing Flow")
+parser.add_argument("--data", choices=["mnist", "svhn", "cifar10", 'lsun_church'], type=str, default="mnist")
+parser.add_argument("--dims", type=str, default="8,32,32,8")
+parser.add_argument("--strides", type=str, default="2,2,1,-2,-2")
+parser.add_argument("--num_blocks", type=int, default=1, help='Number of stacked CNFs.')
+
+parser.add_argument("--conv", type=eval, default=True, choices=[True, False])
+parser.add_argument(
+    "--layer_type", type=str, default="ignore",
+    choices=["ignore", "concat", "concat_v2", "squash", "concatsquash", "concatcoord", "hyper", "blend"]
+)
+parser.add_argument("--divergence_fn", type=str, default="approximate", choices=["brute_force", "approximate"])
+parser.add_argument(
+    "--nonlinearity", type=str, default="softplus", choices=["tanh", "relu", "softplus", "elu", "swish"]
+)
+parser.add_argument('--solver', type=str, default='dopri5', choices=SOLVERS)
+parser.add_argument('--atol', type=float, default=1e-5)
+parser.add_argument('--rtol', type=float, default=1e-5)
+parser.add_argument("--step_size", type=float, default=None, help="Optional fixed step size.")
+
+parser.add_argument('--test_solver', type=str, default=None, choices=SOLVERS + [None])
+parser.add_argument('--test_atol', type=float, default=None)
+parser.add_argument('--test_rtol', type=float, default=None)
+
+parser.add_argument("--imagesize", type=int, default=None)
+parser.add_argument("--alpha", type=float, default=1e-6)
+parser.add_argument('--time_length', type=float, default=1.0)
+parser.add_argument('--train_T', type=eval, default=True)
+
+parser.add_argument("--num_epochs", type=int, default=1000)
+parser.add_argument("--batch_size", type=int, default=200)
+parser.add_argument(
+    "--batch_size_schedule", type=str, default="", help="Dash-separated epochs at which to increase the batch size; each milestone passed multiplies --batch_size."
+)
+parser.add_argument("--test_batch_size", type=int, default=200)
+parser.add_argument("--lr", type=float, default=1e-3)
+parser.add_argument("--warmup_iters", type=float, default=1000)
+parser.add_argument("--weight_decay", type=float, default=0.0)
+parser.add_argument("--spectral_norm_niter", type=int, default=10)
+
+parser.add_argument("--add_noise", type=eval, default=True, choices=[True, False])
+parser.add_argument("--batch_norm", type=eval, default=False, choices=[True, False])
+parser.add_argument('--residual', type=eval, default=False, choices=[True, False])
+parser.add_argument('--autoencode', type=eval, default=False, choices=[True, False])
+parser.add_argument('--rademacher', type=eval, default=True, choices=[True, False])
+parser.add_argument('--spectral_norm', type=eval, default=False, choices=[True, False])
+parser.add_argument('--multiscale', type=eval, default=False, choices=[True, False])
+parser.add_argument('--parallel', type=eval, default=False, choices=[True, False])
+
+# Regularizations
+parser.add_argument('--l1int', type=float, default=None, help="int_t ||f||_1")
+parser.add_argument('--l2int', type=float, default=None, help="int_t ||f||_2")
+parser.add_argument('--dl2int', type=float, default=None, help="int_t ||f^T df/dt||_2")
+parser.add_argument('--JFrobint', type=float, default=None, help="int_t ||df/dx||_F")
+parser.add_argument('--JdiagFrobint', type=float, default=None, help="int_t ||df_i/dx_i||_F")
+parser.add_argument('--JoffdiagFrobint', type=float, default=None, help="int_t ||df/dx - df_i/dx_i||_F")
+
+parser.add_argument("--time_penalty", type=float, default=0, help="Regularization on the end_time.")
+parser.add_argument(
+    "--max_grad_norm", type=float, default=1e10,
+    help="Max norm of gradients (the default is very high, so clipping is effectively disabled)"
+)
+
+parser.add_argument("--begin_epoch", type=int, default=1)
+parser.add_argument("--resume", type=str, default=None)
+parser.add_argument("--save", type=str, default="experiments/cnf")
+parser.add_argument("--val_freq", type=int, default=1)
+parser.add_argument("--log_freq", type=int, default=10)
+
+args = parser.parse_args()
+
+# logger
+utils.makedirs(args.save)
+logger = utils.get_logger(logpath=os.path.join(args.save, 'logs'), filepath=os.path.abspath(__file__))
+
+if args.layer_type == "blend":
+    logger.info("!! Setting time_length from None to 1.0 due to use of Blend layers.")
+    args.time_length = 1.0
+
+logger.info(args)
+
+
+def add_noise(x):
+    """
+    [0, 1] -> [0, 255] -> add noise -> [0, 1]
+    """
+    if args.add_noise:
+        noise = x.new().resize_as_(x).uniform_()
+        x = x * 255 + noise
+        x = x / 256
+    return x
+
+
+def update_lr(optimizer, itr):
+    iter_frac = min(float(itr + 1) / max(args.warmup_iters, 1), 1.0)
+    lr = args.lr * iter_frac
+    for param_group in optimizer.param_groups:
+        param_group["lr"] = lr
+
+
+def get_train_loader(train_set, epoch):
+    if args.batch_size_schedule != "":
+        epochs = [0] + list(map(int, args.batch_size_schedule.split("-")))
+        n_passed = sum(np.array(epochs) <= epoch)
+        current_batch_size = int(args.batch_size * n_passed)
+    else:
+        current_batch_size = args.batch_size
+    train_loader = torch.utils.data.DataLoader(
+        dataset=train_set, batch_size=current_batch_size, shuffle=True, drop_last=True, pin_memory=True
+    )
+    logger.info("===> Using batch size {}.
Total {} iterations/epoch.".format(current_batch_size, len(train_loader))) + return train_loader + + +def get_dataset(args): + trans = lambda im_size: tforms.Compose([tforms.Resize(im_size), tforms.ToTensor(), add_noise]) + + if args.data == "mnist": + im_dim = 1 + im_size = 28 if args.imagesize is None else args.imagesize + train_set = dset.MNIST(root="./data", train=True, transform=trans(im_size), download=True) + test_set = dset.MNIST(root="./data", train=False, transform=trans(im_size), download=True) + elif args.data == "svhn": + im_dim = 3 + im_size = 32 if args.imagesize is None else args.imagesize + train_set = dset.SVHN(root="./data", split="train", transform=trans(im_size), download=True) + test_set = dset.SVHN(root="./data", split="test", transform=trans(im_size), download=True) + elif args.data == "cifar10": + im_dim = 3 + im_size = 32 if args.imagesize is None else args.imagesize + train_set = dset.CIFAR10( + root="./data", train=True, transform=tforms.Compose([ + tforms.Resize(im_size), + tforms.RandomHorizontalFlip(), + tforms.ToTensor(), + add_noise, + ]), download=True + ) + test_set = dset.CIFAR10(root="./data", train=False, transform=trans(im_size), download=True) + elif args.data == 'celeba': + im_dim = 3 + im_size = 64 if args.imagesize is None else args.imagesize + train_set = dset.CelebA( + train=True, transform=tforms.Compose([ + tforms.ToPILImage(), + tforms.Resize(im_size), + tforms.RandomHorizontalFlip(), + tforms.ToTensor(), + add_noise, + ]) + ) + test_set = dset.CelebA( + train=False, transform=tforms.Compose([ + tforms.ToPILImage(), + tforms.Resize(im_size), + tforms.ToTensor(), + add_noise, + ]) + ) + elif args.data == 'lsun_church': + im_dim = 3 + im_size = 64 if args.imagesize is None else args.imagesize + train_set = dset.LSUN( + 'data', ['church_outdoor_train'], transform=tforms.Compose([ + tforms.Resize(96), + tforms.RandomCrop(64), + tforms.Resize(im_size), + tforms.ToTensor(), + add_noise, + ]) + ) + test_set = dset.LSUN( + 'data', ['church_outdoor_val'], transform=tforms.Compose([ + tforms.Resize(96), + tforms.RandomCrop(64), + tforms.Resize(im_size), + tforms.ToTensor(), + add_noise, + ]) + ) + data_shape = (im_dim, im_size, im_size) + if not args.conv: + data_shape = (im_dim * im_size * im_size,) + + test_loader = torch.utils.data.DataLoader( + dataset=test_set, batch_size=args.test_batch_size, shuffle=False, drop_last=True + ) + return train_set, test_loader, data_shape + + +def compute_bits_per_dim(x, model): + zero = torch.zeros(x.shape[0], 1).to(x) + + # Don't use data parallelize if batch size is small. 
+ # if x.shape[0] < 200: + # model = model.module + + z, delta_logp = model(x, zero) # run model forward + + logpz = standard_normal_logprob(z).view(z.shape[0], -1).sum(1, keepdim=True) # logp(z) + logpx = logpz - delta_logp + + logpx_per_dim = torch.sum(logpx) / x.nelement() # averaged over batches + bits_per_dim = -(logpx_per_dim - np.log(256)) / np.log(2) + + return bits_per_dim + + +def create_model(args, data_shape, regularization_fns): + hidden_dims = tuple(map(int, args.dims.split(","))) + strides = tuple(map(int, args.strides.split(","))) + + if args.multiscale: + model = odenvp.ODENVP( + (args.batch_size, *data_shape), + n_blocks=args.num_blocks, + intermediate_dims=hidden_dims, + nonlinearity=args.nonlinearity, + alpha=args.alpha, + cnf_kwargs={"T": args.time_length, "train_T": args.train_T, "regularization_fns": regularization_fns}, + ) + elif args.parallel: + model = multiscale_parallel.MultiscaleParallelCNF( + (args.batch_size, *data_shape), + n_blocks=args.num_blocks, + intermediate_dims=hidden_dims, + alpha=args.alpha, + time_length=args.time_length, + ) + else: + if args.autoencode: + + def build_cnf(): + autoencoder_diffeq = layers.AutoencoderDiffEqNet( + hidden_dims=hidden_dims, + input_shape=data_shape, + strides=strides, + conv=args.conv, + layer_type=args.layer_type, + nonlinearity=args.nonlinearity, + ) + odefunc = layers.AutoencoderODEfunc( + autoencoder_diffeq=autoencoder_diffeq, + divergence_fn=args.divergence_fn, + residual=args.residual, + rademacher=args.rademacher, + ) + cnf = layers.CNF( + odefunc=odefunc, + T=args.time_length, + regularization_fns=regularization_fns, + solver=args.solver, + ) + return cnf + else: + + def build_cnf(): + diffeq = layers.ODEnet( + hidden_dims=hidden_dims, + input_shape=data_shape, + strides=strides, + conv=args.conv, + layer_type=args.layer_type, + nonlinearity=args.nonlinearity, + ) + odefunc = layers.ODEfunc( + diffeq=diffeq, + divergence_fn=args.divergence_fn, + residual=args.residual, + rademacher=args.rademacher, + ) + cnf = layers.CNF( + odefunc=odefunc, + T=args.time_length, + train_T=args.train_T, + regularization_fns=regularization_fns, + solver=args.solver, + ) + return cnf + + chain = [layers.LogitTransform(alpha=args.alpha)] if args.alpha > 0 else [layers.ZeroMeanTransform()] + chain = chain + [build_cnf() for _ in range(args.num_blocks)] + if args.batch_norm: + chain.append(layers.MovingBatchNorm2d(data_shape[0])) + model = layers.SequentialFlow(chain) + return model + + +if __name__ == "__main__": + + # get deivce + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + cvt = lambda x: x.type(torch.float32).to(device, non_blocking=True) + + # load dataset + train_set, test_loader, data_shape = get_dataset(args) + + # build model + regularization_fns, regularization_coeffs = create_regularization_fns(args) + model = create_model(args, data_shape, regularization_fns) + + if args.spectral_norm: add_spectral_norm(model, logger) + set_cnf_options(args, model) + + logger.info(model) + logger.info("Number of trainable parameters: {}".format(count_parameters(model))) + + # optimizer + optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) + + # restore parameters + if args.resume is not None: + checkpt = torch.load(args.resume, map_location=lambda storage, loc: storage) + model.load_state_dict(checkpt["state_dict"]) + if "optim_state_dict" in checkpt.keys(): + optimizer.load_state_dict(checkpt["optim_state_dict"]) + # Manually move optimizer state to device. 
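+            # (torch.load above used map_location to keep tensors on the CPU, so
+            # Adam's exp_avg/exp_avg_sq buffers must be moved next to the
+            # parameters before the first optimizer.step().)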
+ for state in optimizer.state.values(): + for k, v in state.items(): + if torch.is_tensor(v): + state[k] = cvt(v) + + if torch.cuda.is_available(): + model = torch.nn.DataParallel(model).cuda() + + # For visualization. + fixed_z = cvt(torch.randn(100, *data_shape)) + + time_meter = utils.RunningAverageMeter(0.97) + loss_meter = utils.RunningAverageMeter(0.97) + steps_meter = utils.RunningAverageMeter(0.97) + grad_meter = utils.RunningAverageMeter(0.97) + tt_meter = utils.RunningAverageMeter(0.97) + + if args.spectral_norm and not args.resume: spectral_norm_power_iteration(model, 500) + + best_loss = float("inf") + itr = 0 + for epoch in range(args.begin_epoch, args.num_epochs + 1): + model.train() + train_loader = get_train_loader(train_set, epoch) + for _, (x, y) in enumerate(train_loader): + start = time.time() + update_lr(optimizer, itr) + optimizer.zero_grad() + + if not args.conv: + x = x.view(x.shape[0], -1) + + # cast data and move to device + x = cvt(x) + # compute loss + loss = compute_bits_per_dim(x, model) + if regularization_coeffs: + reg_states = get_regularization(model, regularization_coeffs) + reg_loss = sum( + reg_state * coeff for reg_state, coeff in zip(reg_states, regularization_coeffs) if coeff != 0 + ) + loss = loss + reg_loss + total_time = count_total_time(model) + loss = loss + total_time * args.time_penalty + + loss.backward() + grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) + + optimizer.step() + + if args.spectral_norm: spectral_norm_power_iteration(model, args.spectral_norm_niter) + + time_meter.update(time.time() - start) + loss_meter.update(loss.item()) + steps_meter.update(count_nfe(model)) + grad_meter.update(grad_norm) + tt_meter.update(total_time) + + if itr % args.log_freq == 0: + log_message = ( + "Iter {:04d} | Time {:.4f}({:.4f}) | Bit/dim {:.4f}({:.4f}) | " + "Steps {:.0f}({:.2f}) | Grad Norm {:.4f}({:.4f}) | Total Time {:.2f}({:.2f})".format( + itr, time_meter.val, time_meter.avg, loss_meter.val, loss_meter.avg, steps_meter.val, + steps_meter.avg, grad_meter.val, grad_meter.avg, tt_meter.val, tt_meter.avg + ) + ) + if regularization_coeffs: + log_message = append_regularization_to_log(log_message, regularization_fns, reg_states) + logger.info(log_message) + + itr += 1 + + # compute test loss + model.eval() + if epoch % args.val_freq == 0: + with torch.no_grad(): + start = time.time() + logger.info("validating...") + losses = [] + for (x, y) in test_loader: + if not args.conv: + x = x.view(x.shape[0], -1) + x = cvt(x) + loss = compute_bits_per_dim(x, model) + losses.append(loss) + + loss = np.mean(losses) + logger.info("Epoch {:04d} | Time {:.4f}, Bit/dim {:.4f}".format(epoch, time.time() - start, loss)) + if loss < best_loss: + best_loss = loss + utils.makedirs(args.save) + torch.save({ + "args": args, + "state_dict": model.module.state_dict() if torch.cuda.is_available() else model.state_dict(), + "optim_state_dict": optimizer.state_dict(), + }, os.path.join(args.save, "checkpt.pth")) + + # visualize samples and density + with torch.no_grad(): + fig_filename = os.path.join(args.save, "figs", "{:04d}.jpg".format(epoch)) + utils.makedirs(os.path.dirname(fig_filename)) + generated_samples = model(fixed_z, reverse=True).view(-1, *data_shape) + save_image(generated_samples, fig_filename, nrow=10) diff --git a/src/torchprune/torchprune/util/external/ffjord/train_discrete_tabular.py b/src/torchprune/torchprune/util/external/ffjord/train_discrete_tabular.py new file mode 100644 index 0000000..4492842 --- /dev/null +++ 
b/src/torchprune/torchprune/util/external/ffjord/train_discrete_tabular.py @@ -0,0 +1,235 @@ +import argparse +import os +import time + +import torch + +import lib.utils as utils +from lib.custom_optimizers import Adam +import lib.layers as layers + +import datasets + +from train_misc import standard_normal_logprob, count_parameters + +parser = argparse.ArgumentParser() +parser.add_argument( + '--data', choices=['power', 'gas', 'hepmass', 'miniboone', 'bsds300'], type=str, default='miniboone' +) + +parser.add_argument('--depth', type=int, default=10) +parser.add_argument('--dims', type=str, default="100-100") +parser.add_argument('--nonlinearity', type=str, default="tanh") +parser.add_argument('--glow', type=eval, default=False, choices=[True, False]) +parser.add_argument('--batch_norm', type=eval, default=False, choices=[True, False]) +parser.add_argument('--bn_lag', type=float, default=0) + +parser.add_argument('--early_stopping', type=int, default=30) +parser.add_argument('--batch_size', type=int, default=1000) +parser.add_argument('--test_batch_size', type=int, default=None) +parser.add_argument('--lr', type=float, default=1e-4) +parser.add_argument('--weight_decay', type=float, default=1e-6) + +parser.add_argument('--resume', type=str, default=None) +parser.add_argument('--save', type=str, default='experiments/cnf') +parser.add_argument('--evaluate', action='store_true') +parser.add_argument('--val_freq', type=int, default=200) +parser.add_argument('--log_freq', type=int, default=10) +args = parser.parse_args() + +# logger +utils.makedirs(args.save) +logger = utils.get_logger(logpath=os.path.join(args.save, 'logs'), filepath=os.path.abspath(__file__)) +logger.info(args) + +test_batch_size = args.test_batch_size if args.test_batch_size else args.batch_size + + +def batch_iter(X, batch_size=args.batch_size, shuffle=False): + """ + X: feature tensor (shape: num_instances x num_features) + """ + if shuffle: + idxs = torch.randperm(X.shape[0]) + else: + idxs = torch.arange(X.shape[0]) + if X.is_cuda: + idxs = idxs.cuda() + for batch_idxs in idxs.split(batch_size): + yield X[batch_idxs] + + +ndecs = 0 + + +def update_lr(optimizer, n_vals_without_improvement): + global ndecs + if ndecs == 0 and n_vals_without_improvement > args.early_stopping // 3: + for param_group in optimizer.param_groups: + param_group["lr"] = args.lr / 10 + ndecs = 1 + elif ndecs == 1 and n_vals_without_improvement > args.early_stopping // 3 * 2: + for param_group in optimizer.param_groups: + param_group["lr"] = args.lr / 100 + ndecs = 2 + else: + for param_group in optimizer.param_groups: + param_group["lr"] = args.lr / 10**ndecs + + +def load_data(name): + + if name == 'bsds300': + return datasets.BSDS300() + + elif name == 'power': + return datasets.POWER() + + elif name == 'gas': + return datasets.GAS() + + elif name == 'hepmass': + return datasets.HEPMASS() + + elif name == 'miniboone': + return datasets.MINIBOONE() + + else: + raise ValueError('Unknown dataset') + + +def build_model(input_dim): + hidden_dims = tuple(map(int, args.dims.split("-"))) + chain = [] + for i in range(args.depth): + if args.glow: chain.append(layers.BruteForceLayer(input_dim)) + chain.append(layers.MaskedCouplingLayer(input_dim, hidden_dims, 'alternate', swap=i % 2 == 0)) + if args.batch_norm: chain.append(layers.MovingBatchNorm1d(input_dim, bn_lag=args.bn_lag)) + return layers.SequentialFlow(chain) + + +def compute_loss(x, model): + zero = torch.zeros(x.shape[0], 1).to(x) + + z, delta_logp = model(x, zero) # run model forward + + logpz = 
standard_normal_logprob(z).view(z.shape[0], -1).sum(1, keepdim=True) # logp(z) + logpx = logpz - delta_logp + loss = -torch.mean(logpx) + return loss + + +def restore_model(model, filename): + checkpt = torch.load(filename, map_location=lambda storage, loc: storage) + model.load_state_dict(checkpt["state_dict"]) + return model + + +if __name__ == '__main__': + + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + cvt = lambda x: x.type(torch.float32).to(device, non_blocking=True) + + logger.info('Using {} GPUs.'.format(torch.cuda.device_count())) + + data = load_data(args.data) + data.trn.x = torch.from_numpy(data.trn.x) + data.val.x = torch.from_numpy(data.val.x) + data.tst.x = torch.from_numpy(data.tst.x) + + model = build_model(data.n_dims).to(device) + + if args.resume is not None: + checkpt = torch.load(args.resume) + model.load_state_dict(checkpt['state_dict']) + + logger.info(model) + logger.info("Number of trainable parameters: {}".format(count_parameters(model))) + + if not args.evaluate: + optimizer = Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) + + time_meter = utils.RunningAverageMeter(0.98) + loss_meter = utils.RunningAverageMeter(0.98) + + best_loss = float('inf') + itr = 0 + n_vals_without_improvement = 0 + end = time.time() + model.train() + while True: + if args.early_stopping > 0 and n_vals_without_improvement > args.early_stopping: + break + + for x in batch_iter(data.trn.x, shuffle=True): + if args.early_stopping > 0 and n_vals_without_improvement > args.early_stopping: + break + + optimizer.zero_grad() + + x = cvt(x) + loss = compute_loss(x, model) + loss_meter.update(loss.item()) + + loss.backward() + optimizer.step() + + time_meter.update(time.time() - end) + + if itr % args.log_freq == 0: + log_message = ( + 'Iter {:06d} | Epoch {:.2f} | Time {:.4f}({:.4f}) | Loss {:.6f}({:.6f}) | '.format( + itr, + float(itr) / (data.trn.x.shape[0] / float(args.batch_size)), time_meter.val, time_meter.avg, + loss_meter.val, loss_meter.avg + ) + ) + logger.info(log_message) + itr += 1 + end = time.time() + + # Validation loop. 
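+                # Runs every val_freq iterations: the model is scored on the
+                # held-out split, the best checkpoint is saved, and update_lr
+                # decays the learning rate (10x, then 100x) once a third and then
+                # two thirds of the early-stopping patience is exhausted.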
+ if itr % args.val_freq == 0: + model.eval() + start_time = time.time() + with torch.no_grad(): + val_loss = utils.AverageMeter() + for x in batch_iter(data.val.x, batch_size=test_batch_size): + x = cvt(x) + val_loss.update(compute_loss(x, model).item(), x.shape[0]) + + if val_loss.avg < best_loss: + best_loss = val_loss.avg + utils.makedirs(args.save) + torch.save({ + 'args': args, + 'state_dict': model.state_dict(), + }, os.path.join(args.save, 'checkpt.pth')) + n_vals_without_improvement = 0 + else: + n_vals_without_improvement += 1 + update_lr(optimizer, n_vals_without_improvement) + + log_message = ( + '[VAL] Iter {:06d} | Val Loss {:.6f} | ' + 'NoImproveEpochs {:02d}/{:02d}'.format( + itr, val_loss.avg, n_vals_without_improvement, args.early_stopping + ) + ) + logger.info(log_message) + model.train() + + logger.info('Training has finished.') + model = restore_model(model, os.path.join(args.save, 'checkpt.pth')).to(device) + + logger.info('Evaluating model on test set.') + model.eval() + + with torch.no_grad(): + test_loss = utils.AverageMeter() + for itr, x in enumerate(batch_iter(data.tst.x, batch_size=test_batch_size)): + x = cvt(x) + test_loss.update(compute_loss(x, model).item(), x.shape[0]) + logger.info('Progress: {:.2f}%'.format(itr / (data.tst.x.shape[0] / test_batch_size))) + log_message = '[TEST] Iter {:06d} | Test Loss {:.6f} '.format(itr, test_loss.avg) + logger.info(log_message) diff --git a/src/torchprune/torchprune/util/external/ffjord/train_discrete_toy.py b/src/torchprune/torchprune/util/external/ffjord/train_discrete_toy.py new file mode 100644 index 0000000..afa536e --- /dev/null +++ b/src/torchprune/torchprune/util/external/ffjord/train_discrete_toy.py @@ -0,0 +1,186 @@ +import matplotlib +matplotlib.use('Agg') +import matplotlib.pyplot as plt + +import argparse +import os +import time + +import torch +import torch.optim as optim + +import lib.layers as layers +import lib.toy_data as toy_data +import lib.utils as utils +from lib.visualize_flow import visualize_transform + +from train_misc import standard_normal_logprob +from train_misc import count_parameters + +SOLVERS = ["dopri5", "bdf", "rk4", "midpoint", 'adams', 'explicit_adams', 'fixed_adams'] +parser = argparse.ArgumentParser('Continuous Normalizing Flow') +parser.add_argument( + '--data', choices=['swissroll', '8gaussians', 'pinwheel', 'circles', 'moons', '2spirals', 'checkerboard', 'rings'], + type=str, default='pinwheel' +) + +parser.add_argument('--depth', help='number of coupling layers', type=int, default=10) +parser.add_argument('--glow', type=eval, choices=[True, False], default=False) +parser.add_argument('--nf', type=eval, choices=[True, False], default=False) + +parser.add_argument('--niters', type=int, default=100001) +parser.add_argument('--batch_size', type=int, default=100) +parser.add_argument('--test_batch_size', type=int, default=1000) +parser.add_argument('--lr', type=float, default=1e-4) +parser.add_argument('--weight_decay', type=float, default=0) + +# Track quantities +parser.add_argument('--l1int', type=float, default=None, help="int_t ||f||_1") +parser.add_argument('--l2int', type=float, default=None, help="int_t ||f||_2") +parser.add_argument('--dl2int', type=float, default=None, help="int_t ||f^T df/dt||_2") +parser.add_argument('--JFrobint', type=float, default=None, help="int_t ||df/dx||_F") +parser.add_argument('--JdiagFrobint', type=float, default=None, help="int_t ||df_i/dx_i||_F") +parser.add_argument('--JoffdiagFrobint', type=float, default=None, help="int_t ||df/dx - 
df_i/dx_i||_F") + +parser.add_argument('--save', type=str, default='experiments/cnf') +parser.add_argument('--viz_freq', type=int, default=1000) +parser.add_argument('--val_freq', type=int, default=1000) +parser.add_argument('--log_freq', type=int, default=100) +parser.add_argument('--gpu', type=int, default=0) +args = parser.parse_args() + +# logger +utils.makedirs(args.save) +logger = utils.get_logger(logpath=os.path.join(args.save, 'logs'), filepath=os.path.abspath(__file__)) + +logger.info(args) + +device = torch.device('cuda:' + str(args.gpu) if torch.cuda.is_available() else 'cpu') + + +def construct_model(): + + if args.nf: + chain = [] + for i in range(args.depth): + chain.append(layers.PlanarFlow(2)) + return layers.SequentialFlow(chain) + else: + chain = [] + for i in range(args.depth): + if args.glow: chain.append(layers.BruteForceLayer(2)) + chain.append(layers.CouplingLayer(2, swap=i % 2 == 0)) + return layers.SequentialFlow(chain) + + +def get_transforms(model): + + if args.nf: + sample_fn = None + else: + + def sample_fn(z, logpz=None): + if logpz is not None: + return model(z, logpz, reverse=True) + else: + return model(z, reverse=True) + + def density_fn(x, logpx=None): + if logpx is not None: + return model(x, logpx, reverse=False) + else: + return model(x, reverse=False) + + return sample_fn, density_fn + + +def compute_loss(args, model, batch_size=None): + if batch_size is None: batch_size = args.batch_size + + # load data + x = toy_data.inf_train_gen(args.data, batch_size=batch_size) + x = torch.from_numpy(x).type(torch.float32).to(device) + zero = torch.zeros(x.shape[0], 1).to(x) + + # transform to z + z, delta_logp = model(x, zero) + + # compute log q(z) + logpz = standard_normal_logprob(z).sum(1, keepdim=True) + + logpx = logpz - delta_logp + loss = -torch.mean(logpx) + return loss + + +if __name__ == '__main__': + + model = construct_model().to(device) + + logger.info(model) + logger.info("Number of trainable parameters: {}".format(count_parameters(model))) + + optimizer = optim.Adamax(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) + + time_meter = utils.RunningAverageMeter(0.98) + loss_meter = utils.RunningAverageMeter(0.98) + + end = time.time() + best_loss = float('inf') + model.train() + for itr in range(1, args.niters + 1): + optimizer.zero_grad() + + loss = compute_loss(args, model) + loss_meter.update(loss.item()) + + loss.backward() + optimizer.step() + + time_meter.update(time.time() - end) + + if itr % args.log_freq == 0: + log_message = ( + 'Iter {:04d} | Time {:.4f}({:.4f}) | Loss {:.6f}({:.6f})'.format( + itr, time_meter.val, time_meter.avg, loss_meter.val, loss_meter.avg + ) + ) + logger.info(log_message) + + if itr % args.val_freq == 0 or itr == args.niters: + with torch.no_grad(): + model.eval() + test_loss = compute_loss(args, model, batch_size=args.test_batch_size) + log_message = '[TEST] Iter {:04d} | Test Loss {:.6f}'.format(itr, test_loss) + logger.info(log_message) + + if test_loss.item() < best_loss: + best_loss = test_loss.item() + utils.makedirs(args.save) + torch.save({ + 'args': args, + 'state_dict': model.state_dict(), + }, os.path.join(args.save, 'checkpt.pth')) + model.train() + + if itr % args.viz_freq == 0: + with torch.no_grad(): + model.eval() + p_samples = toy_data.inf_train_gen(args.data, batch_size=2000) + + sample_fn, density_fn = get_transforms(model) + + plt.figure(figsize=(9, 3)) + visualize_transform( + p_samples, torch.randn, standard_normal_logprob, transform=sample_fn, inverse_transform=density_fn, + 
samples=True, npts=800, device=device + ) + fig_filename = os.path.join(args.save, 'figs', '{:04d}.jpg'.format(itr)) + utils.makedirs(os.path.dirname(fig_filename)) + plt.savefig(fig_filename) + plt.close() + model.train() + + end = time.time() + + logger.info('Training has finished.') diff --git a/src/torchprune/torchprune/util/external/ffjord/train_img2d.py b/src/torchprune/torchprune/util/external/ffjord/train_img2d.py new file mode 100644 index 0000000..5041e25 --- /dev/null +++ b/src/torchprune/torchprune/util/external/ffjord/train_img2d.py @@ -0,0 +1,253 @@ +import matplotlib +matplotlib.use('Agg') +import matplotlib.pyplot as plt + +import argparse +import os +import time + +import torch +import torch.optim as optim + +import lib.utils as utils +from lib.visualize_flow import visualize_transform +import lib.layers.odefunc as odefunc + +from train_misc import standard_normal_logprob +from train_misc import set_cnf_options, count_nfe, count_parameters, count_total_time +from train_misc import add_spectral_norm, spectral_norm_power_iteration +from train_misc import create_regularization_fns, get_regularization, append_regularization_to_log +from train_misc import build_model_tabular + +from diagnostics.viz_toy import save_trajectory, trajectory_to_video + +SOLVERS = ["dopri5", "bdf", "rk4", "midpoint", 'adams', 'explicit_adams', 'fixed_adams'] +parser = argparse.ArgumentParser('Continuous Normalizing Flow') +parser.add_argument('--img', type=str, required=True) +parser.add_argument('--data', type=str, default='dummy') +parser.add_argument( + "--layer_type", type=str, default="concatsquash", + choices=["ignore", "concat", "concat_v2", "squash", "concatsquash", "concatcoord", "hyper", "blend"] +) +parser.add_argument('--dims', type=str, default='64-64-64') +parser.add_argument("--num_blocks", type=int, default=1, help='Number of stacked CNFs.') +parser.add_argument('--time_length', type=float, default=0.5) +parser.add_argument('--train_T', type=eval, default=True) +parser.add_argument("--divergence_fn", type=str, default="brute_force", choices=["brute_force", "approximate"]) +parser.add_argument("--nonlinearity", type=str, default="tanh", choices=odefunc.NONLINEARITIES) + +parser.add_argument('--solver', type=str, default='dopri5', choices=SOLVERS) +parser.add_argument('--atol', type=float, default=1e-5) +parser.add_argument('--rtol', type=float, default=1e-5) +parser.add_argument("--step_size", type=float, default=None, help="Optional fixed step size.") + +parser.add_argument('--test_solver', type=str, default=None, choices=SOLVERS + [None]) +parser.add_argument('--test_atol', type=float, default=None) +parser.add_argument('--test_rtol', type=float, default=None) + +parser.add_argument('--residual', type=eval, default=False, choices=[True, False]) +parser.add_argument('--rademacher', type=eval, default=False, choices=[True, False]) +parser.add_argument('--spectral_norm', type=eval, default=False, choices=[True, False]) +parser.add_argument('--batch_norm', type=eval, default=False, choices=[True, False]) +parser.add_argument('--bn_lag', type=float, default=0) + +parser.add_argument('--niters', type=int, default=10000) +parser.add_argument('--batch_size', type=int, default=1000) +parser.add_argument('--test_batch_size', type=int, default=1000) +parser.add_argument('--lr', type=float, default=1e-3) +parser.add_argument('--weight_decay', type=float, default=1e-5) + +# Track quantities +parser.add_argument('--l1int', type=float, default=None, help="int_t ||f||_1") 
+parser.add_argument('--l2int', type=float, default=None, help="int_t ||f||_2") +parser.add_argument('--dl2int', type=float, default=None, help="int_t ||f^T df/dt||_2") +parser.add_argument('--JFrobint', type=float, default=None, help="int_t ||df/dx||_F") +parser.add_argument('--JdiagFrobint', type=float, default=None, help="int_t ||df_i/dx_i||_F") +parser.add_argument('--JoffdiagFrobint', type=float, default=None, help="int_t ||df/dx - df_i/dx_i||_F") + +parser.add_argument('--save', type=str, default='experiments/cnf') +parser.add_argument('--viz_freq', type=int, default=100) +parser.add_argument('--val_freq', type=int, default=100) +parser.add_argument('--log_freq', type=int, default=10) +parser.add_argument('--gpu', type=int, default=0) +args = parser.parse_args() + +# logger +utils.makedirs(args.save) +logger = utils.get_logger(logpath=os.path.join(args.save, 'logs'), filepath=os.path.abspath(__file__)) + +if args.layer_type == "blend": + logger.info("!! Setting time_length from None to 1.0 due to use of Blend layers.") + args.time_length = 1.0 + +logger.info(args) + +device = torch.device('cuda:' + str(args.gpu) if torch.cuda.is_available() else 'cpu') + +from PIL import Image +import numpy as np + +img = np.array(Image.open(args.img).convert('L')) +h, w = img.shape +xx = np.linspace(-4, 4, w) +yy = np.linspace(-4, 4, h) +xx, yy = np.meshgrid(xx, yy) +xx = xx.reshape(-1, 1) +yy = yy.reshape(-1, 1) + +means = np.concatenate([xx, yy], 1) +img = img.max() - img +probs = img.reshape(-1) / img.sum() + +std = np.array([8 / w / 2, 8 / h / 2]) + + +def sample_data(data=None, rng=None, batch_size=200): + """data and rng are ignored.""" + inds = np.random.choice(int(probs.shape[0]), int(batch_size), p=probs) + m = means[inds] + samples = np.random.randn(*m.shape) * std + m + return samples + + +def get_transforms(model): + + def sample_fn(z, logpz=None): + if logpz is not None: + return model(z, logpz, reverse=True) + else: + return model(z, reverse=True) + + def density_fn(x, logpx=None): + if logpx is not None: + return model(x, logpx, reverse=False) + else: + return model(x, reverse=False) + + return sample_fn, density_fn + + +def compute_loss(args, model, batch_size=None): + if batch_size is None: batch_size = args.batch_size + + # load data + x = sample_data(args.data, batch_size=batch_size) + x = torch.from_numpy(x).type(torch.float32).to(device) + zero = torch.zeros(x.shape[0], 1).to(x) + + # transform to z + z, delta_logp = model(x, zero) + + # compute log q(z) + logpz = standard_normal_logprob(z).sum(1, keepdim=True) + + logpx = logpz - delta_logp + loss = -torch.mean(logpx) + return loss + + +if __name__ == '__main__': + + regularization_fns, regularization_coeffs = create_regularization_fns(args) + model = build_model_tabular(args, 2, regularization_fns).to(device) + if args.spectral_norm: add_spectral_norm(model) + set_cnf_options(args, model) + + logger.info(model) + logger.info("Number of trainable parameters: {}".format(count_parameters(model))) + + optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) + + time_meter = utils.RunningAverageMeter(0.93) + loss_meter = utils.RunningAverageMeter(0.93) + nfef_meter = utils.RunningAverageMeter(0.93) + nfeb_meter = utils.RunningAverageMeter(0.93) + tt_meter = utils.RunningAverageMeter(0.93) + + end = time.time() + best_loss = float('inf') + model.train() + for itr in range(1, args.niters + 1): + optimizer.zero_grad() + if args.spectral_norm: spectral_norm_power_iteration(model, 1) + + loss = 
compute_loss(args, model) + loss_meter.update(loss.item()) + + if len(regularization_coeffs) > 0: + reg_states = get_regularization(model, regularization_coeffs) + reg_loss = sum( + reg_state * coeff for reg_state, coeff in zip(reg_states, regularization_coeffs) if coeff != 0 + ) + loss = loss + reg_loss + + total_time = count_total_time(model) + nfe_forward = count_nfe(model) + + loss.backward() + optimizer.step() + + nfe_total = count_nfe(model) + nfe_backward = nfe_total - nfe_forward + nfef_meter.update(nfe_forward) + nfeb_meter.update(nfe_backward) + + time_meter.update(time.time() - end) + tt_meter.update(total_time) + + log_message = ( + 'Iter {:04d} | Time {:.4f}({:.4f}) | Loss {:.6f}({:.6f}) | NFE Forward {:.0f}({:.1f})' + ' | NFE Backward {:.0f}({:.1f}) | CNF Time {:.4f}({:.4f})'.format( + itr, time_meter.val, time_meter.avg, loss_meter.val, loss_meter.avg, nfef_meter.val, nfef_meter.avg, + nfeb_meter.val, nfeb_meter.avg, tt_meter.val, tt_meter.avg + ) + ) + if len(regularization_coeffs) > 0: + log_message = append_regularization_to_log(log_message, regularization_fns, reg_states) + + logger.info(log_message) + + if itr % args.val_freq == 0 or itr == args.niters: + with torch.no_grad(): + model.eval() + test_loss = compute_loss(args, model, batch_size=args.test_batch_size) + test_nfe = count_nfe(model) + log_message = '[TEST] Iter {:04d} | Test Loss {:.6f} | NFE {:.0f}'.format(itr, test_loss, test_nfe) + logger.info(log_message) + + if test_loss.item() < best_loss: + best_loss = test_loss.item() + utils.makedirs(args.save) + torch.save({ + 'args': args, + 'state_dict': model.state_dict(), + }, os.path.join(args.save, 'checkpt.pth')) + model.train() + + if itr % args.viz_freq == 0: + with torch.no_grad(): + model.eval() + p_samples = sample_data(args.data, batch_size=2000) + + sample_fn, density_fn = get_transforms(model) + + plt.figure(figsize=(9, 3)) + visualize_transform( + p_samples, torch.randn, standard_normal_logprob, transform=sample_fn, inverse_transform=density_fn, + samples=True, npts=800, device=device + ) + fig_filename = os.path.join(args.save, 'figs', '{:04d}.jpg'.format(itr)) + utils.makedirs(os.path.dirname(fig_filename)) + plt.savefig(fig_filename) + plt.close() + model.train() + + end = time.time() + + logger.info('Training has finished.') + + save_traj_dir = os.path.join(args.save, 'trajectory') + logger.info('Plotting trajectory to {}'.format(save_traj_dir)) + data_samples = sample_data(args.data, batch_size=2000) + save_trajectory(model, data_samples, save_traj_dir, device=device) + trajectory_to_video(save_traj_dir) diff --git a/src/torchprune/torchprune/util/external/ffjord/train_misc.py b/src/torchprune/torchprune/util/external/ffjord/train_misc.py new file mode 100644 index 0000000..d2ee9e1 --- /dev/null +++ b/src/torchprune/torchprune/util/external/ffjord/train_misc.py @@ -0,0 +1,200 @@ +import six +import math + +from .lib.layers.wrappers import cnf_regularization as reg_lib +from .lib import spectral_norm as spectral_norm +from .lib import layers as layers +from .lib.layers.odefunc import divergence_bf, divergence_approx + + +def standard_normal_logprob(z): + logZ = -0.5 * math.log(2 * math.pi) + return logZ - z.pow(2) / 2 + + +def set_cnf_options(args, model): + + def _set(module): + if isinstance(module, layers.CNF): + # Set training settings + module.solver = args.solver + module.atol = args.atol + module.rtol = args.rtol + if args.step_size is not None: + module.solver_options['step_size'] = args.step_size + + # If using fixed-grid adams, restrict 
order to not be too high. + if args.solver in ['fixed_adams', 'explicit_adams']: + module.solver_options['max_order'] = 4 + + # Set the test settings + module.test_solver = args.test_solver if args.test_solver else args.solver + module.test_atol = args.test_atol if args.test_atol else args.atol + module.test_rtol = args.test_rtol if args.test_rtol else args.rtol + + if isinstance(module, layers.ODEfunc): + module.rademacher = args.rademacher + module.residual = args.residual + + model.apply(_set) + + +def override_divergence_fn(model, divergence_fn): + + def _set(module): + if isinstance(module, layers.ODEfunc): + if divergence_fn == "brute_force": + module.divergence_fn = divergence_bf + elif divergence_fn == "approximate": + module.divergence_fn = divergence_approx + + model.apply(_set) + + +def count_nfe(model): + + class AccNumEvals(object): + + def __init__(self): + self.num_evals = 0 + + def __call__(self, module): + if isinstance(module, layers.ODEfunc): + self.num_evals += module.num_evals() + + accumulator = AccNumEvals() + model.apply(accumulator) + return accumulator.num_evals + + +def count_parameters(model): + return sum(p.numel() for p in model.parameters() if p.requires_grad) + + +def count_total_time(model): + + class Accumulator(object): + + def __init__(self): + self.total_time = 0 + + def __call__(self, module): + if isinstance(module, layers.CNF): + self.total_time = self.total_time + module.sqrt_end_time * module.sqrt_end_time + + accumulator = Accumulator() + model.apply(accumulator) + return accumulator.total_time + + +def add_spectral_norm(model, logger=None): + """Applies spectral norm to all modules within the scope of a CNF.""" + + def apply_spectral_norm(module): + if 'weight' in module._parameters: + if logger: logger.info("Adding spectral norm to {}".format(module)) + spectral_norm.inplace_spectral_norm(module, 'weight') + + def find_cnf(module): + if isinstance(module, layers.CNF): + module.apply(apply_spectral_norm) + else: + for child in module.children(): + find_cnf(child) + + find_cnf(model) + + +def spectral_norm_power_iteration(model, n_power_iterations=1): + + def recursive_power_iteration(module): + if hasattr(module, spectral_norm.POWER_ITERATION_FN): + getattr(module, spectral_norm.POWER_ITERATION_FN)(n_power_iterations) + + model.apply(recursive_power_iteration) + + +REGULARIZATION_FNS = { + "l1int": reg_lib.l1_regularzation_fn, + "l2int": reg_lib.l2_regularzation_fn, + "dl2int": reg_lib.directional_l2_regularization_fn, + "JFrobint": reg_lib.jacobian_frobenius_regularization_fn, + "JdiagFrobint": reg_lib.jacobian_diag_frobenius_regularization_fn, + "JoffdiagFrobint": reg_lib.jacobian_offdiag_frobenius_regularization_fn, +} + +INV_REGULARIZATION_FNS = {v: k for k, v in six.iteritems(REGULARIZATION_FNS)} + + +def append_regularization_to_log(log_message, regularization_fns, reg_states): + for i, reg_fn in enumerate(regularization_fns): + log_message = log_message + " | " + INV_REGULARIZATION_FNS[reg_fn] + ": {:.8f}".format(reg_states[i].item()) + return log_message + + +def create_regularization_fns(args): + regularization_fns = [] + regularization_coeffs = [] + + for arg_key, reg_fn in six.iteritems(REGULARIZATION_FNS): + if getattr(args, arg_key) is not None: + regularization_fns.append(reg_fn) + regularization_coeffs.append(eval("args." 
+ arg_key)) + + regularization_fns = tuple(regularization_fns) + regularization_coeffs = tuple(regularization_coeffs) + return regularization_fns, regularization_coeffs + + +def get_regularization(model, regularization_coeffs): + if len(regularization_coeffs) == 0: + return None + + acc_reg_states = tuple([0.] * len(regularization_coeffs)) + for module in model.modules(): + if isinstance(module, layers.CNF): + acc_reg_states = tuple(acc + reg for acc, reg in zip(acc_reg_states, module.get_regularization_states())) + return acc_reg_states + + +def build_model_tabular(args, dims, regularization_fns=None): + + hidden_dims = tuple(map(int, args.dims.split("-"))) + + def build_cnf(): + diffeq = layers.ODEnet( + hidden_dims=hidden_dims, + input_shape=(dims,), + strides=None, + conv=False, + layer_type=args.layer_type, + nonlinearity=args.nonlinearity, + ) + odefunc = layers.ODEfunc( + diffeq=diffeq, + divergence_fn=args.divergence_fn, + residual=args.residual, + rademacher=args.rademacher, + ) + cnf = layers.CNF( + odefunc=odefunc, + T=args.time_length, + train_T=args.train_T, + regularization_fns=regularization_fns, + solver=args.solver, + ) + return cnf + + chain = [build_cnf() for _ in range(args.num_blocks)] + if args.batch_norm: + bn_layers = [layers.MovingBatchNorm1d(dims, bn_lag=args.bn_lag) for _ in range(args.num_blocks)] + bn_chain = [layers.MovingBatchNorm1d(dims, bn_lag=args.bn_lag)] + for a, b in zip(chain, bn_layers): + bn_chain.append(a) + bn_chain.append(b) + chain = bn_chain + model = layers.SequentialFlow(chain) + + set_cnf_options(args, model) + + return model diff --git a/src/torchprune/torchprune/util/external/ffjord/train_tabular.py b/src/torchprune/torchprune/util/external/ffjord/train_tabular.py new file mode 100644 index 0000000..b66d1d7 --- /dev/null +++ b/src/torchprune/torchprune/util/external/ffjord/train_tabular.py @@ -0,0 +1,306 @@ +import argparse +import os +import time + +import torch + +import lib.utils as utils +import lib.layers.odefunc as odefunc +from lib.custom_optimizers import Adam + +import datasets + +from train_misc import standard_normal_logprob +from train_misc import set_cnf_options, count_nfe, count_parameters, count_total_time +from train_misc import create_regularization_fns, get_regularization, append_regularization_to_log +from train_misc import build_model_tabular, override_divergence_fn + +SOLVERS = ["dopri5", "bdf", "rk4", "midpoint", 'adams', 'explicit_adams', 'fixed_adams'] +parser = argparse.ArgumentParser('Continuous Normalizing Flow') +parser.add_argument( + '--data', choices=['power', 'gas', 'hepmass', 'miniboone', 'bsds300'], type=str, default='miniboone' +) +parser.add_argument( + "--layer_type", type=str, default="concatsquash", + choices=["ignore", "concat", "concat_v2", "squash", "concatsquash", "concatcoord", "hyper", "blend"] +) +parser.add_argument('--hdim_factor', type=int, default=10) +parser.add_argument('--nhidden', type=int, default=1) +parser.add_argument("--num_blocks", type=int, default=1, help='Number of stacked CNFs.') +parser.add_argument('--time_length', type=float, default=1.0) +parser.add_argument('--train_T', type=eval, default=True) +parser.add_argument("--divergence_fn", type=str, default="approximate", choices=["brute_force", "approximate"]) +parser.add_argument("--nonlinearity", type=str, default="softplus", choices=odefunc.NONLINEARITIES) + +parser.add_argument('--solver', type=str, default='dopri5', choices=SOLVERS) +parser.add_argument('--atol', type=float, default=1e-8) +parser.add_argument('--rtol', 
type=float, default=1e-6) +parser.add_argument("--step_size", type=float, default=None, help="Optional fixed step size.") + +parser.add_argument('--test_solver', type=str, default=None, choices=SOLVERS + [None]) +parser.add_argument('--test_atol', type=float, default=None) +parser.add_argument('--test_rtol', type=float, default=None) + +parser.add_argument('--residual', type=eval, default=False, choices=[True, False]) +parser.add_argument('--rademacher', type=eval, default=False, choices=[True, False]) +parser.add_argument('--batch_norm', type=eval, default=False, choices=[True, False]) +parser.add_argument('--bn_lag', type=float, default=0) + +parser.add_argument('--early_stopping', type=int, default=30) +parser.add_argument('--batch_size', type=int, default=1000) +parser.add_argument('--test_batch_size', type=int, default=None) +parser.add_argument('--lr', type=float, default=1e-3) +parser.add_argument('--weight_decay', type=float, default=1e-6) + +# Track quantities +parser.add_argument('--l1int', type=float, default=None, help="int_t ||f||_1") +parser.add_argument('--l2int', type=float, default=None, help="int_t ||f||_2") +parser.add_argument('--dl2int', type=float, default=None, help="int_t ||f^T df/dt||_2") +parser.add_argument('--JFrobint', type=float, default=None, help="int_t ||df/dx||_F") +parser.add_argument('--JdiagFrobint', type=float, default=None, help="int_t ||df_i/dx_i||_F") +parser.add_argument('--JoffdiagFrobint', type=float, default=None, help="int_t ||df/dx - df_i/dx_i||_F") + +parser.add_argument('--resume', type=str, default=None) +parser.add_argument('--save', type=str, default='experiments/cnf') +parser.add_argument('--evaluate', action='store_true') +parser.add_argument('--val_freq', type=int, default=200) +parser.add_argument('--log_freq', type=int, default=10) +args = parser.parse_args() + +# logger +utils.makedirs(args.save) +logger = utils.get_logger(logpath=os.path.join(args.save, 'logs'), filepath=os.path.abspath(__file__)) + +if args.layer_type == "blend": + logger.info("!! 
Setting time_length from None to 1.0 due to use of Blend layers.") + args.time_length = 1.0 + args.train_T = False + +logger.info(args) + +test_batch_size = args.test_batch_size if args.test_batch_size else args.batch_size + + +def batch_iter(X, batch_size=args.batch_size, shuffle=False): + """ + X: feature tensor (shape: num_instances x num_features) + """ + if shuffle: + idxs = torch.randperm(X.shape[0]) + else: + idxs = torch.arange(X.shape[0]) + if X.is_cuda: + idxs = idxs.cuda() + for batch_idxs in idxs.split(batch_size): + yield X[batch_idxs] + + +ndecs = 0 + + +def update_lr(optimizer, n_vals_without_improvement): + global ndecs + if ndecs == 0 and n_vals_without_improvement > args.early_stopping // 3: + for param_group in optimizer.param_groups: + param_group["lr"] = args.lr / 10 + ndecs = 1 + elif ndecs == 1 and n_vals_without_improvement > args.early_stopping // 3 * 2: + for param_group in optimizer.param_groups: + param_group["lr"] = args.lr / 100 + ndecs = 2 + else: + for param_group in optimizer.param_groups: + param_group["lr"] = args.lr / 10**ndecs + + +def load_data(name): + + if name == 'bsds300': + return datasets.BSDS300() + + elif name == 'power': + return datasets.POWER() + + elif name == 'gas': + return datasets.GAS() + + elif name == 'hepmass': + return datasets.HEPMASS() + + elif name == 'miniboone': + return datasets.MINIBOONE() + + else: + raise ValueError('Unknown dataset') + + +def compute_loss(x, model): + zero = torch.zeros(x.shape[0], 1).to(x) + + z, delta_logp = model(x, zero) # run model forward + + logpz = standard_normal_logprob(z).view(z.shape[0], -1).sum(1, keepdim=True) # logp(z) + logpx = logpz - delta_logp + loss = -torch.mean(logpx) + return loss + + +def restore_model(model, filename): + checkpt = torch.load(filename, map_location=lambda storage, loc: storage) + model.load_state_dict(checkpt["state_dict"]) + return model + + +if __name__ == '__main__': + + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + cvt = lambda x: x.type(torch.float32).to(device, non_blocking=True) + + logger.info('Using {} GPUs.'.format(torch.cuda.device_count())) + + data = load_data(args.data) + data.trn.x = torch.from_numpy(data.trn.x) + data.val.x = torch.from_numpy(data.val.x) + data.tst.x = torch.from_numpy(data.tst.x) + + args.dims = '-'.join([str(args.hdim_factor * data.n_dims)] * args.nhidden) + + regularization_fns, regularization_coeffs = create_regularization_fns(args) + model = build_model_tabular(args, data.n_dims, regularization_fns).to(device) + set_cnf_options(args, model) + + for k in model.state_dict().keys(): + logger.info(k) + + if args.resume is not None: + checkpt = torch.load(args.resume) + + # Backwards compatibility with an older version of the code. + # TODO: remove upon release. 
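+            # The filtering below makes old checkpoints loadable: it drops
+            # stale 'diffeq.diffeq' entries and strips any 'module.' prefix
+            # (e.g. left over from nn.DataParallel) from the parameter names.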
+ filtered_state_dict = {} + for k, v in checkpt['state_dict'].items(): + if 'diffeq.diffeq' not in k: + filtered_state_dict[k.replace('module.', '')] = v + model.load_state_dict(filtered_state_dict) + + logger.info(model) + logger.info("Number of trainable parameters: {}".format(count_parameters(model))) + + if not args.evaluate: + optimizer = Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) + + time_meter = utils.RunningAverageMeter(0.98) + loss_meter = utils.RunningAverageMeter(0.98) + nfef_meter = utils.RunningAverageMeter(0.98) + nfeb_meter = utils.RunningAverageMeter(0.98) + tt_meter = utils.RunningAverageMeter(0.98) + + best_loss = float('inf') + itr = 0 + n_vals_without_improvement = 0 + end = time.time() + model.train() + while True: + if args.early_stopping > 0 and n_vals_without_improvement > args.early_stopping: + break + + for x in batch_iter(data.trn.x, shuffle=True): + if args.early_stopping > 0 and n_vals_without_improvement > args.early_stopping: + break + + optimizer.zero_grad() + + x = cvt(x) + loss = compute_loss(x, model) + loss_meter.update(loss.item()) + + if len(regularization_coeffs) > 0: + reg_states = get_regularization(model, regularization_coeffs) + reg_loss = sum( + reg_state * coeff for reg_state, coeff in zip(reg_states, regularization_coeffs) if coeff != 0 + ) + loss = loss + reg_loss + + total_time = count_total_time(model) + nfe_forward = count_nfe(model) + + loss.backward() + optimizer.step() + + nfe_total = count_nfe(model) + nfe_backward = nfe_total - nfe_forward + nfef_meter.update(nfe_forward) + nfeb_meter.update(nfe_backward) + + time_meter.update(time.time() - end) + tt_meter.update(total_time) + + if itr % args.log_freq == 0: + log_message = ( + 'Iter {:06d} | Epoch {:.2f} | Time {:.4f}({:.4f}) | Loss {:.6f}({:.6f}) | ' + 'NFE Forward {:.0f}({:.1f}) | NFE Backward {:.0f}({:.1f}) | CNF Time {:.4f}({:.4f})'.format( + itr, + float(itr) / (data.trn.x.shape[0] / float(args.batch_size)), time_meter.val, time_meter.avg, + loss_meter.val, loss_meter.avg, nfef_meter.val, nfef_meter.avg, nfeb_meter.val, + nfeb_meter.avg, tt_meter.val, tt_meter.avg + ) + ) + if len(regularization_coeffs) > 0: + log_message = append_regularization_to_log(log_message, regularization_fns, reg_states) + + logger.info(log_message) + itr += 1 + end = time.time() + + # Validation loop. 
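+            # Every `val_freq` iterations: average the loss over the full
+            # validation set, checkpoint the best model so far, and count
+            # validations without improvement; that counter drives both the
+            # staged learning-rate decay in update_lr() and early stopping.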
+ if itr % args.val_freq == 0: + model.eval() + start_time = time.time() + with torch.no_grad(): + val_loss = utils.AverageMeter() + val_nfe = utils.AverageMeter() + for x in batch_iter(data.val.x, batch_size=test_batch_size): + x = cvt(x) + val_loss.update(compute_loss(x, model).item(), x.shape[0]) + val_nfe.update(count_nfe(model)) + + if val_loss.avg < best_loss: + best_loss = val_loss.avg + utils.makedirs(args.save) + torch.save({ + 'args': args, + 'state_dict': model.state_dict(), + }, os.path.join(args.save, 'checkpt.pth')) + n_vals_without_improvement = 0 + else: + n_vals_without_improvement += 1 + update_lr(optimizer, n_vals_without_improvement) + + log_message = ( + '[VAL] Iter {:06d} | Val Loss {:.6f} | NFE {:.0f} | ' + 'NoImproveEpochs {:02d}/{:02d}'.format( + itr, val_loss.avg, val_nfe.avg, n_vals_without_improvement, args.early_stopping + ) + ) + logger.info(log_message) + model.train() + + logger.info('Training has finished.') + model = restore_model(model, os.path.join(args.save, 'checkpt.pth')).to(device) + set_cnf_options(args, model) + + logger.info('Evaluating model on test set.') + model.eval() + + override_divergence_fn(model, "brute_force") + + with torch.no_grad(): + test_loss = utils.AverageMeter() + test_nfe = utils.AverageMeter() + for itr, x in enumerate(batch_iter(data.tst.x, batch_size=test_batch_size)): + x = cvt(x) + test_loss.update(compute_loss(x, model).item(), x.shape[0]) + test_nfe.update(count_nfe(model)) + logger.info('Progress: {:.2f}%'.format(100. * itr / (data.tst.x.shape[0] / test_batch_size))) + log_message = '[TEST] Iter {:06d} | Test Loss {:.6f} | NFE {:.0f}'.format(itr, test_loss.avg, test_nfe.avg) + logger.info(log_message) diff --git a/src/torchprune/torchprune/util/external/ffjord/train_toy.py b/src/torchprune/torchprune/util/external/ffjord/train_toy.py new file mode 100644 index 0000000..72ab811 --- /dev/null +++ b/src/torchprune/torchprune/util/external/ffjord/train_toy.py @@ -0,0 +1,231 @@ +import matplotlib +matplotlib.use('Agg') +import matplotlib.pyplot as plt + +import argparse +import os +import time + +import torch +import torch.optim as optim + +import lib.toy_data as toy_data +import lib.utils as utils +from lib.visualize_flow import visualize_transform +import lib.layers.odefunc as odefunc + +from train_misc import standard_normal_logprob +from train_misc import set_cnf_options, count_nfe, count_parameters, count_total_time +from train_misc import add_spectral_norm, spectral_norm_power_iteration +from train_misc import create_regularization_fns, get_regularization, append_regularization_to_log +from train_misc import build_model_tabular + +from diagnostics.viz_toy import save_trajectory, trajectory_to_video + +SOLVERS = ["dopri5", "bdf", "rk4", "midpoint", 'adams', 'explicit_adams', 'fixed_adams'] +parser = argparse.ArgumentParser('Continuous Normalizing Flow') +parser.add_argument( + '--data', choices=['swissroll', '8gaussians', 'pinwheel', 'circles', 'moons', '2spirals', 'checkerboard', 'rings'], + type=str, default='pinwheel' +) +parser.add_argument( + "--layer_type", type=str, default="concatsquash", + choices=["ignore", "concat", "concat_v2", "squash", "concatsquash", "concatcoord", "hyper", "blend"] +) +parser.add_argument('--dims', type=str, default='64-64-64') +parser.add_argument("--num_blocks", type=int, default=1, help='Number of stacked CNFs.') +parser.add_argument('--time_length', type=float, default=0.5) +parser.add_argument('--train_T', type=eval, default=True) +parser.add_argument("--divergence_fn", type=str, 
default="brute_force", choices=["brute_force", "approximate"]) +parser.add_argument("--nonlinearity", type=str, default="tanh", choices=odefunc.NONLINEARITIES) + +parser.add_argument('--solver', type=str, default='dopri5', choices=SOLVERS) +parser.add_argument('--atol', type=float, default=1e-5) +parser.add_argument('--rtol', type=float, default=1e-5) +parser.add_argument("--step_size", type=float, default=None, help="Optional fixed step size.") + +parser.add_argument('--test_solver', type=str, default=None, choices=SOLVERS + [None]) +parser.add_argument('--test_atol', type=float, default=None) +parser.add_argument('--test_rtol', type=float, default=None) + +parser.add_argument('--residual', type=eval, default=False, choices=[True, False]) +parser.add_argument('--rademacher', type=eval, default=False, choices=[True, False]) +parser.add_argument('--spectral_norm', type=eval, default=False, choices=[True, False]) +parser.add_argument('--batch_norm', type=eval, default=False, choices=[True, False]) +parser.add_argument('--bn_lag', type=float, default=0) + +parser.add_argument('--niters', type=int, default=10000) +parser.add_argument('--batch_size', type=int, default=100) +parser.add_argument('--test_batch_size', type=int, default=1000) +parser.add_argument('--lr', type=float, default=1e-3) +parser.add_argument('--weight_decay', type=float, default=1e-5) + +# Track quantities +parser.add_argument('--l1int', type=float, default=None, help="int_t ||f||_1") +parser.add_argument('--l2int', type=float, default=None, help="int_t ||f||_2") +parser.add_argument('--dl2int', type=float, default=None, help="int_t ||f^T df/dt||_2") +parser.add_argument('--JFrobint', type=float, default=None, help="int_t ||df/dx||_F") +parser.add_argument('--JdiagFrobint', type=float, default=None, help="int_t ||df_i/dx_i||_F") +parser.add_argument('--JoffdiagFrobint', type=float, default=None, help="int_t ||df/dx - df_i/dx_i||_F") + +parser.add_argument('--save', type=str, default='experiments/cnf') +parser.add_argument('--viz_freq', type=int, default=100) +parser.add_argument('--val_freq', type=int, default=100) +parser.add_argument('--log_freq', type=int, default=10) +parser.add_argument('--gpu', type=int, default=0) +args = parser.parse_args() + +# logger +utils.makedirs(args.save) +logger = utils.get_logger(logpath=os.path.join(args.save, 'logs'), filepath=os.path.abspath(__file__)) + +if args.layer_type == "blend": + logger.info("!! 
Setting time_length from None to 1.0 due to use of Blend layers.") + args.time_length = 1.0 + +logger.info(args) + +device = torch.device('cuda:' + str(args.gpu) if torch.cuda.is_available() else 'cpu') + + +def get_transforms(model): + + def sample_fn(z, logpz=None): + if logpz is not None: + return model(z, logpz, reverse=True) + else: + return model(z, reverse=True) + + def density_fn(x, logpx=None): + if logpx is not None: + return model(x, logpx, reverse=False) + else: + return model(x, reverse=False) + + return sample_fn, density_fn + + +def compute_loss(args, model, batch_size=None): + if batch_size is None: batch_size = args.batch_size + + # load data + x = toy_data.inf_train_gen(args.data, batch_size=batch_size) + x = torch.from_numpy(x).type(torch.float32).to(device) + zero = torch.zeros(x.shape[0], 1).to(x) + + # transform to z + z, delta_logp = model(x, zero) + + # compute log q(z) + logpz = standard_normal_logprob(z).sum(1, keepdim=True) + + logpx = logpz - delta_logp + loss = -torch.mean(logpx) + return loss + + +if __name__ == '__main__': + + regularization_fns, regularization_coeffs = create_regularization_fns(args) + model = build_model_tabular(args, 2, regularization_fns).to(device) + if args.spectral_norm: add_spectral_norm(model) + set_cnf_options(args, model) + + logger.info(model) + logger.info("Number of trainable parameters: {}".format(count_parameters(model))) + + optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) + + time_meter = utils.RunningAverageMeter(0.93) + loss_meter = utils.RunningAverageMeter(0.93) + nfef_meter = utils.RunningAverageMeter(0.93) + nfeb_meter = utils.RunningAverageMeter(0.93) + tt_meter = utils.RunningAverageMeter(0.93) + + end = time.time() + best_loss = float('inf') + model.train() + for itr in range(1, args.niters + 1): + optimizer.zero_grad() + if args.spectral_norm: spectral_norm_power_iteration(model, 1) + + loss = compute_loss(args, model) + loss_meter.update(loss.item()) + + if len(regularization_coeffs) > 0: + reg_states = get_regularization(model, regularization_coeffs) + reg_loss = sum( + reg_state * coeff for reg_state, coeff in zip(reg_states, regularization_coeffs) if coeff != 0 + ) + loss = loss + reg_loss + + total_time = count_total_time(model) + nfe_forward = count_nfe(model) + + loss.backward() + optimizer.step() + + nfe_total = count_nfe(model) + nfe_backward = nfe_total - nfe_forward + nfef_meter.update(nfe_forward) + nfeb_meter.update(nfe_backward) + + time_meter.update(time.time() - end) + tt_meter.update(total_time) + + log_message = ( + 'Iter {:04d} | Time {:.4f}({:.4f}) | Loss {:.6f}({:.6f}) | NFE Forward {:.0f}({:.1f})' + ' | NFE Backward {:.0f}({:.1f}) | CNF Time {:.4f}({:.4f})'.format( + itr, time_meter.val, time_meter.avg, loss_meter.val, loss_meter.avg, nfef_meter.val, nfef_meter.avg, + nfeb_meter.val, nfeb_meter.avg, tt_meter.val, tt_meter.avg + ) + ) + if len(regularization_coeffs) > 0: + log_message = append_regularization_to_log(log_message, regularization_fns, reg_states) + + logger.info(log_message) + + if itr % args.val_freq == 0 or itr == args.niters: + with torch.no_grad(): + model.eval() + test_loss = compute_loss(args, model, batch_size=args.test_batch_size) + test_nfe = count_nfe(model) + log_message = '[TEST] Iter {:04d} | Test Loss {:.6f} | NFE {:.0f}'.format(itr, test_loss, test_nfe) + logger.info(log_message) + + if test_loss.item() < best_loss: + best_loss = test_loss.item() + utils.makedirs(args.save) + torch.save({ + 'args': args, + 'state_dict': 
model.state_dict(), + }, os.path.join(args.save, 'checkpt.pth')) + model.train() + + if itr % args.viz_freq == 0: + with torch.no_grad(): + model.eval() + p_samples = toy_data.inf_train_gen(args.data, batch_size=2000) + + sample_fn, density_fn = get_transforms(model) + + plt.figure(figsize=(9, 3)) + visualize_transform( + p_samples, torch.randn, standard_normal_logprob, transform=sample_fn, inverse_transform=density_fn, + samples=True, npts=800, device=device + ) + fig_filename = os.path.join(args.save, 'figs', '{:04d}.jpg'.format(itr)) + utils.makedirs(os.path.dirname(fig_filename)) + plt.savefig(fig_filename) + plt.close() + model.train() + + end = time.time() + + logger.info('Training has finished.') + + save_traj_dir = os.path.join(args.save, 'trajectory') + logger.info('Plotting trajectory to {}'.format(save_traj_dir)) + data_samples = toy_data.inf_train_gen(args.data, batch_size=2000) + save_trajectory(model, data_samples, save_traj_dir, device=device) + trajectory_to_video(save_traj_dir) diff --git a/src/torchprune/torchprune/util/external/ffjord/train_vae_flow.py b/src/torchprune/torchprune/util/external/ffjord/train_vae_flow.py new file mode 100644 index 0000000..7bec82b --- /dev/null +++ b/src/torchprune/torchprune/util/external/ffjord/train_vae_flow.py @@ -0,0 +1,358 @@ +# !/usr/bin/env python +# -*- coding: utf-8 -*- + +from __future__ import print_function +import argparse +import time +import torch +import torch.utils.data +import torch.optim as optim +import numpy as np +import math +import random + +import os + +import datetime + +import lib.utils as utils +import lib.layers.odefunc as odefunc + +import vae_lib.models.VAE as VAE +import vae_lib.models.CNFVAE as CNFVAE +from vae_lib.optimization.training import train, evaluate +from vae_lib.utils.load_data import load_dataset +from vae_lib.utils.plotting import plot_training_curve + +SOLVERS = ["dopri5", "bdf", "rk4", "midpoint", 'adams', 'explicit_adams', 'fixed_adams'] +parser = argparse.ArgumentParser(description='PyTorch Sylvester Normalizing flows') + +parser.add_argument( + '-d', '--dataset', type=str, default='mnist', choices=['mnist', 'freyfaces', 'omniglot', 'caltech'], + metavar='DATASET', help='Dataset choice.' +) + +parser.add_argument( + '-freys', '--freyseed', type=int, default=123, metavar='FREYSEED', + help="""Seed for shuffling frey face dataset for test split. Ignored for other datasets. + Results in paper are produced with seeds 123, 321, 231""" +) + +parser.add_argument('-nc', '--no_cuda', action='store_true', default=False, help='disables CUDA training') + +parser.add_argument('--manual_seed', type=int, help='manual seed, if not given resorts to random seed.') + +parser.add_argument( + '-li', '--log_interval', type=int, default=10, metavar='LOG_INTERVAL', + help='how many batches to wait before logging training status' +) + +parser.add_argument( + '-od', '--out_dir', type=str, default='snapshots', metavar='OUT_DIR', + help='output directory for model snapshots etc.' 
+) + +# optimization settings +parser.add_argument( + '-e', '--epochs', type=int, default=2000, metavar='EPOCHS', help='number of epochs to train (default: 2000)' +) +parser.add_argument( + '-es', '--early_stopping_epochs', type=int, default=35, metavar='EARLY_STOPPING', + help='number of early stopping epochs' +) + +parser.add_argument( + '-bs', '--batch_size', type=int, default=100, metavar='BATCH_SIZE', help='input batch size for training' +) +parser.add_argument('-lr', '--learning_rate', type=float, default=0.0005, metavar='LEARNING_RATE', help='learning rate') + +parser.add_argument( + '-w', '--warmup', type=int, default=100, metavar='N', + help='number of epochs for warm-up. Set to 0 to turn warmup off.' +) +parser.add_argument('--max_beta', type=float, default=1., metavar='MB', help='max beta for warm-up') +parser.add_argument('--min_beta', type=float, default=0.0, metavar='MB', help='min beta for warm-up') +parser.add_argument( + '-f', '--flow', type=str, default='no_flow', choices=[ + 'planar', 'iaf', 'householder', 'orthogonal', 'triangular', 'cnf', 'cnf_bias', 'cnf_hyper', 'cnf_rank', + 'cnf_lyper', 'no_flow' + ], help="""Type of flows to use, no flows can also be selected""" +) +parser.add_argument('-r', '--rank', type=int, default=1) +parser.add_argument( + '-nf', '--num_flows', type=int, default=4, metavar='NUM_FLOWS', + help='Number of flow layers, ignored in absence of flows' +) +parser.add_argument( + '-nv', '--num_ortho_vecs', type=int, default=8, metavar='NUM_ORTHO_VECS', + help=""" For orthogonal flow: How orthogonal vectors per flow do you need. + Ignored for other flow types.""" +) +parser.add_argument( + '-nh', '--num_householder', type=int, default=8, metavar='NUM_HOUSEHOLDERS', + help=""" For Householder Sylvester flow: Number of Householder matrices per flow. + Ignored for other flow types.""" +) +parser.add_argument( + '-mhs', '--made_h_size', type=int, default=320, metavar='MADEHSIZE', + help='Width of mades for iaf. Ignored for all other flows.' 
+) +parser.add_argument('--z_size', type=int, default=64, metavar='ZSIZE', help='how many stochastic hidden units') +# gpu/cpu +parser.add_argument('--gpu_num', type=int, default=0, metavar='GPU', help='choose GPU to run on.') + +# CNF settings +parser.add_argument( + "--layer_type", type=str, default="concat", + choices=["ignore", "concat", "concat_v2", "squash", "concatsquash", "concatcoord", "hyper", "blend"] +) +parser.add_argument('--dims', type=str, default='512-512') +parser.add_argument("--num_blocks", type=int, default=1, help='Number of stacked CNFs.') +parser.add_argument('--time_length', type=float, default=0.5) +parser.add_argument('--train_T', type=eval, default=False) +parser.add_argument("--divergence_fn", type=str, default="approximate", choices=["brute_force", "approximate"]) +parser.add_argument("--nonlinearity", type=str, default="softplus", choices=odefunc.NONLINEARITIES) + +parser.add_argument('--solver', type=str, default='dopri5', choices=SOLVERS) +parser.add_argument('--atol', type=float, default=1e-5) +parser.add_argument('--rtol', type=float, default=1e-5) +parser.add_argument("--step_size", type=float, default=None, help="Optional fixed step size.") + +parser.add_argument('--test_solver', type=str, default=None, choices=SOLVERS + [None]) +parser.add_argument('--test_atol', type=float, default=None) +parser.add_argument('--test_rtol', type=float, default=None) + +parser.add_argument('--residual', type=eval, default=False, choices=[True, False]) +parser.add_argument('--rademacher', type=eval, default=False, choices=[True, False]) +parser.add_argument('--batch_norm', type=eval, default=False, choices=[True, False]) +parser.add_argument('--bn_lag', type=float, default=0) +# evaluation +parser.add_argument('--evaluate', type=eval, default=False, choices=[True, False]) +parser.add_argument('--model_path', type=str, default='') +parser.add_argument('--retrain_encoder', type=eval, default=False, choices=[True, False]) + +args = parser.parse_args() +args.cuda = not args.no_cuda and torch.cuda.is_available() + +if args.manual_seed is None: + args.manual_seed = random.randint(1, 100000) +random.seed(args.manual_seed) +torch.manual_seed(args.manual_seed) +np.random.seed(args.manual_seed) + +if args.cuda: + # gpu device number + torch.cuda.set_device(args.gpu_num) + +kwargs = {'num_workers': 0, 'pin_memory': True} if args.cuda else {} + + +def run(args, kwargs): + # ================================================================================================================== + # SNAPSHOTS + # ================================================================================================================== + args.model_signature = str(datetime.datetime.now())[0:19].replace(' ', '_') + args.model_signature = args.model_signature.replace(':', '_') + + snapshots_path = os.path.join(args.out_dir, 'vae_' + args.dataset + '_') + snap_dir = snapshots_path + args.flow + + if args.flow != 'no_flow': + snap_dir += '_' + 'num_flows_' + str(args.num_flows) + + if args.flow == 'orthogonal': + snap_dir = snap_dir + '_num_vectors_' + str(args.num_ortho_vecs) + elif args.flow == 'orthogonalH': + snap_dir = snap_dir + '_num_householder_' + str(args.num_householder) + elif args.flow == 'iaf': + snap_dir = snap_dir + '_madehsize_' + str(args.made_h_size) + + elif args.flow == 'permutation': + snap_dir = snap_dir + '_' + 'kernelsize_' + str(args.kernel_size) + elif args.flow == 'mixed': + snap_dir = snap_dir + '_' + 'num_householder_' + str(args.num_householder) + elif args.flow == 
'cnf_rank': + snap_dir = snap_dir + '_rank_' + str(args.rank) + '_' + args.dims + '_num_blocks_' + str(args.num_blocks) + elif 'cnf' in args.flow: + snap_dir = snap_dir + '_' + args.dims + '_num_blocks_' + str(args.num_blocks) + + if args.retrain_encoder: + snap_dir = snap_dir + '_retrain-encoder_' + elif args.evaluate: + snap_dir = snap_dir + '_evaluate_' + + snap_dir = snap_dir + '__' + args.model_signature + '/' + + args.snap_dir = snap_dir + + if not os.path.exists(snap_dir): + os.makedirs(snap_dir) + + # logger + utils.makedirs(args.snap_dir) + logger = utils.get_logger(logpath=os.path.join(args.snap_dir, 'logs'), filepath=os.path.abspath(__file__)) + + logger.info(args) + + # SAVING + torch.save(args, snap_dir + args.flow + '.config') + + # ================================================================================================================== + # LOAD DATA + # ================================================================================================================== + train_loader, val_loader, test_loader, args = load_dataset(args, **kwargs) + + if not args.evaluate: + + # ============================================================================================================== + # SELECT MODEL + # ============================================================================================================== + # flow parameters and architecture choice are passed on to model through args + + if args.flow == 'no_flow': + model = VAE.VAE(args) + elif args.flow == 'planar': + model = VAE.PlanarVAE(args) + elif args.flow == 'iaf': + model = VAE.IAFVAE(args) + elif args.flow == 'orthogonal': + model = VAE.OrthogonalSylvesterVAE(args) + elif args.flow == 'householder': + model = VAE.HouseholderSylvesterVAE(args) + elif args.flow == 'triangular': + model = VAE.TriangularSylvesterVAE(args) + elif args.flow == 'cnf': + model = CNFVAE.CNFVAE(args) + elif args.flow == 'cnf_bias': + model = CNFVAE.AmortizedBiasCNFVAE(args) + elif args.flow == 'cnf_hyper': + model = CNFVAE.HypernetCNFVAE(args) + elif args.flow == 'cnf_lyper': + model = CNFVAE.LypernetCNFVAE(args) + elif args.flow == 'cnf_rank': + model = CNFVAE.AmortizedLowRankCNFVAE(args) + else: + raise ValueError('Invalid flow choice') + + if args.retrain_encoder: + logger.info(f"Initializing decoder from {args.model_path}") + dec_model = torch.load(args.model_path) + dec_sd = {} + for k, v in dec_model.state_dict().items(): + if 'p_x' in k: + dec_sd[k] = v + model.load_state_dict(dec_sd, strict=False) + + if args.cuda: + logger.info("Model on GPU") + model.cuda() + + logger.info(model) + + if args.retrain_encoder: + parameters = [] + logger.info('Optimizing over:') + for name, param in model.named_parameters(): + if 'p_x' not in name: + logger.info(name) + parameters.append(param) + else: + parameters = model.parameters() + + optimizer = optim.Adamax(parameters, lr=args.learning_rate, eps=1.e-7) + + # ================================================================================================================== + # TRAINING + # ================================================================================================================== + train_loss = [] + val_loss = [] + + # for early stopping + best_loss = np.inf + best_bpd = np.inf + e = 0 + epoch = 0 + + train_times = [] + + for epoch in range(1, args.epochs + 1): + + t_start = time.time() + tr_loss = train(epoch, train_loader, model, optimizer, args, logger) + train_loss.append(tr_loss) + train_times.append(time.time() - t_start) + logger.info('One training 
epoch took %.2f seconds' % (time.time() - t_start)) + + v_loss, v_bpd = evaluate(val_loader, model, args, logger, epoch=epoch) + + val_loss.append(v_loss) + + # early-stopping + if v_loss < best_loss: + e = 0 + best_loss = v_loss + if args.input_type != 'binary': + best_bpd = v_bpd + logger.info('->model saved<-') + torch.save(model, snap_dir + args.flow + '.model') + # torch.save(model, snap_dir + args.flow + '_' + args.architecture + '.model') + + elif (args.early_stopping_epochs > 0) and (epoch >= args.warmup): + e += 1 + if e > args.early_stopping_epochs: + break + + if args.input_type == 'binary': + logger.info( + '--> Early stopping: {}/{} (BEST: loss {:.4f})\n'.format(e, args.early_stopping_epochs, best_loss) + ) + + else: + logger.info( + '--> Early stopping: {}/{} (BEST: loss {:.4f}, bpd {:.4f})\n'. + format(e, args.early_stopping_epochs, best_loss, best_bpd) + ) + + if math.isnan(v_loss): + raise ValueError('NaN encountered!') + + train_loss = np.hstack(train_loss) + val_loss = np.array(val_loss) + + plot_training_curve(train_loss, val_loss, fname=snap_dir + '/training_curve_%s.pdf' % args.flow) + + # training time per epoch + train_times = np.array(train_times) + mean_train_time = np.mean(train_times) + std_train_time = np.std(train_times, ddof=1) + logger.info('Average train time per epoch: %.2f +/- %.2f' % (mean_train_time, std_train_time)) + + # ================================================================================================================== + # EVALUATION + # ================================================================================================================== + + logger.info(args) + logger.info('Stopped after %d epochs' % epoch) + logger.info('Average train time per epoch: %.2f +/- %.2f' % (mean_train_time, std_train_time)) + + final_model = torch.load(snap_dir + args.flow + '.model') + validation_loss, validation_bpd = evaluate(val_loader, final_model, args, logger) + + else: + validation_loss = "N/A" + validation_bpd = "N/A" + logger.info(f"Loading model from {args.model_path}") + final_model = torch.load(args.model_path) + + test_loss, test_bpd = evaluate(test_loader, final_model, args, logger, testing=True) + + logger.info('FINAL EVALUATION ON VALIDATION SET. ELBO (VAL): {:.4f}'.format(validation_loss)) + logger.info('FINAL EVALUATION ON TEST SET. NLL (TEST): {:.4f}'.format(test_loss)) + if args.input_type != 'binary': + logger.info('FINAL EVALUATION ON VALIDATION SET. ELBO (VAL) BPD : {:.4f}'.format(validation_bpd)) + logger.info('FINAL EVALUATION ON TEST SET. 
NLL (TEST) BPD: {:.4f}'.format(test_bpd)) + + +if __name__ == "__main__": + + run(args, kwargs) diff --git a/src/torchprune/torchprune/util/models/cnn/LICENSE b/src/torchprune/torchprune/util/external/ffjord/vae_lib/LICENSE similarity index 96% rename from src/torchprune/torchprune/util/models/cnn/LICENSE rename to src/torchprune/torchprune/util/external/ffjord/vae_lib/LICENSE index 0482bd0..a37f7eb 100644 --- a/src/torchprune/torchprune/util/models/cnn/LICENSE +++ b/src/torchprune/torchprune/util/external/ffjord/vae_lib/LICENSE @@ -1,6 +1,6 @@ MIT License -Copyright (c) 2017 Wei Yang +Copyright (c) 2018 Rianne van den Berg Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/src/torchprune/torchprune/util/external/ffjord/vae_lib/models/CNFVAE.py b/src/torchprune/torchprune/util/external/ffjord/vae_lib/models/CNFVAE.py new file mode 100644 index 0000000..3e24dc4 --- /dev/null +++ b/src/torchprune/torchprune/util/external/ffjord/vae_lib/models/CNFVAE.py @@ -0,0 +1,412 @@ +import torch +import torch.nn as nn +from train_misc import build_model_tabular +import lib.layers as layers +from .VAE import VAE +import lib.layers.diffeq_layers as diffeq_layers +from lib.layers.odefunc import NONLINEARITIES + +from torchdiffeq import odeint_adjoint as odeint + + +def get_hidden_dims(args): + return tuple(map(int, args.dims.split("-"))) + (args.z_size,) + + +def concat_layer_num_params(in_dim, out_dim): + return (in_dim + 1) * out_dim + out_dim + + +class CNFVAE(VAE): + + def __init__(self, args): + super(CNFVAE, self).__init__(args) + + # CNF model + self.cnf = build_model_tabular(args, args.z_size) + + if args.cuda: + self.cuda() + + def encode(self, x): + """ + Encoder that ouputs parameters for base distribution of z and flow parameters. + """ + + h = self.q_z_nn(x) + h = h.view(-1, self.q_z_nn_output_dim) + mean_z = self.q_z_mean(h) + var_z = self.q_z_var(h) + + return mean_z, var_z + + def forward(self, x): + """ + Forward pass with planar flows for the transformation z_0 -> z_1 -> ... -> z_k. + Log determinant is computed as log_det_j = N E_q_z0[\sum_k log |det dz_k/dz_k-1| ]. 
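+        Here the transformation is a continuous normalizing flow rather than a
+        stack of planar flows: zk, delta_logp = self.cnf(z0, zero), so the
+        log-det term returned below is simply -delta_logp.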
+ """ + + z_mu, z_var = self.encode(x) + + # Sample z_0 + z0 = self.reparameterize(z_mu, z_var) + + zero = torch.zeros(x.shape[0], 1).to(x) + zk, delta_logp = self.cnf(z0, zero) # run model forward + + x_mean = self.decode(zk) + + return x_mean, z_mu, z_var, -delta_logp.view(-1), z0, zk + + +class AmortizedBiasODEnet(nn.Module): + + def __init__(self, hidden_dims, input_dim, layer_type="concat", nonlinearity="softplus"): + super(AmortizedBiasODEnet, self).__init__() + base_layer = { + "ignore": diffeq_layers.IgnoreLinear, + "hyper": diffeq_layers.HyperLinear, + "squash": diffeq_layers.SquashLinear, + "concat": diffeq_layers.ConcatLinear, + "concat_v2": diffeq_layers.ConcatLinear_v2, + "concatsquash": diffeq_layers.ConcatSquashLinear, + "blend": diffeq_layers.BlendLinear, + "concatcoord": diffeq_layers.ConcatLinear, + }[layer_type] + self.input_dim = input_dim + + # build layers and add them + layers = [] + activation_fns = [] + hidden_shape = input_dim + for dim_out in hidden_dims: + layer = base_layer(hidden_shape, dim_out) + layers.append(layer) + activation_fns.append(NONLINEARITIES[nonlinearity]) + hidden_shape = dim_out + + self.layers = nn.ModuleList(layers) + self.activation_fns = nn.ModuleList(activation_fns[:-1]) + + def _unpack_params(self, params): + return [params] + + def forward(self, t, y, am_biases): + dx = y + for l, layer in enumerate(self.layers): + dx = layer(t, dx) + this_bias, am_biases = am_biases[:, :dx.size(1)], am_biases[:, dx.size(1):] + dx = dx + this_bias + # if not last layer, use nonlinearity + if l < len(self.layers) - 1: + dx = self.activation_fns[l](dx) + return dx + + +class AmortizedLowRankODEnet(nn.Module): + + def __init__(self, hidden_dims, input_dim, rank=1, layer_type="concat", nonlinearity="softplus"): + super(AmortizedLowRankODEnet, self).__init__() + base_layer = { + "ignore": diffeq_layers.IgnoreLinear, + "hyper": diffeq_layers.HyperLinear, + "squash": diffeq_layers.SquashLinear, + "concat": diffeq_layers.ConcatLinear, + "concat_v2": diffeq_layers.ConcatLinear_v2, + "concatsquash": diffeq_layers.ConcatSquashLinear, + "blend": diffeq_layers.BlendLinear, + "concatcoord": diffeq_layers.ConcatLinear, + }[layer_type] + self.input_dim = input_dim + + # build layers and add them + layers = [] + activation_fns = [] + hidden_shape = input_dim + self.output_dims = hidden_dims + self.input_dims = (input_dim,) + hidden_dims[:-1] + for dim_out in hidden_dims: + layer = base_layer(hidden_shape, dim_out) + layers.append(layer) + activation_fns.append(NONLINEARITIES[nonlinearity]) + hidden_shape = dim_out + + self.layers = nn.ModuleList(layers) + self.activation_fns = nn.ModuleList(activation_fns[:-1]) + self.rank = rank + + def _unpack_params(self, params): + return [params] + + def _rank_k_bmm(self, x, u, v): + xu = torch.bmm(x[:, None], u.view(x.shape[0], x.shape[-1], self.rank)) + xuv = torch.bmm(xu, v.view(x.shape[0], self.rank, -1)) + return xuv[:, 0] + + def forward(self, t, y, am_params): + dx = y + for l, (layer, in_dim, out_dim) in enumerate(zip(self.layers, self.input_dims, self.output_dims)): + this_u, am_params = am_params[:, :in_dim * self.rank], am_params[:, in_dim * self.rank:] + this_v, am_params = am_params[:, :out_dim * self.rank], am_params[:, out_dim * self.rank:] + this_bias, am_params = am_params[:, :out_dim], am_params[:, out_dim:] + + xw = layer(t, dx) + xw_am = self._rank_k_bmm(dx, this_u, this_v) + dx = xw + xw_am + this_bias + # if not last layer, use nonlinearity + if l < len(self.layers) - 1: + dx = self.activation_fns[l](dx) + 
return dx + + +class HyperODEnet(nn.Module): + + def __init__(self, hidden_dims, input_dim, layer_type="concat", nonlinearity="softplus"): + super(HyperODEnet, self).__init__() + assert layer_type == "concat" + self.input_dim = input_dim + + # build layers and add them + activation_fns = [] + for dim_out in hidden_dims + (input_dim,): + activation_fns.append(NONLINEARITIES[nonlinearity]) + self.activation_fns = nn.ModuleList(activation_fns[:-1]) + self.output_dims = hidden_dims + self.input_dims = (input_dim,) + hidden_dims[:-1] + + def _pack_inputs(self, t, x): + tt = torch.ones_like(x[:, :1]) * t + ttx = torch.cat([tt, x], 1) + return ttx + + def _unpack_params(self, params): + layer_params = [] + for in_dim, out_dim in zip(self.input_dims, self.output_dims): + this_num_params = concat_layer_num_params(in_dim, out_dim) + # get params for this layer + this_params, params = params[:, :this_num_params], params[:, this_num_params:] + # split into weight and bias + bias, weight_params = this_params[:, :out_dim], this_params[:, out_dim:] + weight = weight_params.view(weight_params.size(0), in_dim + 1, out_dim) + layer_params.append(weight) + layer_params.append(bias) + return layer_params + + def _layer(self, t, x, weight, bias): + # weights is (batch, in_dim + 1, out_dim) + ttx = self._pack_inputs(t, x) # (batch, in_dim + 1) + ttx = ttx.view(ttx.size(0), 1, ttx.size(1)) # (batch, 1, in_dim + 1) + xw = torch.bmm(ttx, weight)[:, 0, :] # (batch, out_dim) + return xw + bias + + def forward(self, t, y, *layer_params): + dx = y + for l, (weight, bias) in enumerate(zip(layer_params[::2], layer_params[1::2])): + dx = self._layer(t, dx, weight, bias) + # if not last layer, use nonlinearity + if l < len(layer_params) - 1: + dx = self.activation_fns[l](dx) + return dx + + +class LyperODEnet(nn.Module): + + def __init__(self, hidden_dims, input_dim, layer_type="concat", nonlinearity="softplus"): + super(LyperODEnet, self).__init__() + base_layer = { + "ignore": diffeq_layers.IgnoreLinear, + "hyper": diffeq_layers.HyperLinear, + "squash": diffeq_layers.SquashLinear, + "concat": diffeq_layers.ConcatLinear, + "concat_v2": diffeq_layers.ConcatLinear_v2, + "concatsquash": diffeq_layers.ConcatSquashLinear, + "blend": diffeq_layers.BlendLinear, + "concatcoord": diffeq_layers.ConcatLinear, + }[layer_type] + self.input_dim = input_dim + + # build layers and add them + layers = [] + activation_fns = [] + hidden_shape = input_dim + self.dims = (input_dim,) + hidden_dims + self.output_dims = hidden_dims + self.input_dims = (input_dim,) + hidden_dims[:-1] + for dim_out in hidden_dims[:-1]: + layer = base_layer(hidden_shape, dim_out) + layers.append(layer) + activation_fns.append(NONLINEARITIES[nonlinearity]) + hidden_shape = dim_out + + self.layers = nn.ModuleList(layers) + self.activation_fns = nn.ModuleList(activation_fns) + + def _pack_inputs(self, t, x): + tt = torch.ones_like(x[:, :1]) * t + ttx = torch.cat([tt, x], 1) + return ttx + + def _unpack_params(self, params): + return [params] + + def _am_layer(self, t, x, weight, bias): + # weights is (batch, in_dim + 1, out_dim) + ttx = self._pack_inputs(t, x) # (batch, in_dim + 1) + ttx = ttx.view(ttx.size(0), 1, ttx.size(1)) # (batch, 1, in_dim + 1) + xw = torch.bmm(ttx, weight)[:, 0, :] # (batch, out_dim) + return xw + bias + + def forward(self, t, x, am_params): + dx = x + for layer, act in zip(self.layers, self.activation_fns): + dx = act(layer(t, dx)) + bias, weight_params = am_params[:, :self.dims[-1]], am_params[:, self.dims[-1]:] + weight = 
weight_params.view(weight_params.size(0), self.dims[-2] + 1, self.dims[-1]) + dx = self._am_layer(t, dx, weight, bias) + return dx + + +def construct_amortized_odefunc(args, z_dim, amortization_type="bias"): + + hidden_dims = get_hidden_dims(args) + + if amortization_type == "bias": + diffeq = AmortizedBiasODEnet( + hidden_dims=hidden_dims, + input_dim=z_dim, + layer_type=args.layer_type, + nonlinearity=args.nonlinearity, + ) + elif amortization_type == "hyper": + diffeq = HyperODEnet( + hidden_dims=hidden_dims, + input_dim=z_dim, + layer_type=args.layer_type, + nonlinearity=args.nonlinearity, + ) + elif amortization_type == "lyper": + diffeq = LyperODEnet( + hidden_dims=hidden_dims, + input_dim=z_dim, + layer_type=args.layer_type, + nonlinearity=args.nonlinearity, + ) + elif amortization_type == "low_rank": + diffeq = AmortizedLowRankODEnet( + hidden_dims=hidden_dims, + input_dim=z_dim, + layer_type=args.layer_type, + nonlinearity=args.nonlinearity, + rank=args.rank, + ) + odefunc = layers.ODEfunc( + diffeq=diffeq, + divergence_fn=args.divergence_fn, + residual=args.residual, + rademacher=args.rademacher, + ) + return odefunc + + +class AmortizedCNFVAE(VAE): + h_size = 256 + + def __init__(self, args): + super(AmortizedCNFVAE, self).__init__(args) + + # CNF model + self.odefuncs = nn.ModuleList([ + construct_amortized_odefunc(args, args.z_size, self.amortization_type) for _ in range(args.num_blocks) + ]) + self.q_am = self._amortized_layers(args) + assert len(self.q_am) == args.num_blocks or len(self.q_am) == 0 + + if args.cuda: + self.cuda() + + self.register_buffer('integration_times', torch.tensor([0.0, args.time_length])) + + self.atol = args.atol + self.rtol = args.rtol + self.solver = args.solver + + def encode(self, x): + """ + Encoder that ouputs parameters for base distribution of z and flow parameters. + """ + + h = self.q_z_nn(x) + h = h.view(-1, self.q_z_nn_output_dim) + mean_z = self.q_z_mean(h) + var_z = self.q_z_var(h) + am_params = [q_am(h) for q_am in self.q_am] + + return mean_z, var_z, am_params + + def forward(self, x): + + self.log_det_j = 0. 
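+        # Each CNF block is integrated with the adjoint ODE solver; the ODE
+        # state is (z, delta_logp) plus the amortized parameters predicted by
+        # the encoder, so the dynamics are conditioned on the input x.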
+ + z_mu, z_var, am_params = self.encode(x) + + # Sample z_0 + z0 = self.reparameterize(z_mu, z_var) + + delta_logp = torch.zeros(x.shape[0], 1).to(x) + z = z0 + for odefunc, am_param in zip(self.odefuncs, am_params): + am_param_unpacked = odefunc.diffeq._unpack_params(am_param) + odefunc.before_odeint() + states = odeint( + odefunc, + (z, delta_logp) + tuple(am_param_unpacked), + self.integration_times.to(z), + atol=self.atol, + rtol=self.rtol, + method=self.solver, + ) + z, delta_logp = states[0][-1], states[1][-1] + + x_mean = self.decode(z) + + return x_mean, z_mu, z_var, -delta_logp.view(-1), z0, z + + +class AmortizedBiasCNFVAE(AmortizedCNFVAE): + amortization_type = "bias" + + def _amortized_layers(self, args): + hidden_dims = get_hidden_dims(args) + bias_size = sum(hidden_dims) + return nn.ModuleList([nn.Linear(self.h_size, bias_size) for _ in range(args.num_blocks)]) + + +class AmortizedLowRankCNFVAE(AmortizedCNFVAE): + amortization_type = "low_rank" + + def _amortized_layers(self, args): + out_dims = get_hidden_dims(args) + in_dims = (out_dims[-1],) + out_dims[:-1] + params_size = (sum(in_dims) + sum(out_dims)) * args.rank + sum(out_dims) + return nn.ModuleList([nn.Linear(self.h_size, params_size) for _ in range(args.num_blocks)]) + + +class HypernetCNFVAE(AmortizedCNFVAE): + amortization_type = "hyper" + + def _amortized_layers(self, args): + hidden_dims = get_hidden_dims(args) + input_dims = (args.z_size,) + hidden_dims[:-1] + assert args.layer_type == "concat", "hypernets only support concat layers at the moment" + weight_dims = [concat_layer_num_params(in_dim, out_dim) for in_dim, out_dim in zip(input_dims, hidden_dims)] + weight_size = sum(weight_dims) + return nn.ModuleList([nn.Linear(self.h_size, weight_size) for _ in range(args.num_blocks)]) + + +class LypernetCNFVAE(AmortizedCNFVAE): + amortization_type = "lyper" + + def _amortized_layers(self, args): + dims = (args.z_size,) + get_hidden_dims(args) + weight_size = concat_layer_num_params(dims[-2], dims[-1]) + return nn.ModuleList([nn.Linear(self.h_size, weight_size) for _ in range(args.num_blocks)]) diff --git a/src/torchprune/torchprune/util/external/ffjord/vae_lib/models/VAE.py b/src/torchprune/torchprune/util/external/ffjord/vae_lib/models/VAE.py new file mode 100644 index 0000000..96e4ea5 --- /dev/null +++ b/src/torchprune/torchprune/util/external/ffjord/vae_lib/models/VAE.py @@ -0,0 +1,735 @@ +from __future__ import print_function + +import torch +import torch.nn as nn +import vae_lib.models.flows as flows +from vae_lib.models.layers import GatedConv2d, GatedConvTranspose2d + + +class VAE(nn.Module): + """ + The base VAE class containing gated convolutional encoder and decoder architecture. + Can be used as a base class for VAE's with normalizing flows. 
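+    Expects `args` to provide at least `z_size`, `input_size`, `input_type`
+    ('binary' or 'multinomial') and `cuda`.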
+ """ + + def __init__(self, args): + super(VAE, self).__init__() + + # extract model settings from args + self.z_size = args.z_size + self.input_size = args.input_size + self.input_type = args.input_type + + if self.input_size == [1, 28, 28] or self.input_size == [3, 28, 28]: + self.last_kernel_size = 7 + elif self.input_size == [1, 28, 20]: + self.last_kernel_size = (7, 5) + else: + raise ValueError('invalid input size!!') + + self.q_z_nn, self.q_z_mean, self.q_z_var = self.create_encoder() + self.p_x_nn, self.p_x_mean = self.create_decoder() + + self.q_z_nn_output_dim = 256 + + # auxiliary + if args.cuda: + self.FloatTensor = torch.cuda.FloatTensor + else: + self.FloatTensor = torch.FloatTensor + + # log-det-jacobian = 0 without flows + self.log_det_j = self.FloatTensor(1).zero_() + + def create_encoder(self): + """ + Helper function to create the elemental blocks for the encoder. Creates a gated convnet encoder. + the encoder expects data as input of shape (batch_size, num_channels, width, height). + """ + + if self.input_type == 'binary': + q_z_nn = nn.Sequential( + GatedConv2d(self.input_size[0], 32, 5, 1, 2), + GatedConv2d(32, 32, 5, 2, 2), + GatedConv2d(32, 64, 5, 1, 2), + GatedConv2d(64, 64, 5, 2, 2), + GatedConv2d(64, 64, 5, 1, 2), + GatedConv2d(64, 256, self.last_kernel_size, 1, 0), + ) + q_z_mean = nn.Linear(256, self.z_size) + q_z_var = nn.Sequential( + nn.Linear(256, self.z_size), + nn.Softplus(), + ) + return q_z_nn, q_z_mean, q_z_var + + elif self.input_type == 'multinomial': + act = None + + q_z_nn = nn.Sequential( + GatedConv2d(self.input_size[0], 32, 5, 1, 2, activation=act), + GatedConv2d(32, 32, 5, 2, 2, activation=act), + GatedConv2d(32, 64, 5, 1, 2, activation=act), + GatedConv2d(64, 64, 5, 2, 2, activation=act), + GatedConv2d(64, 64, 5, 1, 2, activation=act), + GatedConv2d(64, 256, self.last_kernel_size, 1, 0, activation=act) + ) + q_z_mean = nn.Linear(256, self.z_size) + q_z_var = nn.Sequential(nn.Linear(256, self.z_size), nn.Softplus(), nn.Hardtanh(min_val=0.01, max_val=7.)) + return q_z_nn, q_z_mean, q_z_var + + def create_decoder(self): + """ + Helper function to create the elemental blocks for the decoder. Creates a gated convnet decoder. 
+ """ + + num_classes = 256 + + if self.input_type == 'binary': + p_x_nn = nn.Sequential( + GatedConvTranspose2d(self.z_size, 64, self.last_kernel_size, 1, 0), + GatedConvTranspose2d(64, 64, 5, 1, 2), + GatedConvTranspose2d(64, 32, 5, 2, 2, 1), + GatedConvTranspose2d(32, 32, 5, 1, 2), + GatedConvTranspose2d(32, 32, 5, 2, 2, 1), GatedConvTranspose2d(32, 32, 5, 1, 2) + ) + + p_x_mean = nn.Sequential(nn.Conv2d(32, self.input_size[0], 1, 1, 0), nn.Sigmoid()) + return p_x_nn, p_x_mean + + elif self.input_type == 'multinomial': + act = None + p_x_nn = nn.Sequential( + GatedConvTranspose2d(self.z_size, 64, self.last_kernel_size, 1, 0, activation=act), + GatedConvTranspose2d(64, 64, 5, 1, 2, activation=act), + GatedConvTranspose2d(64, 32, 5, 2, 2, 1, activation=act), + GatedConvTranspose2d(32, 32, 5, 1, 2, activation=act), + GatedConvTranspose2d(32, 32, 5, 2, 2, 1, activation=act), + GatedConvTranspose2d(32, 32, 5, 1, 2, activation=act) + ) + + p_x_mean = nn.Sequential( + nn.Conv2d(32, 256, 5, 1, 2), + nn.Conv2d(256, self.input_size[0] * num_classes, 1, 1, 0), + # output shape: batch_size, num_channels * num_classes, pixel_width, pixel_height + ) + + return p_x_nn, p_x_mean + + else: + raise ValueError('invalid input type!!') + + def reparameterize(self, mu, var): + """ + Samples z from a multivariate Gaussian with diagonal covariance matrix using the + reparameterization trick. + """ + + std = var.sqrt() + eps = self.FloatTensor(std.size()).normal_() + z = eps.mul(std).add_(mu) + + return z + + def encode(self, x): + """ + Encoder expects following data shapes as input: shape = (batch_size, num_channels, width, height) + """ + + h = self.q_z_nn(x) + h = h.view(h.size(0), -1) + mean = self.q_z_mean(h) + var = self.q_z_var(h) + + return mean, var + + def decode(self, z): + """ + Decoder outputs reconstructed image in the following shapes: + x_mean.shape = (batch_size, num_channels, width, height) + """ + + z = z.view(z.size(0), self.z_size, 1, 1) + h = self.p_x_nn(z) + x_mean = self.p_x_mean(h) + + return x_mean + + def forward(self, x): + """ + Evaluates the model as a whole, encodes and decodes. Note that the log det jacobian is zero + for a plain VAE (without flows), and z_0 = z_k. + """ + + # mean and variance of z + z_mu, z_var = self.encode(x) + # sample z + z = self.reparameterize(z_mu, z_var) + x_mean = self.decode(z) + + return x_mean, z_mu, z_var, self.log_det_j, z, z + + +class PlanarVAE(VAE): + """ + Variational auto-encoder with planar flows in the encoder. + """ + + def __init__(self, args): + super(PlanarVAE, self).__init__(args) + + # Initialize log-det-jacobian to zero + self.log_det_j = 0. + + # Flow parameters + flow = flows.Planar + self.num_flows = args.num_flows + + # Amortized flow parameters + self.amor_u = nn.Linear(self.q_z_nn_output_dim, self.num_flows * self.z_size) + self.amor_w = nn.Linear(self.q_z_nn_output_dim, self.num_flows * self.z_size) + self.amor_b = nn.Linear(self.q_z_nn_output_dim, self.num_flows) + + # Normalizing flow layers + for k in range(self.num_flows): + flow_k = flow() + self.add_module('flow_' + str(k), flow_k) + + def encode(self, x): + """ + Encoder that ouputs parameters for base distribution of z and flow parameters. 
+ """ + + batch_size = x.size(0) + + h = self.q_z_nn(x) + h = h.view(-1, self.q_z_nn_output_dim) + mean_z = self.q_z_mean(h) + var_z = self.q_z_var(h) + + # return amortized u an w for all flows + u = self.amor_u(h).view(batch_size, self.num_flows, self.z_size, 1) + w = self.amor_w(h).view(batch_size, self.num_flows, 1, self.z_size) + b = self.amor_b(h).view(batch_size, self.num_flows, 1, 1) + + return mean_z, var_z, u, w, b + + def forward(self, x): + """ + Forward pass with planar flows for the transformation z_0 -> z_1 -> ... -> z_k. + Log determinant is computed as log_det_j = N E_q_z0[\sum_k log |det dz_k/dz_k-1| ]. + """ + + self.log_det_j = 0. + + z_mu, z_var, u, w, b = self.encode(x) + + # Sample z_0 + z = [self.reparameterize(z_mu, z_var)] + + # Normalizing flows + for k in range(self.num_flows): + flow_k = getattr(self, 'flow_' + str(k)) + z_k, log_det_jacobian = flow_k(z[k], u[:, k, :, :], w[:, k, :, :], b[:, k, :, :]) + z.append(z_k) + self.log_det_j += log_det_jacobian + + x_mean = self.decode(z[-1]) + + return x_mean, z_mu, z_var, self.log_det_j, z[0], z[-1] + + +class OrthogonalSylvesterVAE(VAE): + """ + Variational auto-encoder with orthogonal flows in the encoder. + """ + + def __init__(self, args): + super(OrthogonalSylvesterVAE, self).__init__(args) + + # Initialize log-det-jacobian to zero + self.log_det_j = 0. + + # Flow parameters + flow = flows.Sylvester + self.num_flows = args.num_flows + self.num_ortho_vecs = args.num_ortho_vecs + + assert (self.num_ortho_vecs <= self.z_size) and (self.num_ortho_vecs > 0) + + # Orthogonalization parameters + if self.num_ortho_vecs == self.z_size: + self.cond = 1.e-5 + else: + self.cond = 1.e-6 + + self.steps = 100 + identity = torch.eye(self.num_ortho_vecs, self.num_ortho_vecs) + # Add batch dimension + identity = identity.unsqueeze(0) + # Put identity in buffer so that it will be moved to GPU if needed by any call of .cuda + self.register_buffer('_eye', identity) + self._eye.requires_grad = False + + # Masks needed for triangular R1 and R2. + triu_mask = torch.triu(torch.ones(self.num_ortho_vecs, self.num_ortho_vecs), diagonal=1) + triu_mask = triu_mask.unsqueeze(0).unsqueeze(3) + diag_idx = torch.arange(0, self.num_ortho_vecs).long() + + self.register_buffer('triu_mask', triu_mask) + self.triu_mask.requires_grad = False + self.register_buffer('diag_idx', diag_idx) + + # Amortized flow parameters + # Diagonal elements of R1 * R2 have to satisfy -1 < R1 * R2 for flow to be invertible + self.diag_activation = nn.Tanh() + + self.amor_d = nn.Linear(self.q_z_nn_output_dim, self.num_flows * self.num_ortho_vecs * self.num_ortho_vecs) + + self.amor_diag1 = nn.Sequential( + nn.Linear(self.q_z_nn_output_dim, self.num_flows * self.num_ortho_vecs), self.diag_activation + ) + self.amor_diag2 = nn.Sequential( + nn.Linear(self.q_z_nn_output_dim, self.num_flows * self.num_ortho_vecs), self.diag_activation + ) + + self.amor_q = nn.Linear(self.q_z_nn_output_dim, self.num_flows * self.z_size * self.num_ortho_vecs) + self.amor_b = nn.Linear(self.q_z_nn_output_dim, self.num_flows * self.num_ortho_vecs) + + # Normalizing flow layers + for k in range(self.num_flows): + flow_k = flow(self.num_ortho_vecs) + self.add_module('flow_' + str(k), flow_k) + + def batch_construct_orthogonal(self, q): + """ + Batch orthogonal matrix construction. 
+        :param q: q contains batches of matrices, shape: (batch_size * num_flows, z_size * num_ortho_vecs)
+        :return: batches of orthogonalized matrices, shape: (batch_size * num_flows, z_size, num_ortho_vecs)
+        """
+
+        # Reshape to shape (num_flows * batch_size, z_size * num_ortho_vecs)
+        q = q.view(-1, self.z_size * self.num_ortho_vecs)
+
+        norm = torch.norm(q, p=2, dim=1, keepdim=True)
+        amat = torch.div(q, norm)
+        dim0 = amat.size(0)
+        amat = amat.view(dim0, self.z_size, self.num_ortho_vecs)
+
+        max_norm = 0.
+
+        # Iterative orthogonalization
+        for s in range(self.steps):
+            tmp = torch.bmm(amat.transpose(2, 1), amat)
+            tmp = self._eye - tmp
+            tmp = self._eye + 0.5 * tmp
+            amat = torch.bmm(amat, tmp)
+
+            # Testing for convergence
+            test = torch.bmm(amat.transpose(2, 1), amat) - self._eye
+            norms2 = torch.sum(torch.norm(test, p=2, dim=2)**2, dim=1)
+            norms = torch.sqrt(norms2)
+            max_norm = torch.max(norms).item()
+            if max_norm <= self.cond:
+                break
+
+        if max_norm > self.cond:
+            print('\nWARNING WARNING WARNING: orthogonalization not complete')
+            print('\t Final max norm =', max_norm)
+
+            print()
+
+        # Reshaping: first dimension is batch_size
+        amat = amat.view(-1, self.num_flows, self.z_size, self.num_ortho_vecs)
+        amat = amat.transpose(0, 1)
+
+        return amat
+
+    def encode(self, x):
+        """
+        Encoder that outputs parameters for the base distribution of z and flow parameters.
+        """
+
+        batch_size = x.size(0)
+
+        h = self.q_z_nn(x)
+        h = h.view(-1, self.q_z_nn_output_dim)
+        mean_z = self.q_z_mean(h)
+        var_z = self.q_z_var(h)
+
+        # Amortized r1, r2, q, b for all flows
+
+        full_d = self.amor_d(h)
+        diag1 = self.amor_diag1(h)
+        diag2 = self.amor_diag2(h)
+
+        full_d = full_d.view(batch_size, self.num_ortho_vecs, self.num_ortho_vecs, self.num_flows)
+        diag1 = diag1.view(batch_size, self.num_ortho_vecs, self.num_flows)
+        diag2 = diag2.view(batch_size, self.num_ortho_vecs, self.num_flows)
+
+        r1 = full_d * self.triu_mask
+        r2 = full_d.transpose(2, 1) * self.triu_mask
+
+        r1[:, self.diag_idx, self.diag_idx, :] = diag1
+        r2[:, self.diag_idx, self.diag_idx, :] = diag2
+
+        q = self.amor_q(h)
+        b = self.amor_b(h)
+
+        # Resize flow parameters to divide over K flows
+        b = b.view(batch_size, 1, self.num_ortho_vecs, self.num_flows)
+
+        return mean_z, var_z, r1, r2, q, b
+
+    def forward(self, x):
+        r"""
+        Forward pass with orthogonal Sylvester flows for the transformation z_0 -> z_1 -> ... -> z_k.
+        Log determinant is computed as log_det_j = N E_q_z0[\sum_k log |det dz_k/dz_k-1| ].
+        """
+
+        self.log_det_j = 0.
+
+        z_mu, z_var, r1, r2, q, b = self.encode(x)
+
+        # Orthogonalize all q matrices
+        q_ortho = self.batch_construct_orthogonal(q)
+
+        # Sample z_0
+        z = [self.reparameterize(z_mu, z_var)]
+
+        # Normalizing flows
+        for k in range(self.num_flows):
+
+            flow_k = getattr(self, 'flow_' + str(k))
+            z_k, log_det_jacobian = flow_k(z[k], r1[:, :, :, k], r2[:, :, :, k], q_ortho[k, :, :, :], b[:, :, :, k])
+
+            z.append(z_k)
+            self.log_det_j += log_det_jacobian
+
+        x_mean = self.decode(z[-1])
+
+        return x_mean, z_mu, z_var, self.log_det_j, z[0], z[-1]
+
+
+class HouseholderSylvesterVAE(VAE):
+    """
+    Variational auto-encoder with Householder Sylvester flows in the encoder.
+    """
+
+    def __init__(self, args):
+        super(HouseholderSylvesterVAE, self).__init__(args)
+
+        # Initialize log-det-jacobian to zero
+        self.log_det_j = 0.
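+
+        # Background sketch: each flow's orthogonal matrix is assembled in
+        # batch_construct_orthogonal() below as a product of Householder
+        # reflections H = I - 2 v v^T (with v normalized), which is orthogonal
+        # by construction, so no iterative orthogonalization is needed here.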
+
+        # Flow parameters
+        flow = flows.Sylvester
+        self.num_flows = args.num_flows
+        self.num_householder = args.num_householder
+        assert self.num_householder > 0
+
+        identity = torch.eye(self.z_size, self.z_size)
+        # Add batch dimension
+        identity = identity.unsqueeze(0)
+        # Put identity in buffer so that it will be moved to GPU if needed by any call of .cuda
+        self.register_buffer('_eye', identity)
+        self._eye.requires_grad = False
+
+        # Masks needed for triangular r1 and r2.
+        triu_mask = torch.triu(torch.ones(self.z_size, self.z_size), diagonal=1)
+        triu_mask = triu_mask.unsqueeze(0).unsqueeze(3)
+        diag_idx = torch.arange(0, self.z_size).long()
+
+        self.register_buffer('triu_mask', triu_mask)
+        self.triu_mask.requires_grad = False
+        self.register_buffer('diag_idx', diag_idx)
+
+        # Amortized flow parameters
+        # Diagonal elements of r1 * r2 have to satisfy -1 < r1 * r2 for flow to be invertible
+        self.diag_activation = nn.Tanh()
+
+        self.amor_d = nn.Linear(self.q_z_nn_output_dim, self.num_flows * self.z_size * self.z_size)
+
+        self.amor_diag1 = nn.Sequential(
+            nn.Linear(self.q_z_nn_output_dim, self.num_flows * self.z_size), self.diag_activation
+        )
+        self.amor_diag2 = nn.Sequential(
+            nn.Linear(self.q_z_nn_output_dim, self.num_flows * self.z_size), self.diag_activation
+        )
+
+        self.amor_q = nn.Linear(self.q_z_nn_output_dim, self.num_flows * self.z_size * self.num_householder)
+
+        self.amor_b = nn.Linear(self.q_z_nn_output_dim, self.num_flows * self.z_size)
+
+        # Normalizing flow layers
+        for k in range(self.num_flows):
+            flow_k = flow(self.z_size)
+
+            self.add_module('flow_' + str(k), flow_k)
+
+    def batch_construct_orthogonal(self, q):
+        """
+        Batch orthogonal matrix construction.
+        :param q: q contains batches of matrices, shape: (batch_size, num_flows * z_size * num_householder)
+        :return: batches of orthogonalized matrices, shape: (batch_size * num_flows, z_size, z_size)
+        """
+
+        # Reshape to shape (num_flows * batch_size * num_householder, z_size)
+        q = q.view(-1, self.z_size)
+
+        norm = torch.norm(q, p=2, dim=1, keepdim=True)  # ||v||_2
+        v = torch.div(q, norm)  # v / ||v||_2
+
+        # Calculate Householder Matrices
+        vvT = torch.bmm(v.unsqueeze(2), v.unsqueeze(1))  # v * v_T : batch_dot( B x L x 1 * B x 1 x L ) = B x L x L
+
+        amat = self._eye - 2 * vvT  # NOTICE: v is already normalized! so there is no need to calculate vvT/vTv
+
+        # Reshaping: first dimension is batch_size * num_flows
+        amat = amat.view(-1, self.num_householder, self.z_size, self.z_size)
+
+        tmp = amat[:, 0]
+        for k in range(1, self.num_householder):
+            tmp = torch.bmm(amat[:, k], tmp)
+
+        amat = tmp.view(-1, self.num_flows, self.z_size, self.z_size)
+        amat = amat.transpose(0, 1)
+
+        return amat
+
+    def encode(self, x):
+        """
+        Encoder that outputs parameters for the base distribution of z and flow parameters.
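+        For reference, with batch size B the returned shapes below are:
+        r1, r2: (B, z_size, z_size, num_flows); q: (B, num_flows * z_size * num_householder);
+        b: (B, 1, z_size, num_flows).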
+ """ + + batch_size = x.size(0) + + h = self.q_z_nn(x) + h = h.view(-1, self.q_z_nn_output_dim) + mean_z = self.q_z_mean(h) + var_z = self.q_z_var(h) + + # Amortized r1, r2, q, b for all flows + full_d = self.amor_d(h) + diag1 = self.amor_diag1(h) + diag2 = self.amor_diag2(h) + + full_d = full_d.resize(batch_size, self.z_size, self.z_size, self.num_flows) + diag1 = diag1.resize(batch_size, self.z_size, self.num_flows) + diag2 = diag2.resize(batch_size, self.z_size, self.num_flows) + + r1 = full_d * self.triu_mask + r2 = full_d.transpose(2, 1) * self.triu_mask + + r1[:, self.diag_idx, self.diag_idx, :] = diag1 + r2[:, self.diag_idx, self.diag_idx, :] = diag2 + + q = self.amor_q(h) + + b = self.amor_b(h) + + # Resize flow parameters to divide over K flows + b = b.resize(batch_size, 1, self.z_size, self.num_flows) + + return mean_z, var_z, r1, r2, q, b + + def forward(self, x): + """ + Forward pass with orthogonal flows for the transformation z_0 -> z_1 -> ... -> z_k. + Log determinant is computed as log_det_j = N E_q_z0[\sum_k log |det dz_k/dz_k-1| ]. + """ + + self.log_det_j = 0. + + z_mu, z_var, r1, r2, q, b = self.encode(x) + + # Orthogonalize all q matrices + q_ortho = self.batch_construct_orthogonal(q) + + # Sample z_0 + z = [self.reparameterize(z_mu, z_var)] + + # Normalizing flows + for k in range(self.num_flows): + + flow_k = getattr(self, 'flow_' + str(k)) + q_k = q_ortho[k] + + z_k, log_det_jacobian = flow_k(z[k], r1[:, :, :, k], r2[:, :, :, k], q_k, b[:, :, :, k], sum_ldj=True) + + z.append(z_k) + self.log_det_j += log_det_jacobian + + x_mean = self.decode(z[-1]) + + return x_mean, z_mu, z_var, self.log_det_j, z[0], z[-1] + + +class TriangularSylvesterVAE(VAE): + """ + Variational auto-encoder with triangular Sylvester flows in the encoder. Alternates between setting + the orthogonal matrix equal to permutation and identity matrix for each flow. + """ + + def __init__(self, args): + super(TriangularSylvesterVAE, self).__init__(args) + + # Initialize log-det-jacobian to zero + self.log_det_j = 0. + + # Flow parameters + flow = flows.TriangularSylvester + self.num_flows = args.num_flows + + # permuting indices corresponding to Q=P (permutation matrix) for every other flow + flip_idx = torch.arange(self.z_size - 1, -1, -1).long() + self.register_buffer('flip_idx', flip_idx) + + # Masks needed for triangular r1 and r2. + triu_mask = torch.triu(torch.ones(self.z_size, self.z_size), diagonal=1) + triu_mask = triu_mask.unsqueeze(0).unsqueeze(3) + diag_idx = torch.arange(0, self.z_size).long() + + self.register_buffer('triu_mask', triu_mask) + self.triu_mask.requires_grad = False + self.register_buffer('diag_idx', diag_idx) + + # Amortized flow parameters + # Diagonal elements of r1 * r2 have to satisfy -1 < r1 * r2 for flow to be invertible + self.diag_activation = nn.Tanh() + + self.amor_d = nn.Linear(self.q_z_nn_output_dim, self.num_flows * self.z_size * self.z_size) + + self.amor_diag1 = nn.Sequential( + nn.Linear(self.q_z_nn_output_dim, self.num_flows * self.z_size), self.diag_activation + ) + self.amor_diag2 = nn.Sequential( + nn.Linear(self.q_z_nn_output_dim, self.num_flows * self.z_size), self.diag_activation + ) + + self.amor_b = nn.Linear(self.q_z_nn_output_dim, self.num_flows * self.z_size) + + # Normalizing flow layers + for k in range(self.num_flows): + flow_k = flow(self.z_size) + + self.add_module('flow_' + str(k), flow_k) + + def encode(self, x): + """ + Encoder that ouputs parameters for base distribution of z and flow parameters. 
+ """ + + batch_size = x.size(0) + + h = self.q_z_nn(x) + h = h.view(-1, self.q_z_nn_output_dim) + mean_z = self.q_z_mean(h) + var_z = self.q_z_var(h) + + # Amortized r1, r2, b for all flows + full_d = self.amor_d(h) + diag1 = self.amor_diag1(h) + diag2 = self.amor_diag2(h) + + full_d = full_d.resize(batch_size, self.z_size, self.z_size, self.num_flows) + diag1 = diag1.resize(batch_size, self.z_size, self.num_flows) + diag2 = diag2.resize(batch_size, self.z_size, self.num_flows) + + r1 = full_d * self.triu_mask + r2 = full_d.transpose(2, 1) * self.triu_mask + + r1[:, self.diag_idx, self.diag_idx, :] = diag1 + r2[:, self.diag_idx, self.diag_idx, :] = diag2 + + b = self.amor_b(h) + + # Resize flow parameters to divide over K flows + b = b.resize(batch_size, 1, self.z_size, self.num_flows) + + return mean_z, var_z, r1, r2, b + + def forward(self, x): + """ + Forward pass with orthogonal flows for the transformation z_0 -> z_1 -> ... -> z_k. + Log determinant is computed as log_det_j = N E_q_z0[\sum_k log |det dz_k/dz_k-1| ]. + """ + + self.log_det_j = 0. + + z_mu, z_var, r1, r2, b = self.encode(x) + + # Sample z_0 + z = [self.reparameterize(z_mu, z_var)] + + # Normalizing flows + for k in range(self.num_flows): + + flow_k = getattr(self, 'flow_' + str(k)) + if k % 2 == 1: + # Alternate with reorderering z for triangular flow + permute_z = self.flip_idx + else: + permute_z = None + + z_k, log_det_jacobian = flow_k(z[k], r1[:, :, :, k], r2[:, :, :, k], b[:, :, :, k], permute_z, sum_ldj=True) + + z.append(z_k) + self.log_det_j += log_det_jacobian + + x_mean = self.decode(z[-1]) + + return x_mean, z_mu, z_var, self.log_det_j, z[0], z[-1] + + +class IAFVAE(VAE): + """ + Variational auto-encoder with inverse autoregressive flows in the encoder. + """ + + def __init__(self, args): + super(IAFVAE, self).__init__(args) + + # Initialize log-det-jacobian to zero + self.log_det_j = 0. + self.h_size = args.made_h_size + + self.h_context = nn.Linear(self.q_z_nn_output_dim, self.h_size) + + # Flow parameters + self.num_flows = args.num_flows + self.flow = flows.IAF( + z_size=self.z_size, num_flows=self.num_flows, num_hidden=1, h_size=self.h_size, conv2d=False + ) + + def encode(self, x): + """ + Encoder that ouputs parameters for base distribution of z and context h for flows. + """ + + h = self.q_z_nn(x) + h = h.view(-1, self.q_z_nn_output_dim) + mean_z = self.q_z_mean(h) + var_z = self.q_z_var(h) + h_context = self.h_context(h) + + return mean_z, var_z, h_context + + def forward(self, x): + """ + Forward pass with inverse autoregressive flows for the transformation z_0 -> z_1 -> ... -> z_k. + Log determinant is computed as log_det_j = N E_q_z0[\sum_k log |det dz_k/dz_k-1| ]. 
+ """ + + # mean and variance of z + z_mu, z_var, h_context = self.encode(x) + # sample z + z_0 = self.reparameterize(z_mu, z_var) + + # iaf flows + z_k, self.log_det_j = self.flow(z_0, h_context) + + # decode + x_mean = self.decode(z_k) + + return x_mean, z_mu, z_var, self.log_det_j, z_0, z_k diff --git a/src/torchprune/torchprune/util/models/cnn/models/__init__.py b/src/torchprune/torchprune/util/external/ffjord/vae_lib/models/__init__.py similarity index 100% rename from src/torchprune/torchprune/util/models/cnn/models/__init__.py rename to src/torchprune/torchprune/util/external/ffjord/vae_lib/models/__init__.py diff --git a/src/torchprune/torchprune/util/external/ffjord/vae_lib/models/flows.py b/src/torchprune/torchprune/util/external/ffjord/vae_lib/models/flows.py new file mode 100644 index 0000000..0894043 --- /dev/null +++ b/src/torchprune/torchprune/util/external/ffjord/vae_lib/models/flows.py @@ -0,0 +1,299 @@ +""" +Collection of flow strategies +""" + +from __future__ import print_function + +import torch +import torch.nn as nn +from torch.autograd import Variable +import torch.nn.functional as F + +from vae_lib.models.layers import MaskedConv2d, MaskedLinear + + +class Planar(nn.Module): + """ + PyTorch implementation of planar flows as presented in "Variational Inference with Normalizing Flows" + by Danilo Jimenez Rezende, Shakir Mohamed. Model assumes amortized flow parameters. + """ + + def __init__(self): + + super(Planar, self).__init__() + + self.h = nn.Tanh() + self.softplus = nn.Softplus() + + def der_h(self, x): + """ Derivative of tanh """ + + return 1 - self.h(x)**2 + + def forward(self, zk, u, w, b): + """ + Forward pass. Assumes amortized u, w and b. Conditions on diagonals of u and w for invertibility + will be be satisfied inside this function. Computes the following transformation: + z' = z + u h( w^T z + b) + or actually + z'^T = z^T + h(z^T w + b)u^T + Assumes the following input shapes: + shape u = (batch_size, z_size, 1) + shape w = (batch_size, 1, z_size) + shape b = (batch_size, 1, 1) + shape z = (batch_size, z_size). + """ + + zk = zk.unsqueeze(2) + + # reparameterize u such that the flow becomes invertible (see appendix paper) + uw = torch.bmm(w, u) + m_uw = -1. + self.softplus(uw) + w_norm_sq = torch.sum(w**2, dim=2, keepdim=True) + u_hat = u + ((m_uw - uw) * w.transpose(2, 1) / w_norm_sq) + + # compute flow with u_hat + wzb = torch.bmm(w, zk) + b + z = zk + u_hat * self.h(wzb) + z = z.squeeze(2) + + # compute logdetJ + psi = w * self.der_h(wzb) + log_det_jacobian = torch.log(torch.abs(1 + torch.bmm(psi, u_hat))) + log_det_jacobian = log_det_jacobian.squeeze(2).squeeze(1) + + return z, log_det_jacobian + + +class Sylvester(nn.Module): + """ + Sylvester normalizing flow. + """ + + def __init__(self, num_ortho_vecs): + + super(Sylvester, self).__init__() + + self.num_ortho_vecs = num_ortho_vecs + + self.h = nn.Tanh() + + triu_mask = torch.triu(torch.ones(num_ortho_vecs, num_ortho_vecs), diagonal=1).unsqueeze(0) + diag_idx = torch.arange(0, num_ortho_vecs).long() + + self.register_buffer('triu_mask', Variable(triu_mask)) + self.triu_mask.requires_grad = False + self.register_buffer('diag_idx', diag_idx) + + def der_h(self, x): + return self.der_tanh(x) + + def der_tanh(self, x): + return 1 - self.h(x)**2 + + def _forward(self, zk, r1, r2, q_ortho, b, sum_ldj=True): + """ + All flow parameters are amortized. Conditions on diagonals of R1 and R2 for invertibility need to be satisfied + outside of this function. 
Computes the following transformation: + z' = z + QR1 h( R2Q^T z + b) + or actually + z'^T = z^T + h(z^T Q R2^T + b^T)R1^T Q^T + :param zk: shape: (batch_size, z_size) + :param r1: shape: (batch_size, num_ortho_vecs, num_ortho_vecs) + :param r2: shape: (batch_size, num_ortho_vecs, num_ortho_vecs) + :param q_ortho: shape (batch_size, z_size , num_ortho_vecs) + :param b: shape: (batch_size, 1, self.z_size) + :return: z, log_det_j + """ + + # Amortized flow parameters + zk = zk.unsqueeze(1) + + # Save diagonals for log_det_j + diag_r1 = r1[:, self.diag_idx, self.diag_idx] + diag_r2 = r2[:, self.diag_idx, self.diag_idx] + + r1_hat = r1 + r2_hat = r2 + + qr2 = torch.bmm(q_ortho, r2_hat.transpose(2, 1)) + qr1 = torch.bmm(q_ortho, r1_hat) + + r2qzb = torch.bmm(zk, qr2) + b + z = torch.bmm(self.h(r2qzb), qr1.transpose(2, 1)) + zk + z = z.squeeze(1) + + # Compute log|det J| + # Output log_det_j in shape (batch_size) instead of (batch_size,1) + diag_j = diag_r1 * diag_r2 + diag_j = self.der_h(r2qzb).squeeze(1) * diag_j + diag_j += 1. + log_diag_j = diag_j.abs().log() + + if sum_ldj: + log_det_j = log_diag_j.sum(-1) + else: + log_det_j = log_diag_j + + return z, log_det_j + + def forward(self, zk, r1, r2, q_ortho, b, sum_ldj=True): + + return self._forward(zk, r1, r2, q_ortho, b, sum_ldj) + + +class TriangularSylvester(nn.Module): + """ + Sylvester normalizing flow with Q=P or Q=I. + """ + + def __init__(self, z_size): + + super(TriangularSylvester, self).__init__() + + self.z_size = z_size + self.h = nn.Tanh() + + diag_idx = torch.arange(0, z_size).long() + self.register_buffer('diag_idx', diag_idx) + + def der_h(self, x): + return self.der_tanh(x) + + def der_tanh(self, x): + return 1 - self.h(x)**2 + + def _forward(self, zk, r1, r2, b, permute_z=None, sum_ldj=True): + """ + All flow parameters are amortized. conditions on diagonals of R1 and R2 need to be satisfied + outside of this function. + Computes the following transformation: + z' = z + QR1 h( R2Q^T z + b) + or actually + z'^T = z^T + h(z^T Q R2^T + b^T)R1^T Q^T + with Q = P a permutation matrix (equal to identity matrix if permute_z=None) + :param zk: shape: (batch_size, z_size) + :param r1: shape: (batch_size, num_ortho_vecs, num_ortho_vecs). + :param r2: shape: (batch_size, num_ortho_vecs, num_ortho_vecs). + :param b: shape: (batch_size, 1, self.z_size) + :return: z, log_det_j + """ + + # Amortized flow parameters + zk = zk.unsqueeze(1) + + # Save diagonals for log_det_j + diag_r1 = r1[:, self.diag_idx, self.diag_idx] + diag_r2 = r2[:, self.diag_idx, self.diag_idx] + + if permute_z is not None: + # permute order of z + z_per = zk[:, :, permute_z] + else: + z_per = zk + + r2qzb = torch.bmm(z_per, r2.transpose(2, 1)) + b + z = torch.bmm(self.h(r2qzb), r1.transpose(2, 1)) + + if permute_z is not None: + # permute order of z again back again + z = z[:, :, permute_z] + + z += zk + z = z.squeeze(1) + + # Compute log|det J| + # Output log_det_j in shape (batch_size) instead of (batch_size,1) + diag_j = diag_r1 * diag_r2 + diag_j = self.der_h(r2qzb).squeeze(1) * diag_j + diag_j += 1. + log_diag_j = diag_j.abs().log() + + if sum_ldj: + log_det_j = log_diag_j.sum(-1) + else: + log_det_j = log_diag_j + + return z, log_det_j + + def forward(self, zk, r1, r2, q_ortho, b, sum_ldj=True): + + return self._forward(zk, r1, r2, q_ortho, b, sum_ldj) + + +class IAF(nn.Module): + """ + PyTorch implementation of inverse autoregressive flows as presented in + "Improving Variational Inference with Inverse Autoregressive Flow" by Diederik P. 
Kingma, Tim Salimans, + Rafal Jozefowicz, Xi Chen, Ilya Sutskever, Max Welling. + Inverse Autoregressive Flow with either MADE MLPs or Pixel CNNs. Contains several flows. Each transformation + takes as an input the previous stochastic z, and a context h. The structure of each flow is then as follows: + z <- autoregressive_layer(z) + h, allow for diagonal connections + z <- autoregressive_layer(z), allow for diagonal connections + : + z <- autoregressive_layer(z), do not allow for diagonal connections. + + Note that the size of h needs to be the same as h_size, which is the width of the MADE layers. + """ + + def __init__(self, z_size, num_flows=2, num_hidden=0, h_size=50, forget_bias=1., conv2d=False): + super(IAF, self).__init__() + self.z_size = z_size + self.num_flows = num_flows + self.num_hidden = num_hidden + self.h_size = h_size + self.conv2d = conv2d + if not conv2d: + ar_layer = MaskedLinear + else: + ar_layer = MaskedConv2d + self.activation = torch.nn.ELU + # self.activation = torch.nn.ReLU + + self.forget_bias = forget_bias + self.flows = [] + self.param_list = [] + + # For reordering z after each flow + flip_idx = torch.arange(self.z_size - 1, -1, -1).long() + self.register_buffer('flip_idx', flip_idx) + + for k in range(num_flows): + arch_z = [ar_layer(z_size, h_size), self.activation()] + self.param_list += list(arch_z[0].parameters()) + z_feats = torch.nn.Sequential(*arch_z) + arch_zh = [] + for j in range(num_hidden): + arch_zh += [ar_layer(h_size, h_size), self.activation()] + self.param_list += list(arch_zh[-2].parameters()) + zh_feats = torch.nn.Sequential(*arch_zh) + linear_mean = ar_layer(h_size, z_size, diagonal_zeros=True) + linear_std = ar_layer(h_size, z_size, diagonal_zeros=True) + self.param_list += list(linear_mean.parameters()) + self.param_list += list(linear_std.parameters()) + + if torch.cuda.is_available(): + z_feats = z_feats.cuda() + zh_feats = zh_feats.cuda() + linear_mean = linear_mean.cuda() + linear_std = linear_std.cuda() + self.flows.append((z_feats, zh_feats, linear_mean, linear_std)) + + self.param_list = torch.nn.ParameterList(self.param_list) + + def forward(self, z, h_context): + + logdets = 0. 
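+
+        # Sketch of one IAF step below: the masked (autoregressive) layers turn
+        # (z, h_context) into a mean m and a gate s = sigmoid(. + forget_bias);
+        # the update z <- s * z + (1 - s) * m has a diagonal Jacobian with
+        # entries s, so log|det| accumulates as sum(log s) per sample.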
+        for i, flow in enumerate(self.flows):
+            if (i + 1) % 2 == 0 and not self.conv2d:
+                # reverse ordering to help mixing
+                z = z[:, self.flip_idx]
+
+            h = flow[0](z)
+            h = h + h_context
+            h = flow[1](h)
+            mean = flow[2](h)
+            gate = torch.sigmoid(flow[3](h) + self.forget_bias)
+            z = gate * z + (1 - gate) * mean
+            logdets += torch.sum(gate.log().view(gate.size(0), -1), 1)
+        return z, logdets
diff --git a/src/torchprune/torchprune/util/external/ffjord/vae_lib/models/layers.py b/src/torchprune/torchprune/util/external/ffjord/vae_lib/models/layers.py
new file mode 100644
index 0000000..e15c453
--- /dev/null
+++ b/src/torchprune/torchprune/util/external/ffjord/vae_lib/models/layers.py
@@ -0,0 +1,209 @@
+import torch
+import torch.nn as nn
+from torch.nn.parameter import Parameter
+import numpy as np
+import torch.nn.functional as F
+
+
+class Identity(nn.Module):
+
+    def __init__(self):
+        super(Identity, self).__init__()
+
+    def forward(self, x):
+        return x
+
+
+class GatedConv2d(nn.Module):
+
+    def __init__(self, input_channels, output_channels, kernel_size, stride, padding, dilation=1, activation=None):
+        super(GatedConv2d, self).__init__()
+
+        self.activation = activation
+        self.sigmoid = nn.Sigmoid()
+
+        self.h = nn.Conv2d(input_channels, output_channels, kernel_size, stride, padding, dilation)
+        self.g = nn.Conv2d(input_channels, output_channels, kernel_size, stride, padding, dilation)
+
+    def forward(self, x):
+        if self.activation is None:
+            h = self.h(x)
+        else:
+            h = self.activation(self.h(x))
+
+        g = self.sigmoid(self.g(x))
+
+        return h * g
+
+
+class GatedConvTranspose2d(nn.Module):
+
+    def __init__(
+        self, input_channels, output_channels, kernel_size, stride, padding, output_padding=0, dilation=1,
+        activation=None
+    ):
+        super(GatedConvTranspose2d, self).__init__()
+
+        self.activation = activation
+        self.sigmoid = nn.Sigmoid()
+
+        self.h = nn.ConvTranspose2d(
+            input_channels, output_channels, kernel_size, stride, padding, output_padding, dilation=dilation
+        )
+        self.g = nn.ConvTranspose2d(
+            input_channels, output_channels, kernel_size, stride, padding, output_padding, dilation=dilation
+        )
+
+    def forward(self, x):
+        if self.activation is None:
+            h = self.h(x)
+        else:
+            h = self.activation(self.h(x))
+
+        g = self.sigmoid(self.g(x))
+
+        return h * g
+
+
+class MaskedLinear(nn.Module):
+    """
+    Creates masked linear layer for MLP MADE.
+    For input (x) to hidden (h) or hidden to hidden layers choose diagonal_zeros = False.
+    For hidden to output (y) layers:
+    If output depends on input through y_i = f(x_{<i}), set diagonal_zeros = True.
+    Else (y_i = f(x_{<=i})), set diagonal_zeros = False.
+    """
+
+    def __init__(self, in_features, out_features, diagonal_zeros=False, bias=True):
+        super(MaskedLinear, self).__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.diagonal_zeros = diagonal_zeros
+        self.weight = Parameter(torch.FloatTensor(in_features, out_features))
+        if bias:
+            self.bias = Parameter(torch.FloatTensor(out_features))
+        else:
+            self.register_parameter('bias', None)
+        mask = torch.from_numpy(self.build_mask())
+        if torch.cuda.is_available():
+            mask = mask.cuda()
+        self.register_buffer('mask', mask)
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        nn.init.kaiming_normal_(self.weight)
+        if self.bias is not None:
+            self.bias.data.zero_()
+
+    def build_mask(self):
+        n_in, n_out = self.in_features, self.out_features
+        assert n_in % n_out == 0 or n_out % n_in == 0
+
+        mask = np.ones((n_in, n_out), dtype=np.float32)
+        if n_out >= n_in:
+            k = n_out // n_in
+            for i in range(n_in):
+                mask[i + 1:, i * k:(i + 1) * k] = 0
+                if self.diagonal_zeros:
+                    mask[i:i + 1, i * k:(i + 1) * k] = 0
+        else:
+            k = n_in // n_out
+            for i in range(n_out):
+                mask[(i + 1) * k:, i:i + 1] = 0
+                if self.diagonal_zeros:
+                    mask[i * k:(i + 1) * k:, i:i + 1] = 0
+        return mask
+
+    def forward(self, x):
+        output = x.mm(self.mask * self.weight)
+
+        if self.bias is not None:
+            return output.add(self.bias.expand_as(output))
+        else:
+            return output
+
+    def __repr__(self):
+        if self.bias is not None:
+            bias = True
+        else:
+            bias = False
+        return self.__class__.__name__ + ' (' \
+            + str(self.in_features) + ' -> ' \
+            + str(self.out_features) + ', diagonal_zeros=' \
+            + str(self.diagonal_zeros) + ', bias=' \
+            + str(bias) + ')'
+
+
+class MaskedConv2d(nn.Module):
+    """
+    Creates masked convolutional autoregressive layer for pixelCNN.
+    For input (x) to hidden (h) or hidden to hidden layers choose diagonal_zeros = False.
+    For hidden to output (y) layers:
+    If output depends on input through y_i = f(x_{<i}), set diagonal_zeros = True.
+    Else (y_i = f(x_{<=i})), set diagonal_zeros = False.
+    """
+
+    def __init__(self, in_features, out_features, size_kernel=(3, 3), diagonal_zeros=False, bias=True):
+        super(MaskedConv2d, self).__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.size_kernel = size_kernel
+        self.diagonal_zeros = diagonal_zeros
+        self.weight = Parameter(torch.FloatTensor(out_features, in_features, *self.size_kernel))
+        if bias:
+            self.bias = Parameter(torch.FloatTensor(out_features))
+        else:
+            self.register_parameter('bias', None)
+        mask = torch.from_numpy(self.build_mask())
+        if torch.cuda.is_available():
+            mask = mask.cuda()
+        self.register_buffer('mask', mask)
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        nn.init.kaiming_normal_(self.weight)
+        if self.bias is not None:
+            self.bias.data.zero_()
+
+    def build_mask(self):
+        n_in, n_out = self.in_features, self.out_features
+        assert n_out % n_in == 0 or n_in % n_out == 0
+
+        # spatial part of the autoregressive mask around the kernel center (l, m)
+        l = (self.size_kernel[0] - 1) // 2
+        m = (self.size_kernel[1] - 1) // 2
+        mask = np.ones((n_out, n_in) + tuple(self.size_kernel), dtype=np.float32)
+        mask[:, :, :l, :] = 0
+        mask[:, :, l, :m] = 0
+
+        # channel-wise autoregressive part at the kernel center
+        if n_out >= n_in:
+            k = n_out // n_in
+            for i in range(n_in):
+                mask[i * k:(i + 1) * k, i + 1:, l, m] = 0
+                if self.diagonal_zeros:
+                    mask[i * k:(i + 1) * k, i:i + 1, l, m] = 0
+        else:
+            k = n_in // n_out
+            for i in range(n_out):
+                mask[i:i + 1, (i + 1) * k:, l, m] = 0
+                if self.diagonal_zeros:
+                    mask[i:i + 1, i * k:(i + 1) * k:, l, m] = 0
+
+        return mask
+
+    def forward(self, x):
+        output = F.conv2d(x, self.mask * self.weight, bias=self.bias, padding=(1, 1))
+        return output
+
+    def __repr__(self):
+        if self.bias is not None:
+            bias = True
+        else:
+            bias = False
+        return self.__class__.__name__ + ' (' \
+            + str(self.in_features) + ' -> ' \
+            + str(self.out_features) + ', diagonal_zeros=' \
+            + str(self.diagonal_zeros) + ', bias=' \
+            + str(bias) + ', size_kernel=' \
+            + str(self.size_kernel) + ')'
diff --git a/src/torchprune/torchprune/util/external/ffjord/vae_lib/optimization/__init__.py b/src/torchprune/torchprune/util/external/ffjord/vae_lib/optimization/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/torchprune/torchprune/util/external/ffjord/vae_lib/optimization/loss.py b/src/torchprune/torchprune/util/external/ffjord/vae_lib/optimization/loss.py
new file mode 100644
index 0000000..4e56712
--- /dev/null
+++ b/src/torchprune/torchprune/util/external/ffjord/vae_lib/optimization/loss.py
@@ -0,0 +1,271 @@
+from __future__ import print_function
+
+import numpy as np
+import torch
+import torch.nn as nn
+from vae_lib.utils.distributions import log_normal_diag, log_normal_standard, log_bernoulli
+import torch.nn.functional as F
+
+
+def binary_loss_function(recon_x, x, z_mu, z_var, z_0, z_k, ldj, beta=1.):
+    """
+    Computes the binary loss function while summing over batch dimension, not averaged!
+    :param recon_x: shape: (batch_size, num_channels, pixel_width, pixel_height), bernoulli parameters p(x=1)
+    :param x: shape (batchsize, num_channels, pixel_width, pixel_height), pixel values rescaled between [0, 1].
+    :param z_mu: mean of z_0
+    :param z_var: variance of z_0
+    :param z_0: first stochastic latent variable
+    :param z_k: last stochastic latent variable
+    :param ldj: log det jacobian
+    :param beta: beta for kl loss
+    :return: loss, ce, kl
+    """
+
+    reconstruction_function = nn.BCELoss(size_average=False)
+
+    batch_size = x.size(0)
+
+    # - N E_q0 [ ln p(x|z_k) ]
+    bce = reconstruction_function(recon_x, x)
+
+    # ln p(z_k) (not averaged)
+    log_p_zk = log_normal_standard(z_k, dim=1)
+    # ln q(z_0) (not averaged)
+    log_q_z0 = log_normal_diag(z_0, mean=z_mu, log_var=z_var.log(), dim=1)
+    # N E_q0[ ln q(z_0) - ln p(z_k) ]
+    summed_logs = torch.sum(log_q_z0 - log_p_zk)
+
+    # sum over batches
+    summed_ldj = torch.sum(ldj)
+
+    # ldj = N E_q_z0[\sum_k log |det dz_k/dz_k-1| ]
+    kl = (summed_logs - summed_ldj)
+    loss = bce + beta * kl
+
+    loss /= float(batch_size)
+    bce /= float(batch_size)
+    kl /= float(batch_size)
+
+    return loss, bce, kl
+
+
+def multinomial_loss_function(x_logit, x, z_mu, z_var, z_0, z_k, ldj, args, beta=1.):
+    """
+    Computes the cross entropy loss function while summing over batch dimension, not averaged!
+    :param x_logit: shape: (batch_size, num_classes * num_channels, pixel_width, pixel_height), real valued logits
+    :param x: shape (batchsize, num_channels, pixel_width, pixel_height), pixel values rescaled between [0, 1].
+ :param z_mu: mean of z_0 + :param z_var: variance of z_0 + :param z_0: first stochastic latent variable + :param z_k: last stochastic latent variable + :param ldj: log det jacobian + :param args: global parameter settings + :param beta: beta for kl loss + :return: loss, ce, kl + """ + + num_classes = 256 + batch_size = x.size(0) + + x_logit = x_logit.view(batch_size, num_classes, args.input_size[0], args.input_size[1], args.input_size[2]) + + # make integer class labels + target = (x * (num_classes - 1)).long() + + # - N E_q0 [ ln p(x|z_k) ] + # sums over batch dimension (and feature dimension) + ce = cross_entropy(x_logit, target, size_average=False) + + # ln p(z_k) (not averaged) + log_p_zk = log_normal_standard(z_k, dim=1) + # ln q(z_0) (not averaged) + log_q_z0 = log_normal_diag(z_0, mean=z_mu, log_var=z_var.log(), dim=1) + # N E_q0[ ln q(z_0) - ln p(z_k) ] + summed_logs = torch.sum(log_q_z0 - log_p_zk) + + # sum over batches + summed_ldj = torch.sum(ldj) + + # ldj = N E_q_z0[\sum_k log |det dz_k/dz_k-1| ] + kl = (summed_logs - summed_ldj) + loss = ce + beta * kl + + loss /= float(batch_size) + ce /= float(batch_size) + kl /= float(batch_size) + + return loss, ce, kl + + +def binary_loss_array(recon_x, x, z_mu, z_var, z_0, z_k, ldj, beta=1.): + """ + Computes the binary loss without averaging or summing over the batch dimension. + """ + + batch_size = x.size(0) + + # if not summed over batch_dimension + if len(ldj.size()) > 1: + ldj = ldj.view(ldj.size(0), -1).sum(-1) + + # TODO: upgrade to newest pytorch version on master branch, there the nn.BCELoss comes with the option + # reduce, which when set to False, does no sum over batch dimension. + bce = -log_bernoulli(x.view(batch_size, -1), recon_x.view(batch_size, -1), dim=1) + # ln p(z_k) (not averaged) + log_p_zk = log_normal_standard(z_k, dim=1) + # ln q(z_0) (not averaged) + log_q_z0 = log_normal_diag(z_0, mean=z_mu, log_var=z_var.log(), dim=1) + # ln q(z_0) - ln p(z_k) ] + logs = log_q_z0 - log_p_zk + + loss = bce + beta * (logs - ldj) + + return loss + + +def multinomial_loss_array(x_logit, x, z_mu, z_var, z_0, z_k, ldj, args, beta=1.): + """ + Computes the discritezed logistic loss without averaging or summing over the batch dimension. + """ + + num_classes = 256 + batch_size = x.size(0) + + x_logit = x_logit.view(batch_size, num_classes, args.input_size[0], args.input_size[1], args.input_size[2]) + + # make integer class labels + target = (x * (num_classes - 1)).long() + + # - N E_q0 [ ln p(x|z_k) ] + # computes cross entropy over all dimensions separately: + ce = cross_entropy(x_logit, target, size_average=False, reduce=False) + # sum over feature dimension + ce = ce.view(batch_size, -1).sum(dim=1) + + # ln p(z_k) (not averaged) + log_p_zk = log_normal_standard(z_k.view(batch_size, -1), dim=1) + # ln q(z_0) (not averaged) + log_q_z0 = log_normal_diag( + z_0.view(batch_size, -1), mean=z_mu.view(batch_size, -1), log_var=z_var.log().view(batch_size, -1), dim=1 + ) + + # ln q(z_0) - ln p(z_k) ] + logs = log_q_z0 - log_p_zk + + loss = ce + beta * (logs - ldj) + + return loss + + +def cross_entropy(input, target, weight=None, size_average=True, ignore_index=-100, reduce=True): + r""" + Taken from the master branch of pytorch, accepts (N, C, d_1, d_2, ..., d_K) input shapes + instead of only (N, C, d_1, d_2) or (N, C). + This criterion combines `log_softmax` and `nll_loss` in a single + function. + See :class:`~torch.nn.CrossEntropyLoss` for details. 
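+    For illustration, multinomial_loss_function above calls this with input of shape
+    (N, 256, num_channels, pixel_width, pixel_height) and integer targets of shape
+    (N, num_channels, pixel_width, pixel_height).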
+ Args: + input: Variable :math:`(N, C)` where `C = number of classes` + target: Variable :math:`(N)` where each value is + `0 <= targets[i] <= C-1` + weight (Tensor, optional): a manual rescaling weight given to each + class. If given, has to be a Tensor of size `C` + size_average (bool, optional): By default, the losses are averaged + over observations for each minibatch. However, if the field + sizeAverage is set to False, the losses are instead summed + for each minibatch. Ignored if reduce is False. Default: ``True`` + ignore_index (int, optional): Specifies a target value that is ignored + and does not contribute to the input gradient. When size_average is + True, the loss is averaged over non-ignored targets. Default: -100 + reduce (bool, optional): By default, the losses are averaged or summed over + observations for each minibatch depending on size_average. When reduce + is False, returns a loss per batch element instead and ignores + size_average. Default: ``True`` + """ + return nll_loss(F.log_softmax(input, 1), target, weight, size_average, ignore_index, reduce) + + +def nll_loss(input, target, weight=None, size_average=True, ignore_index=-100, reduce=True): + r""" + Taken from the master branch of pytorch, accepts (N, C, d_1, d_2, ..., d_K) input shapes + instead of only (N, C, d_1, d_2) or (N, C). + The negative log likelihood loss. + See :class:`~torch.nn.NLLLoss` for details. + Args: + input: :math:`(N, C)` where `C = number of classes` or :math:`(N, C, H, W)` + in case of 2D Loss, or :math:`(N, C, d_1, d_2, ..., d_K)` where :math:`K > 1` + in the case of K-dimensional loss. + target: :math:`(N)` where each value is `0 <= targets[i] <= C-1`, + or :math:`(N, C, d_1, d_2, ..., d_K)` where :math:`K >= 1` for + K-dimensional loss. + weight (Tensor, optional): a manual rescaling weight given to each + class. If given, has to be a Tensor of size `C` + size_average (bool, optional): By default, the losses are averaged + over observations for each minibatch. If size_average + is False, the losses are summed for each minibatch. Default: ``True`` + ignore_index (int, optional): Specifies a target value that is ignored + and does not contribute to the input gradient. When size_average is + True, the loss is averaged over non-ignored targets. Default: -100 + """ + dim = input.dim() + if dim == 2: + return F.nll_loss( + input, target, weight=weight, size_average=size_average, ignore_index=ignore_index, reduce=reduce + ) + elif dim == 4: + return F.nll_loss( + input, target, weight=weight, size_average=size_average, ignore_index=ignore_index, reduce=reduce + ) + elif dim == 3 or dim > 4: + n = input.size(0) + c = input.size(1) + out_size = (n,) + input.size()[2:] + if target.size()[1:] != input.size()[2:]: + raise ValueError('Expected target size {}, got {}'.format(out_size, input.size())) + input = input.contiguous().view(n, c, 1, -1) + target = target.contiguous().view(n, 1, -1) + if reduce: + _loss = nn.NLLLoss2d(weight=weight, size_average=size_average, ignore_index=ignore_index, reduce=reduce) + return _loss(input, target) + out = F.nll_loss( + input, target, weight=weight, size_average=size_average, ignore_index=ignore_index, reduce=reduce + ) + return out.view(out_size) + else: + raise ValueError('Expected 2 or more dimensions (got {})'.format(dim)) + + +def calculate_loss(x_mean, x, z_mu, z_var, z_0, z_k, ldj, args, beta=1.): + """ + Picks the correct loss depending on the input type. 
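+    :return: loss, rec, kl, bpd (bpd is only computed for multinomial input and is 0. otherwise)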
+ """ + + if args.input_type == 'binary': + loss, rec, kl = binary_loss_function(x_mean, x, z_mu, z_var, z_0, z_k, ldj, beta=beta) + bpd = 0. + + elif args.input_type == 'multinomial': + loss, rec, kl = multinomial_loss_function(x_mean, x, z_mu, z_var, z_0, z_k, ldj, args, beta=beta) + bpd = loss.data[0] / (np.prod(args.input_size) * np.log(2.)) + + else: + raise ValueError('Invalid input type for calculate loss: %s.' % args.input_type) + + return loss, rec, kl, bpd + + +def calculate_loss_array(x_mean, x, z_mu, z_var, z_0, z_k, ldj, args): + """ + Picks the correct loss depending on the input type. + """ + + if args.input_type == 'binary': + loss = binary_loss_array(x_mean, x, z_mu, z_var, z_0, z_k, ldj) + + elif args.input_type == 'multinomial': + loss = multinomial_loss_array(x_mean, x, z_mu, z_var, z_0, z_k, ldj, args) + + else: + raise ValueError('Invalid input type for calculate loss: %s.' % args.input_type) + + return loss diff --git a/src/torchprune/torchprune/util/external/ffjord/vae_lib/optimization/training.py b/src/torchprune/torchprune/util/external/ffjord/vae_lib/optimization/training.py new file mode 100644 index 0000000..e6f4b37 --- /dev/null +++ b/src/torchprune/torchprune/util/external/ffjord/vae_lib/optimization/training.py @@ -0,0 +1,171 @@ +from __future__ import print_function +import time +import torch + +from vae_lib.optimization.loss import calculate_loss +from vae_lib.utils.visual_evaluation import plot_reconstructions +from vae_lib.utils.log_likelihood import calculate_likelihood + +import numpy as np +from train_misc import count_nfe, override_divergence_fn + + +def train(epoch, train_loader, model, opt, args, logger): + + model.train() + train_loss = np.zeros(len(train_loader)) + train_bpd = np.zeros(len(train_loader)) + + num_data = 0 + + # set warmup coefficient + beta = min([(epoch * 1.) / max([args.warmup, 1.]), args.max_beta]) + logger.info('beta = {:5.4f}'.format(beta)) + end = time.time() + for batch_idx, (data, _) in enumerate(train_loader): + if args.cuda: + data = data.cuda() + + if args.dynamic_binarization: + data = torch.bernoulli(data) + + data = data.view(-1, *args.input_size) + + opt.zero_grad() + x_mean, z_mu, z_var, ldj, z0, zk = model(data) + + if 'cnf' in args.flow: + f_nfe = count_nfe(model) + + loss, rec, kl, bpd = calculate_loss(x_mean, data, z_mu, z_var, z0, zk, ldj, args, beta=beta) + + loss.backward() + + if 'cnf' in args.flow: + t_nfe = count_nfe(model) + b_nfe = t_nfe - f_nfe + + train_loss[batch_idx] = loss.item() + train_bpd[batch_idx] = bpd + + opt.step() + + rec = rec.item() + kl = kl.item() + + num_data += len(data) + + batch_time = time.time() - end + end = time.time() + + if batch_idx % args.log_interval == 0: + if args.input_type == 'binary': + perc = 100. * batch_idx / len(train_loader) + log_msg = ( + 'Epoch {:3d} [{:5d}/{:5d} ({:2.0f}%)] | Time {:.3f} | Loss {:11.6f} | ' + 'Rec {:11.6f} | KL {:11.6f}'.format( + epoch, num_data, len(train_loader.sampler), perc, batch_time, loss.item(), rec, kl + ) + ) + else: + perc = 100. 
* batch_idx / len(train_loader) + tmp = 'Epoch {:3d} [{:5d}/{:5d} ({:2.0f}%)] | Time {:.3f} | Loss {:11.6f} | Bits/dim {:8.6f}' + log_msg = tmp.format(epoch, num_data, len(train_loader.sampler), perc, batch_time, loss.item(), + bpd), '\trec: {:11.3f}\tkl: {:11.6f}'.format(rec, kl) + log_msg = "".join(log_msg) + if 'cnf' in args.flow: + log_msg += ' | NFE Forward {} | NFE Backward {}'.format(f_nfe, b_nfe) + logger.info(log_msg) + + if args.input_type == 'binary': + logger.info('====> Epoch: {:3d} Average train loss: {:.4f}'.format(epoch, train_loss.sum() / len(train_loader))) + else: + logger.info( + '====> Epoch: {:3d} Average train loss: {:.4f}, average bpd: {:.4f}'. + format(epoch, train_loss.sum() / len(train_loader), train_bpd.sum() / len(train_loader)) + ) + + return train_loss + + +def evaluate(data_loader, model, args, logger, testing=False, epoch=0): + model.eval() + loss = 0. + batch_idx = 0 + bpd = 0. + + if args.input_type == 'binary': + loss_type = 'elbo' + else: + loss_type = 'bpd' + + if testing and 'cnf' in args.flow: + override_divergence_fn(model, "brute_force") + + for data, _ in data_loader: + batch_idx += 1 + + if args.cuda: + data = data.cuda() + + with torch.no_grad(): + data = data.view(-1, *args.input_size) + + x_mean, z_mu, z_var, ldj, z0, zk = model(data) + + batch_loss, rec, kl, batch_bpd = calculate_loss(x_mean, data, z_mu, z_var, z0, zk, ldj, args) + + bpd += batch_bpd + loss += batch_loss.item() + + # PRINT RECONSTRUCTIONS + if batch_idx == 1 and testing is False: + plot_reconstructions(data, x_mean, batch_loss, loss_type, epoch, args) + + loss /= len(data_loader) + bpd /= len(data_loader) + + if testing: + logger.info('====> Test set loss: {:.4f}'.format(loss)) + + # Compute log-likelihood + if testing and not ("cnf" in args.flow): # don't compute log-likelihood for cnf models + + with torch.no_grad(): + test_data = data_loader.dataset.tensors[0] + + if args.cuda: + test_data = test_data.cuda() + + logger.info('Computing log-likelihood on test set') + + model.eval() + + if args.dataset == 'caltech': + log_likelihood, nll_bpd = calculate_likelihood(test_data, model, args, logger, S=2000, MB=500) + else: + log_likelihood, nll_bpd = calculate_likelihood(test_data, model, args, logger, S=5000, MB=500) + + if 'cnf' in args.flow: + override_divergence_fn(model, args.divergence_fn) + else: + log_likelihood = None + nll_bpd = None + + if args.input_type in ['multinomial']: + bpd = loss / (np.prod(args.input_size) * np.log(2.)) + + if testing and not ("cnf" in args.flow): + logger.info('====> Test set log-likelihood: {:.4f}'.format(log_likelihood)) + + if args.input_type != 'binary': + logger.info('====> Test set bpd (elbo): {:.4f}'.format(bpd)) + logger.info( + '====> Test set bpd (log-likelihood): {:.4f}'. 
+ format(log_likelihood / (np.prod(args.input_size) * np.log(2.))) + ) + + if not testing: + return loss, bpd + else: + return log_likelihood, nll_bpd diff --git a/src/torchprune/torchprune/util/external/ffjord/vae_lib/utils/__init__.py b/src/torchprune/torchprune/util/external/ffjord/vae_lib/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/torchprune/torchprune/util/external/ffjord/vae_lib/utils/distributions.py b/src/torchprune/torchprune/util/external/ffjord/vae_lib/utils/distributions.py new file mode 100644 index 0000000..c58e4bc --- /dev/null +++ b/src/torchprune/torchprune/util/external/ffjord/vae_lib/utils/distributions.py @@ -0,0 +1,65 @@ +from __future__ import print_function +import torch +import torch.utils.data + +import math + +MIN_EPSILON = 1e-5 +MAX_EPSILON = 1. - 1e-5 + +PI = torch.FloatTensor([math.pi]) +if torch.cuda.is_available(): + PI = PI.cuda() + +# N(x | mu, var) = 1/sqrt{2pi var} exp[-1/(2 var) (x-mean)(x-mean)] +# log N(x| mu, var) = -log sqrt(2pi) -0.5 log var - 0.5 (x-mean)(x-mean)/var + + +def log_normal_diag(x, mean, log_var, average=False, reduce=True, dim=None): + log_norm = -0.5 * (log_var + (x - mean) * (x - mean) * log_var.exp().reciprocal()) + if reduce: + if average: + return torch.mean(log_norm, dim) + else: + return torch.sum(log_norm, dim) + else: + return log_norm + + +def log_normal_normalized(x, mean, log_var, average=False, reduce=True, dim=None): + log_norm = -(x - mean) * (x - mean) + log_norm *= torch.reciprocal(2. * log_var.exp()) + log_norm += -0.5 * log_var + log_norm += -0.5 * torch.log(2. * PI) + + if reduce: + if average: + return torch.mean(log_norm, dim) + else: + return torch.sum(log_norm, dim) + else: + return log_norm + + +def log_normal_standard(x, average=False, reduce=True, dim=None): + log_norm = -0.5 * x * x + + if reduce: + if average: + return torch.mean(log_norm, dim) + else: + return torch.sum(log_norm, dim) + else: + return log_norm + + +def log_bernoulli(x, mean, average=False, reduce=True, dim=None): + probs = torch.clamp(mean, min=MIN_EPSILON, max=MAX_EPSILON) + log_bern = x * torch.log(probs) + (1. - x) * torch.log(1. - probs) + if reduce: + if average: + return torch.mean(log_bern, dim) + else: + return torch.sum(log_bern, dim) + else: + return log_bern diff --git a/src/torchprune/torchprune/util/external/ffjord/vae_lib/utils/load_data.py b/src/torchprune/torchprune/util/external/ffjord/vae_lib/utils/load_data.py new file mode 100644 index 0000000..ad1a836 --- /dev/null +++ b/src/torchprune/torchprune/util/external/ffjord/vae_lib/utils/load_data.py @@ -0,0 +1,205 @@ +from __future__ import print_function + +import torch +import torch.utils.data as data_utils +import pickle +from scipy.io import loadmat + +import numpy as np + +import os + + +def load_static_mnist(args, **kwargs): + """ + Dataloading function for static mnist. 
Outputs image data in vectorized form: each image is a vector of size 784 + """ + args.dynamic_binarization = False + args.input_type = 'binary' + + args.input_size = [1, 28, 28] + + # start processing + def lines_to_np_array(lines): + return np.array([[int(i) for i in line.split()] for line in lines]) + + with open(os.path.join('data', 'MNIST_static', 'binarized_mnist_train.amat')) as f: + lines = f.readlines() + x_train = lines_to_np_array(lines).astype('float32') + with open(os.path.join('data', 'MNIST_static', 'binarized_mnist_valid.amat')) as f: + lines = f.readlines() + x_val = lines_to_np_array(lines).astype('float32') + with open(os.path.join('data', 'MNIST_static', 'binarized_mnist_test.amat')) as f: + lines = f.readlines() + x_test = lines_to_np_array(lines).astype('float32') + + # shuffle train data + np.random.shuffle(x_train) + + # idle y's + y_train = np.zeros((x_train.shape[0], 1)) + y_val = np.zeros((x_val.shape[0], 1)) + y_test = np.zeros((x_test.shape[0], 1)) + + # pytorch data loader + train = data_utils.TensorDataset(torch.from_numpy(x_train), torch.from_numpy(y_train)) + train_loader = data_utils.DataLoader(train, batch_size=args.batch_size, shuffle=True, **kwargs) + + validation = data_utils.TensorDataset(torch.from_numpy(x_val).float(), torch.from_numpy(y_val)) + val_loader = data_utils.DataLoader(validation, batch_size=args.batch_size, shuffle=False, **kwargs) + + test = data_utils.TensorDataset(torch.from_numpy(x_test).float(), torch.from_numpy(y_test)) + test_loader = data_utils.DataLoader(test, batch_size=args.batch_size, shuffle=False, **kwargs) + + return train_loader, val_loader, test_loader, args + + +def load_freyfaces(args, **kwargs): + # set args + args.input_size = [1, 28, 20] + args.input_type = 'multinomial' + args.dynamic_binarization = False + + TRAIN = 1565 + VAL = 200 + TEST = 200 + + # start processing + with open('data/Freyfaces/freyfaces.pkl', 'rb') as f: + data = pickle.load(f, encoding="latin1")[0] + + data = data / 255. + + # NOTE: shuffling is done before splitting into train and test set, so test set is different for every run! 
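+    # For a fixed args.freyseed the split below is reproducible; results obtained
+    # with different seeds are not directly comparable.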
+ # shuffle data: + np.random.seed(args.freyseed) + + np.random.shuffle(data) + + # train images + x_train = data[0:TRAIN].reshape(-1, 28 * 20) + # validation images + x_val = data[TRAIN:(TRAIN + VAL)].reshape(-1, 28 * 20) + # test images + x_test = data[(TRAIN + VAL):(TRAIN + VAL + TEST)].reshape(-1, 28 * 20) + + # idle y's + y_train = np.zeros((x_train.shape[0], 1)) + y_val = np.zeros((x_val.shape[0], 1)) + y_test = np.zeros((x_test.shape[0], 1)) + + # pytorch data loader + train = data_utils.TensorDataset(torch.from_numpy(x_train).float(), torch.from_numpy(y_train)) + train_loader = data_utils.DataLoader(train, batch_size=args.batch_size, shuffle=True, **kwargs) + + validation = data_utils.TensorDataset(torch.from_numpy(x_val).float(), torch.from_numpy(y_val)) + val_loader = data_utils.DataLoader(validation, batch_size=args.batch_size, shuffle=False, **kwargs) + + test = data_utils.TensorDataset(torch.from_numpy(x_test).float(), torch.from_numpy(y_test)) + test_loader = data_utils.DataLoader(test, batch_size=args.batch_size, shuffle=False, **kwargs) + return train_loader, val_loader, test_loader, args + + +def load_omniglot(args, **kwargs): + n_validation = 1345 + + # set args + args.input_size = [1, 28, 28] + args.input_type = 'binary' + args.dynamic_binarization = True + + # start processing + def reshape_data(data): + return data.reshape((-1, 28, 28)).reshape((-1, 28 * 28), order='F') + + omni_raw = loadmat(os.path.join('data', 'OMNIGLOT', 'chardata.mat')) + + # train and test data + train_data = reshape_data(omni_raw['data'].T.astype('float32')) + x_test = reshape_data(omni_raw['testdata'].T.astype('float32')) + + # shuffle train data + np.random.shuffle(train_data) + + # set train and validation data + x_train = train_data[:-n_validation] + x_val = train_data[-n_validation:] + + # binarize + if args.dynamic_binarization: + args.input_type = 'binary' + np.random.seed(777) + x_val = np.random.binomial(1, x_val) + x_test = np.random.binomial(1, x_test) + else: + args.input_type = 'gray' + + # idle y's + y_train = np.zeros((x_train.shape[0], 1)) + y_val = np.zeros((x_val.shape[0], 1)) + y_test = np.zeros((x_test.shape[0], 1)) + + # pytorch data loader + train = data_utils.TensorDataset(torch.from_numpy(x_train), torch.from_numpy(y_train)) + train_loader = data_utils.DataLoader(train, batch_size=args.batch_size, shuffle=True, **kwargs) + + validation = data_utils.TensorDataset(torch.from_numpy(x_val).float(), torch.from_numpy(y_val)) + val_loader = data_utils.DataLoader(validation, batch_size=args.batch_size, shuffle=False, **kwargs) + + test = data_utils.TensorDataset(torch.from_numpy(x_test).float(), torch.from_numpy(y_test)) + test_loader = data_utils.DataLoader(test, batch_size=args.batch_size, shuffle=False, **kwargs) + + return train_loader, val_loader, test_loader, args + + +def load_caltech101silhouettes(args, **kwargs): + # set args + args.input_size = [1, 28, 28] + args.input_type = 'binary' + args.dynamic_binarization = False + + # start processing + def reshape_data(data): + return data.reshape((-1, 28, 28)).reshape((-1, 28 * 28), order='F') + + caltech_raw = loadmat(os.path.join('data', 'Caltech101Silhouettes', 'caltech101_silhouettes_28_split1.mat')) + + # train, validation and test data + x_train = 1. - reshape_data(caltech_raw['train_data'].astype('float32')) + np.random.shuffle(x_train) + x_val = 1. - reshape_data(caltech_raw['val_data'].astype('float32')) + np.random.shuffle(x_val) + x_test = 1. 
- reshape_data(caltech_raw['test_data'].astype('float32'))
+
+    y_train = caltech_raw['train_labels']
+    y_val = caltech_raw['val_labels']
+    y_test = caltech_raw['test_labels']
+
+    # pytorch data loader
+    train = data_utils.TensorDataset(torch.from_numpy(x_train), torch.from_numpy(y_train))
+    train_loader = data_utils.DataLoader(train, batch_size=args.batch_size, shuffle=True, **kwargs)
+
+    validation = data_utils.TensorDataset(torch.from_numpy(x_val).float(), torch.from_numpy(y_val))
+    val_loader = data_utils.DataLoader(validation, batch_size=args.batch_size, shuffle=False, **kwargs)
+
+    test = data_utils.TensorDataset(torch.from_numpy(x_test).float(), torch.from_numpy(y_test))
+    test_loader = data_utils.DataLoader(test, batch_size=args.batch_size, shuffle=False, **kwargs)
+
+    return train_loader, val_loader, test_loader, args
+
+
+def load_dataset(args, **kwargs):
+
+    if args.dataset == 'mnist':
+        train_loader, val_loader, test_loader, args = load_static_mnist(args, **kwargs)
+    elif args.dataset == 'caltech':
+        train_loader, val_loader, test_loader, args = load_caltech101silhouettes(args, **kwargs)
+    elif args.dataset == 'freyfaces':
+        train_loader, val_loader, test_loader, args = load_freyfaces(args, **kwargs)
+    elif args.dataset == 'omniglot':
+        train_loader, val_loader, test_loader, args = load_omniglot(args, **kwargs)
+    else:
+        raise Exception('Wrong name of the dataset!')
+
+    return train_loader, val_loader, test_loader, args
diff --git a/src/torchprune/torchprune/util/external/ffjord/vae_lib/utils/log_likelihood.py b/src/torchprune/torchprune/util/external/ffjord/vae_lib/utils/log_likelihood.py
new file mode 100644
index 0000000..b5461b4
--- /dev/null
+++ b/src/torchprune/torchprune/util/external/ffjord/vae_lib/utils/log_likelihood.py
@@ -0,0 +1,60 @@
+from __future__ import print_function
+import time
+import numpy as np
+from scipy.special import logsumexp
+from vae_lib.optimization.loss import calculate_loss_array
+
+
+def calculate_likelihood(X, model, args, logger, S=5000, MB=500):
+
+    # set auxiliary variables for number of training and test sets
+    N_test = X.size(0)
+
+    X = X.view(-1, *args.input_size)
+
+    likelihood_test = []
+
+    if S <= MB:
+        R = 1
+    else:
+        R = S // MB
+        S = MB
+
+    end = time.time()
+    for j in range(N_test):
+
+        x_single = X[j].unsqueeze(0)
+
+        a = []
+        for r in range(0, R):
+            # Repeat it for all training points
+            x = x_single.expand(S, *x_single.size()[1:]).contiguous()
+
+            x_mean, z_mu, z_var, ldj, z0, zk = model(x)
+
+            a_tmp = calculate_loss_array(x_mean, x, z_mu, z_var, z0, zk, ldj, args)
+
+            a.append(-a_tmp.cpu().data.numpy())
+
+        # calculate max
+        a = np.asarray(a)
+        a = np.reshape(a, (a.shape[0] * a.shape[1], 1))
+        likelihood_x = logsumexp(a)
+        likelihood_test.append(likelihood_x - np.log(len(a)))
+
+        if j % 1 == 0:
+            logger.info('Progress: {:.2f}% | Time: {:.4f}'.format(j / (1. * N_test) * 100, time.time() - end))
+            end = time.time()
+
+    likelihood_test = np.array(likelihood_test)
+
+    nll = -np.mean(likelihood_test)
+
+    if args.input_type == 'multinomial':
+        bpd = nll / (np.prod(args.input_size) * np.log(2.))
+    elif args.input_type == 'binary':
+        bpd = 0.
+ else: + raise ValueError('invalid input type!') + + return nll, bpd diff --git a/src/torchprune/torchprune/util/external/ffjord/vae_lib/utils/plotting.py b/src/torchprune/torchprune/util/external/ffjord/vae_lib/utils/plotting.py new file mode 100644 index 0000000..e5c614f --- /dev/null +++ b/src/torchprune/torchprune/util/external/ffjord/vae_lib/utils/plotting.py @@ -0,0 +1,104 @@ +from __future__ import division +from __future__ import print_function + +import numpy as np +import matplotlib +# noninteractive background +matplotlib.use('Agg') +import matplotlib.pyplot as plt + + +def plot_training_curve(train_loss, validation_loss, fname='training_curve.pdf', labels=None): + """ + Plots train_loss and validation loss as a function of optimization iteration + :param train_loss: np.array of train_loss (1D or 2D) + :param validation_loss: np.array of validation loss (1D or 2D) + :param fname: output file name + :param labels: if train_loss and validation loss are 2D, then labels indicate which variable is varied + accross training curves. + :return: None + """ + + plt.close() + + matplotlib.rcParams.update({'font.size': 14}) + matplotlib.rcParams['mathtext.fontset'] = 'stix' + matplotlib.rcParams['font.family'] = 'STIXGeneral' + + if len(train_loss.shape) == 1: + # Single training curve + fig, ax = plt.subplots(nrows=1, ncols=1) + figsize = (6, 4) + + if train_loss.shape[0] == validation_loss.shape[0]: + # validation score evaluated every iteration + x = np.arange(train_loss.shape[0]) + ax.plot(x, train_loss, '-', lw=2., color='black', label='train') + ax.plot(x, validation_loss, '-', lw=2., color='blue', label='val') + + elif train_loss.shape[0] % validation_loss.shape[0] == 0: + # validation score evaluated every epoch + x = np.arange(train_loss.shape[0]) + ax.plot(x, train_loss, '-', lw=2., color='black', label='train') + + x = np.arange(validation_loss.shape[0]) + x = (x + 1) * train_loss.shape[0] / validation_loss.shape[0] + ax.plot(x, validation_loss, '-', lw=2., color='blue', label='val') + else: + raise ValueError('Length of train_loss and validation_loss must be equal or divisible') + + miny = np.minimum(validation_loss.min(), train_loss.min()) - 20. + maxy = np.maximum(validation_loss.max(), train_loss.max()) + 30. + ax.set_ylim([miny, maxy]) + + elif len(train_loss.shape) == 2: + # Multiple training curves + + cmap = plt.cm.brg + + cNorm = matplotlib.colors.Normalize(vmin=0, vmax=train_loss.shape[0]) + scalarMap = matplotlib.cm.ScalarMappable(norm=cNorm, cmap=cmap) + + fig, ax = plt.subplots(nrows=1, ncols=1) + figsize = (6, 4) + + if labels is None: + labels = ['%d' % i for i in range(train_loss.shape[0])] + + if train_loss.shape[1] == validation_loss.shape[1]: + for i in range(train_loss.shape[0]): + color_val = scalarMap.to_rgba(i) + + # validation score evaluated every iteration + x = np.arange(train_loss.shape[0]) + ax.plot(x, train_loss[i], '-', lw=2., color=color_val, label=labels[i]) + ax.plot(x, validation_loss[i], '--', lw=2., color=color_val) + + elif train_loss.shape[1] % validation_loss.shape[1] == 0: + for i in range(train_loss.shape[0]): + color_val = scalarMap.to_rgba(i) + + # validation score evaluated every epoch + x = np.arange(train_loss.shape[1]) + ax.plot(x, train_loss[i], '-', lw=2., color=color_val, label=labels[i]) + + x = np.arange(validation_loss.shape[1]) + x = (x + 1) * train_loss.shape[1] / validation_loss.shape[1] + ax.plot(x, validation_loss[i], '-', lw=2., color=color_val) + + miny = np.minimum(validation_loss.min(), train_loss.min()) - 20. 
+ maxy = np.maximum(validation_loss.max(), train_loss.max()) + 30. + ax.set_ylim([miny, maxy]) + + else: + raise ValueError('train_loss and validation_loss must be 1D or 2D arrays') + + ax.set_xlabel('iteration') + ax.set_ylabel('loss') + plt.title('Training and validation loss') + + fig.set_size_inches(figsize) + fig.subplots_adjust(hspace=0.1) + plt.savefig(fname, bbox_inches='tight') + + plt.close() diff --git a/src/torchprune/torchprune/util/external/ffjord/vae_lib/utils/visual_evaluation.py b/src/torchprune/torchprune/util/external/ffjord/vae_lib/utils/visual_evaluation.py new file mode 100644 index 0000000..1281dea --- /dev/null +++ b/src/torchprune/torchprune/util/external/ffjord/vae_lib/utils/visual_evaluation.py @@ -0,0 +1,53 @@ +from __future__ import print_function +import os +import matplotlib +matplotlib.use('Agg') +import matplotlib.pyplot as plt +import matplotlib.gridspec as gridspec + + +def plot_reconstructions(data, recon_mean, loss, loss_type, epoch, args): + + if args.input_type == 'multinomial': + # data is already between 0 and 1 + num_classes = 256 + # Find largest class logit + tmp = recon_mean.view(-1, num_classes, *args.input_size).max(dim=1)[1] + recon_mean = tmp.float() / (num_classes - 1.) + if epoch == 1: + if not os.path.exists(args.snap_dir + 'reconstruction/'): + os.makedirs(args.snap_dir + 'reconstruction/') + # VISUALIZATION: plot real images + plot_images(args, data.data.cpu().numpy()[0:9], args.snap_dir + 'reconstruction/', 'real', size_x=3, size_y=3) + # VISUALIZATION: plot reconstructions + if loss_type == 'bpd': + fname = str(epoch) + '_bpd_%5.3f' % loss + elif loss_type == 'elbo': + fname = str(epoch) + '_elbo_%6.4f' % loss + plot_images(args, recon_mean.data.cpu().numpy()[0:9], args.snap_dir + 'reconstruction/', fname, size_x=3, size_y=3) + + +def plot_images(args, x_sample, dir, file_name, size_x=3, size_y=3): + + fig = plt.figure(figsize=(size_x, size_y)) + # fig = plt.figure(1) + gs = gridspec.GridSpec(size_x, size_y) + gs.update(wspace=0.05, hspace=0.05) + + for i, sample in enumerate(x_sample): + ax = plt.subplot(gs[i]) + plt.axis('off') + ax.set_xticklabels([]) + ax.set_yticklabels([]) + ax.set_aspect('equal') + sample = sample.reshape((args.input_size[0], args.input_size[1], args.input_size[2])) + sample = sample.swapaxes(0, 2) + sample = sample.swapaxes(0, 1) + if (args.input_type == 'binary') or (args.input_type in ['multinomial'] and args.input_size[0] == 1): + sample = sample[:, :, 0] + plt.imshow(sample, cmap='gray', vmin=0, vmax=1) + else: + plt.imshow(sample) + + plt.savefig(dir + file_name + '.png', bbox_inches='tight') + plt.close(fig) diff --git a/src/torchprune/torchprune/util/metrics.py b/src/torchprune/torchprune/util/metrics.py index 8349f3b..39f2b5f 100644 --- a/src/torchprune/torchprune/util/metrics.py +++ b/src/torchprune/torchprune/util/metrics.py @@ -4,6 +4,8 @@ from scipy import stats import torch +from .nn_loss import NLLPriorLoss, NLLNatsLoss, NLLBitsLoss + class AbstractMetric(ABC): """Functor template for metric.""" @@ -29,7 +31,7 @@ def __call__(self, output, target): # check if output is dict if isinstance(output, dict): if "out" in output: - # Segmentation networks + # Segmentation networks and ffjord networks output = output["out"] elif "logits" in output: # BERT @@ -263,3 +265,60 @@ def short_name(self): def _get_metric(self, output, target): """Compute metric and return as 0d tensor.""" return torch.tensor(0.0) + + +class NLLPrior(AbstractMetric): + """A wrapper for the NLLPriorLoss as metric.""" + + @property 
+    def name(self):
+        """Get the display name of this metric."""
+        return "Negative Log Likelihood"
+
+    @property
+    def short_name(self):
+        """Get the short name of this metric."""
+        return "NLL"
+
+    def _get_metric(self, output, target):
+        """Compute metric and return as 0d tensor."""
+        # higher is better for a metric, so we negate the loss
+        return -(NLLPriorLoss()(output["output"], target))
+
+    def __call__(self, output, target):
+        """Call the metric as usual but wrap the output in a dict first."""
+        return super().__call__({"output": output}, target)
+
+
+class NLLNats(NLLPrior):
+    """A wrapper for the NLLNatsLoss as metric."""
+
+    @property
+    def name(self):
+        """Get the display name of this metric."""
+        return "Negative Log Probability (nats)"
+
+    @property
+    def short_name(self):
+        """Get the short name of this metric."""
+        return "Nats"
+
+    def _get_metric(self, output, target):
+        """Compute metric and return as 0d tensor."""
+        return -(NLLNatsLoss()(output["output"], target))
+
+
+class NLLBits(NLLPrior):
+    """A wrapper for the NLLBitsLoss as metric."""
+
+    @property
+    def name(self):
+        """Get the display name of this metric."""
+        return "Negative Log Probability (bits/dim)"
+
+    @property
+    def short_name(self):
+        """Get the short name of this metric."""
+        return "Bits"
+
+    def _get_metric(self, output, target):
+        """Compute metric and return as 0d tensor."""
+        return -(NLLBitsLoss()(output["output"], target))
diff --git a/src/torchprune/torchprune/util/models/__init__.py b/src/torchprune/torchprune/util/models/__init__.py
index 94d9b5b..fb86d1b 100644
--- a/src/torchprune/torchprune/util/models/__init__.py
+++ b/src/torchprune/torchprune/util/models/__init__.py
@@ -1,13 +1,18 @@
 # flake8: noqa: F401, F403
 """Package with all custom net implementation and CIFAR nets."""
 # import custom nets
-from .fcnet import FCNet, lenet300_100, lenet500_300_100
+from .fcnet import FCNet, lenet300_100, lenet500_300_100, fcnet_nettrim
 from .lenet5 import lenet5
 from .deepknight import deepknight
 from .deeplab import *
 from .cnn60k import cnn60k
 from .cnn5 import cnn5
 from .bert import bert
+from .node import *
+from .cnf import *
+from .ffjord import *
+from .ffjord_tabular import *
+from .ffjord_cnf import *
 
 # import cifar nets
 from ..external.cnn.models.cifar import *
diff --git a/src/torchprune/torchprune/util/models/cnf.py b/src/torchprune/torchprune/util/models/cnf.py
new file mode 100644
index 0000000..1e8cbfa
--- /dev/null
+++ b/src/torchprune/torchprune/util/models/cnf.py
@@ -0,0 +1,452 @@
+"""Module containing various FFjord NODE configurations with torchdyn lib."""
+
+import torch
+import torch.nn as nn
+from torchdyn.models import autograd_trace
+
+from .ffjord import Ffjord
+
+
+class VanillaCNF(Ffjord):
+    """Neural ODEs for CNFs via brute-force trace estimator."""
+
+    @property
+    def trace_estimator(self):
+        """Return the desired trace estimator."""
+        return autograd_trace
+
+
+def cnf_l4_h64_sigmoid(num_classes):
+    """Return a brute-force CNF with 4 layers, 64 neurons, and sigmoid."""
+    return VanillaCNF(
+        num_in=num_classes,
+        num_layers=4,
+        hidden_size=64,
+        module_activate=nn.Sigmoid,
+    )
+
+
+def cnf_l4_h64_softplus(num_classes):
+    """Return a brute-force CNF with 4 layers, 64 neurons, and softplus."""
+    return VanillaCNF(
+        num_in=num_classes,
+        num_layers=4,
+        hidden_size=64,
+        module_activate=nn.Softplus,
+    )
+
+
+def cnf_l4_h64_tanh(num_classes):
+    """Return a brute-force CNF with 4 layers, 64 neurons, and tanh."""
+    return VanillaCNF(
+        num_in=num_classes,
+        num_layers=4,
+        hidden_size=64,
+        module_activate=nn.Tanh,
+    )
+
+
+def cnf_l4_h64_relu(num_classes):
+    
"""Return a brute-force CNF with 4 layers, 64 neurons, and relu.""" + return VanillaCNF( + num_in=num_classes, + num_layers=4, + hidden_size=64, + module_activate=nn.ReLU, + ) + + +def cnf_l8_h64_sigmoid(num_classes): + """Return a brute-force CNF with 8 layers, 64 neurons, and sigmoid.""" + return VanillaCNF( + num_in=num_classes, + num_layers=8, + hidden_size=64, + module_activate=nn.Sigmoid, + ) + + +def cnf_l2_h128_sigmoid(num_classes): + """Return a brute-force CNF with 2 layers, 128 neurons, and sigmoid.""" + return VanillaCNF( + num_in=num_classes, + num_layers=2, + hidden_size=128, + module_activate=nn.Sigmoid, + ) + + +def cnf_l2_h64_sigmoid(num_classes): + """Return a brute-force CNF with 2 layers, 64 neurons, and sigmoid.""" + return VanillaCNF( + num_in=num_classes, + num_layers=2, + hidden_size=64, + module_activate=nn.Sigmoid, + ) + + +def cnf_l4_h64_sigmoid_dopri_adjoint(num_classes): + """Return a brute-force CNF with 4 layers, 64 neurons, and sigmoid.""" + return VanillaCNF( + num_in=num_classes, + num_layers=4, + hidden_size=64, + module_activate=nn.Sigmoid, + s_span=torch.linspace(0, 1, 2), + sensitivity="adjoint", + solver="dopri5", + ) + + +def cnf_l4_h64_sigmoid_dopri_autograd(num_classes): + """Return a brute-force CNF with 4 layers, 64 neurons, and sigmoid.""" + return VanillaCNF( + num_in=num_classes, + num_layers=4, + hidden_size=64, + module_activate=nn.Sigmoid, + s_span=torch.linspace(0, 1, 2), + sensitivity="autograd", + solver="dopri5", + ) + + +def cnf_l4_h64_sigmoid_rk4_autograd(num_classes): + """Return a brute-force CNF with 4 layers, 64 neurons, and sigmoid.""" + return VanillaCNF( + num_in=num_classes, + num_layers=4, + hidden_size=64, + module_activate=nn.Sigmoid, + s_span=torch.linspace(0, 1, 20), + sensitivity="autograd", + solver="rk4", + ) + + +def cnf_l4_h64_sigmoid_rk4_adjoint(num_classes): + """Return a brute-force CNF with 4 layers, 64 neurons, and sigmoid.""" + return VanillaCNF( + num_in=num_classes, + num_layers=4, + hidden_size=64, + module_activate=nn.Sigmoid, + s_span=torch.linspace(0, 1, 20), + sensitivity="adjoint", + solver="rk4", + ) + + +def cnf_l4_h64_sigmoid_euler_autograd(num_classes): + """Return a brute-force CNF with 4 layers, 64 neurons, and sigmoid.""" + return VanillaCNF( + num_in=num_classes, + num_layers=4, + hidden_size=64, + module_activate=nn.Sigmoid, + s_span=torch.linspace(0, 1, 100), + sensitivity="autograd", + solver="euler", + ) + + +def cnf_l4_h64_sigmoid_euler_adjoint(num_classes): + """Return a brute-force CNF with 4 layers, 64 neurons, and sigmoid.""" + return VanillaCNF( + num_in=num_classes, + num_layers=4, + hidden_size=64, + module_activate=nn.Sigmoid, + s_span=torch.linspace(0, 1, 100), + sensitivity="adjoint", + solver="euler", + ) + + +def cnf_l4_h64_sigmoid_da(num_classes): + """Return a brute-force CNF with 4 layers, 64 neurons, and sigmoid.""" + return VanillaCNF( + num_in=num_classes, + num_layers=4, + hidden_size=64, + module_activate=nn.Sigmoid, + s_span=torch.linspace(0, 1, 2), + sensitivity="adjoint", + solver="dopri5", + ) + + +def cnf_l4_h64_softplus_da(num_classes): + """Return a brute-force CNF with 4 layers, 64 neurons, and softplus.""" + return VanillaCNF( + num_in=num_classes, + num_layers=4, + hidden_size=64, + module_activate=nn.Softplus, + s_span=torch.linspace(0, 1, 2), + sensitivity="adjoint", + solver="dopri5", + ) + + +def cnf_l4_h64_tanh_da(num_classes): + """Return a brute-force CNF with 4 layers, 64 neurons, and tanh.""" + return VanillaCNF( + num_in=num_classes, + num_layers=4, + 
hidden_size=64, + module_activate=nn.Tanh, + s_span=torch.linspace(0, 1, 2), + sensitivity="adjoint", + solver="dopri5", + ) + + +def cnf_l4_h64_relu_da(num_classes): + """Return a brute-force CNF with 4 layers, 64 neurons, and relu.""" + return VanillaCNF( + num_in=num_classes, + num_layers=4, + hidden_size=64, + module_activate=nn.ReLU, + s_span=torch.linspace(0, 1, 2), + sensitivity="adjoint", + solver="dopri5", + ) + + +def cnf_l8_h64_sigmoid_da(num_classes): + """Return a brute-force CNF with 8 layers, 64 neurons, and sigmoid.""" + return VanillaCNF( + num_in=num_classes, + num_layers=8, + hidden_size=64, + module_activate=nn.Sigmoid, + s_span=torch.linspace(0, 1, 2), + sensitivity="adjoint", + solver="dopri5", + ) + + +def cnf_l8_h37_sigmoid_da(num_classes): + """Return a brute-force CNF with 8 layers, 37 neurons, and sigmoid.""" + return VanillaCNF( + num_in=num_classes, + num_layers=8, + hidden_size=37, + module_activate=nn.Sigmoid, + s_span=torch.linspace(0, 1, 2), + sensitivity="adjoint", + solver="dopri5", + ) + + +def cnf_l8_h18_sigmoid_da(num_classes): + """Return a brute-force CNF with 8 layers, 18 neurons, and sigmoid.""" + return VanillaCNF( + num_in=num_classes, + num_layers=8, + hidden_size=18, + module_activate=nn.Sigmoid, + s_span=torch.linspace(0, 1, 2), + sensitivity="adjoint", + solver="dopri5", + ) + + +def cnf_l8_h10_sigmoid_da(num_classes): + """Return a brute-force CNF with 8 layers, 10 neurons, and sigmoid.""" + return VanillaCNF( + num_in=num_classes, + num_layers=8, + hidden_size=10, + module_activate=nn.Sigmoid, + s_span=torch.linspace(0, 1, 2), + sensitivity="adjoint", + solver="dopri5", + ) + + +def cnf_l6_h45_sigmoid_da(num_classes): + """Return a brute-force CNF with 6 layers, 45 neurons, and sigmoid.""" + return VanillaCNF( + num_in=num_classes, + num_layers=6, + hidden_size=45, + module_activate=nn.Sigmoid, + s_span=torch.linspace(0, 1, 2), + sensitivity="adjoint", + solver="dopri5", + ) + + +def cnf_l6_h22_sigmoid_da(num_classes): + """Return a brute-force CNF with 6 layers, 22 neurons, and sigmoid.""" + return VanillaCNF( + num_in=num_classes, + num_layers=6, + hidden_size=22, + module_activate=nn.Sigmoid, + s_span=torch.linspace(0, 1, 2), + sensitivity="adjoint", + solver="dopri5", + ) + + +def cnf_l6_h12_sigmoid_da(num_classes): + """Return a brute-force CNF with 6 layers, 12 neurons, and sigmoid.""" + return VanillaCNF( + num_in=num_classes, + num_layers=6, + hidden_size=12, + module_activate=nn.Sigmoid, + s_span=torch.linspace(0, 1, 2), + sensitivity="adjoint", + solver="dopri5", + ) + + +def cnf_l4_h128_sigmoid_da(num_classes): + """Return a brute-force CNF with 4 layers, 128 neurons, and sigmoid.""" + return VanillaCNF( + num_in=num_classes, + num_layers=4, + hidden_size=128, + module_activate=nn.Sigmoid, + s_span=torch.linspace(0, 1, 2), + sensitivity="adjoint", + solver="dopri5", + ) + + +def cnf_l4_h30_sigmoid_da(num_classes): + """Return a brute-force CNF with 4 layers, 30 neurons, and sigmoid.""" + return VanillaCNF( + num_in=num_classes, + num_layers=4, + hidden_size=30, + module_activate=nn.Sigmoid, + s_span=torch.linspace(0, 1, 2), + sensitivity="adjoint", + solver="dopri5", + ) + + +def cnf_l4_h17_sigmoid_da(num_classes): + """Return a brute-force CNF with 4 layers, 17 neurons, and sigmoid.""" + return VanillaCNF( + num_in=num_classes, + num_layers=4, + hidden_size=17, + module_activate=nn.Sigmoid, + s_span=torch.linspace(0, 1, 2), + sensitivity="adjoint", + solver="dopri5", + ) + + +def cnf_l3_h90_sigmoid_da(num_classes): + """Return a 
brute-force CNF with 3 layers, 90 neurons, and sigmoid.""" + return VanillaCNF( + num_in=num_classes, + num_layers=3, + hidden_size=90, + module_activate=nn.Sigmoid, + s_span=torch.linspace(0, 1, 2), + sensitivity="adjoint", + solver="dopri5", + ) + + +def cnf_l3_h43_sigmoid_da(num_classes): + """Return a brute-force CNF with 3 layers, 43 neurons, and sigmoid.""" + return VanillaCNF( + num_in=num_classes, + num_layers=3, + hidden_size=43, + module_activate=nn.Sigmoid, + s_span=torch.linspace(0, 1, 2), + sensitivity="adjoint", + solver="dopri5", + ) + + +def cnf_l3_h23_sigmoid_da(num_classes): + """Return a brute-force CNF with 3 layers, 23 neurons, and sigmoid.""" + return VanillaCNF( + num_in=num_classes, + num_layers=3, + hidden_size=23, + module_activate=nn.Sigmoid, + s_span=torch.linspace(0, 1, 2), + sensitivity="adjoint", + solver="dopri5", + ) + + +def cnf_l2_h1700_sigmoid_da(num_classes): + """Return a brute-force CNF with 2 layers, 1700 neurons, and sigmoid.""" + return VanillaCNF( + num_in=num_classes, + num_layers=2, + hidden_size=1700, + module_activate=nn.Sigmoid, + s_span=torch.linspace(0, 1, 2), + sensitivity="adjoint", + solver="dopri5", + ) + + +def cnf_l2_h400_sigmoid_da(num_classes): + """Return a brute-force CNF with 2 layers, 400 neurons, and sigmoid.""" + return VanillaCNF( + num_in=num_classes, + num_layers=2, + hidden_size=400, + module_activate=nn.Sigmoid, + s_span=torch.linspace(0, 1, 2), + sensitivity="adjoint", + solver="dopri5", + ) + + +def cnf_l2_h128_sigmoid_da(num_classes): + """Return a brute-force CNF with 2 layers, 128 neurons, and sigmoid.""" + return VanillaCNF( + num_in=num_classes, + num_layers=2, + hidden_size=128, + module_activate=nn.Sigmoid, + s_span=torch.linspace(0, 1, 2), + sensitivity="adjoint", + solver="dopri5", + ) + + +def cnf_l2_h64_sigmoid_da(num_classes): + """Return a brute-force CNF with 2 layers, 64 neurons, and sigmoid.""" + return VanillaCNF( + num_in=num_classes, + num_layers=2, + hidden_size=64, + module_activate=nn.Sigmoid, + s_span=torch.linspace(0, 1, 2), + sensitivity="adjoint", + solver="dopri5", + ) + + +def cnf_l4_h64_sigmoid_da_high_tol(num_classes): + """Return a brute-force CNF with 4 layers, 64 neurons, and sigmoid.""" + return VanillaCNF( + num_in=num_classes, + num_layers=4, + hidden_size=64, + module_activate=nn.Sigmoid, + s_span=torch.linspace(0, 1, 2), + sensitivity="adjoint", + solver="dopri5", + atol=1e-4, + rtol=1e-4, + ) diff --git a/src/torchprune/torchprune/util/models/cnn/README.md b/src/torchprune/torchprune/util/models/cnn/README.md deleted file mode 100644 index 707c6fd..0000000 --- a/src/torchprune/torchprune/util/models/cnn/README.md +++ /dev/null @@ -1,74 +0,0 @@ -# pytorch-classification -Classification on CIFAR-10/100 and ImageNet with PyTorch. - -## Features -* Unified interface for different network architectures -* Multi-GPU support -* Training progress bar with rich info -* Training log and training curve visualization code (see `./utils/logger.py`) - -## Install -* Install [PyTorch](http://pytorch.org/) -* Clone recursively - ``` - git clone --recursive https://github.com/bearpaw/pytorch-classification.git - ``` - -## Training -Please see the [Training recipes](TRAINING.md) for how to train the models. - -## Results - -### CIFAR -Top1 error rate on the CIFAR-10/100 benchmarks are reported. You may get different results when training your models with different random seed. -Note that the number of parameters are computed on the CIFAR-10 dataset. 
- -| Model | Params (M) | CIFAR-10 (%) | CIFAR-100 (%) | -| ------------------- | ------------------ | ------------------ | ------------------ | -| alexnet | 2.47 | 22.78 | 56.13 | -| vgg19_bn | 20.04 | 6.66 | 28.05 | -| ResNet-110 | 1.70 | 6.11 | 28.86 | -| PreResNet-110 | 1.70 | 4.94 | 23.65 | -| WRN-28-10 (drop 0.3) | 36.48 | 3.79 | 18.14 | -| ResNeXt-29, 8x64 | 34.43 | 3.69 | 17.38 | -| ResNeXt-29, 16x64 | 68.16 | 3.53 | 17.30 | -| DenseNet-BC (L=100, k=12) | 0.77 | 4.54 | 22.88 | -| DenseNet-BC (L=190, k=40) | 25.62 | 3.32 | 17.17 | - - -![cifar](utils/images/cifar.png) - -### ImageNet -Single-crop (224x224) validation error rate is reported. - - -| Model | Params (M) | Top-1 Error (%) | Top-5 Error (%) | -| ------------------- | ------------------ | ------------------ | ------------------ | -| ResNet-18 | 11.69 | 30.09 | 10.78 | -| ResNeXt-50 (32x4d) | 25.03 | 22.6 | 6.29 | - -![Validation curve](utils/images/imagenet.png) - -## Pretrained models -Our trained models and training logs are downloadable at [OneDrive](https://mycuhk-my.sharepoint.com/personal/1155056070_link_cuhk_edu_hk/_layouts/15/guestaccess.aspx?folderid=0a380d1fece1443f0a2831b761df31905&authkey=Ac5yBC-FSE4oUJZ2Lsx7I5c). - -## Supported Architectures - -### CIFAR-10 / CIFAR-100 -Since the size of images in CIFAR dataset is `32x32`, popular network structures for ImageNet need some modifications to adapt this input size. The modified models is in the package `models.cifar`: -- [x] [AlexNet](https://arxiv.org/abs/1404.5997) -- [x] [VGG](https://arxiv.org/abs/1409.1556) (Imported from [pytorch-cifar](https://github.com/kuangliu/pytorch-cifar)) -- [x] [ResNet](https://arxiv.org/abs/1512.03385) -- [x] [Pre-act-ResNet](https://arxiv.org/abs/1603.05027) -- [x] [ResNeXt](https://arxiv.org/abs/1611.05431) (Imported from [ResNeXt.pytorch](https://github.com/prlz77/ResNeXt.pytorch)) -- [x] [Wide Residual Networks](http://arxiv.org/abs/1605.07146) (Imported from [WideResNet-pytorch](https://github.com/xternalz/WideResNet-pytorch)) -- [x] [DenseNet](https://arxiv.org/abs/1608.06993) - -### ImageNet -- [x] All models in `torchvision.models` (alexnet, vgg, resnet, densenet, inception_v3, squeezenet) -- [x] [ResNeXt](https://arxiv.org/abs/1611.05431) -- [ ] [Wide Residual Networks](http://arxiv.org/abs/1605.07146) - - -## Contribute -Feel free to create a pull request if you find any bugs or you want to contribute (e.g., more datasets and more network structures). 
diff --git a/src/torchprune/torchprune/util/models/cnn/TRAINING.md b/src/torchprune/torchprune/util/models/cnn/TRAINING.md deleted file mode 100644 index ff140ab..0000000 --- a/src/torchprune/torchprune/util/models/cnn/TRAINING.md +++ /dev/null @@ -1,119 +0,0 @@ - -## CIFAR-10 - -#### AlexNet -``` -python cifar.py -a alexnet --epochs 164 --schedule 81 122 --gamma 0.1 --checkpoint checkpoints/cifar10/alexnet -``` - - -#### VGG19 (BN) -``` -python cifar.py -a vgg19_bn --epochs 164 --schedule 81 122 --gamma 0.1 --checkpoint checkpoints/cifar10/vgg19_bn -``` - -#### ResNet-110 -``` -python cifar.py -a resnet --depth 110 --epochs 164 --schedule 81 122 --gamma 0.1 --wd 1e-4 --checkpoint checkpoints/cifar10/resnet-110 -``` - -#### ResNet-1202 -``` -python cifar.py -a resnet --depth 1202 --epochs 164 --schedule 81 122 --gamma 0.1 --wd 1e-4 --checkpoint checkpoints/cifar10/resnet-1202 -``` - -#### PreResNet-110 -``` -python cifar.py -a preresnet --depth 110 --epochs 164 --schedule 81 122 --gamma 0.1 --wd 1e-4 --checkpoint checkpoints/cifar10/preresnet-110 -``` - -#### ResNeXt-29, 8x64d -``` -python cifar.py -a resnext --depth 29 --cardinality 8 --widen-factor 4 --schedule 150 225 --wd 5e-4 --gamma 0.1 --checkpoint checkpoints/cifar10/resnext-8x64d -``` -#### ResNeXt-29, 16x64d -``` -python cifar.py -a resnext --depth 29 --cardinality 16 --widen-factor 4 --schedule 150 225 --wd 5e-4 --gamma 0.1 --checkpoint checkpoints/cifar10/resnext-16x64d -``` - -#### WRN-28-10-drop -``` -python cifar.py -a wrn --depth 28 --depth 28 --widen-factor 10 --drop 0.3 --epochs 200 --schedule 60 120 160 --wd 5e-4 --gamma 0.2 --checkpoint checkpoints/cifar10/WRN-28-10-drop -``` - -#### DenseNet-BC (L=100, k=12) -**Note**: -* DenseNet use weight decay value `1e-4`. Larger weight decay (`5e-4`) if harmful for the accuracy (95.46 vs. 94.05) -* Official batch size is 64. But there is no big difference using batchsize 64 or 128 (95.46 vs 95.11). 
- -``` -python cifar.py -a densenet --depth 100 --growthRate 12 --train-batch 64 --epochs 300 --schedule 150 225 --wd 1e-4 --gamma 0.1 --checkpoint checkpoints/cifar10/densenet-bc-100-12 -``` - -#### DenseNet-BC (L=190, k=40) -``` -python cifar.py -a densenet --depth 190 --growthRate 40 --train-batch 64 --epochs 300 --schedule 150 225 --wd 1e-4 --gamma 0.1 --checkpoint checkpoints/cifar10/densenet-bc-L190-k40 -``` - -## CIFAR-100 - -#### AlexNet -``` -python cifar.py -a alexnet --dataset cifar100 --checkpoint checkpoints/cifar100/alexnet --epochs 164 --schedule 81 122 --gamma 0.1 -``` - -#### VGG19 (BN) -``` -python cifar.py -a vgg19_bn --dataset cifar100 --checkpoint checkpoints/cifar100/vgg19_bn --epochs 164 --schedule 81 122 --gamma 0.1 -``` - -#### ResNet-110 -``` -python cifar.py -a resnet --dataset cifar100 --depth 110 --epochs 164 --schedule 81 122 --gamma 0.1 --wd 1e-4 --checkpoint checkpoints/cifar100/resnet-110 -``` - -#### ResNet-1202 -``` -python cifar.py -a resnet --dataset cifar100 --depth 1202 --epochs 164 --schedule 81 122 --gamma 0.1 --wd 1e-4 --checkpoint checkpoints/cifar100/resnet-1202 -``` - -#### PreResNet-110 -``` -python cifar.py -a preresnet --dataset cifar100 --depth 110 --epochs 164 --schedule 81 122 --gamma 0.1 --wd 1e-4 --checkpoint checkpoints/cifar100/preresnet-110 -``` - -#### ResNeXt-29, 8x64d -``` -python cifar.py -a resnext --dataset cifar100 --depth 29 --cardinality 8 --widen-factor 4 --checkpoint checkpoints/cifar100/resnext-8x64d --schedule 150 225 --wd 5e-4 --gamma 0.1 -``` -#### ResNeXt-29, 16x64d -``` -python cifar.py -a resnext --dataset cifar100 --depth 29 --cardinality 16 --widen-factor 4 --checkpoint checkpoints/cifar100/resnext-16x64d --schedule 150 225 --wd 5e-4 --gamma 0.1 -``` - -#### WRN-28-10-drop -``` -python cifar.py -a wrn --dataset cifar100 --depth 28 --depth 28 --widen-factor 10 --drop 0.3 --epochs 200 --schedule 60 120 160 --wd 5e-4 --gamma 0.2 --checkpoint checkpoints/cifar100/WRN-28-10-drop -``` - -#### DenseNet-BC (L=100, k=12) -``` -python cifar.py -a densenet --dataset cifar100 --depth 100 --growthRate 12 --train-batch 64 --epochs 300 --schedule 150 225 --wd 1e-4 --gamma 0.1 --checkpoint checkpoints/cifar100/densenet-bc-100-12 -``` - -#### DenseNet-BC (L=190, k=40) -``` -python cifar.py -a densenet --dataset cifar100 --depth 190 --growthRate 40 --train-batch 64 --epochs 300 --schedule 150 225 --wd 1e-4 --gamma 0.1 --checkpoint checkpoints/cifar100/densenet-bc-L190-k40 -``` - -## ImageNet -### ResNet-18 -``` -python imagenet.py -a resnet18 --data ~/dataset/ILSVRC2012/ --epochs 90 --schedule 31 61 --gamma 0.1 -c checkpoints/imagenet/resnet18 -``` - -### ResNeXt-50 (32x4d) -*(Originally trained on 8xGPUs)* -``` -python imagenet.py -a resnext50 --base-width 4 --cardinality 32 --data ~/dataset/ILSVRC2012/ --epochs 90 --schedule 31 61 --gamma 0.1 -c checkpoints/imagenet/resnext50-32x4d -``` \ No newline at end of file diff --git a/src/torchprune/torchprune/util/models/cnn/cifar.py b/src/torchprune/torchprune/util/models/cnn/cifar.py deleted file mode 100644 index 510e474..0000000 --- a/src/torchprune/torchprune/util/models/cnn/cifar.py +++ /dev/null @@ -1,350 +0,0 @@ -''' -Training script for CIFAR-10/100 -Copyright (c) Wei YANG, 2017 -''' -from __future__ import print_function - -import argparse -import os -import shutil -import time -import random - -import torch -import torch.nn as nn -import torch.nn.parallel -import torch.backends.cudnn as cudnn -import torch.optim as optim -import torch.utils.data as data -import 
torchvision.transforms as transforms -import torchvision.datasets as datasets -import models.cifar as models - -from utils import Bar, Logger, AverageMeter, accuracy, mkdir_p, savefig - - -model_names = sorted(name for name in models.__dict__ - if name.islower() and not name.startswith("__") - and callable(models.__dict__[name])) - -parser = argparse.ArgumentParser(description='PyTorch CIFAR10/100 Training') -# Datasets -parser.add_argument('-d', '--dataset', default='cifar10', type=str) -parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', - help='number of data loading workers (default: 4)') -# Optimization options -parser.add_argument('--epochs', default=300, type=int, metavar='N', - help='number of total epochs to run') -parser.add_argument('--start-epoch', default=0, type=int, metavar='N', - help='manual epoch number (useful on restarts)') -parser.add_argument('--train-batch', default=128, type=int, metavar='N', - help='train batchsize') -parser.add_argument('--test-batch', default=100, type=int, metavar='N', - help='test batchsize') -parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, - metavar='LR', help='initial learning rate') -parser.add_argument('--drop', '--dropout', default=0, type=float, - metavar='Dropout', help='Dropout ratio') -parser.add_argument('--schedule', type=int, nargs='+', default=[150, 225], - help='Decrease learning rate at these epochs.') -parser.add_argument('--gamma', type=float, default=0.1, help='LR is multiplied by gamma on schedule.') -parser.add_argument('--momentum', default=0.9, type=float, metavar='M', - help='momentum') -parser.add_argument('--weight-decay', '--wd', default=5e-4, type=float, - metavar='W', help='weight decay (default: 1e-4)') -# Checkpoints -parser.add_argument('-c', '--checkpoint', default='checkpoint', type=str, metavar='PATH', - help='path to save checkpoint (default: checkpoint)') -parser.add_argument('--resume', default='', type=str, metavar='PATH', - help='path to latest checkpoint (default: none)') -# Architecture -parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet20', - choices=model_names, - help='model architecture: ' + - ' | '.join(model_names) + - ' (default: resnet18)') -parser.add_argument('--depth', type=int, default=29, help='Model depth.') -parser.add_argument('--cardinality', type=int, default=8, help='Model cardinality (group).') -parser.add_argument('--widen-factor', type=int, default=4, help='Widen factor. 4 -> 64, 8 -> 128, ...') -parser.add_argument('--growthRate', type=int, default=12, help='Growth rate for DenseNet.') -parser.add_argument('--compressionRate', type=int, default=2, help='Compression Rate (theta) for DenseNet.') -# Miscs -parser.add_argument('--manualSeed', type=int, help='manual seed') -parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', - help='evaluate model on validation set') -#Device options -parser.add_argument('--gpu-id', default='0', type=str, - help='id(s) for CUDA_VISIBLE_DEVICES') - -args = parser.parse_args() -state = {k: v for k, v in args._get_kwargs()} - -# Validate dataset -assert args.dataset == 'cifar10' or args.dataset == 'cifar100', 'Dataset can only be cifar10 or cifar100.' 
- -# Use CUDA -os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id -use_cuda = torch.cuda.is_available() - -# Random seed -if args.manualSeed is None: - args.manualSeed = random.randint(1, 10000) -random.seed(args.manualSeed) -torch.manual_seed(args.manualSeed) -if use_cuda: - torch.cuda.manual_seed_all(args.manualSeed) - -best_acc = 0 # best test accuracy - -def main(): - global best_acc - start_epoch = args.start_epoch # start from epoch 0 or last checkpoint epoch - - if not os.path.isdir(args.checkpoint): - mkdir_p(args.checkpoint) - - - - # Data - print('==> Preparing dataset %s' % args.dataset) - transform_train = transforms.Compose([ - transforms.RandomCrop(32, padding=4), - transforms.RandomHorizontalFlip(), - transforms.ToTensor(), - transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), - ]) - - transform_test = transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), - ]) - if args.dataset == 'cifar10': - dataloader = datasets.CIFAR10 - num_classes = 10 - else: - dataloader = datasets.CIFAR100 - num_classes = 100 - - - trainset = dataloader(root='./data', train=True, download=True, transform=transform_train) - trainloader = data.DataLoader(trainset, batch_size=args.train_batch, shuffle=True, num_workers=args.workers) - - testset = dataloader(root='./data', train=False, download=False, transform=transform_test) - testloader = data.DataLoader(testset, batch_size=args.test_batch, shuffle=False, num_workers=args.workers) - - # Model - print("==> creating model '{}'".format(args.arch)) - if args.arch.startswith('resnext'): - model = models.__dict__[args.arch]( - cardinality=args.cardinality, - num_classes=num_classes, - depth=args.depth, - widen_factor=args.widen_factor, - dropRate=args.drop, - ) - elif args.arch.startswith('densenet'): - model = models.__dict__[args.arch]( - num_classes=num_classes, - depth=args.depth, - growthRate=args.growthRate, - compressionRate=args.compressionRate, - dropRate=args.drop, - ) - elif args.arch.startswith('wrn'): - model = models.__dict__[args.arch]( - num_classes=num_classes, - depth=args.depth, - widen_factor=args.widen_factor, - dropRate=args.drop, - ) - elif args.arch.endswith('resnet'): - model = models.__dict__[args.arch]( - num_classes=num_classes, - depth=args.depth, - ) - else: - model = models.__dict__[args.arch](num_classes=num_classes) - - model = torch.nn.DataParallel(model).cuda() - cudnn.benchmark = True - print(' Total params: %.2fM' % (sum(p.numel() for p in model.parameters())/1000000.0)) - - criterion = nn.CrossEntropyLoss() - optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) - - # Resume - title = 'cifar-10-' + args.arch - if args.resume: - # Load checkpoint. - print('==> Resuming from checkpoint..') - assert os.path.isfile(args.resume), 'Error: no checkpoint directory found!' 
- args.checkpoint = os.path.dirname(args.resume) - checkpoint = torch.load(args.resume) - best_acc = checkpoint['best_acc'] - start_epoch = checkpoint['epoch'] - model.load_state_dict(checkpoint['state_dict']) - optimizer.load_state_dict(checkpoint['optimizer']) - logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title, resume=True) - else: - logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title) - logger.set_names(['Learning Rate', 'Train Loss', 'Valid Loss', 'Train Acc.', 'Valid Acc.']) - - - if args.evaluate: - print('\nEvaluation only') - test_loss, test_acc = test(testloader, model, criterion, start_epoch, use_cuda) - print(' Test Loss: %.8f, Test Acc: %.2f' % (test_loss, test_acc)) - return - - # Train and val - for epoch in range(start_epoch, args.epochs): - adjust_learning_rate(optimizer, epoch) - - print('\nEpoch: [%d | %d] LR: %f' % (epoch + 1, args.epochs, state['lr'])) - - train_loss, train_acc = train(trainloader, model, criterion, optimizer, epoch, use_cuda) - test_loss, test_acc = test(testloader, model, criterion, epoch, use_cuda) - - # append logger file - logger.append([state['lr'], train_loss, test_loss, train_acc, test_acc]) - - # save model - is_best = test_acc > best_acc - best_acc = max(test_acc, best_acc) - save_checkpoint({ - 'epoch': epoch + 1, - 'state_dict': model.state_dict(), - 'acc': test_acc, - 'best_acc': best_acc, - 'optimizer' : optimizer.state_dict(), - }, is_best, checkpoint=args.checkpoint) - - logger.close() - logger.plot() - savefig(os.path.join(args.checkpoint, 'log.eps')) - - print('Best acc:') - print(best_acc) - -def train(trainloader, model, criterion, optimizer, epoch, use_cuda): - # switch to train mode - model.train() - - batch_time = AverageMeter() - data_time = AverageMeter() - losses = AverageMeter() - top1 = AverageMeter() - top5 = AverageMeter() - end = time.time() - - bar = Bar('Processing', max=len(trainloader)) - for batch_idx, (inputs, targets) in enumerate(trainloader): - # measure data loading time - data_time.update(time.time() - end) - - if use_cuda: - inputs, targets = inputs.cuda(), targets.cuda(async=True) - inputs, targets = torch.autograd.Variable(inputs), torch.autograd.Variable(targets) - - # compute output - outputs = model(inputs) - loss = criterion(outputs, targets) - - # measure accuracy and record loss - prec1, prec5 = accuracy(outputs.data, targets.data, topk=(1, 5)) - losses.update(loss.data[0], inputs.size(0)) - top1.update(prec1[0], inputs.size(0)) - top5.update(prec5[0], inputs.size(0)) - - # compute gradient and do SGD step - optimizer.zero_grad() - loss.backward() - optimizer.step() - - # measure elapsed time - batch_time.update(time.time() - end) - end = time.time() - - # plot progress - bar.suffix = '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f} | top1: {top1: .4f} | top5: {top5: .4f}'.format( - batch=batch_idx + 1, - size=len(trainloader), - data=data_time.avg, - bt=batch_time.avg, - total=bar.elapsed_td, - eta=bar.eta_td, - loss=losses.avg, - top1=top1.avg, - top5=top5.avg, - ) - bar.next() - bar.finish() - return (losses.avg, top1.avg) - -def test(testloader, model, criterion, epoch, use_cuda): - global best_acc - - batch_time = AverageMeter() - data_time = AverageMeter() - losses = AverageMeter() - top1 = AverageMeter() - top5 = AverageMeter() - - # switch to evaluate mode - model.eval() - - end = time.time() - bar = Bar('Processing', max=len(testloader)) - for batch_idx, (inputs, targets) in enumerate(testloader): - 
# measure data loading time - data_time.update(time.time() - end) - - if use_cuda: - inputs, targets = inputs.cuda(), targets.cuda() - inputs, targets = torch.autograd.Variable(inputs, volatile=True), torch.autograd.Variable(targets) - - # compute output - outputs = model(inputs) - loss = criterion(outputs, targets) - - # measure accuracy and record loss - prec1, prec5 = accuracy(outputs.data, targets.data, topk=(1, 5)) - losses.update(loss.data[0], inputs.size(0)) - top1.update(prec1[0], inputs.size(0)) - top5.update(prec5[0], inputs.size(0)) - - # measure elapsed time - batch_time.update(time.time() - end) - end = time.time() - - # plot progress - bar.suffix = '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f} | top1: {top1: .4f} | top5: {top5: .4f}'.format( - batch=batch_idx + 1, - size=len(testloader), - data=data_time.avg, - bt=batch_time.avg, - total=bar.elapsed_td, - eta=bar.eta_td, - loss=losses.avg, - top1=top1.avg, - top5=top5.avg, - ) - bar.next() - bar.finish() - return (losses.avg, top1.avg) - -def save_checkpoint(state, is_best, checkpoint='checkpoint', filename='checkpoint.pth.tar'): - filepath = os.path.join(checkpoint, filename) - torch.save(state, filepath) - if is_best: - shutil.copyfile(filepath, os.path.join(checkpoint, 'model_best.pth.tar')) - -def adjust_learning_rate(optimizer, epoch): - global state - if epoch in args.schedule: - state['lr'] *= args.gamma - for param_group in optimizer.param_groups: - param_group['lr'] = state['lr'] - -if __name__ == '__main__': - main() diff --git a/src/torchprune/torchprune/util/models/cnn/imagenet.py b/src/torchprune/torchprune/util/models/cnn/imagenet.py deleted file mode 100644 index b1dace6..0000000 --- a/src/torchprune/torchprune/util/models/cnn/imagenet.py +++ /dev/null @@ -1,344 +0,0 @@ -''' -Training script for ImageNet -Copyright (c) Wei YANG, 2017 -''' -from __future__ import print_function - -import argparse -import os -import shutil -import time -import random - -import torch -import torch.nn as nn -import torch.nn.parallel -import torch.backends.cudnn as cudnn -import torch.optim as optim -import torch.utils.data as data -import torchvision.transforms as transforms -import torchvision.datasets as datasets -import torchvision.models as models -import models.imagenet as customized_models - -from utils import Bar, Logger, AverageMeter, accuracy, mkdir_p, savefig - -# Models -default_model_names = sorted(name for name in models.__dict__ - if name.islower() and not name.startswith("__") - and callable(models.__dict__[name])) - -customized_models_names = sorted(name for name in customized_models.__dict__ - if name.islower() and not name.startswith("__") - and callable(customized_models.__dict__[name])) - -for name in customized_models.__dict__: - if name.islower() and not name.startswith("__") and callable(customized_models.__dict__[name]): - models.__dict__[name] = customized_models.__dict__[name] - -model_names = default_model_names + customized_models_names - -# Parse arguments -parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') - -# Datasets -parser.add_argument('-d', '--data', default='path to dataset', type=str) -parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', - help='number of data loading workers (default: 4)') -# Optimization options -parser.add_argument('--epochs', default=90, type=int, metavar='N', - help='number of total epochs to run') -parser.add_argument('--start-epoch', default=0, type=int, metavar='N', - 
help='manual epoch number (useful on restarts)') -parser.add_argument('--train-batch', default=256, type=int, metavar='N', - help='train batchsize (default: 256)') -parser.add_argument('--test-batch', default=200, type=int, metavar='N', - help='test batchsize (default: 200)') -parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, - metavar='LR', help='initial learning rate') -parser.add_argument('--drop', '--dropout', default=0, type=float, - metavar='Dropout', help='Dropout ratio') -parser.add_argument('--schedule', type=int, nargs='+', default=[150, 225], - help='Decrease learning rate at these epochs.') -parser.add_argument('--gamma', type=float, default=0.1, help='LR is multiplied by gamma on schedule.') -parser.add_argument('--momentum', default=0.9, type=float, metavar='M', - help='momentum') -parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float, - metavar='W', help='weight decay (default: 1e-4)') -# Checkpoints -parser.add_argument('-c', '--checkpoint', default='checkpoint', type=str, metavar='PATH', - help='path to save checkpoint (default: checkpoint)') -parser.add_argument('--resume', default='', type=str, metavar='PATH', - help='path to latest checkpoint (default: none)') -# Architecture -parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18', - choices=model_names, - help='model architecture: ' + - ' | '.join(model_names) + - ' (default: resnet18)') -parser.add_argument('--depth', type=int, default=29, help='Model depth.') -parser.add_argument('--cardinality', type=int, default=32, help='ResNet cardinality (group).') -parser.add_argument('--base-width', type=int, default=4, help='ResNet base width.') -parser.add_argument('--widen-factor', type=int, default=4, help='Widen factor. 4 -> 64, 8 -> 128, ...') -# Miscs -parser.add_argument('--manualSeed', type=int, help='manual seed') -parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', - help='evaluate model on validation set') -parser.add_argument('--pretrained', dest='pretrained', action='store_true', - help='use pre-trained model') -#Device options -parser.add_argument('--gpu-id', default='0', type=str, - help='id(s) for CUDA_VISIBLE_DEVICES') - -args = parser.parse_args() -state = {k: v for k, v in args._get_kwargs()} - -# Use CUDA -os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id -use_cuda = torch.cuda.is_available() - -# Random seed -if args.manualSeed is None: - args.manualSeed = random.randint(1, 10000) -random.seed(args.manualSeed) -torch.manual_seed(args.manualSeed) -if use_cuda: - torch.cuda.manual_seed_all(args.manualSeed) - -best_acc = 0 # best test accuracy - -def main(): - global best_acc - start_epoch = args.start_epoch # start from epoch 0 or last checkpoint epoch - - if not os.path.isdir(args.checkpoint): - mkdir_p(args.checkpoint) - - # Data loading code - traindir = os.path.join(args.data, 'train') - valdir = os.path.join(args.data, 'val') - normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], - std=[0.229, 0.224, 0.225]) - - train_loader = torch.utils.data.DataLoader( - datasets.ImageFolder(traindir, transforms.Compose([ - transforms.RandomSizedCrop(224), - transforms.RandomHorizontalFlip(), - transforms.ToTensor(), - normalize, - ])), - batch_size=args.train_batch, shuffle=True, - num_workers=args.workers, pin_memory=True) - - val_loader = torch.utils.data.DataLoader( - datasets.ImageFolder(valdir, transforms.Compose([ - transforms.Scale(256), - transforms.CenterCrop(224), - transforms.ToTensor(), - normalize, - ])), - 
batch_size=args.test_batch, shuffle=False, - num_workers=args.workers, pin_memory=True) - - # create model - if args.pretrained: - print("=> using pre-trained model '{}'".format(args.arch)) - model = models.__dict__[args.arch](pretrained=True) - elif args.arch.startswith('resnext'): - model = models.__dict__[args.arch]( - baseWidth=args.base_width, - cardinality=args.cardinality, - ) - else: - print("=> creating model '{}'".format(args.arch)) - model = models.__dict__[args.arch]() - - if args.arch.startswith('alexnet') or args.arch.startswith('vgg'): - model.features = torch.nn.DataParallel(model.features) - model.cuda() - else: - model = torch.nn.DataParallel(model).cuda() - - cudnn.benchmark = True - print(' Total params: %.2fM' % (sum(p.numel() for p in model.parameters())/1000000.0)) - - # define loss function (criterion) and optimizer - criterion = nn.CrossEntropyLoss().cuda() - optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) - - # Resume - title = 'ImageNet-' + args.arch - if args.resume: - # Load checkpoint. - print('==> Resuming from checkpoint..') - assert os.path.isfile(args.resume), 'Error: no checkpoint directory found!' - args.checkpoint = os.path.dirname(args.resume) - checkpoint = torch.load(args.resume) - best_acc = checkpoint['best_acc'] - start_epoch = checkpoint['epoch'] - model.load_state_dict(checkpoint['state_dict']) - optimizer.load_state_dict(checkpoint['optimizer']) - logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title, resume=True) - else: - logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title) - logger.set_names(['Learning Rate', 'Train Loss', 'Valid Loss', 'Train Acc.', 'Valid Acc.']) - - - if args.evaluate: - print('\nEvaluation only') - test_loss, test_acc = test(val_loader, model, criterion, start_epoch, use_cuda) - print(' Test Loss: %.8f, Test Acc: %.2f' % (test_loss, test_acc)) - return - - # Train and val - for epoch in range(start_epoch, args.epochs): - adjust_learning_rate(optimizer, epoch) - - print('\nEpoch: [%d | %d] LR: %f' % (epoch + 1, args.epochs, state['lr'])) - - train_loss, train_acc = train(train_loader, model, criterion, optimizer, epoch, use_cuda) - test_loss, test_acc = test(val_loader, model, criterion, epoch, use_cuda) - - # append logger file - logger.append([state['lr'], train_loss, test_loss, train_acc, test_acc]) - - # save model - is_best = test_acc > best_acc - best_acc = max(test_acc, best_acc) - save_checkpoint({ - 'epoch': epoch + 1, - 'state_dict': model.state_dict(), - 'acc': test_acc, - 'best_acc': best_acc, - 'optimizer' : optimizer.state_dict(), - }, is_best, checkpoint=args.checkpoint) - - logger.close() - logger.plot() - savefig(os.path.join(args.checkpoint, 'log.eps')) - - print('Best acc:') - print(best_acc) - -def train(train_loader, model, criterion, optimizer, epoch, use_cuda): - # switch to train mode - model.train() - - batch_time = AverageMeter() - data_time = AverageMeter() - losses = AverageMeter() - top1 = AverageMeter() - top5 = AverageMeter() - end = time.time() - - bar = Bar('Processing', max=len(train_loader)) - for batch_idx, (inputs, targets) in enumerate(train_loader): - # measure data loading time - data_time.update(time.time() - end) - - if use_cuda: - inputs, targets = inputs.cuda(), targets.cuda(async=True) - inputs, targets = torch.autograd.Variable(inputs), torch.autograd.Variable(targets) - - # compute output - outputs = model(inputs) - loss = criterion(outputs, targets) - - # measure accuracy and 
record loss - prec1, prec5 = accuracy(outputs.data, targets.data, topk=(1, 5)) - losses.update(loss.data[0], inputs.size(0)) - top1.update(prec1[0], inputs.size(0)) - top5.update(prec5[0], inputs.size(0)) - - # compute gradient and do SGD step - optimizer.zero_grad() - loss.backward() - optimizer.step() - - # measure elapsed time - batch_time.update(time.time() - end) - end = time.time() - - # plot progress - bar.suffix = '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f} | top1: {top1: .4f} | top5: {top5: .4f}'.format( - batch=batch_idx + 1, - size=len(train_loader), - data=data_time.val, - bt=batch_time.val, - total=bar.elapsed_td, - eta=bar.eta_td, - loss=losses.avg, - top1=top1.avg, - top5=top5.avg, - ) - bar.next() - bar.finish() - return (losses.avg, top1.avg) - -def test(val_loader, model, criterion, epoch, use_cuda): - global best_acc - - batch_time = AverageMeter() - data_time = AverageMeter() - losses = AverageMeter() - top1 = AverageMeter() - top5 = AverageMeter() - - # switch to evaluate mode - model.eval() - - end = time.time() - bar = Bar('Processing', max=len(val_loader)) - for batch_idx, (inputs, targets) in enumerate(val_loader): - # measure data loading time - data_time.update(time.time() - end) - - if use_cuda: - inputs, targets = inputs.cuda(), targets.cuda() - inputs, targets = torch.autograd.Variable(inputs, volatile=True), torch.autograd.Variable(targets) - - # compute output - outputs = model(inputs) - loss = criterion(outputs, targets) - - # measure accuracy and record loss - prec1, prec5 = accuracy(outputs.data, targets.data, topk=(1, 5)) - losses.update(loss.data[0], inputs.size(0)) - top1.update(prec1[0], inputs.size(0)) - top5.update(prec5[0], inputs.size(0)) - - # measure elapsed time - batch_time.update(time.time() - end) - end = time.time() - - # plot progress - bar.suffix = '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f} | top1: {top1: .4f} | top5: {top5: .4f}'.format( - batch=batch_idx + 1, - size=len(val_loader), - data=data_time.avg, - bt=batch_time.avg, - total=bar.elapsed_td, - eta=bar.eta_td, - loss=losses.avg, - top1=top1.avg, - top5=top5.avg, - ) - bar.next() - bar.finish() - return (losses.avg, top1.avg) - -def save_checkpoint(state, is_best, checkpoint='checkpoint', filename='checkpoint.pth.tar'): - filepath = os.path.join(checkpoint, filename) - torch.save(state, filepath) - if is_best: - shutil.copyfile(filepath, os.path.join(checkpoint, 'model_best.pth.tar')) - -def adjust_learning_rate(optimizer, epoch): - global state - if epoch in args.schedule: - state['lr'] *= args.gamma - for param_group in optimizer.param_groups: - param_group['lr'] = state['lr'] - -if __name__ == '__main__': - main() diff --git a/src/torchprune/torchprune/util/models/cnn/models/cifar/__init__.py b/src/torchprune/torchprune/util/models/cnn/models/cifar/__init__.py deleted file mode 100644 index 3011b7a..0000000 --- a/src/torchprune/torchprune/util/models/cnn/models/cifar/__init__.py +++ /dev/null @@ -1,70 +0,0 @@ -from __future__ import absolute_import - -"""The models subpackage contains definitions for the following model for CIFAR10/CIFAR100 -architectures: - -- `AlexNet`_ -- `VGG`_ -- `ResNet`_ -- `SqueezeNet`_ -- `DenseNet`_ - -You can construct a model with random weights by calling its constructor: - -.. 
code:: python - - import torchvision.models as models - resnet18 = models.resnet18() - alexnet = models.alexnet() - squeezenet = models.squeezenet1_0() - densenet = models.densenet_161() - -We provide pre-trained models for the ResNet variants and AlexNet, using the -PyTorch :mod:`torch.utils.model_zoo`. These can constructed by passing -``pretrained=True``: - -.. code:: python - - import torchvision.models as models - resnet18 = models.resnet18(pretrained=True) - alexnet = models.alexnet(pretrained=True) - -ImageNet 1-crop error rates (224x224) - -======================== ============= ============= -Network Top-1 error Top-5 error -======================== ============= ============= -ResNet-18 30.24 10.92 -ResNet-34 26.70 8.58 -ResNet-50 23.85 7.13 -ResNet-101 22.63 6.44 -ResNet-152 21.69 5.94 -Inception v3 22.55 6.44 -AlexNet 43.45 20.91 -VGG-11 30.98 11.37 -VGG-13 30.07 10.75 -VGG-16 28.41 9.62 -VGG-19 27.62 9.12 -SqueezeNet 1.0 41.90 19.58 -SqueezeNet 1.1 41.81 19.38 -Densenet-121 25.35 7.83 -Densenet-169 24.00 7.00 -Densenet-201 22.80 6.43 -Densenet-161 22.35 6.20 -======================== ============= ============= - - -.. _AlexNet: https://arxiv.org/abs/1404.5997 -.. _VGG: https://arxiv.org/abs/1409.1556 -.. _ResNet: https://arxiv.org/abs/1512.03385 -.. _SqueezeNet: https://arxiv.org/abs/1602.07360 -.. _DenseNet: https://arxiv.org/abs/1608.06993 -""" - -from .alexnet import * -from .vgg import * -from .resnet import * -from .resnext import * -from .wrn import * -from .preresnet import * -from .densenet import * diff --git a/src/torchprune/torchprune/util/models/cnn/models/cifar/alexnet.py b/src/torchprune/torchprune/util/models/cnn/models/cifar/alexnet.py deleted file mode 100644 index 8c9407d..0000000 --- a/src/torchprune/torchprune/util/models/cnn/models/cifar/alexnet.py +++ /dev/null @@ -1,44 +0,0 @@ -'''AlexNet for CIFAR10. FC layers are removed. Paddings are adjusted. -Without BN, the start learning rate should be 0.01 -(c) YANG, Wei -''' -import torch.nn as nn - - -__all__ = ['alexnet'] - - -class AlexNet(nn.Module): - - def __init__(self, num_classes=10): - super(AlexNet, self).__init__() - self.features = nn.Sequential( - nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=5), - nn.ReLU(inplace=True), - nn.MaxPool2d(kernel_size=2, stride=2), - nn.Conv2d(64, 192, kernel_size=5, padding=2), - nn.ReLU(inplace=True), - nn.MaxPool2d(kernel_size=2, stride=2), - nn.Conv2d(192, 384, kernel_size=3, padding=1), - nn.ReLU(inplace=True), - nn.Conv2d(384, 256, kernel_size=3, padding=1), - nn.ReLU(inplace=True), - nn.Conv2d(256, 256, kernel_size=3, padding=1), - nn.ReLU(inplace=True), - nn.MaxPool2d(kernel_size=2, stride=2), - ) - self.classifier = nn.Linear(256, num_classes) - - def forward(self, x): - x = self.features(x) - x = x.view(x.size(0), -1) - x = self.classifier(x) - return x - - -def alexnet(**kwargs): - r"""AlexNet model architecture from the - `"One weird trick..."