diff --git a/README.md b/README.md index 7e6aa1b..55337d5 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,9 @@ -# Neural Network Pruning -[Lucas Liebenwein](https://people.csail.mit.edu/lucasl/), -[Cenk Baykal](http://www.mit.edu/~baykal/), -[Alaa Maalouf](https://www.linkedin.com/in/alaa-maalouf/), -[Igor Gilitschenski](https://www.gilitschenski.org/igor/), -[Dan Feldman](http://people.csail.mit.edu/dannyf/), -[Daniela Rus](http://danielarus.csail.mit.edu/) +# torchprune +Main contributors of this code base: +[Lucas Liebenwein](http://www.mit.edu/~lucasl/), +[Cenk Baykal](http://www.mit.edu/~baykal/). + +Please check individual paper folders for authors of each paper.

@@ -15,10 +14,11 @@ This repository contains code to reproduce the results from the following papers: | Paper | Venue | Title & Link | | :---: | :---: | :--- | +| **Node** | NeurIPS 2021 | [Sparse Flows: Pruning Continuous-depth Models](https://proceedings.neurips.cc/paper/2021/hash/bf1b2f4b901c21a1d8645018ea9aeb05-Abstract.html) | | **ALDS** | NeurIPS 2021 | [Compressing Neural Networks: Towards Determining the Optimal Layer-wise Decomposition](https://arxiv.org/abs/2107.11442) | | **Lost** | MLSys 2021 | [Lost in Pruning: The Effects of Pruning Neural Networks beyond Test Accuracy](https://proceedings.mlsys.org/paper/2021/hash/2a79ea27c279e471f4d180b08d62b00a-Abstract.html) | | **PFP** | ICLR 2020 | [Provable Filter Pruning for Efficient Neural Networks](https://openreview.net/forum?id=BJxkOlSYDH) | -| **SiPP** | arXiv | [SiPPing Neural Networks: Sensitivity-informed Provable Pruning of Neural Networks](https://arxiv.org/abs/1910.05422) | +| **SiPP** | SIAM 2022 | [SiPPing Neural Networks: Sensitivity-informed Provable Pruning of Neural Networks](https://doi.org/10.1137/20M1383239) | ### Packages In addition, the repo also contains two stand-alone python packages that @@ -35,6 +35,7 @@ about the paper and scripts and parameter configuration to reproduce the exact results from the paper. | Paper | Location | | :---: | :---: | +| **Node** | [paper/node](./paper/node) | | **ALDS** | [paper/alds](./paper/alds) | | **Lost** | [paper/lost](./paper/lost) | | **PFP** | [paper/pfp](./paper/pfp) | @@ -98,14 +99,27 @@ using the codebase. | --- | --- | | [src/torchprune/README.md](./src/torchprune) | more details to prune neural networks, how to use and setup the data sets, how to implement custom pruning methods, and how to add your data sets and networks. | | [src/experiment/README.md](./src/experiment) | more details on how to configure and run your own experiments, and more information on how to re-produce the results. | +| [paper/node/README.md](./paper/node) | check out for more information on the [Node](https://proceedings.neurips.cc/paper/2021/hash/bf1b2f4b901c21a1d8645018ea9aeb05-Abstract.html) paper. | | [paper/alds/README.md](./paper/alds) | check out for more information on the [ALDS](https://arxiv.org/abs/2107.11442) paper. | | [paper/lost/README.md](./paper/lost) | check out for more information on the [Lost](https://proceedings.mlsys.org/paper/2021/hash/2a79ea27c279e471f4d180b08d62b00a-Abstract.html) paper. | | [paper/pfp/README.md](./paper/pfp) | check out for more information on the [PFP](https://openreview.net/forum?id=BJxkOlSYDH) paper. | -| [paper/sipp/README.md](./paper/sipp) | check out for more information on the [SiPP](https://arxiv.org/abs/1910.05422) paper. | +| [paper/sipp/README.md](./paper/sipp) | check out for more information on the [SiPP](https://doi.org/10.1137/20M1383239) paper. | ## Citations Please cite the respective papers when using our work. 
+### [Sparse flows: Pruning continuous-depth models](https://proceedings.neurips.cc/paper/2021/hash/bf1b2f4b901c21a1d8645018ea9aeb05-Abstract.html) +``` +@article{liebenwein2021sparse, + title={Sparse flows: Pruning continuous-depth models}, + author={Liebenwein, Lucas and Hasani, Ramin and Amini, Alexander and Rus, Daniela}, + journal={Advances in Neural Information Processing Systems}, + volume={34}, + pages={22628--22642}, + year={2021} +} +``` + ### [Towards Determining the Optimal Layer-wise Decomposition](https://arxiv.org/abs/2107.11442) ``` @inproceedings{liebenwein2021alds, @@ -140,12 +154,16 @@ url={https://openreview.net/forum?id=BJxkOlSYDH} } ``` -### [SiPPing Neural Networks](https://arxiv.org/abs/1910.05422) +### [SiPPing Neural Networks](https://doi.org/10.1137/20M1383239) (Weight Pruning) ``` -@article{baykal2019sipping, -title={SiPPing Neural Networks: Sensitivity-informed Provable Pruning of Neural Networks}, -author={Baykal, Cenk and Liebenwein, Lucas and Gilitschenski, Igor and Feldman, Dan and Rus, Daniela}, -journal={arXiv preprint arXiv:1910.05422}, -year={2019} +@article{baykal2022sensitivity, + title={Sensitivity-informed provable pruning of neural networks}, + author={Baykal, Cenk and Liebenwein, Lucas and Gilitschenski, Igor and Feldman, Dan and Rus, Daniela}, + journal={SIAM Journal on Mathematics of Data Science}, + volume={4}, + number={1}, + pages={26--45}, + year={2022}, + publisher={SIAM} } ``` \ No newline at end of file diff --git a/misc/imgs/node_overview.png b/misc/imgs/node_overview.png new file mode 100644 index 0000000..ff19684 Binary files /dev/null and b/misc/imgs/node_overview.png differ diff --git a/misc/requirements.txt b/misc/requirements.txt index efe3eb2..704fb17 100644 --- a/misc/requirements.txt +++ b/misc/requirements.txt @@ -3,10 +3,10 @@ -e ./src/experiment # We need those with special tags unfortunately... 
--f https://download.pytorch.org/whl/torch_stable.html -torch==1.7.1+cu110 -torchvision==0.8.2+cu110 -torchaudio===0.7.2 +-f https://download.pytorch.org/whl/lts/1.8/torch_lts.html +torch==1.8.2+cu111 +torchvision==0.9.2+cu111 +torchaudio==0.8.2 # Some extra requirements for the code base jupyter diff --git a/paper/alds/param/imagenet/prune/mbv2.yaml b/paper/alds/param/imagenet/prune/mobilenet_v2.yaml similarity index 100% rename from paper/alds/param/imagenet/prune/mbv2.yaml rename to paper/alds/param/imagenet/prune/mobilenet_v2.yaml diff --git a/paper/alds/param/imagenet/retrain/mbv2.yaml b/paper/alds/param/imagenet/retrain/mobilenet_v2.yaml similarity index 100% rename from paper/alds/param/imagenet/retrain/mbv2.yaml rename to paper/alds/param/imagenet/retrain/mobilenet_v2.yaml diff --git a/paper/alds/script/results_viewer.py b/paper/alds/script/results_viewer.py index 95b405a..f2c8d8b 100644 --- a/paper/alds/script/results_viewer.py +++ b/paper/alds/script/results_viewer.py @@ -38,7 +38,7 @@ TABLE_BOLD_THRESHOLD = 0.005 # auto-discover files from folder without "common.yaml" -FILES = glob.glob(os.path.join(FOLDER, "[!common]*.yaml")) +FILES = glob.glob(os.path.join(FOLDER, "*[!common]*.yaml")) def key_files(item): @@ -52,6 +52,7 @@ def key_files(item): "resnet18", "resnet101", "wide_resnet50_2", + "mobilenet_v2", "deeplabv3_resnet50", ] @@ -127,6 +128,8 @@ def get_results(file, logger, legend_on): elif "imagenet/prune" in file: graphers[0]._figure.gca().set_xlim([0, 87]) graphers[0]._figure.gca().set_ylim([-87, 5]) + elif "imagenet/retrain/mobilenet_v2" in file: + graphers[0]._figure.gca().set_ylim([-5, 0.5]) elif "imagenet/retrain/" in file: graphers[0]._figure.gca().set_ylim([-3.5, 1.5]) elif "imagenet/retraincascade" in file: @@ -317,6 +320,7 @@ def generate_table_entries( "resnet18": "ResNet18", "resnet101": "ResNet101", "wide_resnet50_2": "WRN50-2", + "mobilenet_v2": "MobileNetV2", "deeplabv3_resnet50": "DeeplabV3-ResNet50", } diff --git a/paper/node/README.md b/paper/node/README.md new file mode 100644 index 0000000..c24e74d --- /dev/null +++ b/paper/node/README.md @@ -0,0 +1,68 @@ +# Sparse flows: Pruning continuous-depth models +[Lucas Liebenwein*](https://people.csail.mit.edu/lucasl/), +[Ramin Hasani*](http://www.raminhasani.com), +[Alexander Amini](https://www.mit.edu/~amini/), +[Daniela Rus](http://danielarus.csail.mit.edu/) + +***Equal contribution** + +<img src="../../misc/imgs/node_overview.png" alt="Overview of Sparse Flows: Pruning continuous-depth models">
+ + +Continuous deep learning architectures enable learning of flexible +probabilistic models for predictive modeling as neural ordinary differential +equations (ODEs), and for generative modeling as continuous normalizing flows. +In this work, we design a framework to decipher the internal dynamics of these +continuous-depth models by pruning their network architectures. Our empirical +results suggest that pruning improves generalization for neural ODEs in +generative modeling. We empirically show that the improvement is because +pruning helps avoid mode collapse and flattens the loss surface. Moreover, +pruning finds efficient neural ODE representations with up to 98% fewer +parameters than the original network, without loss of accuracy. We hope +our results will invigorate further research into the performance-size +trade-offs of modern continuous-depth models. + +## Setup +Check out the main [README.md](../../README.md) and the respective packages for +more information on the code base. + +## Overview + +### Run compression experiments +The experiment configurations are located [here](./param). To reproduce the +experiments for a specific configuration, run: +```bash +python -m experiment.main param/toy/ffjord/spirals/vanilla_l4_h64.yaml +``` + +The pruning experiments run fully automatically and store all the +results. + +### Experimental evaluations + +The [script](./script) folder contains the evaluation and plotting scripts +used to analyze the various experiments. Please take a look at each of them +to understand how to load and analyze the pruning experiments. + +Each plot and experiment presented in the paper can be reproduced this way.
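+ +For convenience, the snippet below is a minimal sketch of how one might +batch-run several of the provided configurations in sequence. The config +paths are actual files under [param](./param); the loop itself is only an +illustrative wrapper around the command above and assumes it is launched +from the same directory as that command: +```python +# Illustrative sketch only: run a few of the provided configs back-to-back. +# Results are stored under the directories set in param/directories.yaml +# (by default ./data/node/results, relative to the working directory). +import subprocess + +CONFIGS = [ + "param/toy/ffjord/spirals/vanilla_l4_h64.yaml", + "param/toy/node/moons/l2_h64_tanh_da.yaml", + "param/tabular/miniboone/l2_hm20_f1_softplus.yaml", +] + +for cfg in CONFIGS: + # each invocation prunes, retrains, and stores its results automatically + subprocess.run(["python", "-m", "experiment.main", cfg], check=True) +``` + +## Citation +Please cite the following paper when using our work.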
+ +### Paper link +[Sparse flows: Pruning continuous-depth models](https://proceedings.neurips.cc/paper/2021/hash/bf1b2f4b901c21a1d8645018ea9aeb05-Abstract.html) + +### Bibtex +``` +@article{liebenwein2021sparse, + title={Sparse flows: Pruning continuous-depth models}, + author={Liebenwein, Lucas and Hasani, Ramin and Amini, Alexander and Rus, Daniela}, + journal={Advances in Neural Information Processing Systems}, + volume={34}, + pages={22628--22642}, + year={2021} +} +``` diff --git a/paper/node/param/cnf/cifar_multiscale.yaml b/paper/node/param/cnf/cifar_multiscale.yaml new file mode 100644 index 0000000..c08c237 --- /dev/null +++ b/paper/node/param/cnf/cifar_multiscale.yaml @@ -0,0 +1,68 @@ +network: + name: "ffjord_multiscale_cifar" + dataset: "CIFAR10" + outputSize: 10 + +training: + transformsTrain: + - type: RandomHorizontalFlip + kwargs: {} + transformsTest: [] + transformsFinal: + - type: Resize + kwargs: { size: 32 } + - type: ToTensor + kwargs: {} + - type: RandomNoise + kwargs: { "normalization": 255.0 } + + loss: "NLLBitsLoss" + lossKwargs: {} + + metricsTest: + - type: NLLBits + kwargs: {} + - type: Dummy + kwargs: {} + + batchSize: 200 # don't change that since it's hard-coded + + optimizer: "Adam" + optimizerKwargs: + lr: 1.0e-3 + weight_decay: 0.0 + + numEpochs: 50 + earlyStopEpoch: 0 + enableAMP: False + + lrSchedulers: + - type: MultiStepLR + stepKwargs: { milestones: [45] } + kwargs: { gamma: 0.1 } + +file: "paper/node/param/directories.yaml" + +retraining: + startEpoch: 0 + +experiments: + methods: + - "ThresNet" + - "FilterThresNet" + mode: "cascade" + + numRepetitions: 1 + numNets: 1 + + plotting: + minVal: 0.02 + maxVal: 0.85 + + spacing: + - type: "geometric" + numIntervals: 12 + maxVal: 0.80 + minVal: 0.05 + + retrainIterations: -1 diff --git a/paper/node/param/cnf/mnist_multiscale.yaml b/paper/node/param/cnf/mnist_multiscale.yaml new file mode 100644 index 0000000..fab6fac --- /dev/null +++ b/paper/node/param/cnf/mnist_multiscale.yaml @@ -0,0 +1,66 @@ +network: + name: "ffjord_multiscale_mnist" + dataset: "MNIST" + outputSize: 10 + +training: + transformsTrain: [] + transformsTest: [] + transformsFinal: + - type: Resize + kwargs: { size: 28 } + - type: ToTensor + kwargs: {} + - type: RandomNoise + kwargs: { "normalization": 255.0 } + + loss: "NLLBitsLoss" + lossKwargs: {} + + metricsTest: + - type: NLLBits + kwargs: {} + - type: Dummy + kwargs: {} + + batchSize: 200 # don't change that since it's hard-coded + + optimizer: "Adam" + optimizerKwargs: + lr: 1.0e-3 + weight_decay: 0.0 + + numEpochs: 50 + earlyStopEpoch: 0 + enableAMP: False + + lrSchedulers: + - type: MultiStepLR + stepKwargs: { milestones: [45] } + kwargs: { gamma: 0.1 } + +file: "paper/node/param/directories.yaml" + +retraining: + startEpoch: 0 + +experiments: + methods: + - "ThresNet" + - "FilterThresNet" + mode: "cascade" + + numRepetitions: 1 + numNets: 1 + + plotting: + minVal: 0.02 + maxVal: 0.85 + + spacing: + - type: "geometric" + numIntervals: 12 + maxVal: 0.80 + minVal: 0.05 + + retrainIterations: -1 diff --git a/paper/node/param/directories.yaml b/paper/node/param/directories.yaml new file mode 100644 index 0000000..ab6b9e8 --- /dev/null +++ b/paper/node/param/directories.yaml @@ -0,0 +1,6 @@ +# relative directories from where main.py was called +directories: + results: "./data/node/results" + trained_networks: null + training_data: "./data/training" + local_data: "./local" diff --git a/paper/node/param/tabular/bsds300/l3_hm20_f2_softplus.yaml 
b/paper/node/param/tabular/bsds300/l3_hm20_f2_softplus.yaml new file mode 100644 index 0000000..ced304b --- /dev/null +++ b/paper/node/param/tabular/bsds300/l3_hm20_f2_softplus.yaml @@ -0,0 +1,61 @@ +network: + name: "ffjord_l3_hm20_f2_softplus" + dataset: "Bsds300" + outputSize: 63 + +training: + transformsTrain: [] + transformsTest: [] + transformsFinal: [] + + loss: "NLLNatsLoss" + lossKwargs: {} + + metricsTest: + - type: NLLNats + kwargs: {} + - type: Dummy + kwargs: {} + + batchSize: 10000 + + optimizer: "Adam" + optimizerKwargs: + lr: 1.0e-3 + weight_decay: 1.0e-6 + + numEpochs: 100 + earlyStopEpoch: 0 + + enableAMP: False + + lrSchedulers: + - type: MultiStepLR + stepKwargs: { milestones: [96, 99] } + kwargs: { gamma: 0.1 } + +file: "paper/node/param/directories.yaml" + +retraining: + startEpoch: 0 + +experiments: + methods: + - "ThresNet" + - "FilterThresNet" + mode: "cascade" + + numRepetitions: 1 + numNets: 1 + + plotting: + minVal: 0.02 + maxVal: 0.85 + + spacing: + - type: "geometric" + numIntervals: 8 + maxVal: 0.70 + minVal: 0.10 + + retrainIterations: -1 diff --git a/paper/node/param/tabular/gas/l3_hm20_f5_tanh.yaml b/paper/node/param/tabular/gas/l3_hm20_f5_tanh.yaml new file mode 100644 index 0000000..7678a0b --- /dev/null +++ b/paper/node/param/tabular/gas/l3_hm20_f5_tanh.yaml @@ -0,0 +1,61 @@ +network: + name: "ffjord_l3_hm20_f5_tanh" + dataset: "Gas" + outputSize: 8 + +training: + transformsTrain: [] + transformsTest: [] + transformsFinal: [] + + loss: "NLLNatsLoss" + lossKwargs: {} + + metricsTest: + - type: NLLNats + kwargs: {} + - type: Dummy + kwargs: {} + + batchSize: 1000 + + optimizer: "Adam" + optimizerKwargs: + lr: 1.0e-3 + weight_decay: 1.0e-6 + + numEpochs: 30 + earlyStopEpoch: 0 + + enableAMP: False + + lrSchedulers: + - type: MultiStepLR + stepKwargs: { milestones: [25, 28] } + kwargs: { gamma: 0.1 } + +file: "paper/node/param/directories.yaml" + +retraining: + startEpoch: 0 + +experiments: + methods: + - "ThresNet" + - "FilterThresNet" + mode: "cascade" + + numRepetitions: 1 + numNets: 1 + + plotting: + minVal: 0.02 + maxVal: 0.85 + + spacing: + - type: "geometric" + numIntervals: 8 + maxVal: 0.70 + minVal: 0.10 + + retrainIterations: -1 diff --git a/paper/node/param/tabular/hepmass/l2_hm10_f10_softplus.yaml b/paper/node/param/tabular/hepmass/l2_hm10_f10_softplus.yaml new file mode 100644 index 0000000..04f28ea --- /dev/null +++ b/paper/node/param/tabular/hepmass/l2_hm10_f10_softplus.yaml @@ -0,0 +1,68 @@ +network: + name: "ffjord_l2_hm10_f10_softplus" + dataset: "Hepmass" + outputSize: 21 + +training: + transformsTrain: [] + transformsTest: [] + transformsFinal: [] + + loss: "NLLNatsLoss" + lossKwargs: {} + + metricsTest: + - type: NLLNats + kwargs: {} + - type: Dummy + kwargs: {} + + batchSize: 10000 + + optimizer: "Adam" + optimizerKwargs: + lr: 1.0e-3 + weight_decay: 1.0e-6 + + numEpochs: 400 + earlyStopEpoch: 0 + + enableAMP: False + + lrSchedulers: + - type: MultiStepLR + stepKwargs: { milestones: [325, 375] } + kwargs: { gamma: 0.1 } + +file: "paper/node/param/directories.yaml" + +retraining: + startEpoch: 0 + numEpochs: 300 + earlyStopEpoch: 0 + + lrSchedulers: + - type: MultiStepLR + stepKwargs: { milestones: [250, 295] } + kwargs: { gamma: 0.1 } + +experiments: + methods: + - "ThresNet" + - "FilterThresNet" + mode: "cascade" + + numRepetitions: 1 + numNets: 1 + + plotting: + minVal: 0.02 + maxVal: 0.85 + + spacing: + - type: "geometric" + numIntervals: 12 + maxVal: 0.80 + minVal: 0.05 + + retrainIterations: -1 diff --git 
a/paper/node/param/tabular/miniboone/l2_hm20_f1_softplus.yaml b/paper/node/param/tabular/miniboone/l2_hm20_f1_softplus.yaml new file mode 100644 index 0000000..9cd4394 --- /dev/null +++ b/paper/node/param/tabular/miniboone/l2_hm20_f1_softplus.yaml @@ -0,0 +1,65 @@ +network: + name: "ffjord_l2_hm20_f1_softplus" + dataset: "Miniboone" + outputSize: 43 + +training: + transformsTrain: [] + transformsTest: [] + transformsFinal: [] + + loss: "NLLNatsLoss" + lossKwargs: {} + + metricsTest: + - type: NLLNats + kwargs: {} + - type: Dummy + kwargs: {} + + batchSize: 1000 + + optimizer: "Adam" + optimizerKwargs: + lr: 1.0e-3 + weight_decay: 1.0e-6 + + numEpochs: 400 + earlyStopEpoch: 0 + + enableAMP: False + + lrSchedulers: + - type: MultiStepLR + stepKwargs: { milestones: [300, 350] } + kwargs: { gamma: 0.1 } + +file: "paper/node/param/directories.yaml" + +retraining: + startEpoch: 0 + +experiments: + methods: + - "ThresNet" + - "FilterThresNet" + mode: "cascade" + + numRepetitions: 1 + numNets: 3 + + plotting: + minVal: 0.02 + maxVal: 0.85 + + spacing: + - type: "geometric" + numIntervals: 12 + maxVal: 0.80 + minVal: 0.05 + - type: "linear" + numIntervals: 4 + maxVal: 0.04 + minVal: 0.01 + + retrainIterations: -1 diff --git a/paper/node/param/tabular/power/l3_hm10_f5_tanh.yaml b/paper/node/param/tabular/power/l3_hm10_f5_tanh.yaml new file mode 100644 index 0000000..71b07b4 --- /dev/null +++ b/paper/node/param/tabular/power/l3_hm10_f5_tanh.yaml @@ -0,0 +1,61 @@ +network: + name: "ffjord_l3_hm10_f5_tanh" + dataset: "Power" + outputSize: 6 + +training: + transformsTrain: [] + transformsTest: [] + transformsFinal: [] + + loss: "NLLNatsLoss" + lossKwargs: {} + + metricsTest: + - type: NLLNats + kwargs: {} + - type: Dummy + kwargs: {} + + batchSize: 10000 + + optimizer: "Adam" + optimizerKwargs: + lr: 1.0e-3 + weight_decay: 1.0e-6 + + numEpochs: 100 + earlyStopEpoch: 0 + + enableAMP: False + + lrSchedulers: + - type: MultiStepLR + stepKwargs: { milestones: [90, 97] } + kwargs: { gamma: 0.1 } + +file: "paper/node/param/directories.yaml" + +retraining: + startEpoch: 0 + +experiments: + methods: + - "ThresNet" + - "FilterThresNet" + mode: "cascade" + + numRepetitions: 1 + numNets: 1 + + plotting: + minVal: 0.02 + maxVal: 0.85 + + spacing: + - type: "geometric" + numIntervals: 8 + maxVal: 0.70 + minVal: 0.10 + + retrainIterations: -1 diff --git a/paper/node/param/toy/ffjord/common/experiment.yaml b/paper/node/param/toy/ffjord/common/experiment.yaml new file mode 100644 index 0000000..511c93b --- /dev/null +++ b/paper/node/param/toy/ffjord/common/experiment.yaml @@ -0,0 +1,29 @@ +file: "paper/node/param/directories.yaml" + +retraining: + startEpoch: 0 + +experiments: + methods: + - "FilterThresNet" + - "ThresNet" + mode: "cascade" + + numRepetitions: 1 + numNets: 3 + + plotting: + minVal: 0.02 + maxVal: 0.85 + + spacing: + - type: "geometric" + numIntervals: 20 + maxVal: 0.80 + minVal: 0.20 + - type: "linear" + numIntervals: 9 + maxVal: 0.18 + minVal: 0.02 + + retrainIterations: -1 diff --git a/paper/node/param/toy/ffjord/common/sweep_activation_da.yaml b/paper/node/param/toy/ffjord/common/sweep_activation_da.yaml new file mode 100644 index 0000000..fcbe321 --- /dev/null +++ b/paper/node/param/toy/ffjord/common/sweep_activation_da.yaml @@ -0,0 +1,12 @@ +file: "paper/node/param/toy/ffjord/common/experiment.yaml" + +# we vary the activation function +customizations: + - key: ["network", "name"] + value: "ffjord_l4_h64_sigmoid_da" + - key: ["network", "name"] + value: "ffjord_l4_h64_softplus_da" + - key: 
["network", "name"] + value: "ffjord_l4_h64_tanh_da" + - key: ["network", "name"] + value: "ffjord_l4_h64_relu_da" diff --git a/paper/node/param/toy/ffjord/common/sweep_model_da.yaml b/paper/node/param/toy/ffjord/common/sweep_model_da.yaml new file mode 100644 index 0000000..d906115 --- /dev/null +++ b/paper/node/param/toy/ffjord/common/sweep_model_da.yaml @@ -0,0 +1,12 @@ +file: "paper/node/param/toy/ffjord/common/experiment.yaml" + +# now we vary the architecture +customizations: + - key: ["network", "name"] + value: "ffjord_l4_h64_sigmoid_da" + - key: ["network", "name"] + value: "ffjord_l8_h64_sigmoid_da" + - key: ["network", "name"] + value: "ffjord_l2_h128_sigmoid_da" + - key: ["network", "name"] + value: "ffjord_l2_h64_sigmoid_da" diff --git a/paper/node/param/toy/ffjord/common/sweep_model_large.yaml b/paper/node/param/toy/ffjord/common/sweep_model_large.yaml new file mode 100644 index 0000000..ad6de78 --- /dev/null +++ b/paper/node/param/toy/ffjord/common/sweep_model_large.yaml @@ -0,0 +1,14 @@ +file: "paper/node/param/toy/ffjord/common/experiment.yaml" + +# now we vary the architecture +customizations: + - key: ["network", "name"] + value: "ffjord_l8_h37_sigmoid_da" + - key: ["network", "name"] + value: "ffjord_l6_h45_sigmoid_da" + - key: ["network", "name"] + value: "ffjord_l4_h64_sigmoid_da" + - key: ["network", "name"] + value: "ffjord_l3_h90_sigmoid_da" + - key: ["network", "name"] + value: "ffjord_l2_h1700_sigmoid_da" diff --git a/paper/node/param/toy/ffjord/common/sweep_model_med.yaml b/paper/node/param/toy/ffjord/common/sweep_model_med.yaml new file mode 100644 index 0000000..d37fc6b --- /dev/null +++ b/paper/node/param/toy/ffjord/common/sweep_model_med.yaml @@ -0,0 +1,14 @@ +file: "paper/node/param/toy/ffjord/common/experiment.yaml" + +# now we vary the architecture with approximately fixed # of parameters +customizations: + - key: ["network", "name"] + value: "ffjord_l8_h18_sigmoid_da" + - key: ["network", "name"] + value: "ffjord_l6_h22_sigmoid_da" + - key: ["network", "name"] + value: "ffjord_l4_h30_sigmoid_da" + - key: ["network", "name"] + value: "ffjord_l3_h43_sigmoid_da" + - key: ["network", "name"] + value: "ffjord_l2_h400_sigmoid_da" diff --git a/paper/node/param/toy/ffjord/common/sweep_model_small.yaml b/paper/node/param/toy/ffjord/common/sweep_model_small.yaml new file mode 100644 index 0000000..9dcccf5 --- /dev/null +++ b/paper/node/param/toy/ffjord/common/sweep_model_small.yaml @@ -0,0 +1,14 @@ +file: "paper/node/param/toy/ffjord/common/experiment.yaml" + +# now we vary the architecture with approximately fixed # of parameters +customizations: + - key: ["network", "name"] + value: "ffjord_l8_h10_sigmoid_da" + - key: ["network", "name"] + value: "ffjord_l6_h12_sigmoid_da" + - key: ["network", "name"] + value: "ffjord_l4_h17_sigmoid_da" + - key: ["network", "name"] + value: "ffjord_l3_h23_sigmoid_da" + - key: ["network", "name"] + value: "ffjord_l2_h128_sigmoid_da" diff --git a/paper/node/param/toy/ffjord/common/sweep_solver.yaml b/paper/node/param/toy/ffjord/common/sweep_solver.yaml new file mode 100644 index 0000000..4f75ce2 --- /dev/null +++ b/paper/node/param/toy/ffjord/common/sweep_solver.yaml @@ -0,0 +1,16 @@ +file: "paper/node/param/toy/ffjord/common/experiment.yaml" + +# we vary the solver +customizations: + - key: ["network", "name"] + value: "ffjord_l4_h64_sigmoid_rk4_autograd" + - key: ["network", "name"] + value: "ffjord_l4_h64_sigmoid_rk4_adjoint" + - key: ["network", "name"] + value: "ffjord_l4_h64_sigmoid_dopri_autograd" + - key: ["network", 
"name"] + value: "ffjord_l4_h64_sigmoid_dopri_adjoint" + - key: ["network", "name"] + value: "ffjord_l4_h64_sigmoid_euler_autograd" + - key: ["network", "name"] + value: "ffjord_l4_h64_sigmoid_euler_adjoint" diff --git a/paper/node/param/toy/ffjord/gaussians/l2_h128_sigmoid_da.yaml b/paper/node/param/toy/ffjord/gaussians/l2_h128_sigmoid_da.yaml new file mode 100644 index 0000000..28818a0 --- /dev/null +++ b/paper/node/param/toy/ffjord/gaussians/l2_h128_sigmoid_da.yaml @@ -0,0 +1,8 @@ +network: + file: "paper/node/param/toy/ffjord/gaussians/network_da.yaml" + name: "ffjord_l2_h128_sigmoid_da" + +training: + file: "paper/node/param/toy/ffjord/gaussians/training.yaml" + +file: "paper/node/param/toy/ffjord/common/experiment.yaml" diff --git a/paper/node/param/toy/ffjord/gaussians/l4_h64_sigmoid_da.yaml b/paper/node/param/toy/ffjord/gaussians/l4_h64_sigmoid_da.yaml new file mode 100644 index 0000000..05ed81c --- /dev/null +++ b/paper/node/param/toy/ffjord/gaussians/l4_h64_sigmoid_da.yaml @@ -0,0 +1,7 @@ +network: + file: "paper/node/param/toy/ffjord/gaussians/network_da.yaml" + +training: + file: "paper/node/param/toy/ffjord/gaussians/training.yaml" + +file: "paper/node/param/toy/ffjord/common/experiment.yaml" diff --git a/paper/node/param/toy/ffjord/gaussians/network_da.yaml b/paper/node/param/toy/ffjord/gaussians/network_da.yaml new file mode 100644 index 0000000..4dcacdc --- /dev/null +++ b/paper/node/param/toy/ffjord/gaussians/network_da.yaml @@ -0,0 +1,3 @@ +name: "ffjord_l4_h64_sigmoid_da" +dataset: "ToyGaussians" +outputSize: 2 diff --git a/paper/node/param/toy/ffjord/gaussians/sweep_activation_da.yaml b/paper/node/param/toy/ffjord/gaussians/sweep_activation_da.yaml new file mode 100644 index 0000000..b1c848b --- /dev/null +++ b/paper/node/param/toy/ffjord/gaussians/sweep_activation_da.yaml @@ -0,0 +1,9 @@ +# Base Experiment is the classic retraining experiment +network: + file: "paper/node/param/toy/ffjord/gaussians/network_da.yaml" + +training: + file: "paper/node/param/toy/ffjord/gaussians/training.yaml" + +# now we vary the activation function +file: "paper/node/param/toy/ffjord/common/sweep_activation_da.yaml" diff --git a/paper/node/param/toy/ffjord/gaussians/sweep_model_da.yaml b/paper/node/param/toy/ffjord/gaussians/sweep_model_da.yaml new file mode 100644 index 0000000..9c1c004 --- /dev/null +++ b/paper/node/param/toy/ffjord/gaussians/sweep_model_da.yaml @@ -0,0 +1,9 @@ +# Base Experiment is the classic retraining experiment +network: + file: "paper/node/param/toy/ffjord/gaussians/network_da.yaml" + +training: + file: "paper/node/param/toy/ffjord/gaussians/training.yaml" + +# now we vary the model +file: "paper/node/param/toy/ffjord/common/sweep_model_da.yaml" diff --git a/paper/node/param/toy/ffjord/gaussians/sweep_model_small.yaml b/paper/node/param/toy/ffjord/gaussians/sweep_model_small.yaml new file mode 100644 index 0000000..472c5fa --- /dev/null +++ b/paper/node/param/toy/ffjord/gaussians/sweep_model_small.yaml @@ -0,0 +1,9 @@ +# Base Experiment is the classic retraining experiment +network: + file: "paper/node/param/toy/ffjord/gaussians/network_da.yaml" + +training: + file: "paper/node/param/toy/ffjord/gaussians/training.yaml" + +# now we vary the model +file: "paper/node/param/toy/ffjord/common/sweep_model_small.yaml" diff --git a/paper/node/param/toy/ffjord/gaussians/sweep_solver.yaml b/paper/node/param/toy/ffjord/gaussians/sweep_solver.yaml new file mode 100644 index 0000000..80f89fa --- /dev/null +++ b/paper/node/param/toy/ffjord/gaussians/sweep_solver.yaml @@ -0,0 
+1,9 @@ +# Base Experiment is the classic retraining experiment +network: + file: "paper/node/param/toy/ffjord/gaussians/network.yaml" + +training: + file: "paper/node/param/toy/ffjord/gaussians/training.yaml" + +# now we vary the solver +file: "paper/node/param/toy/ffjord/common/sweep_solver.yaml" diff --git a/paper/node/param/toy/ffjord/gaussians/training.yaml b/paper/node/param/toy/ffjord/gaussians/training.yaml new file mode 100644 index 0000000..e9b2554 --- /dev/null +++ b/paper/node/param/toy/ffjord/gaussians/training.yaml @@ -0,0 +1,25 @@ +transformsTrain: [] +transformsTest: [] +transformsFinal: [] + +loss: "NLLPriorLoss" +lossKwargs: {} + +metricsTest: + - type: NLLPrior + kwargs: {} + - type: Dummy + kwargs: {} + +batchSize: 1024 + +optimizer: "AdamW" +optimizerKwargs: + lr: 5.0e-3 + weight_decay: 1.0e-5 + +numEpochs: 100 + +enableAMP: False + +lrSchedulers: [] diff --git a/paper/node/param/toy/ffjord/gaussians/vanilla_l2_h128.yaml b/paper/node/param/toy/ffjord/gaussians/vanilla_l2_h128.yaml new file mode 100644 index 0000000..2326bdd --- /dev/null +++ b/paper/node/param/toy/ffjord/gaussians/vanilla_l2_h128.yaml @@ -0,0 +1,8 @@ +network: + file: "paper/node/param/toy/ffjord/gaussians/network_da.yaml" + name: "cnf_l2_h128_sigmoid_da" + +training: + file: "paper/node/param/toy/ffjord/gaussians/training.yaml" + +file: "paper/node/param/toy/ffjord/common/experiment.yaml" diff --git a/paper/node/param/toy/ffjord/gaussiansspiral/l4_h64_sigmoid_da.yaml b/paper/node/param/toy/ffjord/gaussiansspiral/l4_h64_sigmoid_da.yaml new file mode 100644 index 0000000..cac948f --- /dev/null +++ b/paper/node/param/toy/ffjord/gaussiansspiral/l4_h64_sigmoid_da.yaml @@ -0,0 +1,7 @@ +network: + file: "paper/node/param/toy/ffjord/gaussiansspiral/network_da.yaml" + +training: + file: "paper/node/param/toy/ffjord/gaussiansspiral/training.yaml" + +file: "paper/node/param/toy/ffjord/common/experiment.yaml" diff --git a/paper/node/param/toy/ffjord/gaussiansspiral/network_da.yaml b/paper/node/param/toy/ffjord/gaussiansspiral/network_da.yaml new file mode 100644 index 0000000..5e1865e --- /dev/null +++ b/paper/node/param/toy/ffjord/gaussiansspiral/network_da.yaml @@ -0,0 +1,3 @@ +name: "ffjord_l4_h64_sigmoid_da" +dataset: "ToyGaussiansSpiral" +outputSize: 2 diff --git a/paper/node/param/toy/ffjord/gaussiansspiral/sweep_act_da.yaml b/paper/node/param/toy/ffjord/gaussiansspiral/sweep_act_da.yaml new file mode 100644 index 0000000..5abf6ba --- /dev/null +++ b/paper/node/param/toy/ffjord/gaussiansspiral/sweep_act_da.yaml @@ -0,0 +1,9 @@ +# Base Experiment is the classic retraining experiment +network: + file: "paper/node/param/toy/ffjord/gaussiansspiral/network_da.yaml" + +training: + file: "paper/node/param/toy/ffjord/gaussiansspiral/training.yaml" + +# now we vary the activation function +file: "paper/node/param/toy/ffjord/common/sweep_activation_da.yaml" diff --git a/paper/node/param/toy/ffjord/gaussiansspiral/sweep_model_da.yaml b/paper/node/param/toy/ffjord/gaussiansspiral/sweep_model_da.yaml new file mode 100644 index 0000000..9dc88e0 --- /dev/null +++ b/paper/node/param/toy/ffjord/gaussiansspiral/sweep_model_da.yaml @@ -0,0 +1,9 @@ +# Base Experiment is the classic retraining experiment +network: + file: "paper/node/param/toy/ffjord/gaussiansspiral/network_da.yaml" + +training: + file: "paper/node/param/toy/ffjord/gaussiansspiral/training.yaml" + +# now we vary the model +file: "paper/node/param/toy/ffjord/common/sweep_model_da.yaml" diff --git a/paper/node/param/toy/ffjord/gaussiansspiral/sweep_model_med.yaml 
b/paper/node/param/toy/ffjord/gaussiansspiral/sweep_model_med.yaml new file mode 100644 index 0000000..b395378 --- /dev/null +++ b/paper/node/param/toy/ffjord/gaussiansspiral/sweep_model_med.yaml @@ -0,0 +1,9 @@ +# Base Experiment is the classic retraining experiment +network: + file: "paper/node/param/toy/ffjord/gaussiansspiral/network_da.yaml" + +training: + file: "paper/node/param/toy/ffjord/gaussiansspiral/training.yaml" + +# now we vary the model +file: "paper/node/param/toy/ffjord/common/sweep_model_med.yaml" diff --git a/paper/node/param/toy/ffjord/gaussiansspiral/sweep_opt_ref.yaml b/paper/node/param/toy/ffjord/gaussiansspiral/sweep_opt_ref.yaml new file mode 100644 index 0000000..6134a06 --- /dev/null +++ b/paper/node/param/toy/ffjord/gaussiansspiral/sweep_opt_ref.yaml @@ -0,0 +1,56 @@ +# Base Experiment is the classic retraining experiment +network: + file: "paper/node/param/toy/ffjord/gaussiansspiral/network_da.yaml" + +training: + file: "paper/node/param/toy/ffjord/gaussiansspiral/training.yaml" + +file: "paper/node/param/directories.yaml" + +retraining: + startEpoch: 0 + +experiments: + methods: [] + mode: "cascade" + + numRepetitions: 1 + numNets: 1 + + plotting: + minVal: 0.02 + maxVal: 0.85 + + spacing: + - type: "geometric" + numIntervals: 2 + maxVal: 0.80 + minVal: 0.20 + + retrainIterations: -1 + +customizations: + - key: ["training", "optimizerKwargs"] + value: { "lr": 0.05, "weight_decay": 0.01 } + - key: ["training", "optimizerKwargs"] + value: { "lr": 0.05, "weight_decay": 1.0e-6 } + - key: ["training", "optimizerKwargs"] + value: { "lr": 0.05, "weight_decay": 1.0e-5 } + - key: ["training", "optimizerKwargs"] + value: { "lr": 0.05, "weight_decay": 1.0e-4 } + - key: ["training", "optimizerKwargs"] + value: { "lr": 0.01, "weight_decay": 0.01 } + - key: ["training", "optimizerKwargs"] + value: { "lr": 0.01, "weight_decay": 1.0e-6 } + - key: ["training", "optimizerKwargs"] + value: { "lr": 0.01, "weight_decay": 1.0e-5 } + - key: ["training", "optimizerKwargs"] + value: { "lr": 0.01, "weight_decay": 1.0e-4 } + - key: ["training", "optimizerKwargs"] + value: { "lr": 0.005, "weight_decay": 0.01 } + - key: ["training", "optimizerKwargs"] + value: { "lr": 0.005, "weight_decay": 1.0e-7 } + - key: ["training", "optimizerKwargs"] + value: { "lr": 0.005, "weight_decay": 1.0e-6 } + - key: ["training", "optimizerKwargs"] + value: { "lr": 0.005, "weight_decay": 1.0e-5 } diff --git a/paper/node/param/toy/ffjord/gaussiansspiral/sweep_solver.yaml b/paper/node/param/toy/ffjord/gaussiansspiral/sweep_solver.yaml new file mode 100644 index 0000000..784f5b9 --- /dev/null +++ b/paper/node/param/toy/ffjord/gaussiansspiral/sweep_solver.yaml @@ -0,0 +1,9 @@ +# Base Experiment is the classic retraining experiment +network: + file: "paper/node/param/toy/ffjord/gaussiansspiral/network_da.yaml" + +training: + file: "paper/node/param/toy/ffjord/gaussiansspiral/training.yaml" + +# now we vary the solver +file: "paper/node/param/toy/ffjord/common/sweep_solver.yaml" diff --git a/paper/node/param/toy/ffjord/gaussiansspiral/training.yaml b/paper/node/param/toy/ffjord/gaussiansspiral/training.yaml new file mode 100644 index 0000000..21f219c --- /dev/null +++ b/paper/node/param/toy/ffjord/gaussiansspiral/training.yaml @@ -0,0 +1,25 @@ +transformsTrain: [] +transformsTest: [] +transformsFinal: [] + +loss: "NLLPriorLoss" +lossKwargs: {} + +metricsTest: + - type: NLLPrior + kwargs: {} + - type: Dummy + kwargs: {} + +batchSize: 1024 + +optimizer: "AdamW" +optimizerKwargs: + lr: 0.05 + weight_decay: 0.01 + 
+numEpochs: 100 + +enableAMP: False + +lrSchedulers: [] diff --git a/paper/node/param/toy/ffjord/gaussiansspiral/vanilla_l4_h64.yaml b/paper/node/param/toy/ffjord/gaussiansspiral/vanilla_l4_h64.yaml new file mode 100644 index 0000000..feb1d10 --- /dev/null +++ b/paper/node/param/toy/ffjord/gaussiansspiral/vanilla_l4_h64.yaml @@ -0,0 +1,8 @@ +network: + file: "paper/node/param/toy/ffjord/gaussiansspiral/network_da.yaml" + name: "cnf_l4_h64_sigmoid_da" + +training: + file: "paper/node/param/toy/ffjord/gaussiansspiral/training.yaml" + +file: "paper/node/param/toy/ffjord/common/experiment.yaml" diff --git a/paper/node/param/toy/ffjord/spirals/l4_h64_sigmoid_da.yaml b/paper/node/param/toy/ffjord/spirals/l4_h64_sigmoid_da.yaml new file mode 100644 index 0000000..b4e4bef --- /dev/null +++ b/paper/node/param/toy/ffjord/spirals/l4_h64_sigmoid_da.yaml @@ -0,0 +1,7 @@ +network: + file: "paper/node/param/toy/ffjord/spirals/network_da.yaml" + +training: + file: "paper/node/param/toy/ffjord/spirals/training.yaml" + +file: "paper/node/param/toy/ffjord/common/experiment.yaml" diff --git a/paper/node/param/toy/ffjord/spirals/network_da.yaml b/paper/node/param/toy/ffjord/spirals/network_da.yaml new file mode 100644 index 0000000..3965cba --- /dev/null +++ b/paper/node/param/toy/ffjord/spirals/network_da.yaml @@ -0,0 +1,3 @@ +name: "ffjord_l4_h64_sigmoid_da" +dataset: "ToySpirals2" +outputSize: 2 diff --git a/paper/node/param/toy/ffjord/spirals/sweep_act_da.yaml b/paper/node/param/toy/ffjord/spirals/sweep_act_da.yaml new file mode 100644 index 0000000..c6f1dd0 --- /dev/null +++ b/paper/node/param/toy/ffjord/spirals/sweep_act_da.yaml @@ -0,0 +1,9 @@ +# Base Experiment is the classic retraining experiment +network: + file: "paper/node/param/toy/ffjord/spirals/network_da.yaml" + +training: + file: "paper/node/param/toy/ffjord/spirals/training.yaml" + +# now we vary the activation function +file: "paper/node/param/toy/ffjord/common/sweep_activation_da.yaml" diff --git a/paper/node/param/toy/ffjord/spirals/sweep_model_da.yaml b/paper/node/param/toy/ffjord/spirals/sweep_model_da.yaml new file mode 100644 index 0000000..22365fd --- /dev/null +++ b/paper/node/param/toy/ffjord/spirals/sweep_model_da.yaml @@ -0,0 +1,9 @@ +# Base Experiment is the classic retraining experiment +network: + file: "paper/node/param/toy/ffjord/spirals/network_da.yaml" + +training: + file: "paper/node/param/toy/ffjord/spirals/training.yaml" + +# now we vary the model +file: "paper/node/param/toy/ffjord/common/sweep_model_da.yaml" diff --git a/paper/node/param/toy/ffjord/spirals/sweep_model_large.yaml b/paper/node/param/toy/ffjord/spirals/sweep_model_large.yaml new file mode 100644 index 0000000..b0a30e1 --- /dev/null +++ b/paper/node/param/toy/ffjord/spirals/sweep_model_large.yaml @@ -0,0 +1,9 @@ +# Base Experiment is the classic retraining experiment +network: + file: "paper/node/param/toy/ffjord/spirals/network_da.yaml" + +training: + file: "paper/node/param/toy/ffjord/spirals/training.yaml" + +# now we vary the model +file: "paper/node/param/toy/ffjord/common/sweep_model_large.yaml" diff --git a/paper/node/param/toy/ffjord/spirals/sweep_opt_ref.yaml b/paper/node/param/toy/ffjord/spirals/sweep_opt_ref.yaml new file mode 100644 index 0000000..43067d6 --- /dev/null +++ b/paper/node/param/toy/ffjord/spirals/sweep_opt_ref.yaml @@ -0,0 +1,56 @@ +# Base Experiment is the classic retraining experiment +network: + file: "paper/node/param/toy/ffjord/spirals/network_da.yaml" + +training: + file: "paper/node/param/toy/ffjord/spirals/training.yaml" + 
+file: "paper/node/param/directories.yaml" + +retraining: + startEpoch: 0 + +experiments: + methods: [] + mode: "cascade" + + numRepetitions: 1 + numNets: 1 + + plotting: + minVal: 0.02 + maxVal: 0.85 + + spacing: + - type: "geometric" + numIntervals: 2 + maxVal: 0.80 + minVal: 0.20 + + retrainIterations: -1 + +customizations: + - key: ["training", "optimizerKwargs"] + value: { "lr": 0.05, "weight_decay": 0.01 } + - key: ["training", "optimizerKwargs"] + value: { "lr": 0.05, "weight_decay": 1.0e-6 } + - key: ["training", "optimizerKwargs"] + value: { "lr": 0.05, "weight_decay": 1.0e-5 } + - key: ["training", "optimizerKwargs"] + value: { "lr": 0.05, "weight_decay": 1.0e-4 } + - key: ["training", "optimizerKwargs"] + value: { "lr": 0.01, "weight_decay": 0.01 } + - key: ["training", "optimizerKwargs"] + value: { "lr": 0.01, "weight_decay": 1.0e-6 } + - key: ["training", "optimizerKwargs"] + value: { "lr": 0.01, "weight_decay": 1.0e-5 } + - key: ["training", "optimizerKwargs"] + value: { "lr": 0.01, "weight_decay": 1.0e-4 } + - key: ["training", "optimizerKwargs"] + value: { "lr": 0.005, "weight_decay": 0.01 } + - key: ["training", "optimizerKwargs"] + value: { "lr": 0.005, "weight_decay": 1.0e-7 } + - key: ["training", "optimizerKwargs"] + value: { "lr": 0.005, "weight_decay": 1.0e-6 } + - key: ["training", "optimizerKwargs"] + value: { "lr": 0.005, "weight_decay": 1.0e-5 } diff --git a/paper/node/param/toy/ffjord/spirals/training.yaml b/paper/node/param/toy/ffjord/spirals/training.yaml new file mode 100644 index 0000000..7551456 --- /dev/null +++ b/paper/node/param/toy/ffjord/spirals/training.yaml @@ -0,0 +1,25 @@ +transformsTrain: [] +transformsTest: [] +transformsFinal: [] + +loss: "NLLPriorLoss" +lossKwargs: {} + +metricsTest: + - type: NLLPrior + kwargs: {} + - type: Dummy + kwargs: {} + +batchSize: 1024 + +optimizer: "AdamW" +optimizerKwargs: + lr: 0.05 + weight_decay: 1.0e-6 + +numEpochs: 100 + +enableAMP: False + +lrSchedulers: [] diff --git a/paper/node/param/toy/ffjord/spirals/vanilla_l4_h64.yaml b/paper/node/param/toy/ffjord/spirals/vanilla_l4_h64.yaml new file mode 100644 index 0000000..d8cc39e --- /dev/null +++ b/paper/node/param/toy/ffjord/spirals/vanilla_l4_h64.yaml @@ -0,0 +1,8 @@ +network: + file: "paper/node/param/toy/ffjord/spirals/network_da.yaml" + name: "cnf_l4_h64_sigmoid_da_high_tol" + +training: + file: "paper/node/param/toy/ffjord/spirals/training.yaml" + +file: "paper/node/param/toy/ffjord/common/experiment.yaml" diff --git a/paper/node/param/toy/node/common/experiment.yaml b/paper/node/param/toy/node/common/experiment.yaml new file mode 100644 index 0000000..511c93b --- /dev/null +++ b/paper/node/param/toy/node/common/experiment.yaml @@ -0,0 +1,29 @@ +file: "paper/node/param/directories.yaml" + +retraining: + startEpoch: 0 + +experiments: + methods: + - "FilterThresNet" + - "ThresNet" + mode: "cascade" + + numRepetitions: 1 + numNets: 3 + + plotting: + minVal: 0.02 + maxVal: 0.85 + + spacing: + - type: "geometric" + numIntervals: 20 + maxVal: 0.80 + minVal: 0.20 + - type: "linear" + numIntervals: 9 + maxVal: 0.18 + minVal: 0.02 + + retrainIterations: -1 diff --git a/paper/node/param/toy/node/common/sweep_activation_da.yaml b/paper/node/param/toy/node/common/sweep_activation_da.yaml new file mode 100644 index 0000000..b6601af --- /dev/null +++ b/paper/node/param/toy/node/common/sweep_activation_da.yaml @@ -0,0 +1,12 @@ +file: "paper/node/param/toy/node/common/experiment.yaml" + +# now we vary the activation function +customizations: + - key: ["network", "name"] + 
value: "node_l2_h64_tanh_da" + - key: ["network", "name"] + value: "node_l2_h64_sigmoid_da" + - key: ["network", "name"] + value: "node_l2_h64_softplus_da" + - key: ["network", "name"] + value: "node_l2_h64_relu_da" diff --git a/paper/node/param/toy/node/common/sweep_model_da.yaml b/paper/node/param/toy/node/common/sweep_model_da.yaml new file mode 100644 index 0000000..7076374 --- /dev/null +++ b/paper/node/param/toy/node/common/sweep_model_da.yaml @@ -0,0 +1,14 @@ +file: "paper/node/param/toy/node/common/experiment.yaml" + +# now we vary the architecture +customizations: + - key: ["network", "name"] + value: "node_l2_h32_tanh_da" + - key: ["network", "name"] + value: "node_l2_h64_tanh_da" + - key: ["network", "name"] + value: "node_l2_h128_tanh_da" + - key: ["network", "name"] + value: "node_l4_h32_tanh_da" + - key: ["network", "name"] + value: "node_l4_h128_tanh_da" diff --git a/paper/node/param/toy/node/common/sweep_opt.yaml b/paper/node/param/toy/node/common/sweep_opt.yaml new file mode 100644 index 0000000..c6c86e0 --- /dev/null +++ b/paper/node/param/toy/node/common/sweep_opt.yaml @@ -0,0 +1,28 @@ +file: "paper/node/param/toy/node/common/experiment.yaml" + +# now we vary the optimizer +customizations: + - key: ["training", "optimizerKwargs"] + value: { "lr": 0.05 } + - key: ["training", "optimizerKwargs"] + value: { "lr": 0.05, "weight_decay": 1.0e-6 } + - key: ["training", "optimizerKwargs"] + value: { "lr": 0.05, "weight_decay": 1.0e-5 } + - key: ["training", "optimizerKwargs"] + value: { "lr": 0.05, "weight_decay": 1.0e-4 } + - key: ["training", "optimizerKwargs"] + value: { "lr": 0.01 } + - key: ["training", "optimizerKwargs"] + value: { "lr": 0.01, "weight_decay": 1.0e-6 } + - key: ["training", "optimizerKwargs"] + value: { "lr": 0.01, "weight_decay": 1.0e-5 } + - key: ["training", "optimizerKwargs"] + value: { "lr": 0.01, "weight_decay": 1.0e-4 } + - key: ["training", "optimizerKwargs"] + value: { "lr": 0.005 } + - key: ["training", "optimizerKwargs"] + value: { "lr": 0.005, "weight_decay": 1.0e-7 } + - key: ["training", "optimizerKwargs"] + value: { "lr": 0.005, "weight_decay": 1.0e-6 } + - key: ["training", "optimizerKwargs"] + value: { "lr": 0.005, "weight_decay": 1.0e-5 } diff --git a/paper/node/param/toy/node/common/sweep_solver.yaml b/paper/node/param/toy/node/common/sweep_solver.yaml new file mode 100644 index 0000000..9f610b5 --- /dev/null +++ b/paper/node/param/toy/node/common/sweep_solver.yaml @@ -0,0 +1,16 @@ +file: "paper/node/param/toy/node/common/experiment.yaml" + +# now we vary the solver +customizations: + - key: ["network", "name"] + value: "node_l4_h32_tanh_dopri_adjoint" + - key: ["network", "name"] + value: "node_l4_h32_tanh_dopri_autograd" + - key: ["network", "name"] + value: "node_l4_h32_tanh_rk4_adjoint" + - key: ["network", "name"] + value: "node_l4_h32_tanh_rk4_autograd" + - key: ["network", "name"] + value: "node_l4_h32_tanh_euler_adjoint" + - key: ["network", "name"] + value: "node_l4_h32_tanh_euler_autograd" diff --git a/paper/node/param/toy/node/concentric/l2_h128_tanh_da.yaml b/paper/node/param/toy/node/concentric/l2_h128_tanh_da.yaml new file mode 100644 index 0000000..3425d86 --- /dev/null +++ b/paper/node/param/toy/node/concentric/l2_h128_tanh_da.yaml @@ -0,0 +1,8 @@ +network: + name: "node_l2_h128_tanh_da" + file: "paper/node/param/toy/node/concentric/network_da.yaml" + +training: + file: "paper/node/param/toy/node/concentric/training.yaml" + +file: "paper/node/param/toy/node/common/experiment.yaml" diff --git 
a/paper/node/param/toy/node/concentric/l2_h64_tanh_da.yaml b/paper/node/param/toy/node/concentric/l2_h64_tanh_da.yaml new file mode 100644 index 0000000..eb16d82 --- /dev/null +++ b/paper/node/param/toy/node/concentric/l2_h64_tanh_da.yaml @@ -0,0 +1,7 @@ +network: + file: "paper/node/param/toy/node/concentric/network_da.yaml" + +training: + file: "paper/node/param/toy/node/concentric/training.yaml" + +file: "paper/node/param/toy/node/common/experiment.yaml" diff --git a/paper/node/param/toy/node/concentric/network_da.yaml b/paper/node/param/toy/node/concentric/network_da.yaml new file mode 100644 index 0000000..d5f9fd2 --- /dev/null +++ b/paper/node/param/toy/node/concentric/network_da.yaml @@ -0,0 +1,3 @@ +name: "node_l2_h64_tanh_da" +dataset: "ToyConcentric" +outputSize: 2 diff --git a/paper/node/param/toy/node/concentric/sweep_activation_da.yaml b/paper/node/param/toy/node/concentric/sweep_activation_da.yaml new file mode 100644 index 0000000..b23ae9c --- /dev/null +++ b/paper/node/param/toy/node/concentric/sweep_activation_da.yaml @@ -0,0 +1,9 @@ +# Base Experiment is the classic retraining experiment +network: + file: "paper/node/param/toy/node/concentric/network_da.yaml" + +training: + file: "paper/node/param/toy/node/concentric/training.yaml" + +# now we vary the activation function +file: "paper/node/param/toy/node/common/sweep_activation_da.yaml" diff --git a/paper/node/param/toy/node/concentric/sweep_model_da.yaml b/paper/node/param/toy/node/concentric/sweep_model_da.yaml new file mode 100644 index 0000000..f3ce6bc --- /dev/null +++ b/paper/node/param/toy/node/concentric/sweep_model_da.yaml @@ -0,0 +1,9 @@ +# Base Experiment is the classic retraining experiment +network: + file: "paper/node/param/toy/node/concentric/network_da.yaml" + +training: + file: "paper/node/param/toy/node/concentric/training.yaml" + +# now we vary the model +file: "paper/node/param/toy/node/common/sweep_model_da.yaml" diff --git a/paper/node/param/toy/node/concentric/sweep_opt_da.yaml b/paper/node/param/toy/node/concentric/sweep_opt_da.yaml new file mode 100644 index 0000000..727d46b --- /dev/null +++ b/paper/node/param/toy/node/concentric/sweep_opt_da.yaml @@ -0,0 +1,9 @@ +# Base Experiment is the classic retraining experiment +network: + file: "paper/node/param/toy/node/concentric/network_da.yaml" + +training: + file: "paper/node/param/toy/node/concentric/training.yaml" + +# now we vary the optimizer +file: "paper/node/param/toy/node/common/sweep_opt.yaml" diff --git a/paper/node/param/toy/node/concentric/sweep_solver.yaml b/paper/node/param/toy/node/concentric/sweep_solver.yaml new file mode 100644 index 0000000..0d9a67b --- /dev/null +++ b/paper/node/param/toy/node/concentric/sweep_solver.yaml @@ -0,0 +1,9 @@ +# Base Experiment is the classic retraining experiment +network: + file: "paper/node/param/toy/node/concentric/network.yaml" + +training: + file: "paper/node/param/toy/node/concentric/training.yaml" + +# now we vary the solver +file: "paper/node/param/toy/node/common/sweep_solver.yaml" diff --git a/paper/node/param/toy/node/concentric/training.yaml b/paper/node/param/toy/node/concentric/training.yaml new file mode 100644 index 0000000..6c7e65d --- /dev/null +++ b/paper/node/param/toy/node/concentric/training.yaml @@ -0,0 +1,25 @@ +transformsTrain: [] +transformsTest: [] +transformsFinal: [] + +loss: "CrossEntropyLoss" +lossKwargs: { reduction: mean } + +metricsTest: + - type: TopK + kwargs: { topk: 1 } + - type: MCorr + kwargs: {} + +batchSize: 128 + 
+optimizer: "Adam" +optimizerKwargs: + lr: 0.01 + weight_decay: 1.0e-5 + +numEpochs: 50 + +enableAMP: False + +lrSchedulers: [] diff --git a/paper/node/param/toy/node/moons/l2_h128_tanh_da.yaml b/paper/node/param/toy/node/moons/l2_h128_tanh_da.yaml new file mode 100644 index 0000000..c6849fd --- /dev/null +++ b/paper/node/param/toy/node/moons/l2_h128_tanh_da.yaml @@ -0,0 +1,8 @@ +network: + name: "node_l2_h128_tanh_da" + file: "paper/node/param/toy/node/moons/network_da.yaml" + +training: + file: "paper/node/param/toy/node/moons/training.yaml" + +file: "paper/node/param/toy/node/common/experiment.yaml" diff --git a/paper/node/param/toy/node/moons/l2_h32_tanh_da.yaml b/paper/node/param/toy/node/moons/l2_h32_tanh_da.yaml new file mode 100644 index 0000000..3c6ab7d --- /dev/null +++ b/paper/node/param/toy/node/moons/l2_h32_tanh_da.yaml @@ -0,0 +1,8 @@ +network: + name: "node_l2_h32_tanh_da" + file: "paper/node/param/toy/node/moons/network_da.yaml" + +training: + file: "paper/node/param/toy/node/moons/training.yaml" + +file: "paper/node/param/toy/node/common/experiment.yaml" diff --git a/paper/node/param/toy/node/moons/l2_h3_tanh_da.yaml b/paper/node/param/toy/node/moons/l2_h3_tanh_da.yaml new file mode 100644 index 0000000..d1dbf91 --- /dev/null +++ b/paper/node/param/toy/node/moons/l2_h3_tanh_da.yaml @@ -0,0 +1,8 @@ +network: + name: "node_l2_h3_tanh_da" + file: "paper/node/param/toy/node/moons/network_da.yaml" + +training: + file: "paper/node/param/toy/node/moons/training.yaml" + +file: "paper/node/param/toy/node/common/experiment.yaml" diff --git a/paper/node/param/toy/node/moons/l2_h64_tanh_da.yaml b/paper/node/param/toy/node/moons/l2_h64_tanh_da.yaml new file mode 100644 index 0000000..3386573 --- /dev/null +++ b/paper/node/param/toy/node/moons/l2_h64_tanh_da.yaml @@ -0,0 +1,7 @@ +network: + file: "paper/node/param/toy/node/moons/network_da.yaml" + +training: + file: "paper/node/param/toy/node/moons/training.yaml" + +file: "paper/node/param/toy/node/common/experiment.yaml" diff --git a/paper/node/param/toy/node/moons/network_da.yaml b/paper/node/param/toy/node/moons/network_da.yaml new file mode 100644 index 0000000..0aa3519 --- /dev/null +++ b/paper/node/param/toy/node/moons/network_da.yaml @@ -0,0 +1,3 @@ +name: "node_l2_h64_tanh_da" +dataset: "ToyMoons" +outputSize: 2 diff --git a/paper/node/param/toy/node/moons/sweep_activation_da.yaml b/paper/node/param/toy/node/moons/sweep_activation_da.yaml new file mode 100644 index 0000000..b24561c --- /dev/null +++ b/paper/node/param/toy/node/moons/sweep_activation_da.yaml @@ -0,0 +1,9 @@ +# Base Experiment is the classic retraining experiment +network: + file: "paper/node/param/toy/node/moons/network_da.yaml" + +training: + file: "paper/node/param/toy/node/moons/training.yaml" + +# now we vary the activation function +file: "paper/node/param/toy/node/common/sweep_activation_da.yaml" diff --git a/paper/node/param/toy/node/moons/sweep_model_da.yaml b/paper/node/param/toy/node/moons/sweep_model_da.yaml new file mode 100644 index 0000000..ff7a051 --- /dev/null +++ b/paper/node/param/toy/node/moons/sweep_model_da.yaml @@ -0,0 +1,9 @@ +# Base Experiment is the classic retraining experiment +network: + file: "paper/node/param/toy/node/moons/network_da.yaml" + +training: + file: "paper/node/param/toy/node/moons/training.yaml" + +# now we vary the activation function +file: "paper/node/param/toy/node/common/sweep_model_da.yaml" diff --git a/paper/node/param/toy/node/moons/sweep_opt_da.yaml b/paper/node/param/toy/node/moons/sweep_opt_da.yaml new file 
mode 100644 index 0000000..d56369c --- /dev/null +++ b/paper/node/param/toy/node/moons/sweep_opt_da.yaml @@ -0,0 +1,9 @@ +# Base Experiment is the classic retraining experiment +network: + file: "paper/node/param/toy/node/moons/network_da.yaml" + +training: + file: "paper/node/param/toy/node/moons/training.yaml" + +# now we vary the optimizer +file: "paper/node/param/toy/node/common/sweep_opt.yaml" diff --git a/paper/node/param/toy/node/moons/sweep_solver.yaml b/paper/node/param/toy/node/moons/sweep_solver.yaml new file mode 100644 index 0000000..a66cc4a --- /dev/null +++ b/paper/node/param/toy/node/moons/sweep_solver.yaml @@ -0,0 +1,9 @@ +# Base Experiment is the classic retraining experiment +network: + file: "paper/node/param/toy/node/moons/network.yaml" + +training: + file: "paper/node/param/toy/node/moons/training.yaml" + +# now we vary the solver +file: "paper/node/param/toy/node/common/sweep_solver.yaml" diff --git a/paper/node/param/toy/node/moons/training.yaml b/paper/node/param/toy/node/moons/training.yaml new file mode 100644 index 0000000..e091f28 --- /dev/null +++ b/paper/node/param/toy/node/moons/training.yaml @@ -0,0 +1,25 @@ +transformsTrain: [] +transformsTest: [] +transformsFinal: [] + +loss: "CrossEntropyLoss" +lossKwargs: { reduction: mean } + +metricsTest: + - type: TopK + kwargs: { topk: 1 } + - type: MCorr + kwargs: {} + +batchSize: 128 + +optimizer: "Adam" +optimizerKwargs: + lr: 0.01 + weight_decay: 1.0e-4 + +numEpochs: 50 + +enableAMP: False + +lrSchedulers: [] diff --git a/paper/node/param/toy/node/spirals/l2_h64_tanh_da.yaml b/paper/node/param/toy/node/spirals/l2_h64_tanh_da.yaml new file mode 100644 index 0000000..300f3f0 --- /dev/null +++ b/paper/node/param/toy/node/spirals/l2_h64_tanh_da.yaml @@ -0,0 +1,7 @@ +network: + file: "paper/node/param/toy/node/spirals/network_da.yaml" + +training: + file: "paper/node/param/toy/node/spirals/training.yaml" + +file: "paper/node/param/toy/node/common/experiment.yaml" diff --git a/paper/node/param/toy/node/spirals/network_da.yaml b/paper/node/param/toy/node/spirals/network_da.yaml new file mode 100644 index 0000000..d8538c1 --- /dev/null +++ b/paper/node/param/toy/node/spirals/network_da.yaml @@ -0,0 +1,3 @@ +name: "node_l2_h64_tanh_da" +dataset: "ToySpirals" +outputSize: 2 diff --git a/paper/node/param/toy/node/spirals/sweep_activation_da.yaml b/paper/node/param/toy/node/spirals/sweep_activation_da.yaml new file mode 100644 index 0000000..9f7cc31 --- /dev/null +++ b/paper/node/param/toy/node/spirals/sweep_activation_da.yaml @@ -0,0 +1,9 @@ +# Base Experiment is the classic retraining experiment +network: + file: "paper/node/param/toy/node/spirals/network_da.yaml" + +training: + file: "paper/node/param/toy/node/spirals/training.yaml" + +# now we vary the activation function +file: "paper/node/param/toy/node/common/sweep_activation_da.yaml" diff --git a/paper/node/param/toy/node/spirals/sweep_model_da.yaml b/paper/node/param/toy/node/spirals/sweep_model_da.yaml new file mode 100644 index 0000000..73be2cc --- /dev/null +++ b/paper/node/param/toy/node/spirals/sweep_model_da.yaml @@ -0,0 +1,9 @@ +# Base Experiment is the classic retraining experiment +network: + file: "paper/node/param/toy/node/spirals/network_da.yaml" + +training: + file: "paper/node/param/toy/node/spirals/training.yaml" + +# now we vary the model +file: "paper/node/param/toy/node/common/sweep_model_da.yaml" diff --git a/paper/node/param/toy/node/spirals/sweep_opt_da.yaml 
b/paper/node/param/toy/node/spirals/sweep_opt_da.yaml new file mode 100644 index 0000000..546f1b0 --- /dev/null +++ b/paper/node/param/toy/node/spirals/sweep_opt_da.yaml @@ -0,0 +1,9 @@ +# Base Experiment is the classic retraining experiment +network: + file: "paper/node/param/toy/node/spirals/network_da.yaml" + +training: + file: "paper/node/param/toy/node/spirals/training.yaml" + +# now we vary the optimizer +file: "paper/node/param/toy/node/common/sweep_opt.yaml" diff --git a/paper/node/param/toy/node/spirals/sweep_solver.yaml b/paper/node/param/toy/node/spirals/sweep_solver.yaml new file mode 100644 index 0000000..b5eac12 --- /dev/null +++ b/paper/node/param/toy/node/spirals/sweep_solver.yaml @@ -0,0 +1,9 @@ +# Base Experiment is the classic retraining experiment +network: + file: "paper/node/param/toy/node/spirals/network.yaml" + +training: + file: "paper/node/param/toy/node/spirals/training.yaml" + +# now we vary the solver +file: "paper/node/param/toy/node/common/sweep_solver.yaml" diff --git a/paper/node/param/toy/node/spirals/training.yaml b/paper/node/param/toy/node/spirals/training.yaml new file mode 100644 index 0000000..ae28b7e --- /dev/null +++ b/paper/node/param/toy/node/spirals/training.yaml @@ -0,0 +1,25 @@ +transformsTrain: [] +transformsTest: [] +transformsFinal: [] + +loss: "CrossEntropyLoss" +lossKwargs: { reduction: mean } + +metricsTest: + - type: TopK + kwargs: { topk: 1 } + - type: MCorr + kwargs: {} + +batchSize: 128 + +optimizer: "Adam" +optimizerKwargs: + lr: 0.01 + weight_decay: 1.0e-5 + +numEpochs: 100 + +enableAMP: False + +lrSchedulers: [] diff --git a/paper/node/script/plot_datasets.py b/paper/node/script/plot_datasets.py new file mode 100644 index 0000000..0622b39 --- /dev/null +++ b/paper/node/script/plot_datasets.py @@ -0,0 +1,67 @@ +# %% Setup script +import random +import os +import sys + +import numpy as np +from IPython import get_ipython +import matplotlib.colors as mcolors +import matplotlib.pyplot as plt +import torch + +from torchprune.util import datasets + +IN_JUPYTER = True +try: + get_ipython().run_line_magic("matplotlib", "inline") # show plots +except AttributeError: + IN_JUPYTER = False + +# switch to root folder for data +folder = os.path.abspath("") +if "paper/node/script" in folder: + src_folder = os.path.join(folder, "../../..") + os.chdir(src_folder) + +# %% define plotting function +def plot_dataset(dset): + # put the dataset in a loader and retrieve one batch + loader = torch.utils.data.DataLoader( + dataset=dset, + batch_size=min(len(dset), 1000), + num_workers=0, + shuffle=False, + ) + x_data, y_data = next(iter(loader)) + + fig = plt.figure(figsize=(3, 3)) + ax = fig.add_subplot(111) + ax.scatter(x_data[:, 0], x_data[:, 1], s=1, c=y_data) + + +# %% go through datasets and plot each of them +dset_list = [ + # "ToyConcentric", + # "ToyMoons", + # "ToySpirals", + # "ToySpirals2", + # "ToyGaussians", + # "ToyGaussiansSpiral", + # "ToyDiffeqml", + "Bsds300", + "Hepmass", + "Miniboone", + "Power", + "Gas", +] +dsets = [] +for dset_name in dset_list: + dsets.append( + getattr(datasets, dset_name)( + root="./local", + file_dir="./data/training", + download=True, + train=False, + ) + ) + plot_dataset(dsets[-1]) diff --git a/paper/node/script/plots2d.py b/paper/node/script/plots2d.py new file mode 100644 index 0000000..56cdc21 --- /dev/null +++ b/paper/node/script/plots2d.py @@ -0,0 +1,209 @@ +import matplotlib.pyplot as plt +import numpy as np +import torch +import torch.nn as nn +import os + +from torchdyn import utils 
as plot + + +def get_mesh(X_data, N=1000): + X_data = X_data.detach() + spacing = [torch.linspace(x_i.min(), x_i.max(), N) for x_i in X_data.T] + return torch.stack(torch.meshgrid(*spacing), dim=-1) + + +def plot_for_sweep(**kwargs): + """Plot the desired sweep plot.""" + plot_2d_boundary(**kwargs) + + +def plot_dataset(x_data, y_data, **kwargs): + x_data = x_data.detach().cpu() + y_data = y_data.detach().cpu() + colors = ["orange", "blue"] + fig = plt.figure(figsize=(3, 3)) + ax = fig.add_subplot(111) + for i in range(len(x_data)): + ax.scatter( + x_data[i, 0], x_data[i, 1], s=1, color=colors[y_data[i].int()] + ) + + +def plot_2d_boundary( + model, + x_data, + y_data, + mesh, + num_classes=2, + axis=None, + **kwargs, +): + x_data = x_data.detach().cpu() + y_data = y_data.detach().cpu() + preds = torch.argmax(nn.Softmax(-1)(model(mesh)), dim=-1) + preds = preds.detach().cpu().reshape(mesh.size(0), mesh.size(1)) + if axis is None: + plt.figure(figsize=(8, 4)) + axis = plt.gca() + + contour_colors = ["navy", "tab:orange"] + scatter_colors = ["midnightblue", "darkorange"] + axis.contourf( + mesh[:, :, 0].detach().cpu(), + mesh[:, :, 1].detach().cpu(), + preds, + colors=contour_colors, + alpha=0.4, + levels=1, + ) + for i in range(num_classes): + axis.scatter( + x_data[y_data == i, 0], + x_data[y_data == i, 1], + alpha=1.0, + s=6.0, + linewidths=0, + c=scatter_colors[i], + edgecolors=None, + ) + + +def plot_static_vector_field(model, x_data, t=0.0, N=100, axis=None, **kwargs): + device = next(model.parameters()).device + x = torch.linspace(x_data[:, 0].min(), x_data[:, 0].max(), N) + y = torch.linspace(x_data[:, 1].min(), x_data[:, 1].max(), N) + X, Y = torch.meshgrid(x, y) + + U, V = torch.zeros_like(X), torch.zeros_like(Y) + + for i in range(N): + for j in range(N): + p = torch.cat( + [X[i, j].reshape(1, 1), Y[i, j].reshape(1, 1)], 1 + ).to(device) + O = model.defunc(t, p).detach().cpu() + U[i, j], V[i, j] = O[0, 0], O[0, 1] + + # convert to cpu numpy + X, Y, U, V = [tnsr.cpu().numpy() for tnsr in (X, Y, U, V)] + + if axis is None: + fig = plt.figure(figsize=(3, 3)) + axis = fig.add_subplot(111) + axis.contourf( + X, + Y, + np.sqrt(U ** 2 + V ** 2), + cmap="RdYlBu", + levels=1000, + alpha=0.6, + ) + axis.streamplot( + X.T, + Y.T, + U.T, + V.T, + color="k", + density=1.5, + linewidth=0.7, + arrowsize=0.7, + arrowstyle="-|>", + ) + + axis.set_xlim([x.min(), x.max()]) + axis.set_ylim([y.min(), y.max()]) + axis.set_xlabel(r"$h_0$") + axis.set_ylabel(r"$h_1$") + axis.set_title("Learned Vector Field") + + +def plot_2D_state_space(trajectory, y_data, n_lines, **kwargs): + plot.plot_2D_state_space(trajectory, y_data, n_lines) + + +def plot_2D_depth_trajectory( + s_span, trajectory, y_data, axis1=None, axis2=None, **kwargs +): + if axis1 is None or axis2 is None: + fig = plt.figure(figsize=(8, 2)) + axis1 = fig.add_subplot(121) + axis2 = fig.add_subplot(122) + + colors = ["midnightblue", "darkorange"] + + for i, label in enumerate(y_data): + color = colors[int(label)] + axis1.plot(s_span, trajectory[:, i, 0], color=color, alpha=0.1) + axis2.plot(s_span, trajectory[:, i, 1], color=color, alpha=0.1) + + axis1.set_xlabel(r"Depth") + axis1.set_ylabel(r"Dim. 0") + axis2.set_xlabel(r"Depth") + axis2.set_ylabel(r"Dim. 
1") + + +def prepare_data(model, loader, compute_yhat=False, **kwargs): + """Prepare and return the required data.""" + # setup + model = model.model + device = next(model.parameters()).device + plt.style.use("default") + + # collect data from loader + x_data, y_data = None, None + for x_b, y_b in loader: + if x_data is None: + x_data = x_b + y_data = y_b + else: + x_data = torch.cat((x_data, x_b)) + y_data = torch.cat((y_data, y_b)) + x_data, y_data = x_data.to(device), y_data.to(device) + s_span = torch.linspace(0, 1, 100) + trajectory = model.trajectory(x_data, s_span.to(device)).detach().cpu() + mesh = get_mesh(x_data).to(device) + + data = { + "x_data": x_data, + "y_data": y_data, + "n_lines": len(x_data), + "model": model, + "device": device, + "s_span": s_span, + "trajectory": trajectory, + "mesh": mesh, + } + if compute_yhat: + data["y_hat"] = model(x_data).argmax(dim=1) + + return data + + +def plot_all(model, loader, plot_folder=None, all_p=False): + # default plotting style + plt.style.use("default") + + # retrieve plotting kwargs + kwargs_plot = prepare_data(model, loader, all_p) + + def _plot_and_save(plt_handle, plt_name): + plt_handle(**kwargs_plot) + if plot_folder is not None: + os.makedirs(plot_folder, exist_ok=True) + fig = plt.gcf() + fig.savefig( + os.path.join(plot_folder, f"{plt_name}.pdf"), + bbox_inches="tight", + ) + plt.close(fig) + + _plot_and_save(plot_2d_boundary, "2d_boundary") + _plot_and_save(plot_static_vector_field, "static_vector_field") + + if not all_p: + return + + _plot_and_save(plot_2D_state_space, "2D_state_space") + _plot_and_save(plot_2D_depth_trajectory, "2D_depth_trajectory") + _plot_and_save(plot_dataset, "dataset") diff --git a/paper/node/script/plots_cnf.py b/paper/node/script/plots_cnf.py new file mode 100644 index 0000000..f44642e --- /dev/null +++ b/paper/node/script/plots_cnf.py @@ -0,0 +1,215 @@ +import copy +import os +import numpy as np +import matplotlib.pyplot as plt +import torch +from torchdyn.nn import Augmenter + + +def plot_dataset(x_data, **kwargs): + x_data = x_data.detach().cpu() + plt.figure(figsize=(3, 3)) + plot_samples(x_data, axis=plt.gca()) + + +def plot_for_sweep(**kwargs): + """Plot the desired sweep plot.""" + plot_samples(**kwargs) + + +def plot_samples(x_sampled, axis, **kwargs): + x_sampled = x_sampled.detach().cpu() + if x_sampled.shape[1] > 2: + x_sampled = x_sampled[:, 1:3] + axis.scatter( + x_sampled[:, 0], + x_sampled[:, 1], + s=0.2, + alpha=0.8, + linewidths=0, + c="midnightblue", + edgecolors=None, + ) + axis.set_xlim([-2, 2]) + axis.set_ylim([-2, 2]) + + +def plot_samples_density(x_data, x_sampled, **kwargs): + x_data = x_data.detach().cpu() + x_sampled = x_sampled.detach().cpu() + plt.figure(figsize=(12, 4)) + plt.subplot(121) + plot_samples(x_sampled, axis=plt.gca()) + + plt.subplot(122) + plot_samples(x_data, axis=plt.gca(), color="red") + plt.xlim(-2, 2) + plt.ylim(-2, 2) + + +def plot_flow(sample, trajectory, **kwargs): + traj = trajectory.detach().cpu() + sample = sample.detach().cpu() + n = 2000 + plt.figure(figsize=(6, 6)) + plt.scatter(sample[:n, 0], sample[:n, 1], s=10, alpha=0.8, c="black") + plt.scatter(traj[:, :n, 0], traj[:, :n, 1], s=0.2, alpha=0.2, c="olive") + plt.scatter(traj[-1, :n, 0], traj[-1, :n, 1], s=4, alpha=1, c="blue") + plt.legend(["Prior sample z(S)", "Flow", "z(0)"]) + + +def plot_2D_depth_trajectory( + s_span, trajectory, axis1=None, axis2=None, num_lines=200, **kwargs +): + if axis1 is None or axis2 is None: + fig = plt.figure(figsize=(8, 2)) + axis1 = fig.add_subplot(121) 
+        axis2 = fig.add_subplot(122)
+
+    # trajectory has shape [len(s_span), num_data_points, dim] originally
+    trajectory = trajectory.detach().permute(1, 2, 0).cpu()
+
+    # randomly subsample trajectories
+    num_lines = min(num_lines, len(trajectory))
+    trajectory = trajectory[torch.randperm(len(trajectory))[:num_lines]]
+    for traj_one in trajectory:
+        axis1.plot(s_span, traj_one[0], alpha=0.2)
+        axis2.plot(s_span, traj_one[1], alpha=0.2)
+
+    axis1.set_xlabel(r"Depth")
+    axis1.set_ylabel(r"Dim. 0")
+    axis2.set_xlabel(r"Depth")
+    axis2.set_ylabel(r"Dim. 1")
+
+
+def plot_static_vector_field(model, N=100, axis=None, **kwargs):
+    device = next(model.parameters()).device
+    model = model[1].defunc.m.net
+    x = torch.linspace(-2, 2, N)
+    y = torch.linspace(-2, 2, N)
+
+    X, Y = torch.meshgrid(x, y)
+    U, V = torch.zeros(N, N), torch.zeros(N, N)
+
+    for i in range(N):
+        for j in range(N):
+            p = torch.cat(
+                [X[i, j].reshape(1, 1), Y[i, j].reshape(1, 1)], 1
+            ).to(device)
+            O = model(p).detach().cpu()
+            U[i, j], V[i, j] = O[0, 0], O[0, 1]
+
+    # convert to cpu numpy
+    X, Y, U, V = [tnsr.cpu().numpy() for tnsr in (X, Y, U, V)]
+
+    if axis is None:
+        fig = plt.figure(figsize=(3, 3))
+        axis = fig.add_subplot(111)
+
+    axis.contourf(
+        X,
+        Y,
+        np.sqrt(U ** 2 + V ** 2),
+        cmap="RdYlBu",
+        levels=1000,
+        alpha=0.6,
+    )
+
+    axis.streamplot(
+        X.T,
+        Y.T,
+        U.T,
+        V.T,
+        color="k",
+        density=1.5,
+        linewidth=0.7,
+        arrowsize=0.7,
+        arrowstyle="<|-",
+    )
+
+    axis.set_xlim([x.min(), x.max()])
+    axis.set_ylim([y.min(), y.max()])
+    axis.set_xlabel(r"$h_0$")
+    axis.set_ylabel(r"$h_1$")
+    axis.set_title("Learned Vector Field")
+
+
+def prepare_data(model, loader, collect_from_loader=False, n_samp=2 ** 14):
+    """Prepare model and data and return as dict."""
+    device = next(model.parameters()).device
+
+    # collect prior from model
+    prior = None
+    for x_b, _ in loader:
+        prior = model(x_b.to(device))["prior"]
+        break
+
+    # extract ffjord model
+    model = model.model
+
+    # set s-span but keep old one around!
+ s_span_backup = copy.deepcopy(model[1].s_span) + model[1].s_span = torch.linspace(1, 0, 2).to(device) + + # s_span for trajectory + s_span_traj = torch.linspace(1, 0, 100) + + # preparing some data and samples for plotting + sample = prior.sample(torch.Size([n_samp])).to(device) + x_sampled = model(sample) + trajectory = model[1].trajectory( + Augmenter(1, 1)(sample), + s_span=s_span_traj.to(device), + ) + # scrapping first dimension := jacobian trace + trajectory = trajectory[:, :, 1:] + + # restore s-span + model[1].s_span = s_span_backup + + data = { + "sample": sample, + "x_sampled": x_sampled, + "s_span": s_span_traj, + "trajectory": trajectory, + "model": model, + "device": device, + } + + # collect data from loader + if collect_from_loader: + x_data = None + for x_b, _ in loader: + if x_data is None: + x_data = x_b + else: + x_data = torch.cat((x_data, x_b)) + x_data = x_data.to(device) + data["x_data"] = x_data + + return data + + +def plot_all(model, loader, plot_folder=None, all_p=False): + plt.style.use("default") + + # retrieve plotting kwargs + kwargs_plot = prepare_data(model, loader, True) + + def _plot_and_save(plt_handle, plt_name): + plt_handle(**kwargs_plot) + if plot_folder is not None: + os.makedirs(plot_folder, exist_ok=True) + fig = plt.gcf() + fig.savefig( + os.path.join(plot_folder, f"{plt_name}.pdf"), + bbox_inches="tight", + ) + plt.close(fig) + + _plot_and_save(plot_samples_density, "samples_density") + _plot_and_save(plot_static_vector_field, "static_vector_field") + + if all_p: + _plot_and_save(plot_dataset, "dataset") + _plot_and_save(plot_flow, "flow") diff --git a/paper/node/script/sizes/nn_.py b/paper/node/script/sizes/nn_.py new file mode 100644 index 0000000..b90c45e --- /dev/null +++ b/paper/node/script/sizes/nn_.py @@ -0,0 +1,567 @@ +#!/usr/bin/env python2 +# -*- coding: utf-8 -*- +""" +Created on Mon Dec 11 13:58:12 2017 + +@author: CW +""" + + +import math +import numpy as np + +import torch +import torch.nn as nn +from torch.nn import Module +from torch.nn import functional as F +from torch.nn.parameter import Parameter +from torch.nn.modules.utils import _pair +from torch.autograd import Variable + +# aliasing +N_ = None + + +delta = 1e-6 +softplus_ = nn.Softplus() +softplus = lambda x: softplus_(x) + delta +sigmoid_ = nn.Sigmoid() +sigmoid = lambda x: sigmoid_(x) * (1 - delta) + 0.5 * delta +sigmoid2 = lambda x: sigmoid(x) * 2.0 +logsigmoid = lambda x: -softplus(-x) +logit = lambda x: torch.log +log = lambda x: torch.log(x * 1e2) - np.log(1e2) +logit = lambda x: log(x) - log(1 - x) + + +def softmax(x, dim=-1): + e_x = torch.exp(x - x.max(dim=dim, keepdim=True)[0]) + out = e_x / e_x.sum(dim=dim, keepdim=True) + return out + + +sum1 = lambda x: x.sum(1) +sum_from_one = ( + lambda x: sum_from_one(sum1(x)) if len(x.size()) > 2 else sum1(x) +) + + +class Sigmoid(Module): + def forward(self, x): + return sigmoid(x) + + +class WNlinear(Module): + def __init__( + self, in_features, out_features, bias=True, mask=N_, norm=True + ): + super(WNlinear, self).__init__() + self.in_features = in_features + self.out_features = out_features + self.register_buffer("mask", mask) + self.norm = norm + self.direction = Parameter(torch.Tensor(out_features, in_features)) + self.scale = Parameter(torch.Tensor(out_features)) + if bias: + self.bias = Parameter(torch.Tensor(out_features)) + else: + self.register_parameter("bias", N_) + self.reset_parameters() + + def reset_parameters(self): + stdv = 1.0 / math.sqrt(self.direction.size(1)) + 
self.direction.data.uniform_(-stdv, stdv) + self.scale.data.uniform_(1, 1) + if self.bias is not N_: + self.bias.data.uniform_(-stdv, stdv) + + def forward(self, input): + if self.norm: + dir_ = self.direction + direction = dir_.div(dir_.pow(2).sum(1).sqrt()[:, N_]) + weight = self.scale[:, N_].mul(direction) + else: + weight = self.scale[:, N_].mul(self.direction) + if self.mask is not N_: + # weight = weight * getattr(self.mask, + # ('cpu', 'cuda')[weight.is_cuda])() + weight = weight * Variable(self.mask) + return F.linear(input, weight, self.bias) + + def __repr__(self): + return ( + self.__class__.__name__ + + "(" + + "in_features=" + + str(self.in_features) + + ", out_features=" + + str(self.out_features) + + ")" + ) + + +class CWNlinear(Module): + def __init__( + self, in_features, out_features, context_features, mask=N_, norm=True + ): + super(CWNlinear, self).__init__() + self.in_features = in_features + self.out_features = out_features + self.context_features = context_features + self.register_buffer("mask", mask) + self.norm = norm + self.direction = Parameter(torch.Tensor(out_features, in_features)) + self.cscale = nn.Linear(context_features, out_features) + self.cbias = nn.Linear(context_features, out_features) + self.reset_parameters() + self.cscale.weight.data.normal_(0, 0.001) + self.cbias.weight.data.normal_(0, 0.001) + + def reset_parameters(self): + self.direction.data.normal_(0, 0.001) + + def forward(self, inputs): + input, context = inputs + scale = self.cscale(context) + bias = self.cbias(context) + if self.norm: + dir_ = self.direction + direction = dir_.div(dir_.pow(2).sum(1).sqrt()[:, N_]) + weight = direction + else: + weight = self.direction + if self.mask is not N_: + # weight = weight * getattr(self.mask, + # ('cpu', 'cuda')[weight.is_cuda])() + weight = weight * Variable(self.mask) + return scale * F.linear(input, weight, None) + bias, context + + def __repr__(self): + return ( + self.__class__.__name__ + + "(" + + "in_features=" + + str(self.in_features) + + ", out_features=" + + str(self.out_features) + + ")" + ) + + +class WNBilinear(Module): + def __init__(self, in1_features, in2_features, out_features, bias=True): + super(WNBilinear, self).__init__() + self.in1_features = in1_features + self.in2_features = in2_features + self.out_features = out_features + self.direction = Parameter( + torch.Tensor(out_features, in1_features, in2_features) + ) + self.scale = Parameter(torch.Tensor(out_features)) + if bias: + self.bias = Parameter(torch.Tensor(out_features)) + else: + self.register_parameter("bias", N_) + self.reset_parameters() + + def reset_parameters(self): + stdv = 1.0 / math.sqrt(self.direction.size(1)) + self.direction.data.uniform_(-stdv, stdv) + self.scale.data.uniform_(1, 1) + if self.bias is not N_: + self.bias.data.uniform_(-stdv, stdv) + + def forward(self, input1, input2): + dir_ = self.direction + direction = dir_.div(dir_.pow(2).sum(1).sum(1).sqrt()[:, N_, N_]) + weight = self.scale[:, N_, N_].mul(direction) + return F.bilinear(input1, input2, weight, self.bias) + + def __repr__(self): + return ( + self.__class__.__name__ + + "(" + + "in1_features=" + + str(self.in1_features) + + ", in2_features=" + + str(self.in2_features) + + ", out_features=" + + str(self.out_features) + + ")" + ) + + +class _WNconvNd(Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + bias, + ): + super(_WNconvNd, self).__init__() + if in_channels % groups != 0: + 
raise ValueError("in_channels must be divisible by groups") + if out_channels % groups != 0: + raise ValueError("out_channels must be divisible by groups") + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.dilation = dilation + self.transposed = transposed + self.output_padding = output_padding + self.groups = groups + + # weight – filters tensor (out_channels x in_channels/groups x kH x kW) + if transposed: + self.direction = Parameter( + torch.Tensor(in_channels, out_channels // groups, *kernel_size) + ) + self.scale = Parameter(torch.Tensor(in_channels)) + else: + self.direction = Parameter( + torch.Tensor(out_channels, in_channels // groups, *kernel_size) + ) + self.scale = Parameter(torch.Tensor(out_channels)) + if bias: + self.bias = Parameter(torch.Tensor(out_channels)) + else: + self.register_parameter("bias", N_) + self.reset_parameters() + + def reset_parameters(self): + n = self.in_channels + for k in self.kernel_size: + n *= k + stdv = 1.0 / math.sqrt(n) + self.direction.data.uniform_(-stdv, stdv) + self.scale.data.uniform_(1, 1) + if self.bias is not N_: + self.bias.data.uniform_(-stdv, stdv) + + def __repr__(self): + s = ( + "{name} ({in_channels}, {out_channels}, kernel_size={kernel_size}" + ", stride={stride}" + ) + if self.padding != (0,) * len(self.padding): + s += ", padding={padding}" + if self.dilation != (1,) * len(self.dilation): + s += ", dilation={dilation}" + if self.output_padding != (0,) * len(self.output_padding): + s += ", output_padding={output_padding}" + if self.groups != 1: + s += ", groups={groups}" + if self.bias is N_: + s += ", bias=False" + s += ")" + return s.format(name=self.__class__.__name__, **self.__dict__) + + +class WNconv2d(_WNconvNd): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + mask=N_, + norm=True, + ): + kernel_size = _pair(kernel_size) + stride = _pair(stride) + padding = _pair(padding) + dilation = _pair(dilation) + super(WNconv2d, self).__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + False, + _pair(0), + groups, + bias, + ) + + self.register_buffer("mask", mask) + self.norm = norm + + def forward(self, input): + if self.norm: + dir_ = self.direction + direction = dir_.div( + dir_.pow(2).sum(1).sum(1).sum(1).sqrt()[:, N_, N_, N_] + ) + weight = self.scale[:, N_, N_, N_].mul(direction) + else: + weight = self.scale[:, N_, N_, N_].mul(self.direction) + if self.mask is not None: + weight = weight * Variable(self.mask) + return F.conv2d( + input, + weight, + self.bias, + self.stride, + self.padding, + self.dilation, + self.groups, + ) + + +class CWNconv2d(_WNconvNd): + def __init__( + self, + context_features, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + mask=N_, + norm=True, + ): + kernel_size = _pair(kernel_size) + stride = _pair(stride) + padding = _pair(padding) + dilation = _pair(dilation) + super(CWNconv2d, self).__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + False, + _pair(0), + groups, + False, + ) + + self.register_buffer("mask", mask) + self.norm = norm + self.cscale = nn.Linear(context_features, out_channels) + self.cbias = nn.Linear(context_features, out_channels) + + def forward(self, inputs): + input, context = inputs + scale = self.cscale(context)[:, :, N_, N_] + bias = self.cbias(context)[:, 
:, N_, N_] + if self.norm: + dir_ = self.direction + direction = dir_.div( + dir_.pow(2).sum(1).sum(1).sum(1).sqrt()[:, N_, N_, N_] + ) + weight = direction + else: + weight = self.direction + if self.mask is not None: + weight = weight * Variable(self.mask) + pre = F.conv2d( + input, + weight, + None, + self.stride, + self.padding, + self.dilation, + self.groups, + ) + return pre * scale + bias, context + + +class ResConv2d(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + activation=nn.ReLU(), + oper=WNconv2d, + ): + super(ResConv2d, self).__init__() + + self.conv_0h = oper( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + ) + self.conv_h1 = oper(out_channels, out_channels, 3, 1, 1, 1, 1, True) + self.conv_01 = oper( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + ) + + self.activation = activation + + def forward(self, input): + h = self.activation(self.conv_0h(input)) + out_nonlinear = self.conv_h1(h) + out_skip = self.conv_01(input) + return out_nonlinear + out_skip + + +class ResLinear(nn.Module): + def __init__( + self, + in_features, + out_features, + bias=True, + same_dim=False, + activation=nn.ReLU(), + oper=WNlinear, + ): + super(ResLinear, self).__init__() + + self.same_dim = same_dim + + self.dot_0h = oper(in_features, out_features, bias) + self.dot_h1 = oper(out_features, out_features, bias) + if not same_dim: + self.dot_01 = oper(in_features, out_features, bias) + + self.activation = activation + + def forward(self, input): + h = self.activation(self.dot_0h(input)) + out_nonlinear = self.dot_h1(h) + out_skip = input if self.same_dim else self.dot_01(input) + return out_nonlinear + out_skip + + +class GatingLinear(nn.Module): + def __init__(self, in_features, out_features, oper=WNlinear, **kwargs): + super(GatingLinear, self).__init__() + + self.dot = oper(in_features, out_features, **kwargs) + self.gate = oper(in_features, out_features, **kwargs) + + def forward(self, input): + h = self.dot(input) + s = sigmoid_(self.gate(input)) + return s * h + + +class Reshape(nn.Module): + def __init__(self, shape): + super(Reshape, self).__init__() + self.shape = shape + + def forward(self, input): + return input.view(self.shape) + + +class Slice(nn.Module): + def __init__(self, slc): + super(Slice, self).__init__() + self.slc = slc + + def forward(self, input): + return input.__getitem__(self.slc) + + +class SliceFactory(object): + def __init__(self): + pass + + def __getitem__(self, slc): + return Slice(slc) + + +slicer = SliceFactory() + + +class Lambda(nn.Module): + def __init__(self, function): + super(Lambda, self).__init__() + self.function = function + + def forward(self, input): + return self.function(input) + + +class SequentialFlow(nn.Sequential): + def sample(self, n=1, context=None, **kwargs): + dim = self[0].dim + if isinstance(dim, int): + dim = [ + dim, + ] + + spl = torch.autograd.Variable(torch.FloatTensor(n, *dim).normal_()) + lgd = torch.autograd.Variable( + torch.from_numpy(np.random.rand(n).astype("float32")) + ) + if context is None: + context = torch.autograd.Variable( + torch.from_numpy( + np.zeros((n, self[0].context_dim)).astype("float32") + ) + ) + + if hasattr(self, "gpu"): + if self.gpu: + spl = spl.cuda() + lgd = lgd.cuda() + context = context.cuda() + + return self.forward((spl, lgd, context)) + + def cuda(self): + self.gpu = True + return 
super(SequentialFlow, self).cuda() + + +class ContextWrapper(nn.Module): + def __init__(self, module): + super(ContextWrapper, self).__init__() + self.module = module + + def forward(self, inputs): + input, context = inputs + output = self.module.forward(input) + return output, context + + +if __name__ == "__main__": + + mdl = CWNlinear(2, 5, 3) + + inp = torch.autograd.Variable( + torch.from_numpy(np.random.rand(2, 2).astype("float32")) + ) + con = torch.autograd.Variable( + torch.from_numpy(np.random.rand(2, 3).astype("float32")) + ) + + print(mdl((inp, con))[0].size()) diff --git a/paper/node/script/sizes/sizes_maf.py b/paper/node/script/sizes/sizes_maf.py new file mode 100644 index 0000000..39de27a --- /dev/null +++ b/paper/node/script/sizes/sizes_maf.py @@ -0,0 +1,61 @@ +# %% just compute a couple of sizes + + +def made_size(l, h, d, k=None): + return int(1.5 * d * h + 0.5 * (l - 1) * h * h) + + +def nvp_size(l, h, d, k=10): + return int(2 * k * d * h + 2 * k * (l - 1) * h * h) + + +def maf_size(l, h, d, k=10): + return int(1.5 * k * d * h + 0.5 * k * (l - 1) * h * h) + + +def format_as_str(num): + if num / 1e9 > 1: + factor, suffix = 1e9, "B" + elif num / 1e6 > 1: + factor, suffix = 1e6, "M" + elif num / 1e3 > 1: + factor, suffix = 1e3, "K" + else: + factor, suffix = 1e0, "" + + num_factored = num / factor + + if num_factored / 1e2 > 1 or True: + num_rounded = str(int(round(num_factored))) + elif num_factored / 1e1 > 1: + num_rounded = f"{num_factored:.1f}" + else: + num_rounded = f"{num_factored:.2f}" + + return f"{num_rounded}{suffix} % {num}" + + +datasets = { + "power": {"l": 2, "h": 100, "d": 6}, + "gas": {"l": 2, "h": 100, "d": 8}, + "hepmass": {"l": 2, "h": 512, "d": 21}, + "miniboone": {"l": 2, "h": 512, "d": 43}, + "bsds300": {"l": 2, "h": 1024, "d": 63}, + "mnist": {"l": 1, "h": 1024, "d": 784, "k": 10}, + "cifar": {"l": 2, "h": 2048, "d": 3072, "k": 10}, +} + +networks = { + "MADE": made_size, + "RealNVP": nvp_size, + "MAF": maf_size, +} + +for net, handle in networks.items(): + print(net) + for dset, s_kwargs in datasets.items(): + print(f"{dset}: #params: {format_as_str(handle(**s_kwargs))}") + print("\n") + + +# miniboone ffjord: 820613 \ No newline at end of file diff --git a/paper/node/script/sizes/sizes_naf.py b/paper/node/script/sizes/sizes_naf.py new file mode 100644 index 0000000..27b03f2 --- /dev/null +++ b/paper/node/script/sizes/sizes_naf.py @@ -0,0 +1,1082 @@ +# %% + +import numpy as np +import torch +import torch.utils.data as data +import torch.nn as nn + +# import torch.optim as optim +from torch.autograd import Variable +from torchvision.transforms import transforms + + +import time +import json +import argparse, os + + +import torch +import torch.nn as nn +from torch.nn import Module +from torch.nn.parameter import Parameter +from torch.nn import functional as F +import nn_ as nn_ +from nn_ import log +from torch.autograd import Variable +import numpy as np + + +sum_from_one = nn_.sum_from_one + + +#!/usr/bin/env python2 +# -*- coding: utf-8 -*- +""" +Created on Mon Dec 11 13:58:12 2017 + +@author: CW +""" + + +import numpy as np + +import torch +import torch.nn as nn +from torch.nn import Module + + +from functools import reduce + +# aliasing +N_ = None + + +delta = 1e-6 +softplus_ = nn.Softplus() +softplus = lambda x: softplus_(x) + delta + + +tile = lambda x, r: np.tile(x, r).reshape(x.shape[0], x.shape[1] * r) + + +# %------------ MADE ------------% + + +def get_rank(max_rank, num_out): + rank_out = np.array([]) + while len(rank_out) < num_out: + 
rank_out = np.concatenate([rank_out, np.arange(max_rank)]) + excess = len(rank_out) - num_out + remove_ind = np.random.choice(max_rank, excess, False) + rank_out = np.delete(rank_out, remove_ind) + np.random.shuffle(rank_out) + return rank_out.astype("float32") + + +def get_mask_from_ranks(r1, r2): + return (r2[:, None] >= r1[None, :]).astype("float32") + + +def get_masks_all(ds, fixed_order=False, derank=1): + # ds: list of dimensions dx, d1, d2, ... dh, dx, + # (2 in/output + h hidden layers) + # derank only used for self connection, dim > 1 + dx = ds[0] + ms = list() + rx = get_rank(dx, dx) + if fixed_order: + rx = np.sort(rx) + r1 = rx + if dx != 1: + for d in ds[1:-1]: + r2 = get_rank(dx - derank, d) + ms.append(get_mask_from_ranks(r1, r2)) + r1 = r2 + r2 = rx - derank + ms.append(get_mask_from_ranks(r1, r2)) + else: + ms = [ + np.zeros([ds[i + 1], ds[i]]).astype("float32") + for i in range(len(ds) - 1) + ] + if derank == 1: + assert np.all(np.diag(reduce(np.dot, ms[::-1])) == 0), "wrong masks" + + return ms, rx + + +def get_masks(dim, dh, num_layers, num_outlayers, fixed_order=False, derank=1): + ms, rx = get_masks_all( + [ + dim, + ] + + [dh for i in range(num_layers - 1)] + + [ + dim, + ], + fixed_order, + derank, + ) + ml = ms[-1] + ml_ = ( + ( + ml.transpose(1, 0)[:, :, None] + * ( + [ + np.cast["float32"](1), + ] + * num_outlayers + ) + ) + .reshape(dh, dim * num_outlayers) + .transpose(1, 0) + ) + ms[-1] = ml_ + return ms, rx + + +class MADE(Module): + def __init__( + self, + dim, + hid_dim, + num_layers, + num_outlayers=1, + activation=nn.ELU(), + fixed_order=False, + derank=1, + ): + super(MADE, self).__init__() + + oper = nn_.WNlinear + + self.dim = dim + self.hid_dim = hid_dim + self.num_layers = num_layers + self.num_outlayers = num_outlayers + self.activation = activation + + ms, rx = get_masks( + dim, hid_dim, num_layers, num_outlayers, fixed_order, derank + ) + ms = [m for m in map(torch.from_numpy, ms)] + self.rx = rx + + sequels = list() + for i in range(num_layers - 1): + if i == 0: + sequels.append(oper(dim, hid_dim, True, ms[i], False)) + sequels.append(activation) + else: + sequels.append(oper(hid_dim, hid_dim, True, ms[i], False)) + sequels.append(activation) + + self.input_to_hidden = nn.Sequential(*sequels) + self.hidden_to_output = oper( + hid_dim, dim * num_outlayers, True, ms[-1] + ) + + def forward(self, input): + hid = self.input_to_hidden(input) + return self.hidden_to_output(hid).view( + -1, self.dim, self.num_outlayers + ) + + def randomize(self): + ms, rx = get_masks( + self.dim, self.hid_dim, self.num_layers, self.num_outlayers + ) + for i in range(self.num_layers - 1): + mask = torch.from_numpy(ms[i]) + if self.input_to_hidden[i * 2].mask.is_cuda: + mask = mask.cuda() + self.input_to_hidden[i * 2].mask.data.zero_().add_(mask) + self.rx = rx + + +class cMADE(Module): + def __init__( + self, + dim, + hid_dim, + context_dim, + num_layers, + num_outlayers=1, + activation=nn.ELU(), + fixed_order=False, + derank=1, + ): + super(cMADE, self).__init__() + + oper = nn_.CWNlinear + + self.dim = dim + self.hid_dim = hid_dim + self.num_layers = num_layers + self.context_dim = context_dim + self.num_outlayers = num_outlayers + self.activation = nn_.Lambda(lambda x: (activation(x[0]), x[1])) + + ms, rx = get_masks( + dim, hid_dim, num_layers, num_outlayers, fixed_order, derank + ) + ms = [m for m in map(torch.from_numpy, ms)] + self.rx = rx + + sequels = list() + for i in range(num_layers - 1): + if i == 0: + sequels.append(oper(dim, hid_dim, context_dim, ms[i], 
False)) + sequels.append(self.activation) + else: + sequels.append( + oper(hid_dim, hid_dim, context_dim, ms[i], False) + ) + sequels.append(self.activation) + + self.input_to_hidden = nn.Sequential(*sequels) + self.hidden_to_output = oper( + hid_dim, dim * num_outlayers, context_dim, ms[-1] + ) + + def forward(self, inputs): + input, context = inputs + hid, _ = self.input_to_hidden((input, context)) + out, _ = self.hidden_to_output((hid, context)) + return out.view(-1, self.dim, self.num_outlayers), context + + def randomize(self): + ms, rx = get_masks( + self.dim, self.hid_dim, self.num_layers, self.num_outlayers + ) + for i in range(self.num_layers - 1): + mask = torch.from_numpy(ms[i]) + if self.input_to_hidden[i * 2].mask.is_cuda: + mask = mask.cuda() + self.input_to_hidden[i * 2].mask.zero_().add_(mask) + self.rx = rx + + +class BaseFlow(Module): + def sample(self, n=1, context=None, **kwargs): + dim = self.dim + if isinstance(self.dim, int): + dim = [ + dim, + ] + + spl = Variable(torch.FloatTensor(n, *dim).normal_()) + lgd = Variable(torch.from_numpy(np.zeros(n).astype("float32"))) + if context is None: + context = Variable( + torch.from_numpy( + np.ones((n, self.context_dim)).astype("float32") + ) + ) + + if hasattr(self, "gpu"): + if self.gpu: + spl = spl.cuda() + lgd = lgd.cuda() + context = context.gpu() + + return self.forward((spl, lgd, context)) + + def cuda(self): + self.gpu = True + return super(BaseFlow, self).cuda() + + +class LinearFlow(BaseFlow): + def __init__( + self, dim, context_dim, oper=nn_.ResLinear, realify=nn_.softplus + ): + super(LinearFlow, self).__init__() + self.realify = realify + + self.dim = dim + self.context_dim = context_dim + + if type(dim) is int: + dim_ = dim + else: + dim_ = np.prod(dim) + + self.mean = oper(context_dim, dim_) + self.lstd = oper(context_dim, dim_) + + self.reset_parameters() + + def reset_parameters(self): + if isinstance(self.mean, nn_.ResLinear): + self.mean.dot_01.scale.data.uniform_(-0.001, 0.001) + self.mean.dot_h1.scale.data.uniform_(-0.001, 0.001) + self.mean.dot_01.bias.data.uniform_(-0.001, 0.001) + self.mean.dot_h1.bias.data.uniform_(-0.001, 0.001) + self.lstd.dot_01.scale.data.uniform_(-0.001, 0.001) + self.lstd.dot_h1.scale.data.uniform_(-0.001, 0.001) + if self.realify == nn_.softplus: + inv = np.log(np.exp(1 - nn_.delta) - 1) * 0.5 + self.lstd.dot_01.bias.data.uniform_(inv - 0.001, inv + 0.001) + self.lstd.dot_h1.bias.data.uniform_(inv - 0.001, inv + 0.001) + else: + self.lstd.dot_01.bias.data.uniform_(-0.001, 0.001) + self.lstd.dot_h1.bias.data.uniform_(-0.001, 0.001) + elif isinstance(self.mean, nn.Linear): + self.mean.weight.data.uniform_(-0.001, 0.001) + self.mean.bias.data.uniform_(-0.001, 0.001) + self.lstd.weight.data.uniform_(-0.001, 0.001) + if self.realify == nn_.softplus: + inv = np.log(np.exp(1 - nn_.delta) - 1) * 0.5 + self.lstd.bias.data.uniform_(inv - 0.001, inv + 0.001) + else: + self.lstd.bias.data.uniform_(-0.001, 0.001) + + def forward(self, inputs): + x, logdet, context = inputs + mean = self.mean(context) + lstd = self.lstd(context) + std = self.realify(lstd) + + if type(self.dim) is int: + x_ = mean + std * x + else: + size = x.size() + x_ = mean.view(size) + std.view(size) * x + logdet_ = sum_from_one(torch.log(std)) + logdet + return x_, logdet_, context + + +class BlockAffineFlow(Module): + # NICE, volume preserving + # x2' = x2 + nonLinfunc(x1) + + def __init__(self, dim1, dim2, context_dim, hid_dim, activation=nn.ELU()): + super(BlockAffineFlow, self).__init__() + self.dim1 = dim1 + 
self.dim2 = dim2 + self.actv = activation + + self.hid = nn_.WNBilinear(dim1, context_dim, hid_dim) + self.shift = nn_.WNBilinear(hid_dim, context_dim, dim2) + + def forward(self, inputs): + x, logdet, context = inputs + x1, x2 = x + + hid = self.actv(self.hid(x1, context)) + shift = self.shift(hid, context) + + x2_ = x2 + shift + + return (x1, x2_), 0, context + + +class IAF(BaseFlow): + def __init__( + self, + dim, + hid_dim, + context_dim, + num_layers, + activation=nn.ELU(), + realify=nn_.sigmoid, + fixed_order=False, + ): + super(IAF, self).__init__() + self.realify = realify + + self.dim = dim + self.context_dim = context_dim + + if type(dim) is int: + self.mdl = cMADE( + dim, + hid_dim, + context_dim, + num_layers, + 2, + activation, + fixed_order, + ) + self.reset_parameters() + + def reset_parameters(self): + self.mdl.hidden_to_output.cscale.weight.data.uniform_(-0.001, 0.001) + self.mdl.hidden_to_output.cscale.bias.data.uniform_(0.0, 0.0) + self.mdl.hidden_to_output.cbias.weight.data.uniform_(-0.001, 0.001) + self.mdl.hidden_to_output.cbias.bias.data.uniform_(0.0, 0.0) + if self.realify == nn_.softplus: + inv = np.log(np.exp(1 - nn_.delta) - 1) + self.mdl.hidden_to_output.cbias.bias.data[1::2].uniform_(inv, inv) + elif self.realify == nn_.sigmoid: + self.mdl.hidden_to_output.cbias.bias.data[1::2].uniform_(2.0, 2.0) + + def forward(self, inputs): + x, logdet, context = inputs + out, _ = self.mdl((x, context)) + if isinstance(self.mdl, cMADE): + mean = out[:, :, 0] + lstd = out[:, :, 1] + + std = self.realify(lstd) + + if self.realify == nn_.softplus: + x_ = mean + std * x + elif self.realify == nn_.sigmoid: + x_ = (-std + 1.0) * mean + std * x + elif self.realify == nn_.sigmoid2: + x_ = (-std + 2.0) * mean + std * x + logdet_ = sum_from_one(torch.log(std)) + logdet + return x_, logdet_, context + + +class IAF_VP(BaseFlow): + def __init__( + self, + dim, + hid_dim, + context_dim, + num_layers, + activation=nn.ELU(), + fixed_order=True, + ): + super(IAF_VP, self).__init__() + + self.dim = dim + self.context_dim = context_dim + + if type(dim) is int: + self.mdl = cMADE( + dim, + hid_dim, + context_dim, + num_layers, + 1, + activation, + fixed_order, + ) + self.reset_parameters() + + def reset_parameters(self): + self.mdl.hidden_to_output.cscale.weight.data.uniform_(-0.001, 0.001) + self.mdl.hidden_to_output.cscale.bias.data.uniform_(0.0, 0.0) + self.mdl.hidden_to_output.cbias.weight.data.uniform_(-0.001, 0.001) + self.mdl.hidden_to_output.cbias.bias.data.uniform_(0.0, 0.0) + + def forward(self, inputs): + x, logdet, context = inputs + out, _ = self.mdl((x, context)) + mean = out[:, :, 0] + x_ = mean + x + return x_, logdet, context + + +class IAF_DSF(BaseFlow): + + mollify = 0.0 + + def __init__( + self, + dim, + hid_dim, + context_dim, + num_layers, + activation=nn.ELU(), + fixed_order=False, + num_ds_dim=4, + num_ds_layers=1, + num_ds_multiplier=3, + ): + super(IAF_DSF, self).__init__() + + self.dim = dim + self.context_dim = context_dim + self.num_ds_dim = num_ds_dim + self.num_ds_layers = num_ds_layers + + if type(dim) is int: + self.mdl = cMADE( + dim, + hid_dim, + context_dim, + num_layers, + num_ds_multiplier * (hid_dim // dim) * num_ds_layers, + activation, + fixed_order, + ) + self.out_to_dsparams = nn.Conv1d( + num_ds_multiplier * (hid_dim // dim) * num_ds_layers, + 3 * num_ds_layers * num_ds_dim, + 1, + ) + self.reset_parameters() + + self.sf = SigmoidFlow(num_ds_dim) + + def reset_parameters(self): + self.out_to_dsparams.weight.data.uniform_(-0.001, 0.001) + 
self.out_to_dsparams.bias.data.uniform_(0.0, 0.0) + + inv = np.log(np.exp(1 - nn_.delta) - 1) + for l in range(self.num_ds_layers): + nc = self.num_ds_dim + nparams = nc * 3 + s = l * nparams + self.out_to_dsparams.bias.data[s : s + nc].uniform_(inv, inv) + + def forward(self, inputs): + x, logdet, context = inputs + out, _ = self.mdl((x, context)) + if isinstance(self.mdl, cMADE): + out = out.permute(0, 2, 1) + dsparams = self.out_to_dsparams(out).permute(0, 2, 1) + nparams = self.num_ds_dim * 3 + + mollify = self.mollify + h = x.view(x.size(0), -1) + for i in range(self.num_ds_layers): + params = dsparams[:, :, i * nparams : (i + 1) * nparams] + h, logdet = self.sf(h, logdet, params, mollify) + + return h, logdet, context + + +class SigmoidFlow(BaseFlow): + def __init__(self, num_ds_dim=4): + super(SigmoidFlow, self).__init__() + self.num_ds_dim = num_ds_dim + + self.act_a = lambda x: nn_.softplus(x) + self.act_b = lambda x: x + self.act_w = lambda x: nn_.softmax(x, dim=2) + + def forward(self, x, logdet, dsparams, mollify=0.0, delta=nn_.delta): + + ndim = self.num_ds_dim + a_ = self.act_a(dsparams[:, :, 0 * ndim : 1 * ndim]) + b_ = self.act_b(dsparams[:, :, 1 * ndim : 2 * ndim]) + w = self.act_w(dsparams[:, :, 2 * ndim : 3 * ndim]) + + a = a_ * (1 - mollify) + 1.0 * mollify + b = b_ * (1 - mollify) + 0.0 * mollify + + pre_sigm = a * x[:, :, None] + b + sigm = torch.sigmoid(pre_sigm) + x_pre = torch.sum(w * sigm, dim=2) + x_pre_clipped = x_pre * (1 - delta) + delta * 0.5 + x_ = log(x_pre_clipped) - log(1 - x_pre_clipped) + xnew = x_ + + logj = ( + F.log_softmax(dsparams[:, :, 2 * ndim : 3 * ndim], dim=2) + + nn_.logsigmoid(pre_sigm) + + nn_.logsigmoid(-pre_sigm) + + log(a) + ) + + logj = torch.exp(logj, 2).sum(2) + logdet_ = ( + logj + + np.log(1 - delta) + - (log(x_pre_clipped) + log(-x_pre_clipped + 1)) + ) + logdet = logdet_.sum(1) + logdet + + return xnew, logdet + + +class IAF_DDSF(BaseFlow): + def __init__( + self, + dim, + hid_dim, + context_dim, + num_layers, + activation=nn.ELU(), + fixed_order=False, + num_ds_dim=4, + num_ds_layers=1, + num_ds_multiplier=3, + ): + super(IAF_DDSF, self).__init__() + + self.dim = dim + self.context_dim = context_dim + self.num_ds_dim = num_ds_dim + self.num_ds_layers = num_ds_layers + + if type(dim) is int: + self.mdl = cMADE( + dim, + hid_dim, + context_dim, + num_layers, + int(num_ds_multiplier * (hid_dim / dim) * num_ds_layers), + activation, + fixed_order, + ) + + num_dsparams = 0 + for i in range(num_ds_layers): + if i == 0: + in_dim = 1 + else: + in_dim = num_ds_dim + if i == num_ds_layers - 1: + out_dim = 1 + else: + out_dim = num_ds_dim + + u_dim = in_dim + w_dim = num_ds_dim + a_dim = b_dim = num_ds_dim + num_dsparams += u_dim + w_dim + a_dim + b_dim + + self.add_module( + "sf{}".format(i), DenseSigmoidFlow(in_dim, num_ds_dim, out_dim) + ) + if type(dim) is int: + self.out_to_dsparams = nn.Conv1d( + int(num_ds_multiplier * (hid_dim / dim) * num_ds_layers), + int(num_dsparams), + 1, + ) + else: + self.out_to_dsparams = nn.Conv1d( + num_ds_multiplier * (hid_dim / dim[0]) * num_ds_layers, + num_dsparams, + 1, + ) + + self.reset_parameters() + + def reset_parameters(self): + self.out_to_dsparams.weight.data.uniform_(-0.001, 0.001) + self.out_to_dsparams.bias.data.uniform_(0.0, 0.0) + + def forward(self, inputs): + x, logdet, context = inputs + out, _ = self.mdl((x, context)) + out = out.permute(0, 2, 1) + dsparams = self.out_to_dsparams(out).permute(0, 2, 1) + + start = 0 + + h = x.view(x.size(0), -1)[:, :, None] + n = x.size(0) + dim = 
self.dim if type(self.dim) is int else self.dim[0] + lgd = Variable( + torch.from_numpy(np.zeros((n, dim, 1, 1)).astype("float32")) + ) + if self.out_to_dsparams.weight.is_cuda: + lgd = lgd.cuda() + for i in range(self.num_ds_layers): + if i == 0: + in_dim = 1 + else: + in_dim = self.num_ds_dim + if i == self.num_ds_layers - 1: + out_dim = 1 + else: + out_dim = self.num_ds_dim + + u_dim = in_dim + w_dim = self.num_ds_dim + a_dim = b_dim = self.num_ds_dim + end = start + u_dim + w_dim + a_dim + b_dim + + params = dsparams[:, :, start:end] + h, lgd = getattr(self, "sf{}".format(i))(h, lgd, params) + start = end + + assert out_dim == 1, "last dsf out dim should be 1" + return h[:, :, 0], lgd[:, :, 0, 0].sum(1) + logdet, context + + +class DenseSigmoidFlow(BaseFlow): + def __init__(self, in_dim, hidden_dim, out_dim): + super(DenseSigmoidFlow, self).__init__() + self.in_dim = in_dim + self.hidden_dim = hidden_dim + self.out_dim = out_dim + + self.act_a = lambda x: nn_.softplus(x) + self.act_b = lambda x: x + self.act_w = lambda x: nn_.softmax(x, dim=3) + self.act_u = lambda x: nn_.softmax(x, dim=3) + + self.u_ = Parameter(torch.Tensor(hidden_dim, in_dim)) + self.w_ = Parameter(torch.Tensor(out_dim, hidden_dim)) + + self.reset_parameters() + + def reset_parameters(self): + self.u_.data.uniform_(-0.001, 0.001) + self.w_.data.uniform_(-0.001, 0.001) + + def forward(self, x, logdet, dsparams): + inv = np.log(np.exp(1 - nn_.delta) - 1) + ndim = self.hidden_dim + pre_u = ( + self.u_[None, None, :, :] + + dsparams[:, :, -self.in_dim :][:, :, None, :] + ) + pre_w = ( + self.w_[None, None, :, :] + + dsparams[:, :, 2 * ndim : 3 * ndim][:, :, None, :] + ) + a = self.act_a(dsparams[:, :, 0 * ndim : 1 * ndim] + inv) + b = self.act_b(dsparams[:, :, 1 * ndim : 2 * ndim]) + w = self.act_w(pre_w) + u = self.act_u(pre_u) + + pre_sigm = torch.sum(u * a[:, :, :, None] * x[:, :, None, :], 3) + b + sigm = torch.sigmoid(pre_sigm) + x_pre = torch.sum(w * sigm[:, :, None, :], dim=3) + x_pre_clipped = x_pre * (1 - nn_.delta) + nn_.delta * 0.5 + x_ = log(x_pre_clipped) - log(1 - x_pre_clipped) + xnew = x_ + + logj = ( + F.log_softmax(pre_w, dim=3) + + nn_.logsigmoid(pre_sigm[:, :, None, :]) + + nn_.logsigmoid(-pre_sigm[:, :, None, :]) + + log(a[:, :, None, :]) + ) + # n, d, d2, dh + + logj = ( + logj[:, :, :, :, None] + + F.log_softmax(pre_u, dim=3)[:, :, None, :, :] + ) + # n, d, d2, dh, d1 + + logj = torch.exp(logj, 3).sum(3) + # n, d, d2, d1 + + logdet_ = ( + logj + + np.log(1 - nn_.delta) + - (log(x_pre_clipped) + log(-x_pre_clipped + 1))[:, :, :, None] + ) + + logdet = torch.exp( + logdet_[:, :, :, :, None] + logdet[:, :, None, :, :], 3 + ).sum(3) + # n, d, d2, d1, d0 -> n, d, d2, d0 + + return xnew, logdet + + +class FlipFlow(BaseFlow): + def __init__(self, dim): + self.dim = dim + super(FlipFlow, self).__init__() + + def forward(self, inputs): + input, logdet, context = inputs + + dim = self.dim + index = Variable( + getattr( + torch.arange(input.size(dim) - 1, -1, -1), + ("cpu", "cuda")[input.is_cuda], + )().long() + ) + + output = torch.index_select(input, dim, index) + + return output, logdet, context + + +class Sigmoid(BaseFlow): + def __init__(self): + super(Sigmoid, self).__init__() + + def forward(self, inputs): + if len(inputs) == 2: + input, logdet = inputs + elif len(inputs) == 3: + input, logdet, context = inputs + else: + raise (Exception("inputs length not correct")) + + output = F.sigmoid(input) + logdet += sum_from_one(-F.softplus(input) - F.softplus(-input)) + + if len(inputs) == 2: + return 
output, logdet + elif len(inputs) == 3: + return output, logdet, context + else: + raise (Exception("inputs length not correct")) + + +class Logit(BaseFlow): + def __init__(self): + super(Logit, self).__init__() + + def forward(self, inputs): + if len(inputs) == 2: + input, logdet = inputs + elif len(inputs) == 3: + input, logdet, context = inputs + else: + raise (Exception("inputs length not correct")) + + output = log(input) - log(1 - input) + logdet -= sum_from_one(log(input) + log(-input + 1)) + + if len(inputs) == 2: + return output, logdet + elif len(inputs) == 3: + return output, logdet, context + else: + raise (Exception("inputs length not correct")) + + +class Shift(BaseFlow): + def __init__(self, b): + self.b = b + super(Shift, self).__init__() + + def forward(self, inputs): + if len(inputs) == 2: + input, logdet = inputs + elif len(inputs) == 3: + input, logdet, context = inputs + else: + raise (Exception("inputs length not correct")) + + output = input + self.b + + if len(inputs) == 2: + return output, logdet + elif len(inputs) == 3: + return output, logdet, context + else: + raise (Exception("inputs length not correct")) + + +class Scale(BaseFlow): + def __init__(self, g): + self.g = g + super(Scale, self).__init__() + + def forward(self, inputs): + if len(inputs) == 2: + input, logdet = inputs + elif len(inputs) == 3: + input, logdet, context = inputs + else: + raise (Exception("inputs length not correct")) + + output = input * self.g + logdet += np.log(np.abs(self.g)) * np.prod(input.size()[1:]) + + if len(inputs) == 2: + return output, logdet + elif len(inputs) == 3: + return output, logdet, context + else: + raise (Exception("inputs length not correct")) + + +class MAF(object): + def __init__(self, args, p): + + self.args = args + self.__dict__.update(args.__dict__) + self.p = p + + dim = p + dimc = 1 + dimh = args.dimh + flowtype = args.flowtype + num_flow_layers = args.num_flow_layers + num_ds_dim = args.num_ds_dim + num_ds_layers = args.num_ds_layers + fixed_order = args.fixed_order + + act = nn.ELU() + if flowtype == "affine": + flow = IAF + elif flowtype == "dsf": + flow = lambda **kwargs: IAF_DSF( + num_ds_dim=num_ds_dim, num_ds_layers=num_ds_layers, **kwargs + ) + elif flowtype == "ddsf": + flow = lambda **kwargs: IAF_DDSF( + num_ds_dim=num_ds_dim, num_ds_layers=num_ds_layers, **kwargs + ) + + sequels = [ + nn_.SequentialFlow( + flow( + dim=dim, + hid_dim=dimh, + context_dim=dimc, + num_layers=args.num_hid_layers + 1, + activation=act, + fixed_order=fixed_order, + ), + FlipFlow(1), + ) + for i in range(num_flow_layers) + ] + [ + LinearFlow(dim, dimc), + ] + + self.flow = nn.Sequential(*sequels) + + +# ============================================================================= +# main +# ============================================================================= + + +"""parsing and configuration""" + + +def parse_args(): + desc = "MAF" + parser = argparse.ArgumentParser(description=desc) + + parser.add_argument( + "--dataset", + type=str, + default="miniboone", + choices=["power", "gas", "hepmass", "miniboone", "bsds300"], + ) + parser.add_argument( + "--epoch", type=int, default=400, help="The number of epochs to run" + ) + parser.add_argument( + "--batch_size", type=int, default=100, help="The size of batch" + ) + parser.add_argument( + "--save_dir", + type=str, + default="models", + help="Directory name to save the model", + ) + parser.add_argument( + "--result_dir", + type=str, + default="results", + help="Directory name to save the generated images", + ) + 
parser.add_argument( + "--log_dir", + type=str, + default="logs", + help="Directory name to save training logs", + ) + parser.add_argument("--seed", type=int, default=1993, help="Random seed") + parser.add_argument( + "--fn", type=str, default="0", help="Filename of model to be loaded" + ) + parser.add_argument( + "--to_train", type=int, default=1, help="1 if to train 0 if not" + ) + parser.add_argument("--lr", type=float, default=0.0001) + parser.add_argument("--clip", type=float, default=5.0) + parser.add_argument("--beta1", type=float, default=0.9) + parser.add_argument("--beta2", type=float, default=0.999) + parser.add_argument("--amsgrad", type=int, default=0) + parser.add_argument("--polyak", type=float, default=0.0) + parser.add_argument("--cuda", type=bool, default=False) + + parser.add_argument("--dimh", type=int, default=100) + parser.add_argument("--flowtype", type=str, default="ddsf") + parser.add_argument("--num_flow_layers", type=int, default=10) + parser.add_argument("--num_hid_layers", type=int, default=1) + parser.add_argument("--num_ds_dim", type=int, default=16) + parser.add_argument("--num_ds_layers", type=int, default=1) + parser.add_argument( + "--fixed_order", + type=bool, + default=True, + help="Fix the made ordering to be the given order", + ) + + return check_args(parser.parse_args()) + + +"""checking arguments""" + + +def check_args(args): + # --save_dir + if not os.path.exists(args.save_dir): + os.makedirs(args.save_dir) + + # --result_dir + if not os.path.exists(args.result_dir + "_" + args.dataset): + os.makedirs(args.result_dir + "_" + args.dataset) + + # --result_dir + if not os.path.exists(args.log_dir): + os.makedirs(args.log_dir) + + # --epoch + try: + assert args.epoch >= 1 + except: + print("number of epochs must be larger than or equal to one") + + # --batch_size + try: + assert args.batch_size >= 1 + except: + print("batch size must be larger than or equal to one") + + return args + + +datasets = { + "power": {"d": 6, "dimh": 100, "num_hid_layers": 2}, + "gas": {"d": 8, "dimh": 100, "num_hid_layers": 2}, + "hepmass": {"d": 21, "dimh": 512, "num_hid_layers": 2}, + "miniboone": {"d": 43, "dimh": 512, "num_hid_layers": 1}, + "bsds300": {"d": 63, "dimh": 1024, "num_hid_layers": 2}, +} + + +def format_as_str(num): + if num / 1e9 > 1: + factor, suffix = 1e9, "B" + elif num / 1e6 > 1: + factor, suffix = 1e6, "M" + elif num / 1e3 > 1: + factor, suffix = 1e3, "K" + else: + factor, suffix = 1e0, "" + + num_factored = num / factor + + if num_factored / 1e2 > 1 or True: + num_rounded = str(int(round(num_factored))) + elif num_factored / 1e1 > 1: + num_rounded = f"{num_factored:.1f}" + else: + num_rounded = f"{num_factored:.2f}" + + return f"{num_rounded}{suffix} % {num}" + + +def naf_size(d, **kwargs): + from torchprune.util.net import NetHandle + + args = parse_args() + for key, val in kwargs.items(): + setattr(args, key, val) + model = MAF(args, d) + model = NetHandle(model.flow) + return model.size() + + +for dset, s_kwargs in datasets.items(): + print(f"{dset}: #params: {format_as_str(naf_size(**s_kwargs))}") + print("\n") diff --git a/paper/node/script/sizes/sizes_sos.py b/paper/node/script/sizes/sizes_sos.py new file mode 100644 index 0000000..af5dc28 --- /dev/null +++ b/paper/node/script/sizes/sizes_sos.py @@ -0,0 +1,459 @@ +# %% +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +import math + +from torchprune.util.net import NetHandle + + +def get_mask(in_features, out_features, in_flow_features, mask_type=None): + 
""" + mask_type: input | None | output + + See Figure 1 for a better illustration: + https://arxiv.org/pdf/1502.03509.pdf + """ + if mask_type == "input": + in_degrees = torch.arange(in_features) % in_flow_features + else: + in_degrees = torch.arange(in_features) % (in_flow_features - 1) + + if mask_type == "output": + out_degrees = torch.arange(out_features) % in_flow_features - 1 + else: + out_degrees = torch.arange(out_features) % (in_flow_features - 1) + + return (out_degrees.unsqueeze(-1) >= in_degrees.unsqueeze(0)).float() + + +class MaskedLinear(nn.Linear): + def __init__(self, in_features, out_features, mask, bias=True): + super(MaskedLinear, self).__init__(in_features, out_features, bias) + self.register_buffer("mask", mask) + + def forward(self, inputs): + return F.linear(inputs, self.weight * self.mask, self.bias) + + +class ConditionerNet(nn.Module): + def __init__(self, input_size, hidden_size, k, m, n_layers=1): + super().__init__() + self.k = k + self.m = m + self.input_size = input_size + self.output_size = k * self.m * input_size + input_size + self.network = self._make_net( + input_size, hidden_size, self.output_size, n_layers + ) + + def _make_net(self, input_size, hidden_size, output_size, n_layers): + if self.input_size > 1: + input_mask = get_mask( + input_size, hidden_size, input_size, mask_type="input" + ) + hidden_mask = get_mask(hidden_size, hidden_size, input_size) + output_mask = get_mask( + hidden_size, output_size, input_size, mask_type="output" + ) + + network = nn.Sequential( + MaskedLinear(input_size, hidden_size, input_mask), + nn.ReLU(), + MaskedLinear(hidden_size, hidden_size, hidden_mask), + nn.ReLU(), + MaskedLinear(hidden_size, output_size, output_mask), + ) + else: + network = nn.Sequential( + nn.Linear(input_size, hidden_size), + nn.ReLU(), + nn.Linear(hidden_size, hidden_size), + nn.ReLU(), + nn.Linear(hidden_size, output_size), + ) + + """ + for module in network.modules(): + if isinstance(module, nn.Linear): + nn.init.orthogonal_(module.weight) + module.bias.data.fill_(0) + """ + + return network + + def forward(self, inputs): + batch_size = inputs.size(0) + params = self.network(inputs) + i = self.k * self.m * self.input_size + c = ( + params[:, :i] + .view(batch_size, -1, self.input_size) + .transpose(1, 2) + .view(batch_size, self.input_size, self.k, self.m, 1) + ) + const = params[:, i:].view(batch_size, self.input_size) + C = torch.matmul(c, c.transpose(3, 4)) + return C, const + + +# +# SOS Block: +# + + +class SOSFlow(nn.Module): + @staticmethod + def power(z, k): + return z ** (torch.arange(k).float().to(z.device)) + + def __init__(self, input_size, hidden_size, k, r, n_layers=1): + super().__init__() + self.k = k + self.m = r + 1 + + self.conditioner = ConditionerNet( + input_size, hidden_size, k, self.m, n_layers + ) + self.register_buffer("filter", self._make_filter()) + + def _make_filter(self): + n = torch.arange(self.m).unsqueeze(1) + e = torch.ones(self.m).unsqueeze(1).long() + filter = (n.mm(e.transpose(0, 1))) + (e.mm(n.transpose(0, 1))) + 1 + return filter.float() + + def forward(self, inputs, mode="direct"): + batch_size, input_size = inputs.size(0), inputs.size(1) + C, const = self.conditioner(inputs) + X = SOSFlow.power(inputs.unsqueeze(-1), self.m).view( + batch_size, input_size, 1, self.m, 1 + ) # bs x d x 1 x m x 1 + Z = self._transform(X, C / self.filter) * inputs + const + logdet = torch.log(torch.abs(self._transform(X, C))).sum( + dim=1, keepdim=True + ) + return Z, logdet + + def _transform(self, X, C): + CX = 
torch.matmul(C, X) # bs x d x k x m x 1 + XCX = torch.matmul(X.transpose(3, 4), CX) # bs x d x k x 1 x 1 + summed = XCX.squeeze(-1).squeeze(-1).sum(-1) # bs x d + return summed + + def _jacob(self, inputs, mode="direct"): + X = inputs.clone() + X.requires_grad_() + X.retain_grad() + d = X.size(0) + + X_in = X.unsqueeze(0) + C, const = self.conditioner(X_in) + Xpow = SOSFlow.power(X_in.unsqueeze(-1), self.m).view( + 1, d, 1, self.m, 1 + ) # bs x d x 1 x m x 1 + Z = (self._transform(Xpow, C / self.filter) * X_in + const).view(-1) + + J = torch.zeros(d, d) + for i in range(d): + self.zero_grad() + Z[i].backward(retain_graph=True) + J[i, :] = X.grad + + del X, X_in, C, const, Xpow, Z + return J + + +class MADE(nn.Module): + """An implementation of MADE + (https://arxiv.org/abs/1502.03509s). + """ + + def __init__(self, num_inputs, num_hidden, act="relu", pre_exp_tanh=False): + super(MADE, self).__init__() + + activations = {"relu": nn.ReLU, "sigmoid": nn.Sigmoid, "tanh": nn.Tanh} + act_func = activations[act] + + input_mask = get_mask( + num_inputs, num_hidden, num_inputs, mask_type="input" + ) + hidden_mask = get_mask(num_hidden, num_hidden, num_inputs) + output_mask = get_mask( + num_hidden, num_inputs * 2, num_inputs, mask_type="output" + ) + + self.joiner = MaskedLinear(num_inputs, num_hidden, input_mask) + + self.trunk = nn.Sequential( + act_func(), + MaskedLinear(num_hidden, num_hidden, hidden_mask), + act_func(), + MaskedLinear(num_hidden, num_inputs * 2, output_mask), + ) + + def forward(self, inputs, mode="direct"): + if mode == "direct": + h = self.joiner(inputs) + m, a = self.trunk(h).chunk(2, 1) + u = (inputs - m) * torch.exp(-a) + return u, -a.sum(-1, keepdim=True) + + else: + x = torch.zeros_like(inputs) + for i_col in range(inputs.shape[1]): + h = self.joiner(x) + m, a = self.trunk(h).chunk(2, 1) + x[:, i_col] = ( + inputs[:, i_col] * torch.exp(a[:, i_col]) + m[:, i_col] + ) + return x, -a.sum(-1, keepdim=True) + + def _jacob(self, inputs): + X = inputs.clone() + X.requires_grad_() + X.retain_grad() + d = X.size(0) + + X_in = X.unsqueeze(0) + + h = self.joiner(X_in) + m, a = self.trunk(h).chunk(2, 1) + u = ((X_in - m) * torch.exp(-a)).view(-1) + + J = torch.zeros(d, d) + for i in range(d): + self.zero_grad() + u[i].backward(retain_graph=True) + J[i, :] = X.grad + + del X, X_in, h, m, a, u + return J + + +class BatchNormFlow(nn.Module): + """An implementation of a batch normalization layer from + Density estimation using Real NVP + (https://arxiv.org/abs/1605.08803). 
+ """ + + def __init__(self, num_inputs, momentum=0.0, eps=1e-5): + super(BatchNormFlow, self).__init__() + + self.log_gamma = nn.Parameter(torch.zeros(num_inputs)) + self.beta = nn.Parameter(torch.zeros(num_inputs)) + self.momentum = momentum + self.eps = eps + + self.register_buffer("running_mean", torch.zeros(num_inputs)) + self.register_buffer("running_var", torch.ones(num_inputs)) + + def forward(self, inputs, mode="direct"): + if mode == "direct": + if True: # self.training: + self.batch_mean = inputs.mean(0) + self.batch_var = (inputs - self.batch_mean).pow(2).mean( + 0 + ) + self.eps + + self.running_mean.mul_(self.momentum) + self.running_var.mul_(self.momentum) + + self.running_mean.add_( + self.batch_mean.data * (1 - self.momentum) + ) + self.running_var.add_( + self.batch_var.data * (1 - self.momentum) + ) + + mean = self.batch_mean + var = self.batch_var + else: + mean = self.running_mean + var = self.running_var + + x_hat = (inputs - mean) / var.sqrt() + y = torch.exp(self.log_gamma) * x_hat + self.beta + return y, (self.log_gamma - 0.5 * torch.log(var)).sum( + -1, keepdim=True + ) + else: + if True: # self.training: + mean = self.batch_mean + var = self.batch_var + else: + mean = self.running_mean + var = self.running_var + + x_hat = (inputs - self.beta) / torch.exp(self.log_gamma) + + y = x_hat * var.sqrt() + mean + + return y, (-self.log_gamma + 0.5 * torch.log(var)).sum( + -1, keepdim=True + ) + + def _jacob(self, X): + return None + + +class Reverse(nn.Module): + """An implementation of a reversing layer from + Density estimation using Real NVP + (https://arxiv.org/abs/1605.08803). + """ + + def __init__(self, num_inputs): + super(Reverse, self).__init__() + self.perm = np.array(np.arange(0, num_inputs)[::-1]) + self.inv_perm = np.argsort(self.perm) + + def forward(self, inputs, mode="direct"): + if mode == "direct": + return inputs[:, self.perm], torch.zeros( + inputs.size(0), 1, device=inputs.device + ) + else: + return inputs[:, self.inv_perm], torch.zeros( + inputs.size(0), 1, device=inputs.device + ) + + def _jacob(self, X): + return None + + +class FlowSequential(nn.Sequential): + """A sequential container for flows. + In addition to a forward pass it implements a backward pass and + computes log jacobians. + """ + + def forward(self, inputs, mode="direct", logdets=None): + """Performs a forward or backward pass for flow modules. 
+        Args:
+            inputs: the input tensor
+            mode: to run direct computation or inverse
+        """
+        self.num_inputs = inputs.size(-1)
+
+        if logdets is None:
+            logdets = torch.zeros(inputs.size(0), 1, device=inputs.device)
+
+        assert mode in ["direct", "inverse"]
+        if mode == "direct":
+            for module in self._modules.values():
+                inputs, logdet = module(inputs, mode)
+                logdets += logdet
+        else:
+            for module in reversed(self._modules.values()):
+                inputs, logdet = module(inputs, mode)
+                logdets += logdet
+
+        return inputs, logdets
+
+    def evaluate(self, inputs):
+        N = len(self._modules)
+        outputs = torch.zeros(
+            N + 1, inputs.size(0), inputs.size(1), device=inputs.device
+        )
+        outputs[0, :, :] = inputs
+        logdets = torch.zeros(N, inputs.size(0), 1, device=inputs.device)
+        for i in range(N):
+            outputs[i + 1, :, :], logdets[i, :, :] = self._modules[str(i)](
+                outputs[i, :, :], mode="direct"
+            )
+        return outputs, logdets
+
+    def log_probs(self, inputs):
+        u, log_jacob = self(inputs)
+        log_probs = (-0.5 * u.pow(2) - 0.5 * math.log(2 * math.pi)).sum(
+            -1, keepdim=True
+        )
+        return (log_probs + log_jacob).sum(-1, keepdim=True)
+
+    def sample(self, num_samples=None, noise=None, cond_inputs=None):
+        if noise is None:
+            noise = torch.Tensor(num_samples, self.num_inputs).normal_()
+        device = next(self.parameters()).device
+        noise = noise.to(device)
+        if cond_inputs is not None:
+            cond_inputs = cond_inputs.to(device)
+        # NOTE: this FlowSequential takes no conditional inputs; passing
+        # ``cond_inputs`` positionally here would collide with ``mode``.
+        samples = self.forward(noise, mode="inverse")[0]
+        return samples
+
+    def jacobians(self, X):
+        assert len(X.size()) == 1
+        N = len(self._modules)
+        num_inputs = X.size(-1)
+        jacobians = torch.zeros(N, num_inputs, num_inputs)
+
+        n_jacob = 0
+        for i in range(N):
+            J_i = self._modules[str(i)]._jacob(X)
+            if J_i is not None:
+                jacobians[n_jacob, :, :] = J_i
+                n_jacob += 1
+            del J_i
+
+        return jacobians[:n_jacob, :, :]
+
+
+def build_model(input_size, hidden_size, k, r, n_blocks):
+    modules = []
+    for i in range(n_blocks):
+        modules += [
+            SOSFlow(input_size, hidden_size, k, r),
+            BatchNormFlow(input_size),
+            Reverse(input_size),
+        ]
+    model = FlowSequential(*modules)
+
+    for module in model.modules():
+        if isinstance(module, nn.Linear):
+            nn.init.orthogonal_(module.weight)
+
+    return model
+
+
+datasets = {
+    "power": {"hidden_size": 100, "k": 5, "r": 4, "d": 6, "n_blocks": 8},
+    "gas": {"hidden_size": 100, "k": 5, "r": 4, "d": 8, "n_blocks": 8},
+    "hepmass": {"hidden_size": 512, "k": 5, "r": 4, "d": 21, "n_blocks": 8},
+    "miniboone": {"hidden_size": 512, "k": 5, "r": 4, "d": 43, "n_blocks": 8},
+    "bsds300": {"hidden_size": 512, "k": 5, "r": 4, "d": 63, "n_blocks": 8},
+    "mnist": {"hidden_size": 100, "k": 5, "r": 4, "d": 784, "n_blocks": 8},
+    "cifar": {"hidden_size": 100, "k": 5, "r": 4, "d": 3072, "n_blocks": 8},
+}
+
+
+def format_as_str(num):
+    if num / 1e9 > 1:
+        factor, suffix = 1e9, "B"
+    elif num / 1e6 > 1:
+        factor, suffix = 1e6, "M"
+    elif num / 1e3 > 1:
+        factor, suffix = 1e3, "K"
+    else:
+        factor, suffix = 1e0, ""
+
+    num_factored = num / factor
+
+    # NOTE: the "or True" intentionally forces integer rounding; the two
+    # fractional branches below are currently unreachable.
+    if num_factored / 1e2 > 1 or True:
+        num_rounded = str(int(round(num_factored)))
+    elif num_factored / 1e1 > 1:
+        num_rounded = f"{num_factored:.1f}"
+    else:
+        num_rounded = f"{num_factored:.2f}"
+
+    return f"{num_rounded}{suffix} % {num}"
+
+
+def sos_size(d, hidden_size, k, r, n_blocks):
+    model = build_model(d, hidden_size, k, r, n_blocks)
+    model = NetHandle(model)
+    return model.size()
+
+
+for dset, s_kwargs in datasets.items():
+    print(f"{dset}: #params: {format_as_str(sos_size(**s_kwargs))}")
diff --git a/paper/node/script/sparsehessian.py b/paper/node/script/sparsehessian.py
new file mode 100644
index 0000000..408fe49
--- /dev/null
+++ b/paper/node/script/sparsehessian.py
@@ -0,0 +1,321 @@
+# *
+# @file Different utility functions
+# Copyright (c) Zhewei Yao, Amir Gholami
+# All rights reserved.
+# This file is part of PyHessian library.
+#
+# PyHessian is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# PyHessian is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with PyHessian. If not, see <http://www.gnu.org/licenses/>.
+# *
+
+import torch
+import math
+from torch.autograd import Variable
+import numpy as np
+
+from pyhessian.utils import (
+    group_product,
+    group_add,
+    normalization,
+    get_params_grad,
+    hessian_vector_product,
+    orthnormal,
+)
+
+
+class hessian:
+    """
+    The class used to compute:
+    i) the top 1 (n) eigenvalue(s) of the Hessian
+    ii) the trace of the Hessian
+    iii) the estimated eigenvalue density of the Hessian
+    """
+
+    def __init__(
+        self, model, criterion, data=None, dataloader=None, cuda=True
+    ):
+        """
+        model: the model that needs Hessian information
+        criterion: the loss function
+        data: a single batch of data, including inputs and corresponding labels
+        dataloader: the data loader providing batches of data
+        """
+
+        # make sure we either pass a single batch or a dataloader
+        assert (data != None and dataloader == None) or (
+            data == None and dataloader != None
+        )
+
+        self.model = model.eval()  # make sure model is in evaluation mode
+        self.criterion = criterion
+
+        if data != None:
+            self.data = data
+            self.full_dataset = False
+        else:
+            self.data = dataloader
+            self.full_dataset = True
+
+        if cuda:
+            self.device = "cuda"
+        else:
+            self.device = "cpu"
+
+        # pre-processing for single batch case to simplify the computation.
+        if not self.full_dataset:
+            self.inputs, self.targets = self.data
+            if self.device == "cuda":
+                self.inputs, self.targets = (
+                    self.inputs.cuda(),
+                    self.targets.cuda(),
+                )
+
+            # if we only compute the Hessian information for a single batch of data, we can re-use the gradients.
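+            # ``create_graph=True`` below keeps the autograd graph of the
+            # gradient alive, so Hessian-vector products can later be taken
+            # as the gradient of (grad . v) w.r.t. the parameters, without
+            # ever forming the Hessian explicitly.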
+            outputs = self.model(self.inputs)
+            loss = self.criterion(outputs, self.targets)
+            loss.backward(create_graph=True)
+
+        # this step is used to extract the parameters from the model
+        params, gradsH = get_params_grad(self.model)
+        self.params = params
+        self.gradsH = gradsH  # gradient used for Hessian computation
+
+        # store sparsity masks
+        self.masks = [param == 0.0 for param in self.params]
+        self.masks2 = [param != 0.0 for param in self.params]
+
+    def reduce(self, v):
+        """Reduce a group of vectors to their non-zero (unpruned) entries."""
+        return [vi[m2] for vi, m2 in zip(v, self.masks2)]
+
+    def sparsify_(self, v):
+        """Sparsify a group of vectors according to the masks."""
+        for vi, mask in zip(v, self.masks):
+            vi[mask] = 0
+        return v
+
+    def dataloader_hv_product(self, v):
+
+        device = self.device
+        num_data = 0  # count the number of datum points in the dataloader
+
+        THv = [
+            torch.zeros(p.size()).to(device) for p in self.params
+        ]  # accumulate result
+        for inputs, targets in self.data:
+            self.model.zero_grad()
+            tmp_num_data = inputs.size(0)
+            outputs = self.model(inputs.to(device))
+            loss = self.criterion(outputs, targets.to(device))
+            loss.backward(create_graph=True)
+            params, gradsH = get_params_grad(self.model)
+            self.model.zero_grad()
+            Hv = torch.autograd.grad(
+                gradsH,
+                params,
+                grad_outputs=v,
+                only_inputs=True,
+                retain_graph=False,
+            )
+            THv = [
+                THv1 + Hv1 * float(tmp_num_data) + 0.0
+                for THv1, Hv1 in zip(THv, Hv)
+            ]
+            num_data += float(tmp_num_data)
+
+        THv = [THv1 / float(num_data) for THv1 in THv]
+        eigenvalue = group_product(THv, v).cpu().item()
+        return eigenvalue, THv
+
+    def eigenvalues(self, maxIter=100, tol=1e-3, top_n=1):
+        """
+        compute the top_n eigenvalues using the power iteration method
+        maxIter: maximum iterations used to compute each single eigenvalue
+        tol: the relative tolerance between two consecutive eigenvalue computations from power iteration
+        top_n: top top_n eigenvalues will be computed
+        """
+
+        assert top_n >= 1
+
+        device = self.device
+
+        eigenvalues = []
+        eigenvectors = []
+
+        computed_dim = 0
+
+        while computed_dim < top_n:
+            eigenvalue = None
+            v = [
+                torch.randn(p.size()).to(device) for p in self.params
+            ]  # generate random vector
+            v = normalization(v)  # normalize the vector
+            v = self.sparsify_(v)  # sparsify vector
+
+            for i in range(maxIter):
+                v = orthnormal(v, eigenvectors)
+                v = self.sparsify_(v)
+                self.model.zero_grad()
+
+                if self.full_dataset:
+                    tmp_eigenvalue, Hv = self.dataloader_hv_product(v)
+                else:
+                    Hv = hessian_vector_product(self.gradsH, self.params, v)
+                    tmp_eigenvalue = group_product(Hv, v).cpu().item()
+
+                v = normalization(Hv)
+                v = self.sparsify_(v)
+
+                if eigenvalue == None:
+                    eigenvalue = tmp_eigenvalue
+                else:
+                    if (
+                        abs(eigenvalue - tmp_eigenvalue)
+                        / (abs(eigenvalue) + 1e-6)
+                        < tol
+                    ):
+                        break
+                    else:
+                        eigenvalue = tmp_eigenvalue
+            eigenvalues.append(eigenvalue)
+            eigenvectors.append(v)
+            computed_dim += 1
+
+        return eigenvalues, eigenvectors
+
+    def trace(self, maxIter=100, tol=1e-3):
+        """
+        compute the trace of hessian using Hutchinson's method
+        maxIter: maximum iterations used to compute trace
+        tol: the relative tolerance
+        """
+        device = self.device
+        trace_vhv = []
+        trace = 0.0
+
+        for i in range(maxIter):
+            self.model.zero_grad()
+            v = [
+                torch.randint_like(p, high=2, device=device)
+                for p in self.params
+            ]
+            # generate Rademacher random variables
+            for v_i in v:
+                v_i[v_i == 0] = -1
+
+            # sparsify random variables according to params
+            v = self.sparsify_(v)
+
+            if self.full_dataset:
+                _, Hv = self.dataloader_hv_product(v)
+            else:
+                Hv = hessian_vector_product(self.gradsH, self.params, v)
+            Hv = self.reduce(Hv)
+            v = self.reduce(v)
+            trace_vhv.append(group_product(Hv, v).cpu().item())
+            if abs(np.mean(trace_vhv) - trace) / (trace + 1e-6) < tol:
+                return trace_vhv
+            else:
+                trace = np.mean(trace_vhv)
+        print("No convergence")
+        return trace_vhv
+
+    def density(self, iter=100, n_v=1):
+        """
+        compute estimated eigenvalue density using the stochastic Lanczos quadrature (SLQ) algorithm
+        iter: number of Lanczos iterations per SLQ run
+        n_v: number of SLQ runs
+        """
+        device = self.device
+        eigen_list_full = []
+        weight_list_full = []
+
+        for k in range(n_v):
+            v = [
+                torch.randint_like(p, high=2, device=device)
+                for p in self.params
+            ]
+            # generate Rademacher random variables
+            for v_i in v:
+                v_i[v_i == 0] = -1
+            v = normalization(v)
+            v = self.sparsify_(v)
+
+            # standard Lanczos algorithm initialization
+            v_list = [v]
+            w_list = []
+            alpha_list = []
+            beta_list = []
+            ############### Lanczos
+            for i in range(iter):
+                self.model.zero_grad()
+                w_prime = [
+                    torch.zeros(p.size()).to(device) for p in self.params
+                ]
+                if i == 0:
+                    if self.full_dataset:
+                        _, w_prime = self.dataloader_hv_product(v)
+                    else:
+                        w_prime = hessian_vector_product(
+                            self.gradsH, self.params, v
+                        )
+                    w_prime = self.sparsify_(w_prime)
+                    alpha = group_product(w_prime, v)
+                    alpha_list.append(alpha.cpu().item())
+                    w = group_add(w_prime, v, alpha=-alpha)
+                    w_list.append(w)
+                else:
+                    beta = torch.sqrt(group_product(w, w))
+                    beta_list.append(beta.cpu().item())
+                    if beta_list[-1] != 0.0:
+                        # we should re-orthonormalize it
+                        v = orthnormal(w, v_list)
+                        v = self.sparsify_(v)
+                        v_list.append(v)
+                    else:
+                        # generate a new vector
+                        w = [
+                            torch.randn(p.size()).to(device)
+                            for p in self.params
+                        ]
+                        w = self.sparsify_(w)
+                        v = orthnormal(w, v_list)
+                        v = self.sparsify_(v)
+                        v_list.append(v)
+                    if self.full_dataset:
+                        _, w_prime = self.dataloader_hv_product(v)
+                    else:
+                        w_prime = hessian_vector_product(
+                            self.gradsH, self.params, v
+                        )
+                    w_prime = self.sparsify_(w_prime)
+                    alpha = group_product(w_prime, v)
+                    alpha_list.append(alpha.cpu().item())
+                    w_tmp = group_add(w_prime, v, alpha=-alpha)
+                    w_tmp = self.sparsify_(w_tmp)
+                    w = group_add(w_tmp, v_list[-2], alpha=-beta)
+
+            T = torch.zeros(iter, iter).to(device)
+            for i in range(len(alpha_list)):
+                T[i, i] = alpha_list[i]
+                if i < len(alpha_list) - 1:
+                    T[i + 1, i] = beta_list[i]
+                    T[i, i + 1] = beta_list[i]
+            a_, b_ = torch.eig(T, eigenvectors=True)
+
+            eigen_list = a_[:, 0]
+            weight_list = b_[0, :] ** 2
+            eigen_list_full.append(list(eigen_list.cpu().numpy()))
+            weight_list_full.append(list(weight_list.cpu().numpy()))
+
+        return eigen_list_full, weight_list_full
diff --git a/paper/node/script/view_hessian.py b/paper/node/script/view_hessian.py
new file mode 100644
index 0000000..142edac
--- /dev/null
+++ b/paper/node/script/view_hessian.py
@@ -0,0 +1,276 @@
+"""Analyze trained networks via Hessian."""
+# %%
+import argparse
+import os
+import warnings
+import sys
+import copy
+import numpy as np
+import torch
+from torchprune.util.train import _get_loss_handle
+from torchprune.util import models as tp_models
+import experiment
+from experiment.util.file import get_parameters
+
+PARSER = argparse.ArgumentParser(
+    description="Sparse Flow - Hessian Analysis",
+)
+
+PARSER.add_argument(
+    "-p",
+    "--param",
+    type=str,
+    default="paper/node/param/toy/ffjord/gaussians/l2_h128_sigmoid_da.yaml",
+    dest="param_file",
+    help="provide a parameter file",
+)
+
+
+# switch to root folder for data
+FOLDER = os.path.abspath("")
+if "paper/node/script" in FOLDER:
+    SRC_FOLDER = os.path.join(FOLDER, "../../..")
+    os.chdir(SRC_FOLDER)
+
+# add script path to sys path
+sys.path.append("./paper/node/script")
+
+# import our custom pyhessian library
+from sparsehessian import hessian
+
+
+# %% Some stuff
+class HiddenPrints:
+    def __enter__(self):
+        self._original_stdout = sys.stdout
+        sys.stdout = open(os.devnull, "w")
+        warnings.simplefilter("ignore")
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        sys.stdout.close()
+        sys.stdout = self._original_stdout
+        warnings.simplefilter("default")
+
+
+# retrieve file
+ARGS = PARSER.parse_args()
+FILE = ARGS.param_file
+
+
+# %% Run the Hessian Stats
+def _hessian_spectrum(dataset, criterion, net):
+    """Estimate the Hessian eigenvalue spectrum of the loss via SLQ.
+
+    The two disabled branches below can be toggled manually to return the
+    Hutchinson-based trace estimate or the top eigenvalue instead.
+    """
+    param0 = next(net.parameters())
+    if param0.numel() == 1:
+        param0.requires_grad = False
+
+    with torch.enable_grad():
+        # get Hessian compute model
+        hessian_comp = hessian(net, criterion, data=dataset, cuda=True)
+
+        if False:
+            # get trace
+            return np.mean(hessian_comp.trace(maxIter=200, tol=1e-4))
+        if False:
+            # get top eigenvalue
+            eigs, _ = hessian_comp.eigenvalues(maxIter=100, tol=1e-4)
+            return eigs[-1]
+
+        # get spectrum data
+        eigs, _ = hessian_comp.density(iter=100, n_v=3)
+        return np.asarray(eigs).mean(axis=0)
+
+
+def get_spectrum_stats(spec_collection, loss_collection):
+    """Return useful stats from spectrum."""
+    spec_collect_filt = [spec[spec > 0] for spec in spec_collection]
+
+    # compute spectral norm, largest eigenvalue
+    spec_norm = np.max(spec_collection, axis=-1).mean()
+
+    # compute trace, sum over all eigenvalues
+    trace = np.mean([np.sum(spec) for spec in spec_collect_filt])
+
+    # compute condition number, max/min eigenvalue
+    cond_number = np.mean(
+        [np.max(spec) / np.min(spec) for spec in spec_collect_filt]
+    )
+
+    # get average loss
+    loss = np.mean(loss_collection)
+
+    # print stats
+    print(
+        ", ".join(
+            [
+                f"NLL={loss:.5f}",
+                f"lambda_max={spec_norm:.5f}",
+                f"trace={trace:.5f}",
+                f"kappa={cond_number:.5f}",
+            ]
+        )
+    )
+
+
+def get_bptt_net(net, param):
+    """Return the same net with BPTT (autograd) instead of adjoint backprop."""
+    net_name = param["network"]["name"]
+    num_classes = param["network"]["outputSize"]
+
+    net_bptt = getattr(tp_models, f"{net_name}_autograd")(num_classes)
+    net_bptt.load_state_dict(net.state_dict())
+
+    return net_bptt
+
+
+def generate_hessian_stats(logger, param, data_size=0.1, num_reps=3):
+    """Compute and print Hessian statistics for all pruned networks."""
+    save_and_load = True
+
+    with HiddenPrints():
+        logger.initialize_from_param(param, setup_print=False)
+        evaluator = experiment.Evaluator(logger)
+        loader_train = evaluator.get_dataloader("train")[0]
+        criterion = _get_loss_handle(evaluator._net_trainer.train_params)
+    device = "cuda"
+    print(logger._results_dir)
+
+    # create huge tensor of the data
+    dataset = loader_train.dataset
+    inputs = torch.stack([data[0] for data in dataset]).to(device)
+    targets = torch.tensor([data[1] for data in dataset]).to(device)
+
+    # create a subset of the data as well
+    indices = torch.randperm(len(inputs))[: int(data_size * len(inputs))]
+    subset = (
+        inputs[indices].detach().clone(),
+        targets[indices].detach().clone(),
+    )
+
+    # store prune ratios and add zero prune ratio
+    prune_ratios = 1 - np.array(evaluator._keep_ratios)
+    prune_ratios = np.concatenate(([0.0], prune_ratios))
+
+    # dictionary to store spectrum results
+    hessian_tag = "hessian_spectrum"
+    spectrum_results = {}
+
+    # check and load if anything is already stored
+    if save_and_load:
+        spectrum_results.update(logger.load_custom_state(tag=hessian_tag))
+
+    # check required number of reps
+    num_nets = evaluator._num_nets
+    num_reps_experiment = evaluator._num_repetitions
+    num_reps_per_net = int(np.ceil(num_reps / num_nets))
+
+    for method_name in evaluator._method_names:
+        if "ReferenceNet" in method_name:
+            continue
+        print("")
+        for s_idx, pr in enumerate(prune_ratios):
+            # setup collection of hessian stats for this run
+            spectrum_collection = []
+            loss_collection = []
+
+            print_key = ", ".join(
+                [method_name, f"pr_idx={s_idx}", f"PR={pr*100:5.1f}%"]
+            )
+            print(f"{print_key}: Estimating Hessian Spectrum")
+
+            # check if we need to compute any of these to know whether to save
+            saving_required = False
+
+            for n_idx in range(num_nets):
+                for r_idx in range(num_reps_per_net):
+                    key = "_".join(
+                        map(
+                            str,
+                            [
+                                n_idx,
+                                r_idx,
+                                s_idx,
+                                int(pr * 10000),
+                                method_name,
+                            ],
+                        )
+                    )
+                    hessian_key = f"{key}_hessian"
+                    loss_key = f"{key}_loss"
+
+                    # only re-compute hessian results if necessary
+                    if hessian_key not in spectrum_results:
+                        with HiddenPrints():
+                            try:
+                                if pr == 0.0:
+                                    lookup_name = "ReferenceNet"
+                                else:
+                                    lookup_name = method_name
+                                net = evaluator.get_by_pr(
+                                    prune_ratio=pr,
+                                    method=lookup_name,
+                                    n_idx=n_idx,
+                                    r_idx=r_idx % num_reps_experiment,
+                                ).compressed_net.torchnet
+                            except FileNotFoundError:
+                                continue
+
+                        # wrap net into net with autograd instead of adjoint:
+                        # torchdyn adjoint breaks create_graph=True in the
+                        # backwards pass, which you need for any kind of
+                        # Hessian computation ...
+                        net_bptt = get_bptt_net(net, param)
+
+                        # generate spectrum
+                        net_bptt = net_bptt.to(device)
+                        spectrum = _hessian_spectrum(
+                            subset, criterion, net_bptt
+                        )
+
+                        # get train loss
+                        net = net.to(device)
+                        loss = criterion(net(inputs), targets).item()
+
+                        # update results and store again
+                        spectrum_results[hessian_key] = copy.deepcopy(spectrum)
+                        spectrum_results[loss_key] = copy.deepcopy(loss)
+
+                        # finalize
+                        del net, net_bptt, loss, spectrum
+                        torch.cuda.empty_cache()
+
+                        # recall to save later on
+                        saving_required = True
+
+                    # get stats and collect them together
+                    spectrum = copy.deepcopy(spectrum_results[hessian_key])
+                    loss = copy.deepcopy(spectrum_results[loss_key])
+
+                    spectrum_collection.append(spectrum)
+                    loss_collection.append(loss)
+
+            # store latest results
+            if save_and_load and saving_required:
+                logger.save_custom_state(spectrum_results, hessian_tag)
+                print("Hessian update saved")
+
+            # process collected spectrums and losses
+            if len(spectrum_collection) > 0:
+                spectrum_collection_np = np.asarray(spectrum_collection)
+                loss_collection_np = np.asarray(loss_collection)
+                get_spectrum_stats(spectrum_collection_np, loss_collection_np)
+            else:
+                print("No networks available")
+            print("")
+
+
+def main(file):
+    # get a logger and the parameters
+    print(file)
+    logger = experiment.Logger()
+    param = next(get_parameters(file, 1, 0))
+    generate_hessian_stats(logger, param)
+
+
+if __name__ == "__main__":
+    main(FILE)
diff --git a/paper/node/script/view_hessian_wrapper.py b/paper/node/script/view_hessian_wrapper.py
new file mode 100644
index 0000000..ee46ea0
--- /dev/null
+++ b/paper/node/script/view_hessian_wrapper.py
@@ -0,0 +1,36 @@
+"""Wrapper for Hessian since we keep running into CUDA-OOM errors."""
+
+import subprocess
+import argparse
+
+
+PARSER = argparse.ArgumentParser(
+    description="Sparse Flow - Hessian Analysis",
+)
+
+PARSER.add_argument(
+    "param_file",
+    type=str,
+    metavar="param_file",
+    help="provide a parameter file",
+)
+
+# retrieve file
+ARGS = PARSER.parse_args()
+FILE = ARGS.param_file
+
+
+def main(file):
+    for _ in range(5000):
+        ret_code = subprocess.run(
+            ["python", "paper/node/script/view_hessian.py", "-p", file]
+        ).returncode
+        if ret_code:
+            print("Catching CUDA-OOM and retrying.")
+        else:
+            print("Finished successfully without CUDA-OOM failure.")
+            break
+
+
+if __name__ == "__main__":
+    main(FILE)
diff --git a/paper/node/script/view_modes.py b/paper/node/script/view_modes.py
new file mode 100644
index 0000000..7fe0cf5
--- /dev/null
+++ b/paper/node/script/view_modes.py
@@ -0,0 +1,344 @@
+"""Analyze CNFs via Mode Analysis."""
+# %%
+import argparse
+import os
+import warnings
+import sys
+import copy
+import numpy as np
+import torch
+import experiment
+from experiment.util.file import get_parameters
+from torchprune.util.models import Ffjord, FfjordCNF
+
+PARSER = argparse.ArgumentParser(
+    description="Sparse Flow - Mode Analysis",
+)
+
+PARSER.add_argument(
+    "-p",
+    "--param",
+    type=str,
+    default="paper/node/param/toy/ffjord/gaussians/l4_h64_sigmoid_da.yaml",
+    dest="param_file",
+    help="provide a parameter file",
+)
+
+
+# switch to root folder for data
+FOLDER = os.path.abspath("")
+if "paper/node/script" in FOLDER:
+    SRC_FOLDER = os.path.join(FOLDER, "../../..")
+    os.chdir(SRC_FOLDER)
+
+# add script path to sys path
+sys.path.append("./paper/node/script")
+
+
+# %% Some stuff
+class HiddenPrints:
+    def __enter__(self):
+        self._original_stdout = sys.stdout
+        sys.stdout = open(os.devnull, "w")
+        warnings.simplefilter("ignore")
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        sys.stdout.close()
+        sys.stdout = self._original_stdout
+        warnings.simplefilter("default")
+
+
+# retrieve file
+ARGS = PARSER.parse_args()
+FILE = ARGS.param_file
+
+
+# %% Main functions
+
+
+def get_modes(dataset):
+    """Retrieve the modes of the dataset."""
+    inputs = torch.stack([data[0] for data in dataset])
+    targets = torch.tensor([data[1] for data in dataset])
+
+    # flatten inputs
+    inputs = inputs.reshape(inputs.shape[0], -1)
+
+    # unique labels
+    labels = torch.unique(targets)
+
+    # compute modes (per-label means) and corresponding covariance matrices
+    modes = np.zeros((len(labels), inputs.shape[1]))
+    covs = np.zeros((len(labels), inputs.shape[1], inputs.shape[1]))
+    for i_lab, label in enumerate(labels):
+        inputs_lab = inputs[targets == label]
+        modes[i_lab] = torch.mean(inputs_lab, dim=0).cpu().numpy()
+        covs[i_lab] = np.cov(inputs_lab.cpu().numpy().T)
+
+    return modes, covs
+
+
+def sample_torchdyn_ffjord(net, num_samples=20000):
+    """Sample from a torchdyn network."""
+    device = next(net.parameters()).device
+
+    # extract ffjord model
+    model = net.model
+
+    # set s-span but keep old one around!
+    s_span_backup = copy.deepcopy(model[1].s_span)
+    model[1].s_span = torch.linspace(1, 0, 2).to(device)
+
+    sample = net.prior.sample(torch.Size([num_samples])).to(device)
+    with torch.no_grad():
+        x_sampled = model(sample)
+
+    # restore s-span
+    model[1].s_span = s_span_backup
+
+    return x_sampled[:, 1:]
+
+
+def sample_ffjord_cnf(net, dataset, num_samples=2500):
+    """Sample from the FFJORD CNF model."""
+    device = next(net.parameters()).device
+
+    # extract ffjord model
+    model = net.model
+
+    # set up the prior
+    data_shape = dataset[0][0].shape
+    data_numel = dataset[0][0].numel()
+    prior = torch.distributions.MultivariateNormal(
+        torch.zeros(data_numel), torch.eye(data_numel)
+    )
+
+    # sample now from model
+    batch_size = 1250
+    samples_post_all = []
+    for _ in range((num_samples - 1) // batch_size + 1):
+        samples_prior = prior.sample((batch_size,))
+        samples_prior = samples_prior.view(batch_size, *data_shape)
+        with torch.no_grad():
+            samples_post = model(samples_prior.to(device), reverse=True)
+        samples_post = samples_post.view(batch_size, -1).detach().cpu()
+        samples_post_all.append(samples_post)
+
+    return torch.cat(samples_post_all)
+
+
+def sample(net, dataset):
+    """Sample from the network."""
+    if isinstance(net, Ffjord):
+        samples = sample_torchdyn_ffjord(net)
+    elif isinstance(net, FfjordCNF):
+        samples = sample_ffjord_cnf(net, dataset)
+    else:
+        raise NotImplementedError("Only works for torchdyn ffjord currently.")
+
+    return samples.cpu().numpy()
+
+
+def sample_and_compute_mode_distance(net, dataset, modes, covs):
+    """Sample from the network and compute distance to each mode."""
+    # let's sample first
+    samples = sample(net, dataset)
+
+    # now figure out squared distances of samples to modes as a multiplicative
+    # factor of variance projected onto this direction from the cov-matrix
+    # A little more explanation:
+    # d^2 = "multiples of variance" == "multiple of std.dev. squared"
+    # x = sample
+    # var_unnormed = (x - mode)' * Cov * (x - mode)
+    # var = var_unnormed / ||x - mode||^2
+    # d^2 = ||x - mode||^2 / var
+    #     = ||x - mode||^4 / var_unnormed
+    #
+    # d = ||x - mode||^2 / sqrt((x - mode)' * Cov * (x - mode))
+
+    # now do the computations
+    # modes.shape == num_modes x dim_state
+    # shape = batch_size x num_modes x dim_state
+    samples_centered = samples[:, None] - modes[None]
+
+    # compute "unnormalized variance" using np.matmul broadcasting rules
+    # covs.shape == num_modes x dim_state x dim_state
+    # var_unnormed.shape == num_samples x num_modes
+    var_unnormed = covs[None] @ samples_centered[..., None]
+    var_unnormed = (samples_centered[:, :, None] @ var_unnormed)[:, :, 0, 0]
+
+    # compute distance now
+    # shape == num_samples x num_modes
+    dist_unnormed = np.linalg.norm(samples_centered, ord=2, axis=-1)
+    dist_normalized = dist_unnormed * dist_unnormed / np.sqrt(var_unnormed)
+
+    return dist_normalized
+
+
+def get_mode_stats(mode_distances):
+    """Return stats about distance to nearest mode."""
+    dist_checkers = [0.25, 0.5, 1.0, 2.0, 3.0, 5.0, 7.0, 10.0, 15.0, 20.0]
+    min_distances = np.min(mode_distances, axis=-1)
+    num_modes = mode_distances.shape[-1]
+
+    high_quality_ratio = [
+        (min_distances <= dist).sum(axis=-1).mean() / mode_distances.shape[1]
+        for dist in dist_checkers
+    ]
+
+    modes_captured = [
+        np.any(mode_distances < dist, axis=1).sum(axis=1).mean()
+        for dist in dist_checkers
+    ]
+
+    print(
+        "Normalized std. dev.       : "
+        + " | ".join(map(lambda x: f" {x:6.2f}", dist_checkers))
+    )
+    print(
+        "High-quality samples       : "
+        + " | ".join(map(lambda x: f"{x*100:6.2f}%", high_quality_ratio))
+    )
+    print(
+        f"Modes captured (Total: {num_modes:3.0f}): "
+        + " | ".join(map(lambda x: f" {x:6.2f}", modes_captured))
+    )
+
+
+def generate_mode_stats(logger, param, num_reps=15):
+    """Compute and print mode-coverage statistics for all pruned networks."""
+    # turn saving/loading on and off
+    save_and_load = True
+
+    # initialize experiment with logger and evaluator
+    with HiddenPrints():
+        logger.initialize_from_param(param, setup_print=False)
+        evaluator = experiment.Evaluator(logger)
+        loader_train = evaluator.get_dataloader("train")[0]
+        dataset = loader_train.dataset
+    print(logger._results_dir)
+
+    # do cuda computations
+    device = "cuda"
+
+    # store prune ratios and add zero prune ratio
+    prune_ratios = 1 - np.array(evaluator._keep_ratios)
+    prune_ratios = np.concatenate(([0.0], prune_ratios))
+
+    # dictionary to store mode results
+    mode_tag = "mode_analysis"
+    mode_results = {}
+
+    # check and load if anything is already stored
+    if save_and_load:
+        mode_results.update(logger.load_custom_state(tag=mode_tag))
+
+    # get mean and variance of each mode if not already pre-computed and save
+    if "modes" in mode_results:
+        modes, covs = mode_results["modes"], mode_results["covs"]
+    else:
+        modes, covs = get_modes(dataset)
+        mode_results["modes"] = modes
+        mode_results["covs"] = covs
+        if save_and_load:
+            logger.save_custom_state(mode_results, mode_tag)
+
+    # check required number of reps
+    num_nets = evaluator._num_nets
+    num_reps_experiment = evaluator._num_repetitions
+    num_reps_per_net = int(np.ceil(num_reps / num_nets))
+
+    for method_name in evaluator._method_names:
+        if "ReferenceNet" in method_name:
+            continue
+        print("")
+        for s_idx, pr in enumerate(prune_ratios):
+            # setup collection of mode distances for this run
+            mode_dist_collected = []
+            print_key = ", ".join(
+                [method_name, f"pr_idx={s_idx}", f"PR={pr*100:5.1f}%"]
+            )
+            print(f"{print_key}: Estimating Mode distances")
+
+            # check if we need to compute any of these to know whether to save
+            saving_required = False
+
+            # compute mode_distances
+            for n_idx in range(num_nets):
+                for r_idx in range(num_reps_per_net):
+                    key = "_".join(
+                        map(
+                            str,
+                            [
+                                n_idx,
+                                r_idx,
+                                s_idx,
+                                int(pr * 10000),
+                                method_name,
+                            ],
+                        )
+                    )
+                    mode_key = f"{key}_modes"
+
+                    # only re-compute mode results if necessary
+                    if mode_key not in mode_results:
+                        with HiddenPrints():
+                            try:
+                                if pr == 0.0:
+                                    lookup_name = "ReferenceNet"
+                                else:
+                                    lookup_name = method_name
+                                net = evaluator.get_by_pr(
+                                    prune_ratio=pr,
+                                    method=lookup_name,
+                                    n_idx=n_idx,
+                                    r_idx=r_idx % num_reps_experiment,
+                                ).compressed_net.torchnet
+                            except FileNotFoundError:
+                                continue
+
+                        # set and generate mode assignments for samples
+                        net = net.to(device)
+                        mode_distances = sample_and_compute_mode_distance(
+                            net, dataset, modes, covs
+                        )
+
+                        # update results
+                        mode_results[mode_key] = copy.deepcopy(mode_distances)
+
+                        # finalize
+                        del net, mode_distances
+
+                        # recall to save later on
+                        saving_required = True
+
+                    # get stats and collect them together
+                    mode_distances = copy.deepcopy(mode_results[mode_key])
+                    mode_dist_collected.append(mode_distances)
+
+            # store latest results
+            if save_and_load and saving_required:
+                logger.save_custom_state(mode_results, mode_tag)
+
+            # process collected mode distances
+            # shape = num_reps x num_samples x num_modes
+            if len(mode_dist_collected) > 0:
+                mode_dist_collected_np = np.asarray(mode_dist_collected)
+                get_mode_stats(mode_dist_collected_np)
+            else:
+                print("No networks available")
+            print("")
+
+
+# %% Execute main
+def main(file):
+    # get a logger and the parameters
+    print("\n")
+    print(file)
+    logger = experiment.Logger()
+    param = next(get_parameters(file, 1, 0))
+    generate_mode_stats(logger, param)
+
+
+if __name__ == "__main__":
+    main(FILE)
diff --git a/paper/node/script/view_results.py b/paper/node/script/view_results.py
new file mode 100644
index 0000000..506a775
--- /dev/null
+++ b/paper/node/script/view_results.py
@@ -0,0 +1,762 @@
+"""View and plot Neural ODE results."""
+# %%
+import os
+import warnings
+import sys
+import copy
+import numpy as np
+import matplotlib.pyplot as plt
+import seaborn as sns
+from PIL import Image
+from scipy import signal
+import experiment
+from experiment.util.file import get_parameters
+
+# change working directory to src
+from IPython import get_ipython
+
+# make sure it's using only GPU here...
+os.environ["CUDA_VISIBLE_DEVICES"] = "1"  # noqa
+
+
+# switch to root folder for data
+folder = os.path.abspath("")
+if "paper/node/script" in folder:
+    src_folder = os.path.join(folder, "../../..")
+    os.chdir(src_folder)
+
+# add script path to sys path
+sys.path.append("./paper/node/script")
+
+# %% Define some parameters
+FILES = [
+    # # TOY, VANILLA CNF GENERATIVE MODEL EXPERIMENTS
+    # "paper/node/param/toy/ffjord/gaussians/vanilla_l2_h128.yaml",
+    # "paper/node/param/toy/ffjord/gaussiansspiral/vanilla_l4_h64.yaml",
+    # "paper/node/param/toy/ffjord/spirals/vanilla_l4_h64.yaml",
+    # # TOY, GENERATIVE MODEL EXPERIMENTS
+    # "paper/node/param/toy/ffjord/gaussians/l4_h64_sigmoid_da.yaml",
+    # "paper/node/param/toy/ffjord/gaussians/l2_h128_sigmoid_da.yaml",
+    # "paper/node/param/toy/ffjord/gaussiansspiral/l4_h64_sigmoid_da.yaml",
+    # "paper/node/param/toy/ffjord/spirals/l4_h64_sigmoid_da.yaml",
+    # #
+    # # TOY, CLASSIFICATION EXPERIMENTS
+    # "paper/node/param/toy/node/concentric/l2_h128_tanh_da.yaml",
+    # "paper/node/param/toy/node/moons/l2_h3_tanh_da.yaml",
+    # "paper/node/param/toy/node/moons/l2_h32_tanh_da.yaml",
+    # "paper/node/param/toy/node/moons/l2_h64_tanh_da.yaml",
+    # "paper/node/param/toy/node/moons/l2_h128_tanh_da.yaml",
+    # "paper/node/param/toy/node/spirals/l2_h64_relu_da.yaml",
+    # #
+    # # TABULAR EXPERIMENTS
+    # "paper/node/param/tabular/power/l3_hm10_f5_tanh.yaml",
+    # "paper/node/param/tabular/gas/l3_hm20_f5_tanh.yaml",
+    # "paper/node/param/tabular/hepmass/l2_hm10_f10_softplus.yaml",
+    # "paper/node/param/tabular/miniboone/l2_hm20_f1_softplus.yaml",
+    # "paper/node/param/tabular/bsds300/l3_hm20_f2_softplus.yaml",
+    # #
+    # # IMAGE EXPERIMENTS
+    # "paper/node/param/cnf/mnist_multiscale.yaml",
+    # "paper/node/param/cnf/cifar_multiscale.yaml",
+]
+
+PLOT_FILTERS = [
+    ["WT", "FT"],
+    ["WT"],
+    # ["FT"],
+]
+
+STYLE_KWARGS = {
+    "savgol_on": True,
+    "savgol_mean": {"window_length": 3, "polyorder": 1},
+    "savgol_std": {"window_length": 9, "polyorder": 1},
+    "label": {"fontsize": 20},
+    "tick": {"labelsize": 16},
+    "xlim": [0, 85],
+    "ylim": [1.5, 1.85],
+    "legend": {
+        "loc": "upper left",
+        "bbox_to_anchor": (0.1, 1.3),
+        "fontsize": 20,
+    },
+    "WT": {
+        "plot": {"color": "darkblue", "ls": "-"},
+        "fill": {"color": "lightskyblue", "alpha": 0.4},
+    },
+    "FT": {
+        "plot": {"color": "darkgreen", "ls": "--"},
+        "fill": {"color": "green", "alpha": 0.2},
+    },
+}
+
+NUM_REP_LOSS = 12  # we want a total of 12 reps for the loss for better std dev
+
+PLOT_FOLDER_SPECIAL = os.path.abspath("data/node/plots")
+INLINE_PLOT = False
+
+GEN_NODE_FIGS = False
+REGEN_NODE_FIGS = False
+GEN_ALL_NODE_FIGS = False
+REGEN_FIGS = False
+
+GEN_PAPER_FIGS_LOSS = True
+GEN_PAPER_FIGS_DISTRIBUTION = True
+
+
+# %% Some helpful functions
+class HiddenPrints:
+    def __enter__(self):
+        self._original_stdout = sys.stdout
+        sys.stdout = open(os.devnull, "w")
+        warnings.simplefilter("ignore")
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        sys.stdout.close()
+        sys.stdout = self._original_stdout
+        warnings.simplefilter("default")
+
+
+def generate_node_figs(logger, cnf_plots=False):
+    """Generate and store the Neural ODE figures for each model."""
+    with HiddenPrints():
+        evaluator = experiment.Evaluator(logger)
+        loader_test = evaluator.get_dataloader("test")[0]
+
+    if cnf_plots:
+        from plots_cnf import plot_all
+    else:
+        from plots2d import plot_all
+
+    # for n_idx in range(evaluator._num_nets):
+    for n_idx in range(1):
+        for r_idx in range(evaluator._num_repetitions):
+            for s_idx, keep_ratio in enumerate(evaluator._keep_ratios):
+                for method_name in evaluator._method_names:
+                    if "ReferenceNet" in method_name and s_idx > 0:
+                        continue
+                    tag = "_".join(
+                        [
+                            method_name,
+                            f"n{n_idx}",
+                            f"r{r_idx}",
+                            f"i{s_idx}",
+                            f"p{keep_ratio:.4f}",
+                        ]
+                    )
+                    plt_folder = os.path.join(logger._plot_dir, "flow", tag)
+
+                    if os.path.exists(plt_folder) and not REGEN_NODE_FIGS:
+                        continue
+                    with HiddenPrints():
+                        try:
+                            net = evaluator.get_by_pr(
+                                prune_ratio=1.0 - keep_ratio,
+                                method=method_name,
+                                n_idx=n_idx,
+                                r_idx=r_idx,
+                            ).compressed_net.torchnet
+                        except FileNotFoundError:
+                            continue
+                    print(plt_folder)
+                    plot_all(
+                        net,
+                        loader_test,
+                        plot_folder=plt_folder,
+                        all_p=GEN_ALL_NODE_FIGS
+                        or "ReferenceNet" in method_name,
+                    )
+
+
+def get_results(file, logger, gen_node, regen_figs):
+    """Grab all the results according to the file."""
+    results = []
+    params = []
+    labels = []
+    graphers_all = []
+    # Loop through all experiments
+    for param in get_parameters(file, 1, 0):
+        # initialize logger and setup parameters
+        with HiddenPrints():
+            logger.initialize_from_param(param, setup_print=False)
+
+        # don't fail if the global state has not been computed yet
+        try:
+            state = logger.get_global_state()
+        except ValueError:
+            print("Global state not computed, handle with care!")
+            state = copy.deepcopy(logger._stats)
+
+        # extract the results
+        results.append(copy.deepcopy(state))
+        params.append(copy.deepcopy(param))
+
+        # extract the legend (based on heuristic)
+        label = param["generated"]["datasetTest"].split("_")
+        if len(label) > 2:
+            label = label[2:]
+        labels.append("_".join(label))
+
+        # store custom plots for neural ode as well.
+        # only do that for Toy Examples though ...
+        if gen_node and "toy" in file:
+            generate_node_figs(logger, cnf_plots="ffjord" in file)
+
+        if not regen_figs or not logger.state_loaded:
+            continue
+
+        # extract the plots and store them.
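+        # plot generation is wrapped in a try/except below since incomplete
+        # results can make individual graphs fail; in that case we skip the
+        # graphs for this experiment instead of aborting the whole loop.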
+ try: + with HiddenPrints(): + graphers = logger.generate_plots(store_figs=False) + for grapher in graphers: + grapher.store_plot() + graphers_all.append(graphers) + except: + print("Could not generate main graphs.") + graphers_all.append([]) + + return results, params, labels, graphers_all + + +def get_and_store_results(file, logger, gen_node=False, regen_figs=False): + print(f"PARAM FILE: {file}") + # get the results specified in the file (and hopefully pre-computed) + results, params, _, _ = get_results(file, logger, gen_node, regen_figs) + + for param in params: + print(f"PLOT FOLDER: {param['generated']['plotDir']}\n") + + return results, params + + +# %% Retrieve results +# make sure matplotlib works correctly +IN_JUPYTER = True +try: + if INLINE_PLOT: + get_ipython().run_line_magic("matplotlib", "inline") + else: + get_ipython().run_line_magic("matplotlib", "agg") +except AttributeError: + IN_JUPYTER = False + +# get a logger +LOGGER = experiment.Logger() +STATS_ALL = [] +PARAM_ALL = [] +for file in FILES: + STATS, PARAM = get_and_store_results( + file, LOGGER, gen_node=GEN_NODE_FIGS, regen_figs=REGEN_FIGS + ) + STATS_ALL.append(STATS) + PARAM_ALL.append(PARAM) + + +# %% now re-plot the loss so it looks better with smoothing +def resample_loss(logger, i_gen=None): + """Resample the loss from the networks and return results.""" + if i_gen is None: + tag_gen = logger.dataset_test + else: + tag_gen = f"{logger.dataset_test}_regen_{i_gen}" + + # try loading the re-generated loss if it exists and is compatible + # we should also check that we get the same valid sizes since they + # additional data might have generated when less networks were available + stats_new = logger.load_custom_state(tag_gen) + if logger._check_compatibility(stats_new): + mask_new = np.all(stats_new["sizes"] != 0.0, axis=(0, 2)) + mask_old = np.all(logger.sizes != 0.0, axis=(0, 2)) + if np.all(mask_new == mask_old): + print("Loaded re-sampled stats") + return stats_new + + print("Generating re-sampled stats.") + + with HiddenPrints(): + evaluator = experiment.Evaluator(logger) + + # store prune ratios and add zero prune ratio + prune_ratios = 1 - np.array(evaluator._keep_ratios) + prune_ratios = np.concatenate(([0.0], prune_ratios)) + + for n_idx in range(evaluator._num_nets): + for r_idx in range(evaluator._num_repetitions): + for s_idx, keep_ratio in enumerate(evaluator._keep_ratios): + for a_idx, method_name in enumerate(evaluator._method_names): + if "ReferenceNet" in method_name and s_idx > 0: + continue + with HiddenPrints(): + try: + ffjord_net = evaluator.get_by_pr( + prune_ratio=1.0 - keep_ratio, + method=method_name, + n_idx=n_idx, + r_idx=r_idx, + ) + except FileNotFoundError as f_e: + if "ReferenceNet" in method_name: + raise f_e + else: + continue + # now re-do the stats + logger.update_global_state( + n_idx=n_idx, s_idx=s_idx, r_idx=r_idx, a_idx=a_idx + ) + evaluator._do_stats(ffjord_net.cuda()) + + # store re-generated stats + if tag_gen is not None: + logger.save_custom_state(logger._stats, tag_gen) + print("Saving re-generated data") + + return copy.deepcopy(logger._stats) + + +def format_as_str(num): + if num / 1e9 > 1: + factor, suffix = 1e9, "B" + elif num / 1e6 > 1: + factor, suffix = 1e6, "M" + elif num / 1e3 > 1: + factor, suffix = 1e3, "K" + else: + factor, suffix = 1e0, "" + + num_factored = num / factor + if num_factored / 1e2 > 1: + num_rounded = str(int(round(num_factored))) + elif num_factored / 1e1 > 1: + num_rounded = f"{num_factored:.1f}" + else: + num_rounded = 
f"{num_factored:.2f}" + return f"{num_rounded}{suffix} % {num}" + + +def plot_loss( + logger, + param, + stats, + plot_filters, + style_kwargs, + plt_folder, + num_rep, + compression_rate=False, + use_loss=True, +): + """Plot everything starting from stats and param.""" + # get reference index and names + idx_ref = stats["methods"].index("ReferenceNet") + names = np.delete(stats["names"], idx_ref) + + # initialize logger to current parameters + with HiddenPrints(): + logger.initialize_from_param(param, setup_print=False) + + def _extract_pr_loss(stats): + # [num_nets, num_intervals, num_repetitions, num_algorithms] + prune_ratios = 100.0 * (1.0 - stats["sizes"]) + # prune_ratios = 1.0 / stats["sizes"] + if use_loss: + loss = copy.deepcopy(stats["loss"]) + else: + loss = copy.deepcopy(stats["error"]) * 100.0 + + # add 0 prune ratio to data + prune_ratios = np.pad(prune_ratios, [(0, 0), (1, 0), (0, 0), (0, 0)]) + prune_ratios[:, 0] = 0.0 + loss = np.pad(loss, [(0, 0), (1, 0), (0, 0), (0, 0)]) + loss[:, 0] = loss[:, 1, :, idx_ref : idx_ref + 1] + + # remove ref idx, shape=[num_nets, num_intervals, num_rep, num_alg - 1] + prune_ratios = np.delete(prune_ratios, idx_ref, axis=3) + loss = np.delete(loss, idx_ref, axis=3) + + return prune_ratios, loss + + # get pr and loss with re-sampling always + prune_ratios, loss = None, None + + # re-generate loss until we have enough repetitions + i_gen = 0 + while loss is None or loss[:, 0, :, 0].size < num_rep: + print(f"\nResampling Loss, i_gen={i_gen}") + if num_rep > 1: + stats_new = resample_loss(logger, i_gen) + else: + stats_new = resample_loss(logger) + pr_new, loss_new = _extract_pr_loss(stats_new) + if prune_ratios is None: + prune_ratios = pr_new + loss = loss_new + else: + prune_ratios = np.concatenate((prune_ratios, pr_new), axis=2) + loss = np.concatenate((loss, loss_new), axis=2) + i_gen += 1 + + def _extract_valid_pr_loss(idx_alg): + """Extract valid PRs and losses for desired algorithm index.""" + # shape=[num_nets, num_intervals, num_rep, num_alg - 1] + num_intervals = prune_ratios.shape[1] + pr_m, l_m, l_std = [], [], [] + for i_pr in range(num_intervals): + # extract raw PR, loss for desired algorithm and interval + pr_one_i = prune_ratios[:, i_pr, :, idx_alg].flatten() + loss_one_i = loss[:, i_pr, :, idx_alg].flatten() + + # determine valid entries/repetitions + valid = pr_one_i != 100.0 + + # don't add if nothing valid + if sum(valid) < 1: + continue + + # filter for valid entries + pr_one_i = pr_one_i[valid] + loss_one_i = loss_one_i[valid] + + # store stats + pr_m.append(np.mean(pr_one_i)) + l_m.append(np.mean(loss_one_i)) + l_std.append(np.std(loss_one_i)) + + pr_m, l_m, l_std = np.asarray([pr_m, l_m, l_std]) + return pr_m, l_m, l_std + + def _plot(filter, legend_on=True): + fig = plt.figure() + sns.set_theme() + legends = [] + legends_lookup = { + "WT": "Unstructured Pruning", + "FT": "Structured Pruning", + } + + for name in filter: + # get right data + idx = np.argwhere(names == name) + if len(idx) != 1: + continue + idx = idx[0].item() + + # get valid PRs and loss + pr, l_m, l_std = _extract_valid_pr_loss(idx) + + # collect names for legend + legends.append(legends_lookup[name]) + + # try some smoothing + if style_kwargs["savgol_on"]: + l_m_filt = signal.savgol_filter( + l_m, **style_kwargs["savgol_mean"] + ) + l_std_filt = signal.savgol_filter( + l_std, **style_kwargs["savgol_std"] + ) + else: + l_m_filt = l_m + l_std_filt = l_std + + # plot + # fig.gca().plot(pr, l_m, color="red") + fig.gca().plot(pr, l_m_filt, 
**style_kwargs[name]["plot"]) + fig.gca().fill_between( + pr, + l_m_filt - l_std_filt, + l_m_filt + l_std_filt, + **style_kwargs[name]["fill"], + ) + + # axis labels + if compression_rate: + fig.gca().set_xlabel("Compression Rate", **style_kwargs["label"]) + else: + fig.gca().set_xlabel("Prune Ratio (%)", **style_kwargs["label"]) + if use_loss: + fig.gca().set_ylabel("Loss (NLL)", **style_kwargs["label"]) + else: + fig.gca().set_ylabel("Top-1 Error (%)", **style_kwargs["label"]) + # ticks + fig.gca().tick_params(axis="both", **style_kwargs["tick"]) + + # x limits and y limits + fig.gca().set_xlim(style_kwargs["xlim"]) + fig.gca().set_ylim(style_kwargs["ylim"]) + + # legend now + if legend_on: + fig.gca().legend( + legends, ncol=len(legends), **style_kwargs["legend"] + ) + + # a few stylistic changes + fig.gca().spines["top"].set_visible(False) + fig.gca().spines["right"].set_visible(False) + fig.set_tight_layout(True) + + return fig + + for filters in plot_filters: + # check if all methods that filters wants exist + if not all([filt in names for filt in filters]): + print(filters) + continue + # generate and store figure + fig = _plot(filters, legend_on=False) + file_name = "_".join(filters) + ".pdf" + file_name = os.path.join(plt_folder, file_name) + os.makedirs(plt_folder, exist_ok=True) + fig.savefig(file_name, bbox_inches="tight") + + # now also print data + size_abs = np.mean(stats["sizes_total"]) + for idx_alg, name in enumerate(names): + prs_one, losses_one, _ = _extract_valid_pr_loss(idx_alg) + for pr, loss_one in zip(prs_one, losses_one): + size_pruned = (1 - pr / 100.0) * size_abs + print( + f"Sparse Flows ({name}, PR={int(round(pr))}\\%) & " + f"{loss_one:.2f} & {format_as_str(size_pruned)}" + ) + + +def plot_flow(logger, param, plt_folder, cnf_plots=True): + """Plot the distribution beautifully.""" + if cnf_plots: + import plots_cnf as plots + else: + import plots2d as plots + + print(f"PLOT FOLDER: {plt_folder}") + + def _plot_distribution(plots_kwargs, tag): + # plots once with the default scatter plot + fig = plt.figure(figsize=(5, 5)) + sns.set_style("ticks") + axis = fig.gca() + + plots.plot_for_sweep(axis=axis, **plots_kwargs) + + if cnf_plots: + axis.set_xlim([-2, 2]) + axis.set_ylim([-2, 2]) + axis.set_aspect("equal") + else: + axis.set_aspect(1.5) + plt.axis("off") + axis.get_xaxis().set_ticks([]) + axis.get_yaxis().set_ticks([]) + plt.tight_layout() + + # store first plot + plt_folder_original = os.path.join(plt_folder, "distribution_original") + file_name = os.path.join(plt_folder_original, tag + ".jpg") + os.makedirs(plt_folder_original, exist_ok=True) + fig.savefig(file_name, bbox_inches="tight", pad_inches=0) + + # now re-load plot and filter out light colors + if cnf_plots: + img = np.copy(np.asarray(Image.open(file_name))) + if not (IN_JUPYTER and INLINE_PLOT): + threshold = 150 + else: + threshold = 200 + img[img > threshold] = 255 + + # show filtered plot + fig2 = plt.figure(figsize=(5, 5)) + plt.imshow(img) + fig2.gca().set_aspect("equal") + plt.axis("off") + plt.tight_layout() + + # store filtered plot + plt_folder_filtered = os.path.join( + plt_folder, "distribution_filtered" + ) + file_name2 = os.path.join(plt_folder_filtered, tag + ".jpg") + os.makedirs(plt_folder_filtered, exist_ok=True) + # Image.fromarray(img).save(file_name2) + fig2.savefig(file_name2, bbox_inches="tight", pad_inches=0) + + if not (IN_JUPYTER and INLINE_PLOT): + plt.close(fig) + if cnf_plots: + plt.close(fig2) + + def _plot_field(plots_kwargs, tag): + fig = plt.figure(figsize=(5, 5)) 
+ sns.set_style("ticks") + axis = fig.gca() + + # PLOTTING CODE + plots.plot_static_vector_field(axis=axis, **plots_kwargs) + + if cnf_plots: + axis.set_xlim([-2, 2]) + axis.set_ylim([-2, 2]) + axis.set_aspect("equal") + else: + axis.set_aspect(1.5) + plt.axis("off") + axis.set_title(None) + axis.get_xaxis().set_ticks([]) + axis.get_yaxis().set_ticks([]) + plt.tight_layout() + + # store plot + plt_folder_field = os.path.join(plt_folder, "field") + file_name = os.path.join(plt_folder_field, tag + ".jpg") + os.makedirs(plt_folder_field, exist_ok=True) + fig.savefig(file_name, bbox_inches="tight", pad_inches=0) + + if not (IN_JUPYTER and INLINE_PLOT): + plt.close(fig) + + def _plot_trajectory(plots_kwargs, tag, labels=False): + sns.set_context("paper", font_scale=1.5) + fig = plt.figure(figsize=(5, 3.5)) + axis1 = fig.add_subplot(211) + axis2 = fig.add_subplot(212) + + # PLOTTING CODE + plots.plot_2D_depth_trajectory( + axis1=axis1, axis2=axis2, **plots_kwargs + ) + + # axis limits + xlim = [0, 1] + if cnf_plots: + xlim = xlim[::-1] + axis1.set_xlim(xlim) + axis2.set_xlim(xlim) + + ylim1 = axis1.get_ylim() + ylim2 = axis2.get_ylim() + ylim = np.maximum(ylim1, ylim2) + axis1.set_ylim(ylim) + axis2.set_ylim(ylim) + + # axis layout + sns.despine(offset=10, trim=True) + axis1.get_xaxis().set_ticks([]) + axis1.get_xaxis().set_visible(False) + axis1.spines["bottom"].set_visible(False) + fig.tight_layout() + + # store plot + plt_folder_traj = os.path.join(plt_folder, "trajectory") + file_name = os.path.join(plt_folder_traj, tag + ".jpg") + os.makedirs(plt_folder_traj, exist_ok=True) + fig.savefig(file_name, bbox_inches="tight", pad_inches=0) + + # setup labels as separate plot + labels = labels and not cnf_plots + if labels: + from matplotlib.lines import Line2D + + legend_handles = {} + for color, label in zip(["midnightblue", "darkorange"], [0, 1]): + legend_handles[f"Class {label}"] = Line2D( + [0], [0], color=color, lw=1.5 + ) + fig_labels = plt.figure(figsize=(1, 1)) + fig_labels.gca().legend( + list(legend_handles.values()), list(legend_handles.keys()) + ) + fig_labels.gca().set_axis_off() + fig_labels.tight_layout() + file_name_labels = "labels.pdf" + file_name_labels = os.path.join(plt_folder_traj, file_name_labels) + fig_labels.savefig( + file_name_labels, bbox_inches="tight", pad_inches=0 + ) + + # close figure + if not (IN_JUPYTER and INLINE_PLOT): + plt.close(fig) + if labels: + plt.close(fig_labels) + + with HiddenPrints(): + logger.initialize_from_param(param, setup_print=False) + evaluator = experiment.Evaluator(logger) + loader_test = evaluator.get_dataloader("test")[0] + + # store prune ratios and add zero prune ratio + prune_ratios = 1 - np.array(evaluator._keep_ratios) + prune_ratios = np.concatenate(([0.0], prune_ratios)) + + for n_idx in range(evaluator._num_nets): + for r_idx in range(evaluator._num_repetitions): + for s_idx, pr in enumerate(prune_ratios): + for method_name in evaluator._method_names: + if "ReferenceNet" in method_name: + continue + with HiddenPrints(): + try: + if pr == 0.0: + lookup_name = "ReferenceNet" + else: + lookup_name = method_name + ffjord_net = evaluator.get_by_pr( + prune_ratio=pr, + method=lookup_name, + n_idx=n_idx, + r_idx=r_idx, + ).compressed_net.torchnet + except FileNotFoundError: + continue + + tag = "_".join( + [ + logger.names[logger.methods.index(method_name)], + f"n{n_idx}", + f"r{r_idx}", + f"i{s_idx:02d}", + f"p{int(pr*100):03d}", + ] + ) + + # setup and generate data, and plot + plots_kwargs = plots.prepare_data( + ffjord_net.cuda(), 
loader_test, n_samp=50000 + ) + _plot_distribution(plots_kwargs, tag) + try: + _plot_field(plots_kwargs, tag) + except ValueError: + pass + _plot_trajectory( + plots_kwargs, + tag, + labels=n_idx == 0 and r_idx == 0 and s_idx == 0, + ) + + print(f"Done with pr={pr:.2f}, r_idx={r_idx}, n_idx={n_idx}") + + +for STATS, PARAMS in zip(STATS_ALL, PARAM_ALL): + for STAT, PARAM in zip(STATS, PARAMS): + NET_NAME = PARAM["generated"]["netName"] + DSET = PARAM["generated"]["datasetTest"] + NETWORK = PARAM["network"]["name"] + PLT_FOLDER = os.path.join(PLOT_FOLDER_SPECIAL, DSET, NET_NAME) + IS_CNF = "ffjord_" in NET_NAME or "cnf_" in NET_NAME + if GEN_PAPER_FIGS_LOSS: + FOLDER_LOSS = os.path.join(PLT_FOLDER, "loss") + plot_loss( + LOGGER, + PARAM, + STAT, + PLOT_FILTERS, + STYLE_KWARGS, + FOLDER_LOSS, + NUM_REP_LOSS, + use_loss=IS_CNF, + ) + if ( + GEN_PAPER_FIGS_DISTRIBUTION + and "toy" in PARAM["network"]["dataset"].lower() + ): + plot_flow(LOGGER, PARAM, PLT_FOLDER, IS_CNF) diff --git a/paper/node/script/view_toysweep.py b/paper/node/script/view_toysweep.py new file mode 100644 index 0000000..79af8fd --- /dev/null +++ b/paper/node/script/view_toysweep.py @@ -0,0 +1,535 @@ +"""View and plot Neural ODE sweep results.""" +# %% +import os +import sys +import re +import warnings +import copy +import math +import yaml +import experiment +import matplotlib.colors as mcolors +import matplotlib.pyplot as plt +import numpy as np +from experiment.util.grapher import Grapher +from experiment.util.file import get_parameters, load_param_from_file + +# change working directory to src +from IPython import get_ipython + +# make sure it's using only GPU here... +os.environ["CUDA_VISIBLE_DEVICES"] = "0" # noqa + + +# switch to root folder for data +folder = os.path.abspath("") +if "paper/node/script" in folder: + src_folder = os.path.join(folder, "../../..") + os.chdir(src_folder) + +# add script path to sys path +sys.path.append("./paper/node/script") + +# %% Define some parameters +FILE = "paper/node/param/toy/ffjord/gaussians/sweep_model_da.yaml" + +INLINE_PLOT = False +USE_JPG = True + +GEN_FIGS = False +GEN_ABS_FIGS = True +GEN_POT_FIGS = True +GEN_NODE_FIGS = True +REGEN_NODE_FIGS = True + +COMM_LEVEL = 0.005 +# fmt: off +FILTER_METHODS = [ + ["WT"], + ["WT", "FT"], + # ["SiPP", "PFP"], + # ["WT", "SiPP"], + # ["FT", "PFP"], +] +# fmt: on +IS_FFJORD = "ffjord" in FILE +if IS_FFJORD: + import plots_cnf as plots +else: + import plots2d as plots + +# %% Some helpful functions +class HiddenPrints: + def __enter__(self): + self._original_stdout = sys.stdout + sys.stdout = open(os.devnull, "w") + warnings.simplefilter("ignore") + + def __exit__(self, exc_type, exc_val, exc_tb): + sys.stdout.close() + sys.stdout = self._original_stdout + warnings.simplefilter("default") + + +def plot_abs_size_acc( + logger, params, customizations, plots_dir, plot_loss, inline +): + """Plot the absolute trade-off between # of parameters and accuracy.""" + + # collect data function + def _collect_size_err(param): + # initialize logger and setup parameters + with HiddenPrints(): + logger.initialize_from_param(param, setup_print=False) + + # compute absolute sizes + sizes_abs = logger.sizes * logger.sizes_total[:, None, None, None] + + # get the desired score to plot + if plot_loss: + err_abs = copy.deepcopy(logger.loss) + else: + err_abs = copy.deepcopy(logger.error) + + # add the reference error and size at the beginning + ref_idx = logger.names.index("ReferenceNet") + sizes_abs = np.concatenate( + ( + np.broadcast_to( + 
logger.sizes_total[:, None, None, None], sizes_abs.shape + )[:, :1], + sizes_abs, + ), + axis=1, + ) + err_abs = np.concatenate( + ( + np.repeat( + err_abs[:, 0:1, :, ref_idx : ref_idx + 1], + err_abs.shape[3], + axis=3, + ), + err_abs, + ), + axis=1, + ) + + return sizes_abs, err_abs + + # Loop through all experiments and collect data + sizes_abs, err_abs = None, None + for i, param in enumerate(params): + sizes_abs_one, err_abs_one = _collect_size_err(param) + if sizes_abs is None: + sizes_abs = np.zeros( + sizes_abs_one.shape + (len(params),), dtype=int + ) + err_abs = copy.deepcopy(sizes_abs).astype(float) + sizes_abs[:, :, :, :, i] = sizes_abs_one + err_abs[:, :, :, :, i] = err_abs_one + + # get dataset now + dataset = logger.dataset_test + + # plot per method now + mcolor_list = list(mcolors.CSS4_COLORS.keys()) + custom_colors = [ + mcolor_list[hash(custom) % len(mcolor_list)] + for custom in customizations + ] + for i_m, method in enumerate(logger.names): + + # graph the absolute trade-off + grapher = Grapher( + x_values=sizes_abs[:, :, :, i_m], + y_values=err_abs[:, :, :, i_m], + folder=os.path.join(plots_dir, "tradeoff"), + file_name=f"err_{method}.pdf", + ref_idx=0, + x_min=0, + x_max=1e20, + legend=customizations, + colors=custom_colors, + xlabel="# of parameters", + ylabel="Loss" if plot_loss else "Error", + title=f"{method}, {dataset}", + ) + grapher.graph( + show_ref=True, + show_delta=False, + remove_outlier=True, + store=False, + percentage_y=not plot_loss, + kwargs_legend={ + "loc": "upper left", + "ncol": 1, + "bbox_to_anchor": (1.05, 1.1), + }, + ) + grapher.store_plot() + if not inline: + plt.close(grapher._figure) + + # plot per customization now + for i_c, custom in enumerate(customizations): + + # graph the absolute trade-off + grapher = Grapher( + x_values=sizes_abs[:, :, :, :, i_c], + y_values=err_abs[:, :, :, :, i_c], + folder=os.path.join(plots_dir, "tradeoff"), + file_name=f"err_{custom}.pdf", + ref_idx=0, + x_min=0, + x_max=1e20, + legend=copy.deepcopy(np.array(logger.names)).tolist(), + colors=copy.deepcopy(np.array(logger._colors)).tolist(), + xlabel="# of parameters", + ylabel="Loss" if plot_loss else "Error", + title=f"{custom}, {dataset}", + ) + grapher.graph( + show_ref=True, + show_delta=False, + remove_outlier=True, + store=False, + percentage_y=not plot_loss, + ) + grapher.store_plot() + if not inline: + plt.close(grapher._figure) + + +def get_results(file, logger, regen_figs): + """Grab all the results according to the file.""" + stats = [] + params = [] + # Loop through all experiments + for param in get_parameters(file, 1, 0): + # initialize logger and setup parameters + with HiddenPrints(): + logger.initialize_from_param(param, setup_print=False) + + # print message if incomplete but don't stop + if not logger.state_loaded: + print("Grabbing incomplete results!") + + # compute the stats + try: + stats_one = logger.compute_stats(store_report=False) + except ValueError as err: + print( + "Computing stats failed. Make sure that all partial results " + "are stored as numpy, e.g., by running it in parallel." + ) + raise err + + # extract the results + stats.append(copy.deepcopy(stats_one)) + params.append(copy.deepcopy(param)) + + # extract the plots and store them. 
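+        # same pattern as in view_results.py: only regenerate figures when
+        # explicitly requested and when the full experiment state was loaded.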
+        if not regen_figs or not logger.state_loaded:
+            continue
+        try:
+            with HiddenPrints():
+                graphers = logger.generate_plots(store_figs=False)
+                for grapher in graphers:
+                    grapher.store_plot()
+        except Exception:
+            print("Could not generate main graphs.")
+
+    return stats, params
+
+
+def extract_commensurate_size(stats, comm_level):
+    """Compute prune potential for all experiments and return it."""
+    # get the index closest to our desired comm_level
+    c_idx = np.abs(np.array(stats[0]["commensurate"]) - comm_level).argmin()
+
+    # pre-allocate results array
+    # stats_all[0]['eBest']
+    # has shape (len(commensurate), num_nets, num_rep, num_alg)
+    _, num_nets, num_rep, num_alg = stats[0]["e_best"].shape
+    num_exp = len(stats)
+    size_comm = np.zeros((num_nets, num_exp, num_rep, num_alg))
+
+    for i, stats_one in enumerate(stats):
+        if stats_one is not None:
+            size_comm[:, i, :, :] = stats_one["siz_best"][c_idx]
+
+    return size_comm
+
+
+def get_fig_name(title, tag, legends=None):
+    """Get the name of the figure with the title and tag."""
+    # avoid a mutable default argument
+    legends = list(legends) if legends else []
+    fig_name = "_".join(re.split("/|-|_|,", title) + legends).replace(" ", "")
+    return f"{fig_name}_prunepot_{tag}.pdf"
+
+
+def plot_commensurate_size(
+    size_comm,
+    legends,
+    colors,
+    customizations,
+    title,
+    plots_dir,
+    plots_tag,
+    comm_level,
+):
+    """Plot the prune potential for all methods."""
+    # get the x values
+    x_val = np.arange(size_comm.shape[1], dtype=float)
+
+    grapher_comm = Grapher(
+        x_values=np.broadcast_to(x_val[None, :, None, None], size_comm.shape),
+        y_values=1.0 - size_comm,
+        folder=plots_dir,
+        file_name=get_fig_name(title, plots_tag, legends),
+        ref_idx=None,
+        x_min=-1e10,
+        x_max=1e10,
+        legend=legends,
+        colors=colors,
+        xlabel=rf"Prune Potential, $\delta={comm_level * 100:.1f}\%$",
+        ylabel="Method",
+        title=title,
+    )
+
+    with HiddenPrints():
+        img_comm = grapher_comm.graph_histo(normalize=False, store=False)
+
+    # set custom x ticks with labels
+    img_comm.gca().set_xticks(x_val)
+    img_comm.gca().set_xticklabels(customizations, rotation=75, fontsize=20)
+
+    # then store it
+    grapher_comm.store_plot()
+
+    return img_comm
+
+
+def generate_one_sweep(
+    logger,
+    params,
+    plt_folder,
+    n_idx,
+    r_idx,
+    method_name,
+    customizations,
+    inline,
+    use_jpg,
+    regen_figs,
+):
+    """Generate one grid of sweep plots per keep ratio for the given method."""
+    # get keep ratios and add zero pruning as well
+    m_name_ref = "ReferenceNet"
+    keep_ratios = params[0]["generated"]["keepRatios"]
+    keep_ratios = np.concatenate(([1.0], keep_ratios))
+
+    figsize = [6.4, 4.8]
+    ncols = math.ceil(math.sqrt(len(params)))
+    nrows = math.ceil(len(params) / ncols)
+    figsize[0] *= ncols
+    figsize[1] *= nrows
+
+    for s_idx, keep_ratio in enumerate(keep_ratios):
+        # only plot for one keep ratio for reference net.
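+        # (the unpruned reference net is identical for every keep ratio)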
+ if m_name_ref in method_name and s_idx > 0: + break + + # create folder + os.makedirs(plt_folder, exist_ok=True) + # generate figure name + tag = f"i_{s_idx - 1}_p{keep_ratio:.3f}" + if use_jpg: + file_ending = ".jpg" + else: + file_ending = ".pdf" + fig_file = os.path.join(plt_folder, tag + file_ending) + + # continue if exists and we shouldn't regen figures + if not regen_figs and os.path.exists(fig_file): + continue + + # initialize figure and layout + fig, axes = plt.subplots( + nrows=nrows, + ncols=ncols, + sharex=True, + sharey=True, + figsize=figsize, + squeeze=False, + ) + plt.style.use("default") + plt.rcParams.update( + { + "xtick.labelsize": 16, + "ytick.labelsize": 16, + } + ) + + # go through each parameter config of sweep + for axis, param, custom in zip(axes.flatten(), params, customizations): + logger.initialize_from_param(param, setup_print=False) + with HiddenPrints(): + # initialize evaluator and logger + evaluator = experiment.Evaluator(logger) + + # get data loader + loader_test = evaluator.get_dataloader("test")[0] + + # retrieve model and plot if it exists + try: + net = evaluator.get_by_pr( + prune_ratio=1.0 - keep_ratio, + n_idx=n_idx, + r_idx=r_idx, + method=method_name if keep_ratio < 1.0 else m_name_ref, + ).compressed_net.cuda() + except FileNotFoundError: + continue + + # set plot title + axis.set_title(f"{custom}\n#p={int(net.size())}", fontsize=20) + + # setup and generate plots + plots_kwargs = plots.prepare_data(net.torchnet, loader_test) + plots.plot_for_sweep(axis=axis, **plots_kwargs) + + # store plot at the end now. + fig.suptitle(f"{method_name}, {tag}", fontsize=24) + fig.savefig(fig_file, bbox_inches="tight") + if not inline: + plt.close(fig) + + +def generate_sweepy_figures( + logger, params, customizations, plot_dir, regen_figs, inline, use_jpg +): + """Generate figures with view over sweep.""" + # extract all repetitions from here + num_nets = params[0]["experiments"]["numNets"] + num_repetitions = params[0]["experiments"]["numRepetitions"] + method_names = params[0]["experiments"]["methods"] + + # loop through all repetitions and method names. 
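+    # one sub-folder of sweep figures per (method, net index, repetition)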
+ for n_idx in range(num_nets): + for r_idx in range(num_repetitions): + for method_name in method_names: + tag = "_".join([method_name, f"n{n_idx}", f"r{r_idx}"]) + plt_folder = os.path.join(plot_dir, tag) + + # print folder + print(plt_folder) + + # now plot sequence of keep ratios + generate_one_sweep( + logger, + params, + plt_folder, + n_idx, + r_idx, + method_name, + customizations, + inline, + use_jpg, + regen_figs, + ) + + # print finish message + print("Done\n") + + +# %% Retrieve results +# make sure matplotlib works correctly +IN_JUPYTER = True +try: + if INLINE_PLOT: + get_ipython().run_line_magic("matplotlib", "inline") + else: + get_ipython().run_line_magic("matplotlib", "agg") +except AttributeError: + IN_JUPYTER = False + +# get a logger +LOGGER = experiment.Logger() +print(f"PARAM FILE: {FILE}") + +# get the results specified in the file (and hopefully pre-computed) +STATS, PARAMS = get_results(FILE, LOGGER, GEN_FIGS) + +# extract some other info from params +LABELS_METHOD = PARAMS[0]["generated"]["network_names"] +COLORS_METHOD = PARAMS[0]["generated"]["network_colors"] +NETWORK_NAME = PARAMS[0]["network"]["name"] +TRAIN_DSET = PARAMS[0]["network"]["dataset"] +TITLE_PR = f"{NETWORK_NAME}, {TRAIN_DSET}" +PLOTS_DIR = os.path.join( + PARAMS[0]["generated"]["plotDir"], + os.path.splitext(os.path.basename(FILE))[0], +) +CUSTOMIZATIONS = [ + ", ".join([f"{k}={v}" for k, v in custom["value"].items()]) + if isinstance(custom["value"], dict) + else str(custom["value"]) + for custom in load_param_from_file(FILE)["customizations"] +] + +# print plot folders for reference +for param, custom in zip(PARAMS, CUSTOMIZATIONS): + print(f"Customizations: {custom}") + print(f"Plot Folder: {param['generated']['plotDir']}\n") + +# special folder as folder +print(f"Sweep plot folder: {PLOTS_DIR}") + + +# %% generate prune potential plot +if GEN_POT_FIGS: + # compute commensurate size for desired comm level for all results + SIZE_COMM = extract_commensurate_size(STATS, COMM_LEVEL) + + PLOTS_DIR_POT = os.path.join(PLOTS_DIR, "prune_pot") + print(PLOTS_DIR_POT) + + # plot a subset of the methods + for methods in FILTER_METHODS: + try: + idx_filt = [LABELS_METHOD.index(method) for method in methods] + except ValueError: + continue + fig = plot_commensurate_size( + size_comm=SIZE_COMM[:, :, :, idx_filt], + legends=methods, + colors=[COLORS_METHOD[i] for i in idx_filt], + customizations=CUSTOMIZATIONS, + title=TITLE_PR, + plots_dir=PLOTS_DIR_POT, + plots_tag="prune_pot", + comm_level=COMM_LEVEL, + ) + print("Done\n") + +# %% Generate abs parameter-error trade-off figures +if GEN_ABS_FIGS: + plot_abs_size_acc( + LOGGER, + PARAMS, + CUSTOMIZATIONS, + PLOTS_DIR, + IS_FFJORD, + INLINE_PLOT, + ) + + +# %% now generate the sweep plots +if GEN_NODE_FIGS: + generate_sweepy_figures( + LOGGER, + PARAMS, + CUSTOMIZATIONS, + PLOTS_DIR, + REGEN_NODE_FIGS, + INLINE_PLOT, + USE_JPG, + ) diff --git a/paper/sipp/README.md b/paper/sipp/README.md index 05311c7..78679f7 100644 --- a/paper/sipp/README.md +++ b/paper/sipp/README.md @@ -6,7 +6,7 @@ [Daniela Rus](http://danielarus.csail.mit.edu/) Implementation of provable pruning using sensitivity as introduced in [SiPPing -Neural Networks: Sensitivity-informed Provable Pruning of Neural Networks](https://arxiv.org/abs/1910.05422) +Neural Networks: Sensitivity-informed Provable Pruning of Neural Networks](https://doi.org/10.1137/20M1383239) (weight pruning). ***Equal contribution** @@ -55,10 +55,14 @@ Please cite our paper when using this codebase. 
### Bibtex
```
-@article{baykal2019sipping,
-  title={SiPPing Neural Networks: Sensitivity-informed Provable Pruning of Neural Networks},
+@article{baykal2022sensitivity,
+  title={Sensitivity-informed provable pruning of neural networks},
   author={Baykal, Cenk and Liebenwein, Lucas and Gilitschenski, Igor and Feldman, Dan and Rus, Daniela},
-  journal={arXiv preprint arXiv:1910.05422},
-  year={2019}
+  journal={SIAM Journal on Mathematics of Data Science},
+  volume={4},
+  number={1},
+  pages={26--45},
+  year={2022},
+  publisher={SIAM}
 }
```
\ No newline at end of file
diff --git a/paper/sipp/param/cifar/cascade/wrn28_2.yaml b/paper/sipp/param/cifar/cascade/wrn28_2.yaml
new file mode 100644
index 0000000..2190269
--- /dev/null
+++ b/paper/sipp/param/cifar/cascade/wrn28_2.yaml
@@ -0,0 +1,31 @@
+network:
+  name: "wrn28_2"
+  dataset: "CIFAR10"
+  outputSize: 10
+
+training:
+  file: "training/cifar/wrn.yaml"
+
+retraining: {}
+
+experiments:
+  methods:
+    - "SiPPNet"
+    - "SiPPNetRand"
+    - "SiPPNetHybrid"
+  mode: "cascade"
+
+  numRepetitions: 1
+  numNets: 3
+
+  plotting:
+    minVal: 0.02
+    maxVal: 0.85
+
+  spacing:
+    - type: "geometric"
+      numIntervals: 20
+      maxVal: 0.70
+      minVal: 0.05
+
+  retrainIterations: -1
diff --git a/paper/sipp/param/cifar/sweep/resnet20_default.yaml b/paper/sipp/param/cifar/sweep/resnet20_default.yaml
new file mode 100644
index 0000000..1a1ec69
--- /dev/null
+++ b/paper/sipp/param/cifar/sweep/resnet20_default.yaml
@@ -0,0 +1,36 @@
+network:
+  name: "resnet20"
+  dataset: "CIFAR10"
+  outputSize: 10
+
+training:
+  file: "training/cifar/resnet.yaml"
+
+retraining:
+  numEpochs: 0
+
+experiments:
+  methods:
+    - "SiPPNet"
+    - "SiPPNetRand"
+    - "SiPPNetHybrid"
+  mode: "retrain"
+
+  numRepetitions: 1
+  numNets: 3
+
+  plotting:
+    minVal: 0.02
+    maxVal: 0.99
+
+  spacing:
+    - type: "geometric"
+      numIntervals: 30
+      maxVal: 0.90
+      minVal: 0.40
+
+  retrainIterations: -1
+
+# coreset parameters
+coresets:
+  deltaS: 1.0e-16 # 183
diff --git a/paper/sipp/param/cifar/sweep/resnet20_sizes.yaml b/paper/sipp/param/cifar/sweep/resnet20_sizes.yaml
new file mode 100644
index 0000000..c6ad490
--- /dev/null
+++ b/paper/sipp/param/cifar/sweep/resnet20_sizes.yaml
@@ -0,0 +1,34 @@
+file: "paper/sipp/param/cifar/sweep/resnet20_default.yaml"
+# Make sure num_customization is divisible by num_workers during deployment!!!
+# Currently, we have 15 customizations here...
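+# Each entry below overrides coresets.deltaS in resnet20_default.yaml, i.e.,
+# the failure probability used when constructing the coreset sample set S.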
+customizations: + - key: ["coresets", "deltaS"] + value: 1.0e+1 + - key: ["coresets", "deltaS"] + value: 1.0e+0 + - key: ["coresets", "deltaS"] + value: 1.0e-1 + - key: ["coresets", "deltaS"] + value: 1.0e-2 + - key: ["coresets", "deltaS"] + value: 1.0e-4 + - key: ["coresets", "deltaS"] + value: 1.0e-6 + - key: ["coresets", "deltaS"] + value: 1.0e-8 + - key: ["coresets", "deltaS"] + value: 1.0e-10 + - key: ["coresets", "deltaS"] + value: 1.0e-12 + - key: ["coresets", "deltaS"] + value: 1.0e-14 + - key: ["coresets", "deltaS"] + value: 1.0e-16 # standard, 183 + - key: ["coresets", "deltaS"] + value: 1.0e-20 + - key: ["coresets", "deltaS"] + value: 1.0e-24 + - key: ["coresets", "deltaS"] + value: 1.0e-28 + - key: ["coresets", "deltaS"] + value: 1.0e-32 diff --git a/paper/sipp/param/mnist/fc_nettrim.yaml b/paper/sipp/param/mnist/fc_nettrim.yaml new file mode 100644 index 0000000..2e72b64 --- /dev/null +++ b/paper/sipp/param/mnist/fc_nettrim.yaml @@ -0,0 +1,33 @@ +network: + name: "fcnet_nettrim" + dataset: "MNIST" + outputSize: 10 + +training: + file: "training/mnist/lenet.yaml" + +retraining: {} + +experiments: + methods: + - "SiPPNet" + - "SiPPNetRand" + - "SiPPNetHybrid" + - "ThresNet" + - "SnipNet" + mode: "cascade" + + numRepetitions: 1 + numNets: 3 + + plotting: + minVal: 0.01 + maxVal: 0.3 + + spacing: + - type: "geometric" + numIntervals: 30 # number of intervals to 0.5 in cascade mode + maxVal: 0.80 + minVal: 0.01 + + retrainIterations: -1 diff --git a/paper/sipp/param/mnist/lenet300.yaml b/paper/sipp/param/mnist/lenet300.yaml new file mode 100644 index 0000000..d889727 --- /dev/null +++ b/paper/sipp/param/mnist/lenet300.yaml @@ -0,0 +1,33 @@ +network: + name: "lenet300_100" + dataset: "MNIST" + outputSize: 10 + +training: + file: "training/mnist/lenet.yaml" + +retraining: {} + +experiments: + methods: + - "SiPPNet" + - "SiPPNetRand" + - "SiPPNetHybrid" + - "ThresNet" + - "SnipNet" + mode: "cascade" + + numRepetitions: 1 + numNets: 3 + + plotting: + minVal: 0.01 + maxVal: 0.3 + + spacing: + - type: "geometric" + numIntervals: 30 + maxVal: 0.80 + minVal: 0.01 + + retrainIterations: -1 diff --git a/paper/sipp/param/mnist/lenet300_sizes.yaml b/paper/sipp/param/mnist/lenet300_sizes.yaml new file mode 100644 index 0000000..f9775a4 --- /dev/null +++ b/paper/sipp/param/mnist/lenet300_sizes.yaml @@ -0,0 +1,60 @@ +file: "paper/sipp/param/mnist/lenet300.yaml" +# Make sure num_customization is divisible by num_workers during deployment!!! +# Currently, we have 15 customizations here... 
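+# The overrides below turn off retraining (numEpochs: 0) and sweep the
+# coreset failure probability deltaS across several orders of magnitude.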
+ +retraining: + numEpochs: 0 + +experiments: + methods: + - "SiPPNet" + - "SiPPNetRand" + - "SiPPNetHybrid" + mode: "retrain" + + numRepetitions: 1 + numNets: 3 + + plotting: + minVal: 0.01 + maxVal: 0.3 + + spacing: + - type: "geometric" + numIntervals: 30 + maxVal: 0.99 + minVal: 0.30 + + retrainIterations: -1 + +customizations: + - key: ["coresets", "deltaS"] + value: 1.0e+1 + - key: ["coresets", "deltaS"] + value: 1.0e+0 + - key: ["coresets", "deltaS"] + value: 1.0e-1 + - key: ["coresets", "deltaS"] + value: 1.0e-2 + - key: ["coresets", "deltaS"] + value: 1.0e-4 + - key: ["coresets", "deltaS"] + value: 1.0e-6 + - key: ["coresets", "deltaS"] + value: 1.0e-8 + - key: ["coresets", "deltaS"] + value: 1.0e-10 + - key: ["coresets", "deltaS"] + value: 1.0e-12 + - key: ["coresets", "deltaS"] + value: 1.0e-14 + - key: ["coresets", "deltaS"] + value: 1.0e-16 # standard, 183 + - key: ["coresets", "deltaS"] + value: 1.0e-20 + - key: ["coresets", "deltaS"] + value: 1.0e-24 + - key: ["coresets", "deltaS"] + value: 1.0e-28 + - key: ["coresets", "deltaS"] + value: 1.0e-32 diff --git a/paper/sipp/param/mnist/lenet300_sizes2.yaml b/paper/sipp/param/mnist/lenet300_sizes2.yaml new file mode 100644 index 0000000..86ba0b6 --- /dev/null +++ b/paper/sipp/param/mnist/lenet300_sizes2.yaml @@ -0,0 +1,60 @@ +file: "paper/sipp/param/mnist/lenet300.yaml" +# Make sure num_customization is divisible by num_workers during deployment!!! +# Currently, we have 15 customizations here... + +retraining: + startEpoch: 48 + +experiments: + methods: + - "SiPPNet" + - "SiPPNetRand" + - "SiPPNetHybrid" + mode: "retrain" + + numRepetitions: 1 + numNets: 3 + + plotting: + minVal: 0.01 + maxVal: 0.3 + + spacing: + - type: "geometric" + numIntervals: 30 + maxVal: 0.99 + minVal: 0.30 + + retrainIterations: -1 + +customizations: + - key: ["coresets", "deltaS"] + value: 1.0e+1 + - key: ["coresets", "deltaS"] + value: 1.0e+0 + - key: ["coresets", "deltaS"] + value: 1.0e-1 + - key: ["coresets", "deltaS"] + value: 1.0e-2 + - key: ["coresets", "deltaS"] + value: 1.0e-4 + - key: ["coresets", "deltaS"] + value: 1.0e-6 + - key: ["coresets", "deltaS"] + value: 1.0e-8 + - key: ["coresets", "deltaS"] + value: 1.0e-10 + - key: ["coresets", "deltaS"] + value: 1.0e-12 + - key: ["coresets", "deltaS"] + value: 1.0e-14 + - key: ["coresets", "deltaS"] + value: 1.0e-16 # standard, 183 + - key: ["coresets", "deltaS"] + value: 1.0e-20 + - key: ["coresets", "deltaS"] + value: 1.0e-24 + - key: ["coresets", "deltaS"] + value: 1.0e-28 + - key: ["coresets", "deltaS"] + value: 1.0e-32 diff --git a/paper/sipp/param/mnist/lenet5.yaml b/paper/sipp/param/mnist/lenet5.yaml new file mode 100644 index 0000000..2c143ef --- /dev/null +++ b/paper/sipp/param/mnist/lenet5.yaml @@ -0,0 +1,43 @@ +network: + name: "lenet5" + dataset: "MNIST" + outputSize: 10 + +training: + file: "training/mnist/lenet.yaml" + transformsTrain: + - type: Pad + kwargs: { padding: 4 } + - type: RandomCrop + kwargs: { size: 32 } + transformsTest: + - type: Pad + kwargs: { padding: 4 } + - type: CenterCrop + kwargs: { size: 32 } + +retraining: {} + +experiments: + methods: + - "SiPPNet" + - "SiPPNetRand" + - "SiPPNetHybrid" + - "ThresNet" + - "SnipNet" + mode: "cascade" + + numRepetitions: 1 + numNets: 3 + + plotting: + minVal: 0.01 + maxVal: 0.3 + + spacing: + - type: "geometric" + numIntervals: 30 + maxVal: 0.80 + minVal: 0.01 + + retrainIterations: -1 diff --git a/paper/sipp/script/results_comparisons.py b/paper/sipp/script/results_comparisons.py new file mode 100644 index 0000000..b25528f --- 
/dev/null
+++ b/paper/sipp/script/results_comparisons.py
@@ -0,0 +1,200 @@
+# %%
+# make sure the setup is correct everywhere
+import os
+import copy
+import numpy as np
+
+# change working directory to src
+from IPython import get_ipython
+import experiment
+from experiment.util.file import get_parameters
+
+# make sure it's using only GPU here...
+os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # noqa
+
+# switch to root folder for data
+folder = os.path.abspath("")
+if "paper/sipp/script" in folder:
+    src_folder = os.path.join(folder, "../../..")
+    os.chdir(src_folder)
+
+# %%
+# parameter file for running the test (last uncommented assignment wins)
+# FILE = "paper/sipp/param/mnist/fc_nettrim.yaml"
+# FILE = "paper/sipp/param/mnist/lenet5.yaml"
+FILE = "paper/sipp/param/cifar/cascade/wrn28_2.yaml"
+INLINE_PLOT = True
+
+# %% Manually recorded numbers
+
+# RESULTS FROM NET-TRIM PAPER
+
+# NET-TRIM for FC-Nettrim and LeNet5
+# fmt: off
+pr_nt = [44, 53.5, 59, 62, 66, 75, 80, 84.5, 88, 91, 93.5, 97, 99]
+acc_nt = [98.7, 98.55, 98.6, 98.65, 98.62, 98.68, 98.62, 98.55, 98.45, 98.31, 98.07, 96.52, 0.0]
+# fmt: on
+
+pr_nt5 = [0.0, 76.25, 79.25, 87.5, 94.75, 96.4, 97.75, 98.15, 98.7, 98.75]
+acc_nt5 = [99.46, 99.46, 99.48, 99.48, 99.43, 99.41, 99.38, 99.33, 99.21, 0.0]
+
+# BAYESIAN COMPRESSION for FC-Nettrim and LeNet5
+pr_bc = [39.0, 45.0, 47.0, 59.0, 74.2, 79.0, 83.0, 99.0]
+acc_bc = [98.05, 98.05, 97.95, 97.75, 97.40, 97.15, 97.05, 0.0]
+
+pr_bc5 = [0.0, 75.2, 78.33, 83.33, 93.2, 94.7, 99.5]
+acc_bc5 = [99.3, 99.3, 99.15, 99.12, 98.66, 98.45, 0.0]
+
+# DYNAMIC NETWORK SURGERY for FC-Nettrim and LeNet5
+pr_dns = [51.5, 56.0, 57.5, 76.0, 88.0, 95.0, 97.7, 98.8, 99.5]
+acc_dns = [98.60, 98.65, 98.62, 98.53, 98.60, 98.31, 96.68, 95.3, 0.0]
+
+pr_dns5 = [0.0, 78.5, 84.7, 90.00, 93.1, 97.2, 97.95, 98.5, 99.5]
+acc_dns5 = [99.51, 99.51, 99.48, 99.39, 99.41, 99.35, 99.1, 98.85, 0.0]
+
+# RESULTS FROM DSR PAPER
+
+# DSR (Dynamic Sparse Reparameterization) for WRN28-2
+pr_dsr = [50.0, 60.0, 70.0, 80.0, 90.0]
+acc_dsr = [94.7, 94.7, 94.52, 94.47, 93.65]
+
+# DeepR for WRN28-2
+pr_deepr = [50.0, 60.0, 70.0, 80.0, 90.0]
+acc_deepr = [93.05, 92.83, 92.62, 92.45, 91.45]
+
+# SET for WRN28-2
+pr_set = [50.0, 60.0, 70.0, 80.0, 90.0]
+acc_set = [94.72, 94.57, 94.38, 94.3, 93.3]
+
+# "Compressed Sparse" for WRN28-2 ("To prune, or not to prune", Zhu and Gupta, 2017)
+pr_tpntp = [50.0, 60.0, 70.0, 80.0, 90.0]
+acc_tpntp = [94.53, 94.53, 94.53, 94.17, 93.8]
+
+# %% make sure matplotlib works correctly
+IN_JUPYTER = True
+try:
+    if INLINE_PLOT:
+        get_ipython().run_line_magic("matplotlib", "inline")
+    else:
+        get_ipython().run_line_magic("matplotlib", "agg")
+except AttributeError:
+    IN_JUPYTER = False
+
+
+# %% get results
+def get_results(file, logger):
+    """Grab all the results according to the hyperparameter file."""
+    results = []
+    params = []
+    labels = []
+    # Loop through all experiments
+    for param in get_parameters(file, 1, 0):
+        # initialize logger and setup parameters
+        logger.initialize_from_param(param)
+        # run the experiment (only if necessary)
+        try:
+            state = logger.get_global_state()
+        except ValueError:
+            experiment.Evaluator(logger).run()
+            state = logger.get_global_state()
+        # extract the results
+        results.append(copy.deepcopy(state))
+        params.append(copy.deepcopy(param))
+        # extract the legend (based on heuristic)
+        label = param["generated"]["datasetTest"].split("_")
+        if len(label) > 2:
+            label = label[2:]
+        labels.append("_".join(label))
+        # extract the plots
+        graphers = logger.generate_plots(store_figs=False)
+
+    return results, params,
labels, graphers + + +# get a logger +logger = experiment.Logger() + +# get the results specified in the file (and hopefully pre-computed) +results, params, labels, graphers = get_results(FILE, logger) + +error_res = results[0]["error"] +sizes_res = results[0]["sizes"] +names = results[0]["names"] + +if "mnist/fc_nettrim" in FILE: + legends = ["Net-Trim", "BC", "DNS"] + colors = ["green", "black", "purple"] + pr_new = [pr_nt, pr_bc, pr_dns] + acc_new = [acc_nt, acc_bc, acc_dns] +elif "mnist/lenet5" in FILE: + legends = ["Net-Trim", "BC", "DNS"] + colors = ["green", "black", "purple"] + pr_new = [pr_nt5, pr_bc5, pr_dns5] + acc_new = [acc_nt5, acc_bc5, acc_dns5] +elif "cifar/cascade/wrn28_2" in FILE: + legends = ["DSR", "DeepR", "SET", "TPNTP"] + colors = ["green", "black", "purple", "magenta"] + pr_new = [pr_dsr, pr_deepr, pr_set, pr_tpntp] + acc_new = [acc_dsr, acc_deepr, acc_set, acc_tpntp] +else: + raise ValueError("Please provide a valid parameter file") + +prune_ratios = np.ones(sizes_res.shape[:3] + (len(legends),)) * 100.0 +acc = np.zeros_like(prune_ratios) + + +# now store all the elements +for i, (pr_one, acc_one) in enumerate(zip(pr_new, acc_new)): + # set values + prune_ratios[:, : len(pr_one), :, i] = np.asarray(pr_one)[None, :, None] + acc[:, : len(acc_one), :, i] = np.asarray(acc_one)[None, :, None] + # set last value for remaining ... + prune_ratios[:, len(pr_one) :, :, i] = pr_one[-1] + acc[:, len(acc_one) :, :, i] = acc_one[-1] +# normalize them now +errors_manual = 1.0 - acc / 100.0 +sizes_manual = 1.0 - prune_ratios / 100.0 + +# now we need to merge results + +# remove other methods first +IDX_REMOVE = 4 +errors_res = error_res[:, :, :, :IDX_REMOVE] +sizes_res = sizes_res[:, :, :, :IDX_REMOVE] +names = names[:IDX_REMOVE] + +# merge results +legends_merged = names + legends +errors_merged = np.concatenate((errors_res, errors_manual), axis=-1) +sizes_merged = np.concatenate((sizes_res, sizes_manual), axis=-1) + + +# re-use grapher and store plot +grapher = graphers[0] +grapher._linestyles = ["-"] * len(legends_merged) +grapher._x_values = 1.0 - sizes_merged +grapher._y_values = 1.0 - errors_merged +grapher._legend = legends_merged +colors_merged = grapher._colors[:IDX_REMOVE] +colors_merged.extend(colors) +grapher._colors = colors_merged + +grapher.graph( + percentage_x=True, + percentage_y=True, + store=False, + show_ref=False, + show_delta=False, + remove_outlier=False, +) +if "mnist/lenet5" in FILE: + grapher._figure.gca().set_xlim([80, 100.0]) + grapher._figure.gca().set_ylim([97, 99.9]) +elif "mnist/fc_nettrim" in FILE: + grapher._figure.gca().set_xlim([69, 100.0]) + grapher._figure.gca().set_ylim([95, 99.5]) +elif "wrn28_2" in FILE: + grapher._figure.gca().set_xlim([70, 95]) + grapher._figure.gca().set_ylim([92, 95.2]) + +grapher.store_plot() \ No newline at end of file diff --git a/paper/sipp/results_viewer.py b/paper/sipp/script/results_viewer.py similarity index 73% rename from paper/sipp/results_viewer.py rename to paper/sipp/script/results_viewer.py index 3807709..a4db8c5 100644 --- a/paper/sipp/results_viewer.py +++ b/paper/sipp/script/results_viewer.py @@ -1,8 +1,4 @@ -# To add a new markdown cell, type '# %% [markdown]' -# %% Change working directory from the workspace root to the ipynb file -# location. Turn this addition off with the DataScience.changeDirOnImportExport -# setting -# ms-python.python added +# %% from __future__ import print_function # make sure the setup is correct everywhere @@ -17,22 +13,16 @@ # make sure it's using only GPU here... 
os.environ["CUDA_VISIBLE_DEVICES"] = "0" # noqa -# make sure matplotlib works if we are running the script as notebook -in_jupyter = True -try: - get_ipython().run_line_magic("matplotlib", "inline") -except AttributeError: - in_jupyter = False - # switch to root folder for data folder = os.path.abspath("") -if "paper/sipp" in folder: - src_folder = os.path.join(folder, "../..") +if "paper/sipp/script" in folder: + src_folder = os.path.join(folder, "../../..") os.chdir(src_folder) # %% -# parameters for running the test -FILE = "paper/sipp/param/cifar/resnet56.yaml" +# parameter file to plot results +FILE = "paper/sipp/param/mnist/lenet5.yaml" +INLINE_PLOT = True def get_results(file, logger): @@ -69,7 +59,7 @@ def get_results(file, logger): remove_outlier=False, ) - if "_rand" in file or "retraininit" in file: + elif "_rand" in file or "retraininit" in file: for i, grapher in enumerate(graphers[:6]): percentage_y = bool((i + 1) % 3) grapher.graph( @@ -83,17 +73,38 @@ def get_results(file, logger): if percentage_y: # grapher._figure.gca().set_xlim([50, 99]) grapher._figure.gca().set_ylim([80, 95]) + elif "mnist" in file: + graphers[0].graph( + percentage_x=True, + percentage_y=True, + store=False, + show_ref=False, + show_delta=False, + remove_outlier=False, + ) + graphers[0]._figure.gca().set_xlim([97, 100]) + graphers[0]._figure.gca().set_ylim([95, 99.5]) return results, params, labels, graphers +# %% Run through plots +# make sure matplotlib works correctly +IN_JUPYTER = True +try: + if INLINE_PLOT: + get_ipython().run_line_magic("matplotlib", "inline") + else: + get_ipython().run_line_magic("matplotlib", "agg") +except AttributeError: + IN_JUPYTER = False + # get a logger logger = experiment.Logger() # get the results specified in the file (and hopefully pre-computed) results, params, labels, graphers = get_results(FILE, logger) -# %% show the results for grapher in graphers: grapher._figure.show() grapher.store_plot() diff --git a/paper/sipp/script/sweep_s.py b/paper/sipp/script/sweep_s.py new file mode 100644 index 0000000..1d75f3a --- /dev/null +++ b/paper/sipp/script/sweep_s.py @@ -0,0 +1,269 @@ +# To add a new markdown cell, type '# %% [markdown]' +# %% Set imports and working directory +from __future__ import print_function + +import os +import sys +import copy +import re +import glob +from collections import OrderedDict + +from IPython import get_ipython +import numpy as np +import experiment +import experiment.util as util +from matplotlib import ticker + +# make sure matplotlib works if we are running the script as notebook +IN_JUPYTER = True +try: + get_ipython().run_line_magic("matplotlib", "inline") +except AttributeError: + IN_JUPYTER = False + +# switch to root folder for data +if "paper/sipp/script" in os.path.abspath(""): + os.chdir(os.path.abspath("../../..")) + +# %% set parameters for testing +FILE = "paper/sipp/param/cifar/sweep/resnet20_sizes.yaml" +FILE = "paper/sipp/param/mnist/lenet300_sizes.yaml" +# FILE = "paper/sipp/param/mnist/lenet300_sizes2.yaml" + +INLINE_PLOT = True +LEGEND_ON = True + +# commensurate level for prune potential +COMM_LEVELS = [0.04] +# COMM_LEVELS = [0.01] + +# folder for param/acc plot... 
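+# (prune-potential plots are stored both in the run's plot dir and here)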
+PLOT_FOLDER_SPECIAL = os.path.abspath(os.path.join("data/results/sipp_plots")) + +# %% define functions +def get_results(file, logger): + """Grab all the results according to the hyperparameter file.""" + results = [] + params = [] + deltas_s = [] + # fmt: off + sizes_of_s = [35, 42, 49, 56, 70, 83, 97, 111, 125, 139, 152, 180, 208, 235, 263] + # fmt: on + # Loop through all experiments + for param in util.file.get_parameters(file, 1, 0): + # initialize logger and setup parameters + logger.initialize_from_param(param) + # run the experiment (only if necessary) + try: + state = logger.get_global_state() + delta_s = param["coresets"]["deltaS"] + # evaluator = experiment.Evaluator(logger) + # size_of_s = len(evaluator.get_dataloader("s_set")[0].dataset) + except ValueError as err: + raise ValueError("Please compute global state first") from err + # extract the results + results.append(copy.deepcopy(state)) + params.append(copy.deepcopy(param)) + deltas_s.append(delta_s) + # sizes_of_s.append(size_of_s) + print(deltas_s) + print(sizes_of_s) + return ( + OrderedDict(zip(sizes_of_s, results)), + OrderedDict(zip(sizes_of_s, params)), + ) + + +# do some plotting and analysis of the results now ... +def get_fig_name(title, tag): + """Get the name of the figure with the title and tag.""" + fig_name = "_".join(re.split("/|-|_|,", title)).replace(" ", "") + return f"{fig_name}_sweep_{tag}.pdf" + + +def extract_commensurate_size(stats, comm_level): + """Compute prune potential for each result and return it.""" + # get the index closest to our desired comm_level + c_idx = np.abs(np.array(stats[0]["commensurate"]) - comm_level).argmin() + + # pre-allocate results array + # stats_all[0]['eBest'] + # has shape (len(commensurate), num_nets, num_rep, num_alg) + _, num_nets, num_rep, num_alg = stats[0]["e_best"].shape + num_sweeps = len(stats) + size_comm = np.zeros((num_nets, num_sweeps, num_rep, num_alg)) + flops_comm = np.zeros_like(size_comm) + e_comm = np.zeros_like(size_comm) + e5_comm = np.zeros_like(size_comm) + + for i, stats_one in enumerate(stats): + size_comm[:, i] = stats_one["siz_best"][c_idx] + flops_comm[:, i] = stats_one["flo_best"][c_idx] + e_comm[:, i] = stats_one["e_best"][c_idx] + e5_comm[:, i] = stats_one["e5_best"][c_idx] + + return size_comm, flops_comm, e_comm, e5_comm + + +def plot_prune_potential( + sizes_s, + size_comm, + idx_ref, + legends, + colors, + title, + plots_dir, + plots_tag, + comm_level, + legend_on, + folder_special, +): + """Plot the prune potential for all methods.""" + + sizes_s = np.broadcast_to(sizes_s[None, :, None, None], size_comm.shape) + grapher_pp = util.grapher.Grapher( + x_values=sizes_s, + y_values=1.0 - size_comm, + folder=plots_dir, + file_name=get_fig_name(title, plots_tag), + ref_idx=idx_ref, + x_min=0, + x_max=1000, + legend=legends, + colors=colors, + xlabel="Size of S", + ylabel=f"Prune Ratio, $\Delta\leq{comm_level * 100:.1f}\%$", + title=title, + ) + + img_pp = grapher_pp.graph( + show_ref=False, + show_delta=False, + percentage_x=False, + percentage_y=True, + remove_outlier=False, + logplot=False, + store=False, + ) + + # set axis limits + img_pp.gca().set_xlim(120, 270) + img_pp.gca().set_ylim(35, 72) + + # check for legend off + if not legend_on: + img_pp.gca().get_legend().remove() + + # then store it + grapher_pp.store_plot() + + # and again in special folder + grapher_pp._folder = folder_special + grapher_pp.store_plot() + + return img_pp + + +def get_and_store_results( + file, logger, comm_levels, legend_on, folder_special +): + # get the 
results specified in the file (and hopefully pre-computed) + results, params = get_results(file, logger) + + # reset stdout after our logger modifies it ... + sys.stdout = sys.stdout._stdout_original + + # %% extract some additional information from the results + results_one = list(results.values())[0] + param_one = list(params.values())[0] + train_dset = param_one["network"]["dataset"] + + labels_method = param_one["generated"]["network_names"] + colors_method = param_one["generated"]["network_colors"] + + # some more stuff for plotting + network_name = param_one["network"]["name"] + title_pr = f"{network_name}, {train_dset}" + if "rewind" in param_one["experiments"]["mode"]: + title_pr += ", rewind" + + plots_dir = os.path.join( + param_one["generated"]["resultsDir"], "plots", "sweep" + ) + + # get reference indices + idx_ref_method = labels_method.index("ReferenceNet") + + # recall number of retraining as relative + sizes_s = np.array(list(results.keys())) + + s_c_all, f_c_all, e_c_all, e5_c_all = (None,) * 4 + + for i_c, comm_level in enumerate(comm_levels): + # compute commensurate size for desired comm level for all results + s_c, f_c, e_c, e5_c = extract_commensurate_size( + [res["stats_comm"] for res in results.values()], comm_level + ) + + if s_c_all is None: + s_c_all = np.zeros((len(comm_levels),) + s_c.shape) + f_c_all = np.zeros_like(s_c_all) + e_c_all = np.zeros_like(s_c_all) + e5_c_all = np.zeros_like(s_c_all) + + # store info + s_c_all[i_c] = s_c + f_c_all[i_c] = f_c + e_c_all[i_c] = e_c + e5_c_all[i_c] = e5_c + + # now plot the commensurate size (prune potential) + # plot a subset of the methods + fig = plot_prune_potential( + sizes_s=sizes_s, + size_comm=s_c, + idx_ref=idx_ref_method, + legends=labels_method, + colors=colors_method, + title=title_pr, + plots_dir=plots_dir, + plots_tag=f"prune_pot_delta_{comm_level:.3f}", + comm_level=comm_level, + legend_on=legend_on, + folder_special=folder_special, + ) + + print(f"PLOT DIR: {plots_dir}") + return ( + param_one, + { + "names": results_one["names"], + "sizes_s": sizes_s, + "sizes": s_c_all, + "flops": f_c_all, + "e": e_c_all, + "e5": e5_c_all, + }, + ) + + +# %% plot and store for all files now + +# make sure matplotlib works correctly +IN_JUPYTER = True +try: + if INLINE_PLOT: + get_ipython().run_line_magic("matplotlib", "inline") + else: + get_ipython().run_line_magic("matplotlib", "agg") +except AttributeError: + IN_JUPYTER = False + +# get a logger +LOGGER = experiment.Logger() + +# go through sweep +PARAM, STATS_COMM = get_and_store_results( + FILE, LOGGER, COMM_LEVELS, LEGEND_ON, PLOT_FOLDER_SPECIAL +) diff --git a/src/experiment/experiment/evaluator.py b/src/experiment/experiment/evaluator.py index ee58a14..8de0306 100644 --- a/src/experiment/experiment/evaluator.py +++ b/src/experiment/experiment/evaluator.py @@ -107,8 +107,6 @@ def __init__(self, logger): # Failure probability for constructing S. self._delta = param["coresets"]["deltaS"] - # Failure probability to pick a good coreset (only used for SiPPNet++). 
- self._delta_best = param["coresets"]["deltaBest"] # extract keep and compress ratios self._keep_ratios = param["generated"]["keepRatios"] diff --git a/src/experiment/experiment/param/default.yaml b/src/experiment/experiment/param/default.yaml index b09d137..677aca0 100644 --- a/src/experiment/experiment/param/default.yaml +++ b/src/experiment/experiment/param/default.yaml @@ -39,6 +39,7 @@ directories: network_names: ReferenceNet: ReferenceNet NetHandle: ReferenceNet + FakeNet: FakeNet EllOneAndTwoNet: $\frac{\ell_1+\ell_2}{2}$ EllOneNet: $\ell_1$ EllTwoNet: $\ell_2$ @@ -77,6 +78,7 @@ network_names: network_colors: ReferenceNet: black NetHandle: black + FakeNet: grey EllOneAndTwoNet: purple EllOneNet: blueviolet EllTwoNet: magenta diff --git a/src/experiment/experiment/param/training/mnist/lenet5.yaml b/src/experiment/experiment/param/training/cifar/lenet5.yaml similarity index 100% rename from src/experiment/experiment/param/training/mnist/lenet5.yaml rename to src/experiment/experiment/param/training/cifar/lenet5.yaml diff --git a/src/experiment/experiment/param/training/mnist/lenet.yaml b/src/experiment/experiment/param/training/mnist/lenet.yaml index 1e11b7e..96b8d0f 100644 --- a/src/experiment/experiment/param/training/mnist/lenet.yaml +++ b/src/experiment/experiment/param/training/mnist/lenet.yaml @@ -6,7 +6,11 @@ batchSize: 64 -transformsTrain: [] +transformsTrain: + - type: Pad + kwargs: { padding: 4 } + - type: RandomCrop + kwargs: { size: 28 } transformsTest: [] transformsFinal: - type: ToTensor @@ -21,7 +25,7 @@ optimizerKwargs: nesterov: False momentum: 0.9 -numEpochs: 40 +numEpochs: 50 lrSchedulers: - type: MultiStepLR diff --git a/src/experiment/experiment/util/data.py b/src/experiment/experiment/util/data.py index 953be25..9fd782a 100644 --- a/src/experiment/experiment/util/data.py +++ b/src/experiment/experiment/util/data.py @@ -88,15 +88,26 @@ def get_dset(transform, train): def get_dataloader(dataset, num_threads, shuffle=False, b_size=batch_size): """Construct data loader.""" + # ensure that we don't parallelize in data loader with glue. + # It does not play out well ... 
+ # (also no need to do that for other small-scale datasets) + no_thread_classes = ( + dsets.MNIST, + dsets.BaseGlue, + dsets.BaseToyDataset, + ) + if isinstance(dataset, no_thread_classes): + num_threads = 0 + else: + num_threads = 4 + loader = torch.utils.data.DataLoader( dataset=dataset, batch_size=b_size, num_workers=num_threads, shuffle=shuffle, pin_memory=True, - collate_fn=_glue_data_collator - if isinstance(dataset, dsets.BaseGlue) - else None, + collate_fn=_glue_data_collator if isinstance(dataset, dsets.BaseGlue) else None, ) return loader @@ -110,6 +121,8 @@ def get_dataloader(dataset, num_threads, shuffle=False, b_size=batch_size): no_thread_classes = ( dsets.MNIST, dsets.BaseGlue, + dsets.BaseToyDataset, + dsets.BaseTabularDataset, ) many_thread_classes = ( dsets.ImageNet, @@ -119,9 +132,7 @@ def get_dataloader(dataset, num_threads, shuffle=False, b_size=batch_size): if isinstance(set_test, no_thread_classes): num_threads = 0 elif isinstance(set_test, many_thread_classes): - num_threads = 10 * np.clip( - param["generated"]["numAvailableGPUs"], 1, 4 - ) + num_threads = 10 * np.clip(param["generated"]["numAvailableGPUs"], 1, 4) else: num_threads = 4 @@ -134,12 +145,17 @@ def get_dataloader(dataset, num_threads, shuffle=False, b_size=batch_size): set_valid = get_dset(transform_test, train=True) # get train/valid split - idx_train, idx_valid = _get_valid_split( - data_dir, - dset_name, - len(set_train), - valid_ratio, - ) + if hasattr(set_train, "get_valid_split"): + # use pre-defined split if it exists + idx_train, idx_valid = set_train.get_valid_split() + else: + # use a random split and record it + idx_train, idx_valid = _get_valid_split( + data_dir, + dset_name, + len(set_train), + valid_ratio, + ) # now split the data set_train = torch.utils.data.Subset(set_train, idx_train) @@ -162,15 +178,11 @@ def get_dataloader(dataset, num_threads, shuffle=False, b_size=batch_size): size_s = min(size_s, int(math.ceil(val_split_max * len(set_valid)))) # Now split validation set into valid set, S Set - set_valid, set_s = torch.utils.data.random_split( - set_valid, [len(set_valid) - size_s, size_s] - ) + set_valid, set_s = torch.utils.data.random_split(set_valid, [len(set_valid) - size_s, size_s]) # create the remaining loaders now loader_train = get_dataloader(set_train, num_threads, shuffle=True) - loader_valid = get_dataloader( - set_valid, num_threads, b_size=test_batch_size - ) + loader_valid = get_dataloader(set_valid, num_threads, b_size=test_batch_size) loader_s = get_dataloader(set_s, 0, b_size=min(4, test_batch_size)) return { @@ -220,8 +232,6 @@ def _get_valid_split(data_dir, dset_name, dset_len, ratio_valid): os.makedirs(data_dir, exist_ok=True) # save data - np.savez( - file, set2Idx=idx_train, set1Idx=idx_valid, ratioSet1=ratio_valid - ) + np.savez(file, set2Idx=idx_train, set1Idx=idx_valid, ratioSet1=ratio_valid) return idx_train, idx_valid diff --git a/src/experiment/setup.py b/src/experiment/setup.py index 435815a..303f128 100644 --- a/src/experiment/setup.py +++ b/src/experiment/setup.py @@ -23,6 +23,7 @@ install_requires=[ "matplotlib", "numpy", + "pyhessian", "pyyaml", "seaborn", "torch", diff --git a/src/torchprune/README.md b/src/torchprune/README.md index 1bdcb91..5c8c994 100644 --- a/src/torchprune/README.md +++ b/src/torchprune/README.md @@ -413,6 +413,22 @@ download location. place the `SegmentationClassAug.zip` file into `file_dir`. * Everything else is handled automatically. 
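The `data.py` changes above make the loader prefer a dataset-defined split over a random one: if a dataset class exposes `get_valid_split()`, its indices are used verbatim. A minimal sketch of such a dataset; the class name and the 90/10 split are hypothetical, only the hook itself is given by the code above:

```python
import numpy as np
import torch


class FixedSplitDataset(torch.utils.data.Dataset):
    """Hypothetical dataset with a pre-defined train/validation split."""

    def __init__(self, num_samples=1000):
        self._data = torch.randn(num_samples, 2)
        # deterministic split: first 90% train, last 10% validation
        n_train = int(0.9 * num_samples)
        self._idx_train = list(np.arange(n_train))
        self._idx_valid = list(np.arange(n_train, num_samples))

    def __len__(self):
        return len(self._data)

    def __getitem__(self, idx):
        # dummy target to mimic the (input, target) convention used here
        return self._data[idx], 0

    def get_valid_split(self):
        """Return index lists for the training and validation subsets."""
        return self._idx_train, self._idx_valid
```

This mirrors how `BaseTabularDataset._join_and_store_split` below records the MAF train/validation split so the experiment utilities can recover it instead of drawing a random split.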
+#### **Toy Datasets**
+* Code:
+  [torchprune/util/datasets/toy.py](./torchprune/util/datasets/toy.py)
+* Description: [Torchdyn
+  tutorial](https://torchdyn.readthedocs.io/en/latest/tutorials/00_quickstart.html#Generate-data-from-a-static-toy-dataset)
+* Everything is generated automatically using the `torchdyn` library.
+
+#### **Tabular Datasets**
+* Code:
+  [torchprune/util/datasets/tabular.py](./torchprune/util/datasets/tabular.py)
+* Description: [Pre-processed tabular datasets](https://github.com/gpapamak/maf#how-to-get-the-datasets)
+* Please download the data from
+  [here](https://zenodo.org/record/1161203#.Wmtf_XVl8eN) and
+  place the `data.tar.gz` file into `file_dir`.
+* Everything else is handled automatically.
+
 #### **Corrupted PASCAL VOC Segmentation Datasets**
 * Code:
   [torchprune/util/datasets/voc.py](./torchprune/util/datasets/voc.py)
diff --git a/src/torchprune/setup.py b/src/torchprune/setup.py
index 3d209bb..569ada7 100644
--- a/src/torchprune/setup.py
+++ b/src/torchprune/setup.py
@@ -34,6 +34,7 @@
         "torch",
         "torchvision",
         "tensorboard",
+        "torchdyn==1.0.1",
         "protobuf",
         "wand",
         "transformers @ git+https://github.com/huggingface/transformers.git",
diff --git a/src/torchprune/torchprune/method/base/base_sparsifier.py b/src/torchprune/torchprune/method/base/base_sparsifier.py
index 04cd225..fb406d8 100644
--- a/src/torchprune/torchprune/method/base/base_sparsifier.py
+++ b/src/torchprune/torchprune/method/base/base_sparsifier.py
@@ -84,10 +84,13 @@ def _reweigh(self, counts, num_samples, probs_div):
     def _generate_counts(self, num_samples, probs):
         mask = torch.zeros_like(probs, dtype=torch.bool)
         numel = probs.numel()
-        num_samples = int(np.clip(1, int(num_samples), numel))
-        idx_top = np.argpartition(probs.view(-1).cpu().numpy(), -num_samples)[
-            -num_samples:
-        ]
+        num_samples = int(np.clip(0, int(num_samples), numel))
+        if num_samples > 0:
+            idx_top = np.argpartition(
+                probs.view(-1).cpu().numpy(), -num_samples
+            )[-num_samples:]
+        else:
+            idx_top = []
         mask.view(-1)[idx_top] = True
         return mask
diff --git a/src/torchprune/torchprune/method/ref/__init__.py b/src/torchprune/torchprune/method/ref/__init__.py
index 2514ad7..cc76b3a 100644
--- a/src/torchprune/torchprune/method/ref/__init__.py
+++ b/src/torchprune/torchprune/method/ref/__init__.py
@@ -1,4 +1,4 @@
 # flake8: noqa: F403,F401
 """A package with a fake compression to act as ReferenceNet."""
-from .ref_net import ReferenceNet
+from .ref_net import ReferenceNet, FakeNet
diff --git a/src/torchprune/torchprune/method/ref/ref_net.py b/src/torchprune/torchprune/method/ref/ref_net.py
index 7ee4df8..d4d6fd6 100644
--- a/src/torchprune/torchprune/method/ref/ref_net.py
+++ b/src/torchprune/torchprune/method/ref/ref_net.py
@@ -39,3 +39,16 @@ def size(self):
     def flops(self):
         """Fake flops with current keep_ratio."""
         return super().flops() * float(self._keep_ratio_latest)
+
+
+class FakeNet(ReferenceNet):
+    """A reference net that is retrainable.
+
+    By enabling retraining we can simulate an unpruned network that gets the
+    exact same training and retraining as the pruned networks.
+ """ + + @property + def retrainable(self): + """Indicate whether we can retrain after applying this method.""" + return True diff --git a/src/torchprune/torchprune/method/sipp/sipp_allocator.py b/src/torchprune/torchprune/method/sipp/sipp_allocator.py index 4964fe5..f459d2f 100644 --- a/src/torchprune/torchprune/method/sipp/sipp_allocator.py +++ b/src/torchprune/torchprune/method/sipp/sipp_allocator.py @@ -332,7 +332,8 @@ def get_num_samples(self, layer): """Get the number of samples for a particular layer index.""" # get optimal sample numbers from randAllocator num_samples = self._allocator_rand.get_num_samples(layer) - num_samples[num_samples < 0] = 1 + no_sampling = num_samples < 0 + num_samples[no_sampling] = 1 # check error for both methods error_rand = self._allocator_rand._get_error_theoretical( @@ -346,4 +347,7 @@ def get_num_samples(self, layer): use_rand = error_det > error_rand num_samples[use_rand] = -num_samples[use_rand] + # reset zero samples to zero + num_samples[no_sampling] = 0 + return num_samples diff --git a/src/torchprune/torchprune/util/datasets/__init__.py b/src/torchprune/torchprune/util/datasets/__init__.py index d3e2bd5..6ebe0fe 100644 --- a/src/torchprune/torchprune/util/datasets/__init__.py +++ b/src/torchprune/torchprune/util/datasets/__init__.py @@ -11,9 +11,11 @@ from .driving import Driving from .imagenet import ImageNet -from .imagenet_c import * # noqa: F403,F401 +from .imagenet_c import * from .objectnet import ObjectNet -from .cifar10 import * # noqa: F403,F401 +from .cifar10 import * from .dds import DownloadDataset -from .voc import * # noqa: F403,F401 +from .voc import * from .glue import * +from .toy import * +from .tabular import * diff --git a/src/torchprune/torchprune/util/datasets/dds.py b/src/torchprune/torchprune/util/datasets/dds.py index 09c8dce..43af272 100644 --- a/src/torchprune/torchprune/util/datasets/dds.py +++ b/src/torchprune/torchprune/util/datasets/dds.py @@ -67,7 +67,7 @@ def __init__( Args: root (str): where to store the data set to be downloaded. - file_dir (str): where to look before downloading it from S3. + file_dir (str): where to look for downloading data. train (bool, optional): train or test data set. Defaults to True. transform (torchvision.transforms, optional): set of transforms to apply to input data. Defaults to None. @@ -139,7 +139,7 @@ def __getitem__(self, index): # it might not be a PIL image yet ... 
img = self._convert_to_pil(img) - # target might be weird, so convert it forst + # target might be weird, so convert it first target = self._convert_target(target) if self.transform is not None: diff --git a/src/torchprune/torchprune/util/datasets/tabular.py b/src/torchprune/torchprune/util/datasets/tabular.py new file mode 100644 index 0000000..be0822f --- /dev/null +++ b/src/torchprune/torchprune/util/datasets/tabular.py @@ -0,0 +1,105 @@ +"""Tabular data sets for FFJORD experiments.""" + +from abc import ABC, abstractmethod +import numpy as np +import torch +from .dds import DownloadDataset +from ..external.ffjord import datasets as tabular_external + + +class BaseTabularDataset(DownloadDataset, ABC): + """An abstract interface for tabular datasets.""" + + def __init__(self, *args, **kwargs): + """Initialize like parent but add train_split, val_split.""" + self._train_split, self._val_split = None, None + super().__init__(*args, **kwargs) + + @property + @abstractmethod + def _dataset_name(self): + """Return the name of the tabular dataset.""" + + @property + def _train_tar_file_name(self): + return "data.tar.gz" + + @property + def _test_tar_file_name(self): + return self._train_tar_file_name + + @property + def _train_dir(self): + return "data/" + + @property + def _test_dir(self): + return self._train_dir + + def _get_train_data(self, download): + dset = getattr(tabular_external, self._dataset_name)(self._data_path) + val_trn_x = self._join_and_store_split(dset) + data = torch.tensor(val_trn_x) + return [(x_data, 0) for x_data in data] + + def _get_test_data(self, download): + dset = getattr(tabular_external, self._dataset_name)(self._data_path) + self._join_and_store_split(dset) + return [(x_data, 0) for x_data in torch.tensor(dset.tst.x)] + + def _convert_to_pil(self, img): + return img + + def _convert_target(self, target): + return int(target) + + def _join_and_store_split(self, dset): + """Join train and validation set but recall split.""" + val_trn_x = np.concatenate((dset.trn.x, dset.val.x)) + self._train_split = list(np.arange(len(dset.trn.x))) + self._val_split = list(np.arange(len(dset.trn.x), len(val_trn_x))) + return val_trn_x + + def get_valid_split(self): + """Return indices corresponding to training and validation.""" + return self._train_split, self._val_split + + +class Bsds300(BaseTabularDataset): + """The Bsds300 dataset.""" + + @property + def _dataset_name(self): + return "BSDS300" + + +class Gas(BaseTabularDataset): + """The Gas dataset.""" + + @property + def _dataset_name(self): + return "GAS" + + +class Hepmass(BaseTabularDataset): + """The Hepmass dataset.""" + + @property + def _dataset_name(self): + return "HEPMASS" + + +class Miniboone(BaseTabularDataset): + """The Miniboone dataset.""" + + @property + def _dataset_name(self): + return "MINIBOONE" + + +class Power(BaseTabularDataset): + """The Power dataset.""" + + @property + def _dataset_name(self): + return "POWER" diff --git a/src/torchprune/torchprune/util/datasets/toy.py b/src/torchprune/torchprune/util/datasets/toy.py new file mode 100644 index 0000000..adfa0ed --- /dev/null +++ b/src/torchprune/torchprune/util/datasets/toy.py @@ -0,0 +1,296 @@ +"""A wrapper module for torchdyn data sets with configurations.""" +import os +import shutil +import warnings +from abc import ABC, abstractmethod +from urllib.request import URLError + +import tarfile +import numpy as np +import torch +import torchdyn.datasets as dyn_data + +from .dds import DownloadDataset + + +class BaseToyDataset(DownloadDataset, ABC): 
+ """An abstract interface for torchdyn toy datasets.""" + + @property + @abstractmethod + def _dataset_type(self): + """Return the type of toy dataset we want to get.""" + + @property + @abstractmethod + def _n_samples(self): + """Return number of samples we should generate.""" + + @property + @abstractmethod + def _dataset_kwargs(self): + """Return the kwargs to initialize the toy dataset.""" + + @property + def _dataset_tag(self): + """Return the tag used to identify the files related to dataset.""" + return self._dataset_type + + @property + def _train_tar_file_name(self): + return f"torchdyn_toy_{self._dataset_tag}.tar.gz" + + @property + def _test_tar_file_name(self): + return self._train_tar_file_name + + @property + def _train_dir(self): + return f"torchdyn_toy_{self._dataset_tag}" + + @property + def _test_dir(self): + return self._train_dir + + def _get_train_data(self, download): + x_data = np.load(os.path.join(self._data_path, "x_data_train.npy")) + y_data = np.load(os.path.join(self._data_path, "y_data_train.npy")) + return list(zip(x_data, y_data)) + + def _get_test_data(self, download): + x_data = np.load(os.path.join(self._data_path, "x_data_test.npy")) + y_data = np.load(os.path.join(self._data_path, "y_data_test.npy")) + return list(zip(x_data, y_data)) + + def _convert_to_pil(self, img): + return torch.tensor(img) + + def _convert_target(self, target): + return int(target) + + def _download(self): + """Download data set and generate first if necessary.""" + try: + super()._download() + except URLError: + self._generate_data() + super()._download() + + def _generate_data(self): + """Generate and store data now.""" + # issue warning at the beginning to remind user of this change + warnings.warn(f"Generating new data for {type(self)}.") + + def _sample_data(n_samples): + toy_dset = dyn_data.ToyDataset() + x_data, y_data = toy_dset.generate( + n_samples=n_samples, + dataset_type=self._dataset_type, + **self._dataset_kwargs, + ) + # check that y_data is not none + if y_data is None: + y_data = torch.zeros(len(x_data), dtype=torch.long) + + # normalize x data now + x_data -= x_data.mean() + x_data /= x_data.std() + + return x_data.cpu().numpy(), y_data.cpu().numpy() + + def _x_tmp_file(tag): + return os.path.join("/tmp", f"{self._dataset_tag}_{tag}_x.npy") + + def _y_tmp_file(tag): + return os.path.join("/tmp", f"{self._dataset_tag}_{tag}_y.npy") + + # generate and save train/test data + tags_size = {"train": self._n_samples, "test": self._n_samples // 2} + for tag, n_samples in tags_size.items(): + x_data, y_data = _sample_data(n_samples) + np.save(_x_tmp_file(tag), x_data) + np.save(_y_tmp_file(tag), y_data) + + # now store in tar file + tar_file = os.path.join(self._file_dir, self._train_tar_file_name) + tmp_tar_file = os.path.join("/tmp", self._train_tar_file_name) + with tarfile.open(tmp_tar_file, "w:gz") as tar: + for tag in tags_size: + tar.add( + _x_tmp_file(tag), + arcname=os.path.join(self._train_dir, f"x_data_{tag}.npy"), + ) + tar.add( + _y_tmp_file(tag), + arcname=os.path.join(self._train_dir, f"y_data_{tag}.npy"), + ) + + # move tar file to right location + shutil.move(tmp_tar_file, tar_file) + + # print reminder at the end + print(f"Generated new data for {type(self)}.") + + +class ToyConcentric(BaseToyDataset): + """The concentric spheres dataset.""" + + @property + def _dataset_type(self): + """Return the type of toy dataset we want to get.""" + return "spheres" + + @property + def _n_samples(self): + """Return number of samples we should generate.""" + return 
1024
+
+    @property
+    def _dataset_kwargs(self):
+        """Return the kwargs to initialize the toy dataset."""
+        return {
+            "dim": 2,
+            "noise": 1e-1,
+            "inner_radius": 0.75,
+            "outer_radius": 1.5,
+        }
+
+
+class ToyMoons(BaseToyDataset):
+    """The moons dataset."""
+
+    @property
+    def _dataset_type(self):
+        """Return the type of toy dataset we want to get."""
+        return "moons"
+
+    @property
+    def _n_samples(self):
+        """Return number of samples we should generate."""
+        return 1024
+
+    @property
+    def _dataset_kwargs(self):
+        """Return the kwargs to initialize the toy dataset."""
+        return {"noise": 1e-1}
+
+
+class ToySpirals(BaseToyDataset):
+    """The spirals dataset."""
+
+    @property
+    def _dataset_type(self):
+        """Return the type of toy dataset we want to get."""
+        return "spirals"
+
+    @property
+    def _n_samples(self):
+        """Return number of samples we should generate."""
+        return 1024
+
+    @property
+    def _dataset_kwargs(self):
+        """Return the kwargs to initialize the toy dataset."""
+        return {"noise": 0.9}
+
+
+class ToySpirals2(BaseToyDataset):
+    """The spirals dataset with more samples and less noise."""
+
+    @property
+    def _dataset_type(self):
+        """Return the type of toy dataset we want to get."""
+        return "spirals"
+
+    @property
+    def _dataset_tag(self):
+        """Return the tag used to identify the files related to dataset."""
+        return "spirals2"
+
+    @property
+    def _n_samples(self):
+        """Return number of samples we should generate."""
+        return int(2 ** 14)
+
+    @property
+    def _dataset_kwargs(self):
+        """Return the kwargs to initialize the toy dataset."""
+        return {"noise": 0.5}
+
+
+class ToyGaussians(BaseToyDataset):
+    """The Gaussians dataset."""
+
+    @property
+    def _n_gaussians(self):
+        """Return the number of Gaussians in the mixture."""
+        return 6
+
+    @property
+    def _dataset_type(self):
+        """Return the type of toy dataset we want to get."""
+        return "gaussians"
+
+    @property
+    def _n_samples(self):
+        """Return number of samples we should generate."""
+        return 2 ** 14 // self._n_gaussians
+
+    @property
+    def _dataset_kwargs(self):
+        """Return the kwargs to initialize the toy dataset."""
+        return {
+            "n_gaussians": self._n_gaussians,
+            "std_gaussians": 0.5,
+            "radius": 4,
+            "dim": 2,
+        }
+
+
+class ToyGaussiansSpiral(BaseToyDataset):
+    """The Gaussian spiral dataset."""
+
+    @property
+    def _n_gaussians(self):
+        """Return the number of Gaussians in the spiral."""
+        return 10
+
+    @property
+    def _dataset_type(self):
+        """Return the type of toy dataset we want to get."""
+        return "gaussians_spiral"
+
+    @property
+    def _n_samples(self):
+        """Return number of samples we should generate."""
+        return 2 ** 14 // self._n_gaussians
+
+    @property
+    def _dataset_kwargs(self):
+        """Return the kwargs to initialize the toy dataset."""
+        return {
+            "n_gaussians": self._n_gaussians,
+            "n_gaussians_per_loop": 6,
+            "dim": 2,
+            "radius_start": 4.0,
+            "radius_end": 1.0,
+            "std_gaussians_start": 0.3,
+            "std_gaussians_end": 0.1,
+        }
+
+
+class ToyDiffeqml(BaseToyDataset):
+    """The diffeqml dataset."""
+
+    @property
+    def _dataset_type(self):
+        """Return the type of toy dataset we want to get."""
+        return "diffeqml"
+
+    @property
+    def _n_samples(self):
+        """Return number of samples we should generate."""
+        return 2 ** 14
+
+    @property
+    def _dataset_kwargs(self):
+        """Return the kwargs to initialize the toy dataset."""
+        return {"noise": 5e-2}
diff --git a/src/torchprune/torchprune/util/external/cnn/models/cifar/wrn.py b/src/torchprune/torchprune/util/external/cnn/models/cifar/wrn.py
index 49eff18..67d2712 100644
--- a/src/torchprune/torchprune/util/external/cnn/models/cifar/wrn.py
+++ 
b/src/torchprune/torchprune/util/external/cnn/models/cifar/wrn.py @@ -3,7 +3,7 @@ import torch.nn as nn import torch.nn.functional as F -__all__ = ['wrn', 'wrn16_8', 'wrn28_10', 'wrn40_1', 'wrn40_4'] +__all__ = ['wrn', 'wrn16_8', 'wrn28_2', 'wrn28_10', 'wrn40_1', 'wrn40_4'] class BasicBlock(nn.Module): def __init__(self, in_planes, out_planes, stride, dropRate=0.0): @@ -126,3 +126,9 @@ def wrn28_10(**kwargs): WRN, depth 28, widening factor 10 """ return WideResNet(depth=28, widen_factor=10, **kwargs) + +def wrn28_2(**kwargs): + """ + WRN, depth 28, widening factor 2 + """ + return WideResNet(depth=28, widen_factor=2, **kwargs) diff --git a/src/torchprune/torchprune/util/external/ffjord/.gitignore b/src/torchprune/torchprune/util/external/ffjord/.gitignore new file mode 100644 index 0000000..d2e4bd6 --- /dev/null +++ b/src/torchprune/torchprune/util/external/ffjord/.gitignore @@ -0,0 +1,3 @@ +*__pycache__* +*.pyc +data/* diff --git a/src/torchprune/torchprune/util/external/ffjord/LICENSE b/src/torchprune/torchprune/util/external/ffjord/LICENSE new file mode 100644 index 0000000..5afe560 --- /dev/null +++ b/src/torchprune/torchprune/util/external/ffjord/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2018 Ricky Tian Qi Chen and Will Grathwohl + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/src/torchprune/torchprune/util/external/ffjord/README.md b/src/torchprune/torchprune/util/external/ffjord/README.md new file mode 100644 index 0000000..e0e8ef5 --- /dev/null +++ b/src/torchprune/torchprune/util/external/ffjord/README.md @@ -0,0 +1,56 @@ +# Free-form Jacobian of Reversible Dynamics (FFJORD) + +Code for reproducing the experiments in the paper: + +> Will Grathwohl*, Ricky T. Q. Chen*, Jesse Bettencourt, Ilya Sutskever, David Duvenaud. "FFJORD: Free-form Continuous Dynamics for Scalable Reversible Generative Models." _International Conference on Learning Representations_ (2019). +> [[arxiv]](https://arxiv.org/abs/1810.01367) [[bibtex]](http://www.cs.toronto.edu/~rtqichen/bibtex/ffjord.bib) + + +## Prerequisites + +Install `torchdiffeq` from https://github.com/rtqichen/torchdiffeq. + +## Usage + +Different scripts are provided for different datasets. To see all options, use the `-h` flag. 
+
+Toy 2D:
+```
+python train_toy.py --data 8gaussians --dims 64-64-64 --layer_type concatsquash --save experiment1
+```
+
+Tabular datasets from [MAF](https://github.com/gpapamak/maf):
+```
+python train_tabular.py --data miniboone --nhidden 2 --hdim_factor 20 --num_blocks 1 --nonlinearity softplus --batch_size 1000 --lr 1e-3
+```
+
+MNIST/CIFAR10:
+```
+python train_cnf.py --data mnist --dims 64,64,64 --strides 1,1,1,1 --num_blocks 2 --layer_type concat --multiscale True --rademacher True
+```
+
+VAE experiments (based on [Sylvester VAE](https://github.com/riannevdberg/sylvester-flows)):
+```
+python train_vae_flow.py --dataset mnist --flow cnf_rank --rank 64 --dims 1024-1024 --num_blocks 2
+```
+
+Glow / Real NVP experiments are run using `train_discrete_toy.py` and `train_discrete_tabular.py`.
+
+## Datasets
+
+### Tabular (UCI + BSDS300)
+Follow the instructions at https://github.com/gpapamak/maf to download the tabular datasets, then place them in `data/`.
+
+### VAE datasets
+Follow the instructions at https://github.com/riannevdberg/sylvester-flows to download the VAE datasets, then place them in `data/`.
+
+## Bespoke Flows
+
+Here's a fun script that you can use to create your own 2D flow from an image!
+```
+python train_img2d.py --img imgs/github.png --save github_flow
+```
+

+<p align="center">
+<img align="middle" src="./assets/github_flow.gif" alt="2D flow trained on an image of the GitHub logo" width="400" />
+</p>
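The torchprune-side wrappers added at the top of this section (`ToyMoons`, `ToySpirals`, `ToyGaussians`, ...) all follow one template: subclass `BaseToyDataset` and declare `_dataset_type`, `_n_samples`, and `_dataset_kwargs` (plus `_dataset_tag` when the file tag should differ from the type). Below is a rough sketch of what one more wrapper could look like; `ToyCheckerboard` is hypothetical, and it assumes the underlying toy-data generator accepts a "checkerboard" type (one of the FFJORD toy types) with no extra kwargs:

```
# Hypothetical sketch, not part of this diff; assumes BaseToyDataset is
# importable from the same toy-dataset module as the wrappers above.
class ToyCheckerboard(BaseToyDataset):
    """The checkerboard dataset."""

    @property
    def _dataset_type(self):
        """Return the type of toy dataset we want to get."""
        return "checkerboard"

    @property
    def _n_samples(self):
        """Return number of samples we should generate."""
        return 1024

    @property
    def _dataset_kwargs(self):
        """Return the kwargs to initialize the toy dataset."""
        return {}
```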

diff --git a/src/torchprune/torchprune/util/models/cnn/__init__.py b/src/torchprune/torchprune/util/external/ffjord/__init__.py similarity index 100% rename from src/torchprune/torchprune/util/models/cnn/__init__.py rename to src/torchprune/torchprune/util/external/ffjord/__init__.py diff --git a/src/torchprune/torchprune/util/external/ffjord/assets/github_flow.gif b/src/torchprune/torchprune/util/external/ffjord/assets/github_flow.gif new file mode 100644 index 0000000..26e8d54 Binary files /dev/null and b/src/torchprune/torchprune/util/external/ffjord/assets/github_flow.gif differ diff --git a/src/torchprune/torchprune/util/external/ffjord/datasets/LICENSE.txt b/src/torchprune/torchprune/util/external/ffjord/datasets/LICENSE.txt new file mode 100644 index 0000000..6faf4ff --- /dev/null +++ b/src/torchprune/torchprune/util/external/ffjord/datasets/LICENSE.txt @@ -0,0 +1,26 @@ +Copyright (c) 2017, George Papamakarios +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +The views and conclusions contained in the software and documentation are those +of the authors and should not be interpreted as representing official policies, +either expressed or implied, of anybody else. diff --git a/src/torchprune/torchprune/util/external/ffjord/datasets/__init__.py b/src/torchprune/torchprune/util/external/ffjord/datasets/__init__.py new file mode 100644 index 0000000..7220778 --- /dev/null +++ b/src/torchprune/torchprune/util/external/ffjord/datasets/__init__.py @@ -0,0 +1,5 @@ +from .power import POWER +from .gas import GAS +from .hepmass import HEPMASS +from .miniboone import MINIBOONE +from .bsds300 import BSDS300 diff --git a/src/torchprune/torchprune/util/external/ffjord/datasets/bsds300.py b/src/torchprune/torchprune/util/external/ffjord/datasets/bsds300.py new file mode 100644 index 0000000..0d1e026 --- /dev/null +++ b/src/torchprune/torchprune/util/external/ffjord/datasets/bsds300.py @@ -0,0 +1,33 @@ +import os +import numpy as np +import h5py + + +class BSDS300: + """ + A dataset of patches from BSDS300. + """ + + class Data: + """ + Constructs the dataset. 
+ """ + + def __init__(self, data): + + self.x = data[:] + self.N = self.x.shape[0] + + def __init__(self, root): + + # load dataset + f = h5py.File(os.path.join(root, "BSDS300", "BSDS300.hdf5"), "r") + + self.trn = self.Data(f["train"]) + self.val = self.Data(f["validation"]) + self.tst = self.Data(f["test"]) + + self.n_dims = self.trn.x.shape[1] + self.image_size = [int(np.sqrt(self.n_dims + 1))] * 2 + + f.close() diff --git a/src/torchprune/torchprune/util/external/ffjord/datasets/gas.py b/src/torchprune/torchprune/util/external/ffjord/datasets/gas.py new file mode 100644 index 0000000..5bd68c9 --- /dev/null +++ b/src/torchprune/torchprune/util/external/ffjord/datasets/gas.py @@ -0,0 +1,69 @@ +import os +import pandas as pd +import numpy as np + + +class GAS: + class Data: + def __init__(self, data): + + self.x = data.astype(np.float32) + self.N = self.x.shape[0] + + def __init__(self, root): + + file = os.path.join(root, "gas", "ethylene_CO.pickle") + trn, val, tst = load_data_and_clean_and_split(file) + + self.trn = self.Data(trn) + self.val = self.Data(val) + self.tst = self.Data(tst) + + self.n_dims = self.trn.x.shape[1] + + +def load_data(file): + + data = pd.read_pickle(file) + # data = pd.read_pickle(file).sample(frac=0.25) + # data.to_pickle(file) + data.drop("Meth", axis=1, inplace=True) + data.drop("Eth", axis=1, inplace=True) + data.drop("Time", axis=1, inplace=True) + return data + + +def get_correlation_numbers(data): + C = data.corr() + A = C > 0.98 + B = A.to_numpy().sum(axis=1) + return B + + +def load_data_and_clean(file): + + data = load_data(file) + B = get_correlation_numbers(data) + + while np.any(B > 1): + col_to_remove = np.where(B > 1)[0][0] + col_name = data.columns[col_to_remove] + data.drop(col_name, axis=1, inplace=True) + B = get_correlation_numbers(data) + # print(data.corr()) + data = (data - data.mean()) / data.std() + + return data + + +def load_data_and_clean_and_split(file): + + data = load_data_and_clean(file).to_numpy() + N_test = int(0.1 * data.shape[0]) + data_test = data[-N_test:] + data_train = data[0:-N_test] + N_validate = int(0.1 * data_train.shape[0]) + data_validate = data_train[-N_validate:] + data_train = data_train[0:-N_validate] + + return data_train, data_validate, data_test diff --git a/src/torchprune/torchprune/util/external/ffjord/datasets/hepmass.py b/src/torchprune/torchprune/util/external/ffjord/datasets/hepmass.py new file mode 100644 index 0000000..2fe35d9 --- /dev/null +++ b/src/torchprune/torchprune/util/external/ffjord/datasets/hepmass.py @@ -0,0 +1,112 @@ +import os +import pandas as pd +import numpy as np +from collections import Counter + + +class HEPMASS: + """ + The HEPMASS data set. + http://archive.ics.uci.edu/ml/datasets/HEPMASS + """ + + class Data: + def __init__(self, data): + + self.x = data.astype(np.float32) + self.N = self.x.shape[0] + + def __init__(self, root): + + path = os.path.join(root, "hepmass") + trn, val, tst = load_data_no_discrete_normalised_as_array(path) + + self.trn = self.Data(trn) + self.val = self.Data(val) + self.tst = self.Data(tst) + + self.n_dims = self.trn.x.shape[1] + + +def load_data(path): + + data_train = pd.read_csv( + filepath_or_buffer=os.path.join(path, "1000_train.csv"), + index_col=False, + ) + data_test = pd.read_csv( + filepath_or_buffer=os.path.join(path, "1000_test.csv"), index_col=False + ) + + return data_train, data_test + + +def load_data_no_discrete(path): + """ + Loads the positive class examples from the first 10 percent of the dataset. 
+ """ + data_train, data_test = load_data(path) + + # Gets rid of any background noise examples i.e. class label 0. + data_train = data_train[data_train[data_train.columns[0]] == 1] + data_train = data_train.drop(data_train.columns[0], axis=1) + data_test = data_test[data_test[data_test.columns[0]] == 1] + data_test = data_test.drop(data_test.columns[0], axis=1) + # Because the data set is messed up! + data_test = data_test.drop(data_test.columns[-1], axis=1) + + return data_train, data_test + + +def load_data_no_discrete_normalised(path): + + data_train, data_test = load_data_no_discrete(path) + mu = data_train.mean() + s = data_train.std() + data_train = (data_train - mu) / s + data_test = (data_test - mu) / s + + return data_train, data_test + + +def load_data_no_discrete_normalised_as_array(path): + + data_train, data_test = load_data_no_discrete_normalised(path) + data_train, data_test = data_train.to_numpy(), data_test.to_numpy() + + i = 0 + # Remove any features that have too many re-occurring real values. + features_to_remove = [] + for feature in data_train.T: + c = Counter(feature) + max_count = np.array([v for k, v in sorted(c.items())])[0] + if max_count > 5: + features_to_remove.append(i) + i += 1 + data_train = data_train[ + :, + np.array( + [ + i + for i in range(data_train.shape[1]) + if i not in features_to_remove + ] + ), + ] + data_test = data_test[ + :, + np.array( + [ + i + for i in range(data_test.shape[1]) + if i not in features_to_remove + ] + ), + ] + + N = data_train.shape[0] + N_validate = int(N * 0.1) + data_validate = data_train[-N_validate:] + data_train = data_train[0:-N_validate] + + return data_train, data_validate, data_test diff --git a/src/torchprune/torchprune/util/external/ffjord/datasets/miniboone.py b/src/torchprune/torchprune/util/external/ffjord/datasets/miniboone.py new file mode 100644 index 0000000..10f8881 --- /dev/null +++ b/src/torchprune/torchprune/util/external/ffjord/datasets/miniboone.py @@ -0,0 +1,66 @@ +import os +import numpy as np + + +class MINIBOONE: + class Data: + def __init__(self, data): + + self.x = data.astype(np.float32) + self.N = self.x.shape[0] + + def __init__(self, root): + + file = os.path.join(root, "miniboone", "data.npy") + trn, val, tst = load_data_normalised(file) + + self.trn = self.Data(trn) + self.val = self.Data(val) + self.tst = self.Data(tst) + + self.n_dims = self.trn.x.shape[1] + + +def load_data(root_path): + # NOTE: To remember how the pre-processing was done. + # data = pd.read_csv(root_path, names=[str(x) for x in range(50)], delim_whitespace=True) + # print data.head() + # data = data.as_matrix() + # # Remove some random outliers + # indices = (data[:, 0] < -100) + # data = data[~indices] + # + # i = 0 + # # Remove any features that have too many re-occuring real values. 
+ # features_to_remove = [] + # for feature in data.T: + # c = Counter(feature) + # max_count = np.array([v for k, v in sorted(c.iteritems())])[0] + # if max_count > 5: + # features_to_remove.append(i) + # i += 1 + # data = data[:, np.array([i for i in range(data.shape[1]) if i not in features_to_remove])] + # np.save("~/data/miniboone/data.npy", data) + + data = np.load(root_path) + N_test = int(0.1 * data.shape[0]) + data_test = data[-N_test:] + data = data[0:-N_test] + N_validate = int(0.1 * data.shape[0]) + data_validate = data[-N_validate:] + data_train = data[0:-N_validate] + + return data_train, data_validate, data_test + + +def load_data_normalised(root_path): + + data_train, data_validate, data_test = load_data(root_path) + data = np.vstack((data_train, data_validate)) + mu = data.mean(axis=0) + s = data.std(axis=0) + data_train = (data_train - mu) / s + data_validate = (data_validate - mu) / s + data_test = (data_test - mu) / s + + return data_train, data_validate, data_test diff --git a/src/torchprune/torchprune/util/external/ffjord/datasets/power.py b/src/torchprune/torchprune/util/external/ffjord/datasets/power.py new file mode 100644 index 0000000..5539732 --- /dev/null +++ b/src/torchprune/torchprune/util/external/ffjord/datasets/power.py @@ -0,0 +1,71 @@ +import os +import numpy as np + + +class POWER: + class Data: + def __init__(self, data): + + self.x = data.astype(np.float32) + self.N = self.x.shape[0] + + def __init__(self, root): + + trn, val, tst = load_data_normalised(root) + + self.trn = self.Data(trn) + self.val = self.Data(val) + self.tst = self.Data(tst) + + self.n_dims = self.trn.x.shape[1] + + +def load_data(root): + return np.load(os.path.join(root, "power", "data.npy")) + + +def load_data_split_with_noise(root): + + rng = np.random.RandomState(42) + + data = load_data(root) + rng.shuffle(data) + N = data.shape[0] + + data = np.delete(data, 3, axis=1) + data = np.delete(data, 1, axis=1) + ############################ + # Add noise + ############################ + # global_intensity_noise = 0.1*rng.rand(N, 1) + voltage_noise = 0.01 * rng.rand(N, 1) + # grp_noise = 0.001*rng.rand(N, 1) + gap_noise = 0.001 * rng.rand(N, 1) + sm_noise = rng.rand(N, 3) + time_noise = np.zeros((N, 1)) + # noise = np.hstack((gap_noise, grp_noise, voltage_noise, global_intensity_noise, sm_noise, time_noise)) + # noise = np.hstack((gap_noise, grp_noise, voltage_noise, sm_noise, time_noise)) + noise = np.hstack((gap_noise, voltage_noise, sm_noise, time_noise)) + data = data + noise + + N_test = int(0.1 * data.shape[0]) + data_test = data[-N_test:] + data = data[0:-N_test] + N_validate = int(0.1 * data.shape[0]) + data_validate = data[-N_validate:] + data_train = data[0:-N_validate] + + return data_train, data_validate, data_test + + +def load_data_normalised(root): + + data_train, data_validate, data_test = load_data_split_with_noise(root) + data = np.vstack((data_train, data_validate)) + mu = data.mean(axis=0) + s = data.std(axis=0) + data_train = (data_train - mu) / s + data_validate = (data_validate - mu) / s + data_test = (data_test - mu) / s + + return data_train, data_validate, data_test diff --git a/src/torchprune/torchprune/util/external/ffjord/diagnostics/approx_error_1d.py b/src/torchprune/torchprune/util/external/ffjord/diagnostics/approx_error_1d.py new file mode 100644 index 0000000..c4a52d7 --- /dev/null +++ b/src/torchprune/torchprune/util/external/ffjord/diagnostics/approx_error_1d.py @@ -0,0 +1,265 @@ +from inspect import getsourcefile +import sys +import os + 
+current_path = os.path.abspath(getsourcefile(lambda: 0)) +current_dir = os.path.dirname(current_path) +parent_dir = current_dir[:current_dir.rfind(os.path.sep)] +sys.path.insert(0, parent_dir) + +import matplotlib +matplotlib.use('Agg') +import matplotlib.pyplot as plt + +import numpy as np +import argparse +import os +import time + +import torch +import torch.optim as optim + +import lib.utils as utils +import lib.layers.odefunc as odefunc + +from train_misc import standard_normal_logprob +from train_misc import set_cnf_options, count_nfe, count_parameters, count_total_time +from train_misc import build_model_tabular + +import seaborn as sns +sns.set_style("whitegrid") +colors = ["windows blue", "amber", "greyish", "faded green", "dusty purple"] +sns.palplot(sns.xkcd_palette(colors)) + +SOLVERS = ["dopri5", "bdf", "rk4", "midpoint", 'adams', 'explicit_adams', 'fixed_adams'] +parser = argparse.ArgumentParser('Continuous Normalizing Flow') +parser.add_argument( + "--layer_type", type=str, default="concatsquash", + choices=["ignore", "concat", "concat_v2", "squash", "concatsquash", "concatcoord", "hyper", "blend"] +) +parser.add_argument('--dims', type=str, default='64-64-64') +parser.add_argument("--num_blocks", type=int, default=1, help='Number of stacked CNFs.') +parser.add_argument('--time_length', type=float, default=0.5) +parser.add_argument('--train_T', type=eval, default=True) +parser.add_argument("--divergence_fn", type=str, default="brute_force", choices=["brute_force", "approximate"]) +parser.add_argument("--nonlinearity", type=str, default="tanh", choices=odefunc.NONLINEARITIES) + +parser.add_argument('--solver', type=str, default='dopri5', choices=SOLVERS) +parser.add_argument('--atol', type=float, default=1e-5) +parser.add_argument('--rtol', type=float, default=1e-5) +parser.add_argument("--step_size", type=float, default=None, help="Optional fixed step size.") + +parser.add_argument('--test_solver', type=str, default=None, choices=SOLVERS + [None]) +parser.add_argument('--test_atol', type=float, default=None) +parser.add_argument('--test_rtol', type=float, default=None) + +parser.add_argument('--residual', type=eval, default=False, choices=[True, False]) +parser.add_argument('--rademacher', type=eval, default=False, choices=[True, False]) +parser.add_argument('--spectral_norm', type=eval, default=False, choices=[True, False]) +parser.add_argument('--batch_norm', type=eval, default=False, choices=[True, False]) +parser.add_argument('--bn_lag', type=float, default=0) + +parser.add_argument('--niters', type=int, default=10000) +parser.add_argument('--batch_size', type=int, default=100) +parser.add_argument('--test_batch_size', type=int, default=1000) +parser.add_argument('--lr', type=float, default=1e-3) +parser.add_argument('--weight_decay', type=float, default=1e-5) + +# Track quantities +parser.add_argument('--l1int', type=float, default=None, help="int_t ||f||_1") +parser.add_argument('--l2int', type=float, default=None, help="int_t ||f||_2") +parser.add_argument('--dl2int', type=float, default=None, help="int_t ||f^T df/dt||_2") +parser.add_argument('--JFrobint', type=float, default=None, help="int_t ||df/dx||_F") +parser.add_argument('--JdiagFrobint', type=float, default=None, help="int_t ||df_i/dx_i||_F") +parser.add_argument('--JoffdiagFrobint', type=float, default=None, help="int_t ||df/dx - df_i/dx_i||_F") + +parser.add_argument('--save', type=str, default='experiments/approx_error_1d') +parser.add_argument('--viz_freq', type=int, default=100) 
+parser.add_argument('--val_freq', type=int, default=100) +parser.add_argument('--log_freq', type=int, default=10) +parser.add_argument('--gpu', type=int, default=0) +args = parser.parse_args() + +# logger +utils.makedirs(args.save) +logger = utils.get_logger(logpath=os.path.join(args.save, 'logs'), filepath=os.path.abspath(__file__)) + +if args.layer_type == "blend": + logger.info("!! Setting time_length from None to 1.0 due to use of Blend layers.") + args.time_length = 1.0 + +logger.info(args) + +device = torch.device('cuda:' + str(args.gpu) if torch.cuda.is_available() else 'cpu') + + +def normal_log_density(x, mean=0, stdev=1): + term = (x - mean) / stdev + return -0.5 * (np.log(2 * np.pi) + 2 * np.log(stdev) + term * term) + + +def data_sample(batch_size): + x1 = np.random.randn(batch_size) * np.sqrt(0.4) - 2.8 + x2 = np.random.randn(batch_size) * np.sqrt(0.4) - 0.9 + x3 = np.random.randn(batch_size) * np.sqrt(0.4) + 2. + xs = np.concatenate([x1[:, None], x2[:, None], x3[:, None]], 1) + k = np.random.randint(0, 3, batch_size) + x = xs[np.arange(batch_size), k] + return torch.tensor(x[:, None]).float().to(device) + + +def data_density(x): + p1 = normal_log_density(x, mean=-2.8, stdev=np.sqrt(0.4)) + p2 = normal_log_density(x, mean=-0.9, stdev=np.sqrt(0.4)) + p3 = normal_log_density(x, mean=2.0, stdev=np.sqrt(0.4)) + return torch.log(p1.exp() / 3 + p2.exp() / 3 + p3.exp() / 3) + + +def model_density(x, model): + x = x.to(device) + z, delta_logp = model(x, torch.zeros_like(x)) + logpx = standard_normal_logprob(z) - delta_logp + return logpx + + +def model_sample(model, batch_size): + z = torch.randn(batch_size, 1) + logqz = standard_normal_logprob(z) + x, logqx = model(z, logqz, reverse=True) + return x, logqx + + +def compute_loss(args, model, batch_size=None): + if batch_size is None: batch_size = args.batch_size + + x = data_sample(batch_size) + logpx = model_density(x, model) + return -torch.mean(logpx) + + +def train(): + + model = build_model_tabular(args, 1).to(device) + set_cnf_options(args, model) + + logger.info(model) + logger.info("Number of trainable parameters: {}".format(count_parameters(model))) + + optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) + + time_meter = utils.RunningAverageMeter(0.93) + loss_meter = utils.RunningAverageMeter(0.93) + nfef_meter = utils.RunningAverageMeter(0.93) + nfeb_meter = utils.RunningAverageMeter(0.93) + tt_meter = utils.RunningAverageMeter(0.93) + + end = time.time() + best_loss = float('inf') + model.train() + for itr in range(1, args.niters + 1): + optimizer.zero_grad() + + loss = compute_loss(args, model) + loss_meter.update(loss.item()) + + total_time = count_total_time(model) + nfe_forward = count_nfe(model) + + loss.backward() + optimizer.step() + + nfe_total = count_nfe(model) + nfe_backward = nfe_total - nfe_forward + nfef_meter.update(nfe_forward) + nfeb_meter.update(nfe_backward) + + time_meter.update(time.time() - end) + tt_meter.update(total_time) + + log_message = ( + 'Iter {:04d} | Time {:.4f}({:.4f}) | Loss {:.6f}({:.6f}) | NFE Forward {:.0f}({:.1f})' + ' | NFE Backward {:.0f}({:.1f}) | CNF Time {:.4f}({:.4f})'.format( + itr, time_meter.val, time_meter.avg, loss_meter.val, loss_meter.avg, nfef_meter.val, nfef_meter.avg, + nfeb_meter.val, nfeb_meter.avg, tt_meter.val, tt_meter.avg + ) + ) + logger.info(log_message) + + if itr % args.val_freq == 0 or itr == args.niters: + with torch.no_grad(): + model.eval() + test_loss = compute_loss(args, model, batch_size=args.test_batch_size) + 
test_nfe = count_nfe(model) + log_message = '[TEST] Iter {:04d} | Test Loss {:.6f} | NFE {:.0f}'.format(itr, test_loss, test_nfe) + logger.info(log_message) + + if test_loss.item() < best_loss: + best_loss = test_loss.item() + utils.makedirs(args.save) + torch.save({ + 'args': args, + 'state_dict': model.state_dict(), + }, os.path.join(args.save, 'checkpt.pth')) + model.train() + + if itr % args.viz_freq == 0: + with torch.no_grad(): + model.eval() + + xx = torch.linspace(-10, 10, 10000).view(-1, 1) + true_p = data_density(xx) + plt.plot(xx.view(-1).cpu().numpy(), true_p.view(-1).exp().cpu().numpy(), label='True') + + true_p = model_density(xx, model) + plt.plot(xx.view(-1).cpu().numpy(), true_p.view(-1).exp().cpu().numpy(), label='Model') + + utils.makedirs(os.path.join(args.save, 'figs')) + plt.savefig(os.path.join(args.save, 'figs', '{:06d}.jpg'.format(itr))) + plt.close() + + model.train() + + end = time.time() + + logger.info('Training has finished.') + + +def evaluate(): + model = build_model_tabular(args, 1).to(device) + set_cnf_options(args, model) + + checkpt = torch.load(os.path.join(args.save, 'checkpt.pth')) + model.load_state_dict(checkpt['state_dict']) + model.to(device) + + tols = [1e-1, 1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8] + errors = [] + with torch.no_grad(): + for tol in tols: + args.rtol = tol + args.atol = tol + set_cnf_options(args, model) + + xx = torch.linspace(-15, 15, 500000).view(-1, 1).to(device) + prob_xx = model_density(xx, model).double().view(-1).cpu() + xx = xx.double().cpu().view(-1) + dxx = torch.log(xx[1:] - xx[:-1]) + num_integral = torch.logsumexp(prob_xx[:-1] + dxx, 0).exp() + errors.append(float(torch.abs(num_integral - 1.))) + + print(errors[-1]) + + plt.figure(figsize=(5, 3)) + plt.plot(tols, errors, linewidth=3, marker='o', markersize=7) + # plt.plot([-1, 0.2], [-1, 0.2], '--', color='grey', linewidth=1) + plt.xscale("log", nonposx='clip') + # plt.yscale("log", nonposy='clip') + plt.xlabel('Solver Tolerance', fontsize=17) + plt.ylabel('$| 1 - \int p(x) |$', fontsize=17) + plt.tight_layout() + plt.savefig('ode_solver_error_vs_tol.pdf') + + +if __name__ == '__main__': + # train() + evaluate() diff --git a/src/torchprune/torchprune/util/external/ffjord/diagnostics/plot_bottleneck_losses.py b/src/torchprune/torchprune/util/external/ffjord/diagnostics/plot_bottleneck_losses.py new file mode 100644 index 0000000..82761c0 --- /dev/null +++ b/src/torchprune/torchprune/util/external/ffjord/diagnostics/plot_bottleneck_losses.py @@ -0,0 +1,70 @@ +import re +import numpy as np +import matplotlib +matplotlib.use('Agg') +import matplotlib.pyplot as plt +import scipy.signal +import scipy.ndimage + +# BASE = "experiments/cnf_mnist_64-64-128-128-64-64/logs" +# RESIDUAL = "experiments/cnf_mnist_64-64-128-128-64-64_residual/logs" +# RADEMACHER = "experiments/cnf_mnist_64-64-128-128-64-64_rademacher/logs" + +BOTTLENECK = "experiments/cnf_mnist_bottleneck_64-64-128-5-128-64-64/logs" +BOTTLENECK_EST = "experiments/cnf_mnist_bottleneck_64-64-128-5-128-64-64_ae-est/logs" +RAD_BOTTLENECK = "experiments/cnf_mnist_bottleneck_64-64-128-5-128-64-64_rademacher/logs" +RAD_BOTTLENECK_EST = "experiments/cnf_mnist_bottleneck_64-64-128-5-128-64-64_ae-est_rademacher/logs" + +# ET_ALL = "experiments/cnf_mnist_bottleneck_64-64-128-5-128-64-64_ae-est_residual_rademacher/logs" + + +def get_losses(filename): + with open(filename, "r") as f: + lines = f.readlines() + + losses = [] + + for line in lines: + w = re.findall(r"Bit/dim [^|(]*\([0-9\.]*\)", line) + if w: w = 
re.findall(r"\([0-9\.]*\)", w[0]) + if w: w = re.findall(r"[0-9\.]+", w[0]) + if w: + losses.append(float(w[0])) + + return losses + + +bottleneck_loss = get_losses(BOTTLENECK) +bottleneck_est_loss = get_losses(BOTTLENECK_EST) +rademacher_bottleneck_loss = get_losses(RAD_BOTTLENECK) +rademacher_bottleneck_est_loss = get_losses(RAD_BOTTLENECK_EST) + +bottleneck_loss = scipy.signal.medfilt(bottleneck_loss, 21) +bottleneck_est_loss = scipy.signal.medfilt(bottleneck_est_loss, 21) +rademacher_bottleneck_loss = scipy.signal.medfilt(rademacher_bottleneck_loss, 21) +rademacher_bottleneck_est_loss = scipy.signal.medfilt(rademacher_bottleneck_est_loss, 21) + +import seaborn as sns +sns.set_style("whitegrid") +colors = ["windows blue", "amber", "greyish", "faded green", "dusty purple"] +sns.palplot(sns.xkcd_palette(colors)) + +import brewer2mpl +line_colors = brewer2mpl.get_map('Set2', 'qualitative', 4).mpl_colors +dark_colors = brewer2mpl.get_map('Dark2', 'qualitative', 4).mpl_colors +# plt.style.use('ggplot') + +plt.figure(figsize=(4, 3)) +plt.plot(np.arange(len(bottleneck_loss)) / 30, bottleneck_loss, ':', color=line_colors[1], label="Gaussian w/o Trick") +plt.plot(np.arange(len(bottleneck_est_loss)) / 30, bottleneck_est_loss, color=dark_colors[1], label="Gaussian w/ Trick") +plt.plot(np.arange(len(rademacher_bottleneck_loss)) / 30, rademacher_bottleneck_loss, ':', color=line_colors[2], label="Rademacher w/o Trick") +plt.plot(np.arange(len(rademacher_bottleneck_est_loss)) / 30, rademacher_bottleneck_est_loss, color=dark_colors[2], label="Rademacher w/ Trick") + +plt.legend(frameon=True, fontsize=10.5, loc='upper right') +plt.ylim([1.1, 1.7]) +# plt.yscale("log", nonposy='clip') +plt.xlabel("Epoch", fontsize=18) +plt.ylabel("Bits/dim", fontsize=18) +plt.xlim([0, 170]) +plt.tight_layout() +plt.savefig('bottleneck_losses.pdf') diff --git a/src/torchprune/torchprune/util/external/ffjord/diagnostics/plot_compare_multiscale.py b/src/torchprune/torchprune/util/external/ffjord/diagnostics/plot_compare_multiscale.py new file mode 100644 index 0000000..36ff110 --- /dev/null +++ b/src/torchprune/torchprune/util/external/ffjord/diagnostics/plot_compare_multiscale.py @@ -0,0 +1,60 @@ +import re +import matplotlib +matplotlib.use('Agg') +import matplotlib.pyplot as plt + +MNIST_SINGLESCALE = "diagnostics/mnist.log" +MNIST_MULTISCALE = "diagnostics/mnist_multiscale.log" + + +def get_values(filename): + with open(filename, "r") as f: + lines = f.readlines() + + losses = [] + nfes = [] + + for line in lines: + + w = re.findall(r"Steps [^|(]*\([0-9\.]*\)", line) + if w: w = re.findall(r"\([0-9\.]*\)", w[0]) + if w: w = re.findall(r"[0-9\.]+", w[0]) + if w: + nfes.append(float(w[0])) + + w = re.findall(r"Bit/dim [^|(]*\([0-9\.]*\)", line) + if w: w = re.findall(r"\([0-9\.]*\)", w[0]) + if w: w = re.findall(r"[0-9\.]+", w[0]) + if w: + losses.append(float(w[0])) + + return losses, nfes + + +mnist_singlescale_loss, mnist_singlescale_nfes = get_values(MNIST_SINGLESCALE) +mnist_multiscale_loss, mnist_multiscale_nfes = get_values(MNIST_MULTISCALE) + +import brewer2mpl +line_colors = brewer2mpl.get_map('Set2', 'qualitative', 4).mpl_colors +dark_colors = brewer2mpl.get_map('Dark2', 'qualitative', 4).mpl_colors + +import seaborn as sns +sns.set_style("whitegrid") +colors = ["windows blue", "amber", "greyish", "faded green", "dusty purple"] +sns.palplot(sns.xkcd_palette(colors)) + +plt.figure(figsize=(4, 2.6)) +plt.scatter(mnist_singlescale_nfes[::10], mnist_singlescale_loss[::10], color=line_colors[1], label="Single 
FFJORD") +plt.scatter(mnist_multiscale_nfes[::10], mnist_multiscale_loss[::10], color=line_colors[2], label="Multiscale FFJORD") + +plt.ylim([0.9, 1.25]) +plt.legend(frameon=True, fontsize=10.5) +plt.xlabel("NFE", fontsize=18) +plt.ylabel("Bits/dim", fontsize=18) + +ax = plt.gca() +ax.tick_params(axis='both', which='major', labelsize=14) +ax.tick_params(axis='both', which='minor', labelsize=10) + +plt.tight_layout() +plt.savefig('multiscale_loss_vs_nfe.pdf') diff --git a/src/torchprune/torchprune/util/external/ffjord/diagnostics/plot_flows.py b/src/torchprune/torchprune/util/external/ffjord/diagnostics/plot_flows.py new file mode 100644 index 0000000..d284a33 --- /dev/null +++ b/src/torchprune/torchprune/util/external/ffjord/diagnostics/plot_flows.py @@ -0,0 +1,157 @@ +from inspect import getsourcefile +import sys +import os + +current_path = os.path.abspath(getsourcefile(lambda: 0)) +current_dir = os.path.dirname(current_path) +parent_dir = current_dir[:current_dir.rfind(os.path.sep)] +sys.path.insert(0, parent_dir) + +import matplotlib +matplotlib.use('Agg') +import matplotlib.pyplot as plt + +import argparse +import os + +import torch + +import lib.toy_data as toy_data +import lib.utils as utils +import lib.visualize_flow as viz_flow +import lib.layers.odefunc as odefunc +import lib.layers as layers + +from train_misc import standard_normal_logprob +from train_misc import build_model_tabular, count_parameters + +SOLVERS = ["dopri5", "bdf", "rk4", "midpoint", 'adams', 'explicit_adams', 'fixed_adams'] +parser = argparse.ArgumentParser('Continuous Normalizing Flow') +parser.add_argument( + '--data', choices=['swissroll', '8gaussians', 'pinwheel', 'circles', 'moons', '2spirals', 'checkerboard', 'rings'], + type=str, default='pinwheel' +) + +parser.add_argument('--discrete', action='store_true') + +parser.add_argument('--depth', help='number of coupling layers', type=int, default=10) +parser.add_argument('--glow', type=eval, choices=[True, False], default=False) + +parser.add_argument( + "--layer_type", type=str, default="concatsquash", + choices=["ignore", "concat", "concat_v2", "squash", "concatsquash", "concatcoord", "hyper", "blend"] +) +parser.add_argument('--dims', type=str, default='64-64-64') +parser.add_argument("--num_blocks", type=int, default=1, help='Number of stacked CNFs.') +parser.add_argument('--time_length', type=float, default=0.5) +parser.add_argument('--train_T', type=eval, default=True) +parser.add_argument("--divergence_fn", type=str, default="brute_force", choices=["brute_force", "approximate"]) +parser.add_argument("--nonlinearity", type=str, default="tanh", choices=odefunc.NONLINEARITIES) + +parser.add_argument('--solver', type=str, default='dopri5', choices=SOLVERS) +parser.add_argument('--atol', type=float, default=1e-5) +parser.add_argument('--rtol', type=float, default=1e-5) +parser.add_argument("--step_size", type=float, default=None, help="Optional fixed step size.") + +parser.add_argument('--test_solver', type=str, default=None, choices=SOLVERS + [None]) +parser.add_argument('--test_atol', type=float, default=None) +parser.add_argument('--test_rtol', type=float, default=None) + +parser.add_argument('--residual', type=eval, default=False, choices=[True, False]) +parser.add_argument('--rademacher', type=eval, default=False, choices=[True, False]) +parser.add_argument('--spectral_norm', type=eval, default=False, choices=[True, False]) +parser.add_argument('--batch_norm', type=eval, default=False, choices=[True, False]) +parser.add_argument('--bn_lag', 
type=float, default=0) + +parser.add_argument('--niters', type=int, default=2500) +parser.add_argument('--batch_size', type=int, default=100) +parser.add_argument('--test_batch_size', type=int, default=1000) +parser.add_argument('--lr', type=float, default=1e-3) +parser.add_argument('--weight_decay', type=float, default=1e-5) + +parser.add_argument('--checkpt', type=str, required=True) +parser.add_argument('--save', type=str, default='experiments/cnf') +parser.add_argument('--viz_freq', type=int, default=100) +parser.add_argument('--val_freq', type=int, default=100) +parser.add_argument('--log_freq', type=int, default=10) +parser.add_argument('--gpu', type=int, default=0) +args = parser.parse_args() + +device = torch.device('cuda:' + str(args.gpu) if torch.cuda.is_available() else 'cpu') + + +def construct_discrete_model(): + + chain = [] + for i in range(args.depth): + if args.glow: chain.append(layers.BruteForceLayer(2)) + chain.append(layers.CouplingLayer(2, swap=i % 2 == 0)) + return layers.SequentialFlow(chain) + + +def get_transforms(model): + + def sample_fn(z, logpz=None): + if logpz is not None: + return model(z, logpz, reverse=True) + else: + return model(z, reverse=True) + + def density_fn(x, logpx=None): + if logpx is not None: + return model(x, logpx, reverse=False) + else: + return model(x, reverse=False) + + return sample_fn, density_fn + + +if __name__ == '__main__': + + if args.discrete: + model = construct_discrete_model().to(device) + model.load_state_dict(torch.load(args.checkpt)['state_dict']) + else: + model = build_model_tabular(args, 2).to(device) + + sd = torch.load(args.checkpt)['state_dict'] + fixed_sd = {} + for k, v in sd.items(): + fixed_sd[k.replace('odefunc.odefunc', 'odefunc')] = v + model.load_state_dict(fixed_sd) + + print(model) + print("Number of trainable parameters: {}".format(count_parameters(model))) + + model.eval() + p_samples = toy_data.inf_train_gen(args.data, batch_size=800**2) + + with torch.no_grad(): + sample_fn, density_fn = get_transforms(model) + + plt.figure(figsize=(10, 10)) + ax = ax = plt.gca() + viz_flow.plt_samples(p_samples, ax, npts=800) + plt.subplots_adjust(left=0, right=1, top=1, bottom=0) + fig_filename = os.path.join(args.save, 'figs', 'true_samples.jpg') + utils.makedirs(os.path.dirname(fig_filename)) + plt.savefig(fig_filename) + plt.close() + + plt.figure(figsize=(10, 10)) + ax = ax = plt.gca() + viz_flow.plt_flow_density(standard_normal_logprob, density_fn, ax, npts=800, memory=200, device=device) + plt.subplots_adjust(left=0, right=1, top=1, bottom=0) + fig_filename = os.path.join(args.save, 'figs', 'model_density.jpg') + utils.makedirs(os.path.dirname(fig_filename)) + plt.savefig(fig_filename) + plt.close() + + plt.figure(figsize=(10, 10)) + ax = ax = plt.gca() + viz_flow.plt_flow_samples(torch.randn, sample_fn, ax, npts=800, memory=200, device=device) + plt.subplots_adjust(left=0, right=1, top=1, bottom=0) + fig_filename = os.path.join(args.save, 'figs', 'model_samples.jpg') + utils.makedirs(os.path.dirname(fig_filename)) + plt.savefig(fig_filename) + plt.close() diff --git a/src/torchprune/torchprune/util/external/ffjord/diagnostics/plot_nfe_vs_dim_vae.py b/src/torchprune/torchprune/util/external/ffjord/diagnostics/plot_nfe_vs_dim_vae.py new file mode 100644 index 0000000..26126c5 --- /dev/null +++ b/src/torchprune/torchprune/util/external/ffjord/diagnostics/plot_nfe_vs_dim_vae.py @@ -0,0 +1,50 @@ +import os.path +import re +import numpy as np +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt 
+import scipy.ndimage + +import seaborn as sns +sns.set_style("whitegrid") +colors = ["windows blue", "amber", "greyish", "faded green", "dusty purple"] +sns.palplot(sns.xkcd_palette(colors)) + +dims = [16, 32, 48, 64] +dirs = [ + 'vae_mnist_cnf_num_flows_4_256-256_num_blocks_1__2018-09-16_17_27_03', + 'vae_mnist_cnf_num_flows_4_256-256_num_blocks_1__2018-09-16_17_26_41', + 'vae_mnist_cnf_num_flows_4_256-256_num_blocks_1__2018-09-16_17_23_35', + 'vae_mnist_cnf_num_flows_4_256-256_num_blocks_1__2018-09-16_17_25_03', +] + +nfe_all = [] + +for dim, dirname in zip(dims, dirs): + with open(os.path.join('snapshots', dirname, 'logs'), 'r') as f: + lines = f.readlines() + + nfes_ = [] + + for line in lines: + w = re.findall(r"NFE Forward [0-9]*", line) + if w: w = re.findall(r"[0-9]+", w[0]) + if w: + nfes_.append(float(w[0])) + + nfe_all.append(nfes_) + +plt.figure(figsize=(4, 2.4)) +for i, (dim, nfes) in enumerate(zip(dims, nfe_all)): + nfes = np.array(nfes) + xx = (np.arange(len(nfes)) + 1) / 50 + nfes = scipy.ndimage.gaussian_filter(nfes, 101) + plt.plot(xx, nfes, '--', label='Dim {}'.format(dim)) + +plt.legend(frameon=True, fontsize=10.5) +plt.xlabel('Epoch', fontsize=18) +plt.ylabel('NFE', fontsize=18) +plt.xlim([0, 200]) +plt.tight_layout() +plt.savefig("nfes_vs_dim_vae.pdf") diff --git a/src/torchprune/torchprune/util/external/ffjord/diagnostics/plot_sn_losses.py b/src/torchprune/torchprune/util/external/ffjord/diagnostics/plot_sn_losses.py new file mode 100644 index 0000000..fc425f8 --- /dev/null +++ b/src/torchprune/torchprune/util/external/ffjord/diagnostics/plot_sn_losses.py @@ -0,0 +1,81 @@ +import re +import matplotlib +matplotlib.use('Agg') +import matplotlib.pyplot as plt + +CIFAR10 = "diagnostics/cifar10_multiscale.log" +CIFAR10_SN = "diagnostics/cifar10_multiscale_sn.log" + +MNIST = "diagnostics/mnist_multiscale.log" +MNIST_SN = "diagnostics/mnist_multiscale_sn.log" + + +def get_values(filename): + with open(filename, "r") as f: + lines = f.readlines() + + losses = [] + nfes = [] + + for line in lines: + + w = re.findall(r"Steps [^|(]*\([0-9\.]*\)", line) + if w: w = re.findall(r"\([0-9\.]*\)", w[0]) + if w: w = re.findall(r"[0-9\.]+", w[0]) + if w: + nfes.append(float(w[0])) + + w = re.findall(r"Bit/dim [^|(]*\([0-9\.]*\)", line) + if w: w = re.findall(r"\([0-9\.]*\)", w[0]) + if w: w = re.findall(r"[0-9\.]+", w[0]) + if w: + losses.append(float(w[0])) + + return losses, nfes + + +cifar10_loss, cifar10_nfes = get_values(CIFAR10) +cifar10_sn_loss, cifar10_sn_nfes = get_values(CIFAR10_SN) +mnist_loss, mnist_nfes = get_values(MNIST) +mnist_sn_loss, mnist_sn_nfes = get_values(MNIST_SN) + +import brewer2mpl +line_colors = brewer2mpl.get_map('Set2', 'qualitative', 4).mpl_colors +dark_colors = brewer2mpl.get_map('Dark2', 'qualitative', 4).mpl_colors +plt.style.use('ggplot') + +# CIFAR10 plot +plt.figure(figsize=(6, 7)) +plt.scatter(cifar10_nfes, cifar10_loss, color=line_colors[1], label="w/o Spectral Norm") +plt.scatter(cifar10_sn_nfes, cifar10_sn_loss, color=line_colors[2], label="w/ Spectral Norm") + +plt.ylim([3, 5]) +plt.legend(fontsize=18) +plt.xlabel("NFE", fontsize=30) +plt.ylabel("Bits/dim", fontsize=30) + +ax = plt.gca() +ax.tick_params(axis='both', which='major', labelsize=24) +ax.tick_params(axis='both', which='minor', labelsize=16) +ax.yaxis.set_ticks([3, 3.5, 4, 4.5, 5]) + +plt.tight_layout() +plt.savefig('cifar10_sn_loss_vs_nfe.pdf') + +# MNIST plot +plt.figure(figsize=(6, 7)) +plt.scatter(mnist_nfes, mnist_loss, color=line_colors[1], label="w/o Spectral Norm") 
+plt.scatter(mnist_sn_nfes, mnist_sn_loss, color=line_colors[2], label="w/ Spectral Norm") + +plt.ylim([0.9, 2]) +plt.legend(fontsize=18) +plt.xlabel("NFE", fontsize=30) +plt.ylabel("Bits/dim", fontsize=30) + +ax = plt.gca() +ax.tick_params(axis='both', which='major', labelsize=24) +ax.tick_params(axis='both', which='minor', labelsize=16) +# ax.yaxis.set_ticks([3, 3.5, 4, 4.5, 5]) + +plt.tight_layout() +plt.savefig('mnist_sn_loss_vs_nfe.pdf') diff --git a/src/torchprune/torchprune/util/external/ffjord/diagnostics/scrap_log.py b/src/torchprune/torchprune/util/external/ffjord/diagnostics/scrap_log.py new file mode 100644 index 0000000..c08e99b --- /dev/null +++ b/src/torchprune/torchprune/util/external/ffjord/diagnostics/scrap_log.py @@ -0,0 +1,64 @@ +import os +import re +import csv + + +def log_to_csv(log_filename, csv_filename): + with open(log_filename, 'r') as f: + lines = f.readlines() + + with open(csv_filename, 'w', newline='') as csvfile: + fieldnames = None + writer = None + + for line in lines: + if line.startswith('Iter'): + # A dictionary of quantity : value. + quants = _line_to_dict(line) + # Create writer and write header. + if fieldnames is None: + fieldnames = quants.keys() + writer = csv.DictWriter(csvfile, fieldnames=fieldnames) + writer.writeheader() + # Write a line. + writer.writerow(quants) + + +def _line_to_dict(line): + line = re.sub(':', '', line) # strip colons. + line = re.sub('\([^)]*\)', '', line) # strip running averages. + + quants = {} + for quant_str in line.split('|'): + quant_str = quant_str.strip() # strip beginning and ending whitespaces. + key, val = quant_str.split(' ') + quants[key] = val + + return quants + + +def plot_pairplot(csv_filename, fig_filename, top=None): + import seaborn as sns + import pandas as pd + + sns.set(style="ticks", color_codes=True) + quants = pd.read_csv(csv_filename) + if top is not None: + quants = quants[:top] + + g = sns.pairplot(quants, kind='reg', diag_kind='kde', markers='.') + g.savefig(fig_filename) + + +if __name__ == '__main__': + + import argparse + parser = argparse.ArgumentParser() + parser.add_argument('--log', type=str, required=True) + parser.add_argument('--top_iters', type=int, default=None) + args = parser.parse_args() + + print('Parsing log into csv.') + log_to_csv(args.log, args.log + '.csv') + print('Creating correlation plot.') + plot_pairplot(args.log + '.csv', os.path.join(os.path.dirname(args.log), 'quants.png'), args.top_iters) diff --git a/src/torchprune/torchprune/util/external/ffjord/diagnostics/viz_cnf.py b/src/torchprune/torchprune/util/external/ffjord/diagnostics/viz_cnf.py new file mode 100644 index 0000000..3cb73c6 --- /dev/null +++ b/src/torchprune/torchprune/util/external/ffjord/diagnostics/viz_cnf.py @@ -0,0 +1,257 @@ +from inspect import getsourcefile +import sys +import os +import subprocess + +current_path = os.path.abspath(getsourcefile(lambda: 0)) +current_dir = os.path.dirname(current_path) +parent_dir = current_dir[:current_dir.rfind(os.path.sep)] +sys.path.insert(0, parent_dir) + +import argparse +import torch +import torchvision.datasets as dset +import torchvision.transforms as tforms +from torchvision.utils import save_image + +import lib.layers as layers +import lib.spectral_norm as spectral_norm +import lib.utils as utils + + +def add_noise(x): + """ + [0, 1] -> [0, 255] -> add noise -> [0, 1] + """ + noise = x.new().resize_as_(x).uniform_() + x = x * 255 + noise + x = x / 256 + return x + + +def get_dataset(args): + trans = lambda im_size: 
tforms.Compose([tforms.Resize(im_size), tforms.ToTensor(), add_noise]) + + if args.data == "mnist": + im_dim = 1 + im_size = 28 if args.imagesize is None else args.imagesize + train_set = dset.MNIST(root="./data", train=True, transform=trans(im_size), download=True) + test_set = dset.MNIST(root="./data", train=False, transform=trans(im_size), download=True) + elif args.data == "svhn": + im_dim = 3 + im_size = 32 if args.imagesize is None else args.imagesize + train_set = dset.SVHN(root="./data", split="train", transform=trans(im_size), download=True) + test_set = dset.SVHN(root="./data", split="test", transform=trans(im_size), download=True) + elif args.data == "cifar10": + im_dim = 3 + im_size = 32 if args.imagesize is None else args.imagesize + train_set = dset.CIFAR10(root="./data", train=True, transform=trans(im_size), download=True) + test_set = dset.CIFAR10(root="./data", train=False, transform=trans(im_size), download=True) + elif args.dataset == 'celeba': + im_dim = 3 + im_size = 64 if args.imagesize is None else args.imagesize + train_set = dset.CelebA( + train=True, transform=tforms.Compose([ + tforms.ToPILImage(), + tforms.Resize(im_size), + tforms.RandomHorizontalFlip(), + tforms.ToTensor(), + add_noise, + ]) + ) + test_set = dset.CelebA( + train=False, transform=tforms.Compose([ + tforms.ToPILImage(), + tforms.Resize(args.imagesize), + tforms.ToTensor(), + add_noise, + ]) + ) + data_shape = (im_dim, im_size, im_size) + + train_loader = torch.utils.data.DataLoader(dataset=train_set, batch_size=args.batch_size, shuffle=True) + test_loader = torch.utils.data.DataLoader(dataset=test_set, batch_size=args.batch_size, shuffle=False) + return train_loader, test_loader, data_shape + + +def add_spectral_norm(model): + def recursive_apply_sn(parent_module): + for child_name in list(parent_module._modules.keys()): + child_module = parent_module._modules[child_name] + classname = child_module.__class__.__name__ + if classname.find('Conv') != -1 and 'weight' in child_module._parameters: + del parent_module._modules[child_name] + parent_module.add_module(child_name, spectral_norm.spectral_norm(child_module, 'weight')) + else: + recursive_apply_sn(child_module) + + recursive_apply_sn(model) + + +def build_model(args, state_dict): + # load dataset + train_loader, test_loader, data_shape = get_dataset(args) + + hidden_dims = tuple(map(int, args.dims.split(","))) + strides = tuple(map(int, args.strides.split(","))) + + # neural net that parameterizes the velocity field + if args.autoencode: + + def build_cnf(): + autoencoder_diffeq = layers.AutoencoderDiffEqNet( + hidden_dims=hidden_dims, + input_shape=data_shape, + strides=strides, + conv=args.conv, + layer_type=args.layer_type, + nonlinearity=args.nonlinearity, + ) + odefunc = layers.AutoencoderODEfunc( + autoencoder_diffeq=autoencoder_diffeq, + divergence_fn=args.divergence_fn, + residual=args.residual, + rademacher=args.rademacher, + ) + cnf = layers.CNF( + odefunc=odefunc, + T=args.time_length, + solver=args.solver, + ) + return cnf + else: + + def build_cnf(): + diffeq = layers.ODEnet( + hidden_dims=hidden_dims, + input_shape=data_shape, + strides=strides, + conv=args.conv, + layer_type=args.layer_type, + nonlinearity=args.nonlinearity, + ) + odefunc = layers.ODEfunc( + diffeq=diffeq, + divergence_fn=args.divergence_fn, + residual=args.residual, + rademacher=args.rademacher, + ) + cnf = layers.CNF( + odefunc=odefunc, + T=args.time_length, + solver=args.solver, + ) + return cnf + + chain = [layers.LogitTransform(alpha=args.alpha), 
build_cnf()] + if args.batch_norm: + chain.append(layers.MovingBatchNorm2d(data_shape[0])) + model = layers.SequentialFlow(chain) + + if args.spectral_norm: + add_spectral_norm(model) + + model.load_state_dict(state_dict) + + return model, test_loader.dataset + + +if __name__ == '__main__': + parser = argparse.ArgumentParser("Visualizes experiments trained using train_cnf.py.") + parser.add_argument("--checkpt", type=str, required=True) + parser.add_argument("--nsamples", type=int, default=50) + parser.add_argument("--ntimes", type=int, default=100) + parser.add_argument("--save", type=str, default="imgs") + args = parser.parse_args() + + checkpt = torch.load(args.checkpt, map_location=lambda storage, loc: storage) + ck_args = checkpt["args"] + state_dict = checkpt["state_dict"] + + model, test_set = build_model(ck_args, state_dict) + real_samples = torch.stack([test_set[i][0] for i in range(args.nsamples)], dim=0) + data_shape = real_samples.shape[1:] + fake_latents = torch.randn(args.nsamples, *data_shape) + + # Transfer to GPU if available. + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + print("Running on {}".format(device)) + model.to(device) + real_samples = real_samples.to(device) + fake_latents = fake_latents.to(device) + + # Construct fake samples + fake_samples = model(fake_latents, reverse=True).view(-1, *data_shape) + samples = torch.cat([real_samples, fake_samples], dim=0) + + still_diffeq = torch.zeros_like(samples) + im_indx = 0 + + # Image-saving helper function + def save_im(im, diffeq): + global im_indx + filename = os.path.join(current_dir, args.save, "flow_%05d.png" % im_indx) + utils.makedirs(os.path.dirname(filename)) + + diffeq = diffeq.clone() + de_min, de_max = float(diffeq.min()), float(diffeq.max()) + diffeq.clamp_(min=de_min, max=de_max) + diffeq.add_(-de_min).div_(de_max - de_min + 1e-5) + + assert im.shape == diffeq.shape + shape = im.shape + interleaved = torch.stack([im, diffeq]).transpose(0, 1).contiguous().view(2 * shape[0], *shape[1:]) + save_image(interleaved, filename, nrow=20, padding=0, range=(0, 1)) + im_indx += 1 + + # Still frames with image samples. + for _ in range(30): + save_im(samples, still_diffeq) + + # Forward image to latent. + logits = model.chain[0](samples) + for i in range(1, len(model.chain)): + assert isinstance(model.chain[i], layers.CNF) + cnf = model.chain[i] + tt = torch.linspace(cnf.integration_times[0], cnf.integration_times[-1], args.ntimes) + z_t = cnf(logits, integration_times=tt) + logits = z_t[-1] + + # transform back to image space + im_t = model.chain[0](z_t.view(args.ntimes * args.nsamples * 2, *data_shape), + reverse=True).view(args.ntimes, 2 * args.nsamples, *data_shape) + + # save each step as an image + for t, im in zip(tt, im_t): + diffeq = cnf.odefunc(t, (im, None))[0] + diffeq = model.chain[0](diffeq, reverse=True) + save_im(im, diffeq) + + # Still frames with latent samples. + latents = model.chain[0](logits, reverse=True) + for _ in range(30): + save_im(latents, still_diffeq) + + # Forward image to latent. 
+ for i in range(len(model.chain) - 1, 0, -1): + assert isinstance(model.chain[i], layers.CNF) + cnf = model.chain[i] + tt = torch.linspace(cnf.integration_times[-1], cnf.integration_times[0], args.ntimes) + z_t = cnf(logits, integration_times=tt) + logits = z_t[-1] + + # transform back to image space + im_t = model.chain[0](z_t.view(args.ntimes * args.nsamples * 2, *data_shape), + reverse=True).view(args.ntimes, 2 * args.nsamples, *data_shape) + # save each step as an image + for t, im in zip(tt, im_t): + diffeq = cnf.odefunc(t, (im, None))[0] + diffeq = model.chain[0](diffeq, reverse=True) + save_im(im, -diffeq) + + # Combine the images into a movie + bashCommand = r"ffmpeg -y -i {}/flow_%05d.png {}".format( + os.path.join(current_dir, args.save), os.path.join(current_dir, args.save, "flow.mp4") + ) + process = subprocess.Popen(bashCommand.split(), stdout=subprocess.PIPE) + output, error = process.communicate() diff --git a/src/torchprune/torchprune/util/external/ffjord/diagnostics/viz_high_fidelity_toy.py b/src/torchprune/torchprune/util/external/ffjord/diagnostics/viz_high_fidelity_toy.py new file mode 100644 index 0000000..b4181c2 --- /dev/null +++ b/src/torchprune/torchprune/util/external/ffjord/diagnostics/viz_high_fidelity_toy.py @@ -0,0 +1,132 @@ +import os +import math +from tqdm import tqdm +import numpy as np +import matplotlib +matplotlib.use('Agg') +import matplotlib.pyplot as plt +import torch + + +def standard_normal_logprob(z): + logZ = -0.5 * math.log(2 * math.pi) + return logZ - z.pow(2) / 2 + + +def makedirs(dirname): + if not os.path.exists(dirname): + os.makedirs(dirname) + + +def save_density_traj(model, data_samples, savedir, ntimes=101, memory=0.01, device='cpu'): + model.eval() + + # sample from a grid + npts = 800 + side = np.linspace(-4, 4, npts) + xx, yy = np.meshgrid(side, side) + xx = torch.from_numpy(xx).type(torch.float32).to(device) + yy = torch.from_numpy(yy).type(torch.float32).to(device) + z_grid = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1)], 1) + + with torch.no_grad(): + # We expect the model is a chain of CNF layers wrapped in a SequentialFlow container. 
+ logpz_grid = torch.sum(standard_normal_logprob(z_grid), 1, keepdim=True) + for cnf in model.chain: + end_time = cnf.sqrt_end_time * cnf.sqrt_end_time + viz_times = torch.linspace(0., end_time, ntimes) + + logpz_grid = [standard_normal_logprob(z_grid).sum(1, keepdim=True)] + for t in tqdm(viz_times[1:]): + inds = torch.arange(0, z_grid.shape[0]).to(torch.int64) + logpz_t = [] + for ii in torch.split(inds, int(z_grid.shape[0] * memory)): + z0, delta_logp = cnf( + z_grid[ii], + torch.zeros(z_grid[ii].shape[0], 1).to(z_grid), integration_times=torch.tensor([0., + t.item()]) + ) + logpz_t.append(standard_normal_logprob(z0).sum(1, keepdim=True) - delta_logp) + logpz_grid.append(torch.cat(logpz_t, 0)) + logpz_grid = torch.stack(logpz_grid, 0).cpu().detach().numpy() + z_grid = z_grid.cpu().detach().numpy() + + plt.figure(figsize=(8, 8)) + for t in range(logpz_grid.shape[0]): + + plt.clf() + ax = plt.gca() + + # plot the density + z, logqz = z_grid, logpz_grid[t] + + xx = z[:, 0].reshape(npts, npts) + yy = z[:, 1].reshape(npts, npts) + qz = np.exp(logqz).reshape(npts, npts) + + plt.pcolormesh(xx, yy, qz, cmap='binary') + ax.set_xlim(-4, 4) + ax.set_ylim(-4, 4) + cmap = matplotlib.cm.get_cmap('binary') + ax.set_axis_bgcolor(cmap(0.)) + ax.invert_yaxis() + ax.get_xaxis().set_ticks([]) + ax.get_yaxis().set_ticks([]) + plt.tight_layout() + + makedirs(savedir) + plt.savefig(os.path.join(savedir, f"viz-{t:05d}.jpg")) + + +def trajectory_to_video(savedir): + import subprocess + bashCommand = 'ffmpeg -y -i {} {}'.format(os.path.join(savedir, 'viz-%05d.jpg'), os.path.join(savedir, 'traj.mp4')) + process = subprocess.Popen(bashCommand.split(), stdout=subprocess.PIPE) + output, error = process.communicate() + + +if __name__ == '__main__': + import argparse + import sys + + sys.path.append(os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), '..'))) + + import lib.toy_data as toy_data + from train_misc import count_parameters + from train_misc import set_cnf_options, add_spectral_norm, create_regularization_fns + from train_misc import build_model_tabular + + def get_ckpt_model_and_data(args): + # Load checkpoint. + checkpt = torch.load(args.checkpt, map_location=lambda storage, loc: storage) + ckpt_args = checkpt['args'] + state_dict = checkpt['state_dict'] + + # Construct model and restore checkpoint. 
+ regularization_fns, regularization_coeffs = create_regularization_fns(ckpt_args) + model = build_model_tabular(ckpt_args, 2, regularization_fns).to(device) + if ckpt_args.spectral_norm: add_spectral_norm(model) + set_cnf_options(ckpt_args, model) + + model.load_state_dict(state_dict) + model.to(device) + + print(model) + print("Number of trainable parameters: {}".format(count_parameters(model))) + + # Load samples from dataset + data_samples = toy_data.inf_train_gen(ckpt_args.data, batch_size=2000) + + return model, data_samples + + parser = argparse.ArgumentParser() + parser.add_argument('--checkpt', type=str, required=True) + parser.add_argument('--ntimes', type=int, default=101) + parser.add_argument('--memory', type=float, default=0.01, help='Higher this number, the more memory is consumed.') + parser.add_argument('--save', type=str, default='trajectory') + args = parser.parse_args() + + device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') + model, data_samples = get_ckpt_model_and_data(args) + save_density_traj(model, data_samples, args.save, ntimes=args.ntimes, memory=args.memory, device=device) + trajectory_to_video(args.save) diff --git a/src/torchprune/torchprune/util/external/ffjord/diagnostics/viz_multiscale.py b/src/torchprune/torchprune/util/external/ffjord/diagnostics/viz_multiscale.py new file mode 100644 index 0000000..b3fcd9d --- /dev/null +++ b/src/torchprune/torchprune/util/external/ffjord/diagnostics/viz_multiscale.py @@ -0,0 +1,222 @@ +from inspect import getsourcefile +import sys +import os +import math + +current_path = os.path.abspath(getsourcefile(lambda: 0)) +current_dir = os.path.dirname(current_path) +parent_dir = current_dir[:current_dir.rfind(os.path.sep)] +sys.path.insert(0, parent_dir) + +import argparse + +import lib.layers as layers +import lib.odenvp as odenvp +import torch +import torchvision.transforms as tforms +import torchvision.datasets as dset +from torchvision.utils import save_image +import lib.utils as utils + +from train_misc import add_spectral_norm, set_cnf_options, count_parameters + +parser = argparse.ArgumentParser("Continuous Normalizing Flow") +parser.add_argument("--checkpt", type=str, required=True) +parser.add_argument("--data", choices=["mnist", "svhn", "cifar10", 'lsun_church'], type=str, default="cifar10") +parser.add_argument("--dims", type=str, default="64,64,64") +parser.add_argument("--num_blocks", type=int, default=2, help='Number of stacked CNFs.') +parser.add_argument("--divergence_fn", type=str, default="approximate", choices=["brute_force", "approximate"]) +parser.add_argument( + "--nonlinearity", type=str, default="softplus", choices=["tanh", "relu", "softplus", "elu", "swish"] +) +parser.add_argument("--conv", type=eval, default=True, choices=[True, False]) + +parser.add_argument('--solver', type=str, default='dopri5') +parser.add_argument('--atol', type=float, default=1e-5) +parser.add_argument('--rtol', type=float, default=1e-5) +parser.add_argument("--step_size", type=float, default=None, help="Optional fixed step size.") + +parser.add_argument('--test_solver', type=str, default=None) +parser.add_argument('--test_atol', type=float, default=None) +parser.add_argument('--test_rtol', type=float, default=None) + +parser.add_argument("--imagesize", type=int, default=None) +parser.add_argument("--alpha", type=float, default=-1.0) +parser.add_argument('--time_length', type=float, default=1.0) +parser.add_argument('--train_T', type=eval, default=True) + +parser.add_argument("--add_noise", type=eval, 
default=True, choices=[True, False]) +parser.add_argument('--rademacher', type=eval, default=True, choices=[True, False]) +parser.add_argument('--residual', type=eval, default=False, choices=[True, False]) +parser.add_argument('--spectral_norm', type=eval, default=False, choices=[True, False]) + +parser.add_argument('--ntimes', type=int, default=50) +parser.add_argument('--save', type=str, default='img_trajectory') + +args = parser.parse_args() + +BATCH_SIZE = 8 * 8 + + +def add_noise(x): + """ + [0, 1] -> [0, 255] -> add noise -> [0, 1] + """ + if args.add_noise: + noise = x.new().resize_as_(x).uniform_() + x = x * 255 + noise + x = x / 256 + return x + + +def get_dataset(args): + trans = lambda im_size: tforms.Compose([tforms.Resize(im_size), tforms.ToTensor(), add_noise]) + + if args.data == "mnist": + im_dim = 1 + im_size = 28 if args.imagesize is None else args.imagesize + train_set = dset.MNIST(root="./data", train=True, transform=trans(im_size), download=True) + elif args.data == "cifar10": + im_dim = 3 + im_size = 32 if args.imagesize is None else args.imagesize + train_set = dset.CIFAR10( + root="./data", train=True, transform=tforms.Compose([ + tforms.Resize(im_size), + tforms.RandomHorizontalFlip(), + tforms.ToTensor(), + add_noise, + ]), download=True + ) + elif args.data == 'lsun_church': + im_dim = 3 + im_size = 64 if args.imagesize is None else args.imagesize + train_set = dset.LSUN( + 'data', ['church_outdoor_train'], transform=tforms.Compose([ + tforms.Resize(96), + tforms.RandomCrop(64), + tforms.Resize(im_size), + tforms.ToTensor(), + add_noise, + ]) + ) + data_shape = (im_dim, im_size, im_size) + if not args.conv: + data_shape = (im_dim * im_size * im_size,) + + return train_set, data_shape + + +def create_model(args, data_shape): + hidden_dims = tuple(map(int, args.dims.split(","))) + + model = odenvp.ODENVP( + (BATCH_SIZE, *data_shape), + n_blocks=args.num_blocks, + intermediate_dims=hidden_dims, + nonlinearity=args.nonlinearity, + alpha=args.alpha, + cnf_kwargs={"T": args.time_length, "train_T": args.train_T}, + ) + if args.spectral_norm: add_spectral_norm(model) + set_cnf_options(args, model) + return model + + +if __name__ == '__main__': + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + cvt = lambda x: x.type(torch.float32).to(device, non_blocking=True) + + # load dataset + train_set, data_shape = get_dataset(args) + + # build model + model = create_model(args, data_shape) + + print(model) + print("Number of trainable parameters: {}".format(count_parameters(model))) + + # restore parameters + checkpt = torch.load(args.checkpt, map_location=lambda storage, loc: storage) + pruned_sd = {} + for k, v in checkpt['state_dict'].items(): + pruned_sd[k.replace('odefunc.odefunc', 'odefunc')] = v + model.load_state_dict(pruned_sd) + + train_loader = torch.utils.data.DataLoader(dataset=train_set, batch_size=BATCH_SIZE, shuffle=True) + + data_samples, _ = train_loader.__iter__().__next__() + + # cosine interpolate between 4 real images. 
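+    # (The four samples are encoded to the base distribution, then blended with
+    # cos/sin weights over a grid of angles (phi0, phi1) in [0, pi/2]:
+    #   cos(phi0) * (cos(phi1) * z[0] + sin(phi1) * z[1])
+    #     + sin(phi0) * (cos(phi1) * z[2] + sin(phi1) * z[3])
+    # so the grid corners reproduce the four originals and interior points are
+    # smooth interpolations, which are then decoded back through the flow.)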
+ z = data_samples[:4] + print('Inferring base values for 4 example images.') + z = model(z) + + phi0 = torch.linspace(0, 0.5, int(math.sqrt(BATCH_SIZE))) * math.pi + phi1 = torch.linspace(0, 0.5, int(math.sqrt(BATCH_SIZE))) * math.pi + phi0, phi1 = torch.meshgrid([phi0, phi1]) + phi0, phi1 = phi0.contiguous().view(-1, 1), phi1.contiguous().view(-1, 1) + z = torch.cos(phi0) * (torch.cos(phi1) * z[0] + torch.sin(phi1) * z[1]) + \ + torch.sin(phi0) * (torch.cos(phi1) * z[2] + torch.sin(phi1) * z[3]) + print('Reconstructing images from latent interpolation.') + z = model(z, reverse=True) + + non_cnf_layers = [] + + utils.makedirs(args.save) + img_idx = 0 + + def save_imgs_figure(xs): + global img_idx + save_image( + list(xs), + os.path.join(args.save, "img_{:05d}.jpg".format(img_idx)), nrow=int(math.sqrt(BATCH_SIZE)), normalize=True, + range=(0, 1) + ) + img_idx += 1 + + class FactorOut(torch.nn.Module): + + def __init__(self, factor_out): + super(FactorOut, self).__init__() + self.factor_out = factor_out + + def forward(self, x, reverse=True): + assert reverse + T = x.shape[0] // self.factor_out.shape[0] + return torch.cat([x, self.factor_out.repeat(T, *([1] * (self.factor_out.ndimension() - 1)))], 1) + + time_ratio = 1.0 + print('Visualizing transformations.') + with torch.no_grad(): + for idx, stacked_layers in enumerate(model.transforms): + for layer in stacked_layers.chain: + print(z.shape) + print(non_cnf_layers) + if isinstance(layer, layers.CNF): + # linspace over time, and visualize by reversing through previous non_cnf_layers. + cnf = layer + end_time = (cnf.sqrt_end_time * cnf.sqrt_end_time) + ntimes = int(args.ntimes * time_ratio) + integration_times = torch.linspace(0, end_time.item(), ntimes) + z_traj = cnf(z, integration_times=integration_times) + + # reverse z(t) for all times to the input space + z_flatten = z_traj.view(ntimes * BATCH_SIZE, *z_traj.shape[2:]) + for prev_layer in non_cnf_layers[::-1]: + z_flatten = prev_layer(z_flatten, reverse=True) + z_inv = z_flatten.view(ntimes, BATCH_SIZE, *data_shape) + for t in range(1, z_inv.shape[0]): + z_t = z_inv[t] + save_imgs_figure(z_t) + z = z_traj[-1] + else: + # update z and place in non_cnf_layers. + z = layer(z) + non_cnf_layers.append(layer) + if idx < len(model.transforms) - 1: + d = z.shape[1] // 2 + z, factor_out = z[:, :d], z[:, d:] + non_cnf_layers.append(FactorOut(factor_out)) + + # After every factor out, we half the time for visualization. 
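+                # (Each factor-out removes half of the dimensions from the
+                # active path, so the later CNF blocks are visualized with
+                # proportionally fewer frames.)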
+                time_ratio = time_ratio / 2
diff --git a/src/torchprune/torchprune/util/external/ffjord/diagnostics/viz_toy.py b/src/torchprune/torchprune/util/external/ffjord/diagnostics/viz_toy.py
new file mode 100644
index 0000000..7ebf029
--- /dev/null
+++ b/src/torchprune/torchprune/util/external/ffjord/diagnostics/viz_toy.py
@@ -0,0 +1,179 @@
+import os
+import math
+import numpy as np
+import matplotlib
+matplotlib.use('Agg')
+import matplotlib.pyplot as plt
+import torch
+
+
+def standard_normal_logprob(z):
+    logZ = -0.5 * math.log(2 * math.pi)
+    return logZ - z.pow(2) / 2
+
+
+def makedirs(dirname):
+    if not os.path.exists(dirname):
+        os.makedirs(dirname)
+
+
+def save_trajectory(model, data_samples, savedir, ntimes=101, memory=0.01, device='cpu'):
+    model.eval()
+
+    # Sample from prior
+    z_samples = torch.randn(2000, 2).to(device)
+
+    # sample from a grid
+    npts = 800
+    side = np.linspace(-4, 4, npts)
+    xx, yy = np.meshgrid(side, side)
+    xx = torch.from_numpy(xx).type(torch.float32).to(device)
+    yy = torch.from_numpy(yy).type(torch.float32).to(device)
+    z_grid = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1)], 1)
+
+    with torch.no_grad():
+        # We expect the model to be a chain of CNF layers wrapped in a SequentialFlow container.
+        logp_samples = torch.sum(standard_normal_logprob(z_samples), 1, keepdim=True)
+        logp_grid = torch.sum(standard_normal_logprob(z_grid), 1, keepdim=True)
+        t = 0
+        for cnf in model.chain:
+            end_time = (cnf.sqrt_end_time * cnf.sqrt_end_time)
+            integration_times = torch.linspace(0, end_time, ntimes)
+
+            z_traj, _ = cnf(z_samples, logp_samples, integration_times=integration_times, reverse=True)
+            z_traj = z_traj.cpu().numpy()
+
+            grid_z_traj, grid_logpz_traj = [], []
+            inds = torch.arange(0, z_grid.shape[0]).to(torch.int64)
+            for ii in torch.split(inds, int(z_grid.shape[0] * memory)):
+                _grid_z_traj, _grid_logpz_traj = cnf(
+                    z_grid[ii], logp_grid[ii], integration_times=integration_times, reverse=True
+                )
+                _grid_z_traj, _grid_logpz_traj = _grid_z_traj.cpu().numpy(), _grid_logpz_traj.cpu().numpy()
+                grid_z_traj.append(_grid_z_traj)
+                grid_logpz_traj.append(_grid_logpz_traj)
+            grid_z_traj = np.concatenate(grid_z_traj, axis=1)
+            grid_logpz_traj = np.concatenate(grid_logpz_traj, axis=1)
+
+            plt.figure(figsize=(8, 8))
+            for _ in range(z_traj.shape[0]):
+
+                plt.clf()
+
+                # plot target potential function
+                ax = plt.subplot(2, 2, 1, aspect="equal")
+
+                ax.hist2d(data_samples[:, 0], data_samples[:, 1], range=[[-4, 4], [-4, 4]], bins=200)
+                ax.invert_yaxis()
+                ax.get_xaxis().set_ticks([])
+                ax.get_yaxis().set_ticks([])
+                ax.set_title("Target", fontsize=32)
+
+                # plot the density
+                ax = plt.subplot(2, 2, 2, aspect="equal")
+
+                z, logqz = grid_z_traj[t], grid_logpz_traj[t]
+
+                xx = z[:, 0].reshape(npts, npts)
+                yy = z[:, 1].reshape(npts, npts)
+                qz = np.exp(logqz).reshape(npts, npts)
+
+                plt.pcolormesh(xx, yy, qz)
+                ax.set_xlim(-4, 4)
+                ax.set_ylim(-4, 4)
+                cmap = matplotlib.cm.get_cmap(None)
+                # set_axis_bgcolor was removed in matplotlib 2.2; use set_facecolor.
+                ax.set_facecolor(cmap(0.))
+                ax.invert_yaxis()
+                ax.get_xaxis().set_ticks([])
+                ax.get_yaxis().set_ticks([])
+                ax.set_title("Density", fontsize=32)
+
+                # plot the samples
+                ax = plt.subplot(2, 2, 3, aspect="equal")
+
+                zk = z_traj[t]
+                ax.hist2d(zk[:, 0], zk[:, 1], range=[[-4, 4], [-4, 4]], bins=200)
+                ax.invert_yaxis()
+                ax.get_xaxis().set_ticks([])
+                ax.get_yaxis().set_ticks([])
+                ax.set_title("Samples", fontsize=32)
+
+                # plot vector field
+                ax = plt.subplot(2, 2, 4, aspect="equal")
+
+                K = 13j
+                y, x = np.mgrid[-4:4:K, -4:4:K]
+                K = int(K.imag)
+                zs = torch.from_numpy(np.stack([x, y], -1).reshape(K * K, 2)).to(device, torch.float32)
+                logps = torch.zeros(zs.shape[0], 1).to(device, torch.float32)
+                dydt = cnf.odefunc(integration_times[t], (zs, logps))[0]
+                dydt = -dydt.cpu().detach().numpy()
+                dydt = dydt.reshape(K, K, 2)
+
+                logmag = 2 * np.log(np.hypot(dydt[:, :, 0], dydt[:, :, 1]))
+                ax.quiver(
+                    x, y, dydt[:, :, 0], dydt[:, :, 1],
+                    np.exp(logmag), cmap="coolwarm", scale=20., width=0.015, pivot="mid"
+                )
+                ax.set_xlim(-4, 4)
+                ax.set_ylim(-4, 4)
+                ax.axis("off")
+                ax.set_title("Vector Field", fontsize=32)
+
+                makedirs(savedir)
+                plt.savefig(os.path.join(savedir, f"viz-{t:05d}.jpg"))
+                t += 1
+
+
+def trajectory_to_video(savedir):
+    import subprocess
+    bashCommand = 'ffmpeg -y -i {} {}'.format(os.path.join(savedir, 'viz-%05d.jpg'), os.path.join(savedir, 'traj.mp4'))
+    process = subprocess.Popen(bashCommand.split(), stdout=subprocess.PIPE)
+    output, error = process.communicate()
+
+
+if __name__ == '__main__':
+    import argparse
+    import sys
+
+    sys.path.append(os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), '..')))
+
+    import lib.toy_data as toy_data
+    from train_misc import count_parameters
+    from train_misc import set_cnf_options, add_spectral_norm, create_regularization_fns
+    from train_misc import build_model_tabular
+
+    def get_ckpt_model_and_data(args):
+        # Load checkpoint.
+        checkpt = torch.load(args.checkpt, map_location=lambda storage, loc: storage)
+        ckpt_args = checkpt['args']
+        state_dict = checkpt['state_dict']
+
+        # Construct model and restore checkpoint.
+        regularization_fns, regularization_coeffs = create_regularization_fns(ckpt_args)
+        model = build_model_tabular(ckpt_args, 2, regularization_fns).to(device)
+        if ckpt_args.spectral_norm: add_spectral_norm(model)
+        set_cnf_options(ckpt_args, model)
+
+        model.load_state_dict(state_dict)
+        model.to(device)
+
+        print(model)
+        print("Number of trainable parameters: {}".format(count_parameters(model)))
+
+        # Load samples from dataset
+        data_samples = toy_data.inf_train_gen(ckpt_args.data, batch_size=2000)
+
+        return model, data_samples
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--checkpt', type=str, required=True)
+    parser.add_argument('--ntimes', type=int, default=101)
+    parser.add_argument('--memory', type=float, default=0.01, help='The higher this number, the more memory is consumed.')
+    parser.add_argument('--save', type=str, default='trajectory')
+    args = parser.parse_args()
+
+    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+    model, data_samples = get_ckpt_model_and_data(args)
+    save_trajectory(model, data_samples, args.save, ntimes=args.ntimes, memory=args.memory, device=device)
+    trajectory_to_video(args.save)
diff --git a/src/torchprune/torchprune/util/external/ffjord/imgs/github.png b/src/torchprune/torchprune/util/external/ffjord/imgs/github.png
new file mode 100644
index 0000000..487ed2f
Binary files /dev/null and b/src/torchprune/torchprune/util/external/ffjord/imgs/github.png differ
diff --git a/src/torchprune/torchprune/util/external/ffjord/imgs/maple_leaf.jpg b/src/torchprune/torchprune/util/external/ffjord/imgs/maple_leaf.jpg
new file mode 100644
index 0000000..05bcc19
Binary files /dev/null and b/src/torchprune/torchprune/util/external/ffjord/imgs/maple_leaf.jpg differ
diff --git a/src/torchprune/torchprune/util/external/ffjord/train_cnf.py b/src/torchprune/torchprune/util/external/ffjord/train_cnf.py
new file mode 100644
index 0000000..afbc364
--- /dev/null
+++ b/src/torchprune/torchprune/util/external/ffjord/train_cnf.py
@@ -0,0 +1,444 @@
+import argparse
+import os
+import time
+import numpy as np
+
+import torch
+import torch.optim as optim
+import torchvision.datasets as dset
+import torchvision.transforms as tforms
+from torchvision.utils import save_image
+
+import lib.layers as layers
+import lib.utils as utils
+import lib.odenvp as odenvp
+import lib.multiscale_parallel as multiscale_parallel
+
+from train_misc import standard_normal_logprob
+from train_misc import set_cnf_options, count_nfe, count_parameters, count_total_time
+from train_misc import add_spectral_norm, spectral_norm_power_iteration
+from train_misc import create_regularization_fns, get_regularization, append_regularization_to_log
+
+# go fast boi!!
+torch.backends.cudnn.benchmark = True
+SOLVERS = ["dopri5", "bdf", "rk4", "midpoint", 'adams', 'explicit_adams']
+parser = argparse.ArgumentParser("Continuous Normalizing Flow")
+parser.add_argument("--data", choices=["mnist", "svhn", "cifar10", 'lsun_church'], type=str, default="mnist")
+parser.add_argument("--dims", type=str, default="8,32,32,8")
+parser.add_argument("--strides", type=str, default="2,2,1,-2,-2")
+parser.add_argument("--num_blocks", type=int, default=1, help='Number of stacked CNFs.')
+
+parser.add_argument("--conv", type=eval, default=True, choices=[True, False])
+parser.add_argument(
+    "--layer_type", type=str, default="ignore",
+    choices=["ignore", "concat", "concat_v2", "squash", "concatsquash", "concatcoord", "hyper", "blend"]
+)
+parser.add_argument("--divergence_fn", type=str, default="approximate", choices=["brute_force", "approximate"])
+parser.add_argument(
+    "--nonlinearity", type=str, default="softplus", choices=["tanh", "relu", "softplus", "elu", "swish"]
+)
+parser.add_argument('--solver', type=str, default='dopri5', choices=SOLVERS)
+parser.add_argument('--atol', type=float, default=1e-5)
+parser.add_argument('--rtol', type=float, default=1e-5)
+parser.add_argument("--step_size", type=float, default=None, help="Optional fixed step size.")
+
+parser.add_argument('--test_solver', type=str, default=None, choices=SOLVERS + [None])
+parser.add_argument('--test_atol', type=float, default=None)
+parser.add_argument('--test_rtol', type=float, default=None)
+
+parser.add_argument("--imagesize", type=int, default=None)
+parser.add_argument("--alpha", type=float, default=1e-6)
+parser.add_argument('--time_length', type=float, default=1.0)
+parser.add_argument('--train_T', type=eval, default=True)
+
+parser.add_argument("--num_epochs", type=int, default=1000)
+parser.add_argument("--batch_size", type=int, default=200)
+parser.add_argument(
+    "--batch_size_schedule", type=str, default="", help="Increases the batch size at every given epoch, dash separated."
+)
+parser.add_argument("--test_batch_size", type=int, default=200)
+parser.add_argument("--lr", type=float, default=1e-3)
+parser.add_argument("--warmup_iters", type=float, default=1000)
+parser.add_argument("--weight_decay", type=float, default=0.0)
+parser.add_argument("--spectral_norm_niter", type=int, default=10)
+
+parser.add_argument("--add_noise", type=eval, default=True, choices=[True, False])
+parser.add_argument("--batch_norm", type=eval, default=False, choices=[True, False])
+parser.add_argument('--residual', type=eval, default=False, choices=[True, False])
+parser.add_argument('--autoencode', type=eval, default=False, choices=[True, False])
+parser.add_argument('--rademacher', type=eval, default=True, choices=[True, False])
+parser.add_argument('--spectral_norm', type=eval, default=False, choices=[True, False])
+parser.add_argument('--multiscale', type=eval, default=False, choices=[True, False])
+parser.add_argument('--parallel', type=eval, default=False, choices=[True, False])
+
+# Regularizations
+parser.add_argument('--l1int', type=float, default=None, help="int_t ||f||_1")
+parser.add_argument('--l2int', type=float, default=None, help="int_t ||f||_2")
+parser.add_argument('--dl2int', type=float, default=None, help="int_t ||f^T df/dt||_2")
+parser.add_argument('--JFrobint', type=float, default=None, help="int_t ||df/dx||_F")
+parser.add_argument('--JdiagFrobint', type=float, default=None, help="int_t ||df_i/dx_i||_F")
+parser.add_argument('--JoffdiagFrobint', type=float, default=None, help="int_t ||df/dx - df_i/dx_i||_F")
+
+parser.add_argument("--time_penalty", type=float, default=0, help="Regularization on the end_time.")
+parser.add_argument(
+    "--max_grad_norm", type=float, default=1e10,
+    help="Max norm of gradients (default is just stupidly high to avoid any clipping)"
+)
+
+parser.add_argument("--begin_epoch", type=int, default=1)
+parser.add_argument("--resume", type=str, default=None)
+parser.add_argument("--save", type=str, default="experiments/cnf")
+parser.add_argument("--val_freq", type=int, default=1)
+parser.add_argument("--log_freq", type=int, default=10)
+
+args = parser.parse_args()
+
+# logger
+utils.makedirs(args.save)
+logger = utils.get_logger(logpath=os.path.join(args.save, 'logs'), filepath=os.path.abspath(__file__))
+
+if args.layer_type == "blend":
+    logger.info("!! Setting time_length from None to 1.0 due to use of Blend layers.")
+    args.time_length = 1.0
+
+logger.info(args)
+
+
+def add_noise(x):
+    """
+    [0, 1] -> [0, 255] -> add noise -> [0, 1]
+    """
+    if args.add_noise:
+        noise = x.new().resize_as_(x).uniform_()
+        x = x * 255 + noise
+        x = x / 256
+    return x
+
+
+def update_lr(optimizer, itr):
+    iter_frac = min(float(itr + 1) / max(args.warmup_iters, 1), 1.0)
+    lr = args.lr * iter_frac
+    for param_group in optimizer.param_groups:
+        param_group["lr"] = lr
+
+
+def get_train_loader(train_set, epoch):
+    if args.batch_size_schedule != "":
+        epochs = [0] + list(map(int, args.batch_size_schedule.split("-")))
+        n_passed = sum(np.array(epochs) <= epoch)
+        current_batch_size = int(args.batch_size * n_passed)
+    else:
+        current_batch_size = args.batch_size
+    train_loader = torch.utils.data.DataLoader(
+        dataset=train_set, batch_size=current_batch_size, shuffle=True, drop_last=True, pin_memory=True
+    )
+    logger.info("===> Using batch size {}. Total {} iterations/epoch.".format(current_batch_size, len(train_loader)))
+    return train_loader
+
+
+def get_dataset(args):
+    trans = lambda im_size: tforms.Compose([tforms.Resize(im_size), tforms.ToTensor(), add_noise])
+
+    if args.data == "mnist":
+        im_dim = 1
+        im_size = 28 if args.imagesize is None else args.imagesize
+        train_set = dset.MNIST(root="./data", train=True, transform=trans(im_size), download=True)
+        test_set = dset.MNIST(root="./data", train=False, transform=trans(im_size), download=True)
+    elif args.data == "svhn":
+        im_dim = 3
+        im_size = 32 if args.imagesize is None else args.imagesize
+        train_set = dset.SVHN(root="./data", split="train", transform=trans(im_size), download=True)
+        test_set = dset.SVHN(root="./data", split="test", transform=trans(im_size), download=True)
+    elif args.data == "cifar10":
+        im_dim = 3
+        im_size = 32 if args.imagesize is None else args.imagesize
+        train_set = dset.CIFAR10(
+            root="./data", train=True, transform=tforms.Compose([
+                tforms.Resize(im_size),
+                tforms.RandomHorizontalFlip(),
+                tforms.ToTensor(),
+                add_noise,
+            ]), download=True
+        )
+        test_set = dset.CIFAR10(root="./data", train=False, transform=trans(im_size), download=True)
+    elif args.data == 'celeba':
+        im_dim = 3
+        im_size = 64 if args.imagesize is None else args.imagesize
+        train_set = dset.CelebA(
+            train=True, transform=tforms.Compose([
+                tforms.ToPILImage(),
+                tforms.Resize(im_size),
+                tforms.RandomHorizontalFlip(),
+                tforms.ToTensor(),
+                add_noise,
+            ])
+        )
+        test_set = dset.CelebA(
+            train=False, transform=tforms.Compose([
+                tforms.ToPILImage(),
+                tforms.Resize(im_size),
+                tforms.ToTensor(),
+                add_noise,
+            ])
+        )
+    elif args.data == 'lsun_church':
+        im_dim = 3
+        im_size = 64 if args.imagesize is None else args.imagesize
+        train_set = dset.LSUN(
+            'data', ['church_outdoor_train'], transform=tforms.Compose([
+                tforms.Resize(96),
+                tforms.RandomCrop(64),
+                tforms.Resize(im_size),
+                tforms.ToTensor(),
+                add_noise,
+            ])
+        )
+        test_set = dset.LSUN(
+            'data', ['church_outdoor_val'], transform=tforms.Compose([
+                tforms.Resize(96),
+                tforms.RandomCrop(64),
+                tforms.Resize(im_size),
+                tforms.ToTensor(),
+                add_noise,
+            ])
+        )
+    data_shape = (im_dim, im_size, im_size)
+    if not args.conv:
+        data_shape = (im_dim * im_size * im_size,)
+
+    test_loader = torch.utils.data.DataLoader(
+        dataset=test_set, batch_size=args.test_batch_size, shuffle=False, drop_last=True
+    )
+    return train_set, test_loader, data_shape
+
+
+def compute_bits_per_dim(x, model):
+    zero = torch.zeros(x.shape[0], 1).to(x)
+
+    # Don't use data parallelism if the batch size is small.
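+    # The quantity reported below is the negative log-likelihood in bits per
+    # dimension: bits/dim = -(log p(x) / D - log 256) / log 2, where D is the
+    # number of dimensions per sample and log 256 undoes the [0, 1] scaling of
+    # the 8-bit data.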
+ # if x.shape[0] < 200: + # model = model.module + + z, delta_logp = model(x, zero) # run model forward + + logpz = standard_normal_logprob(z).view(z.shape[0], -1).sum(1, keepdim=True) # logp(z) + logpx = logpz - delta_logp + + logpx_per_dim = torch.sum(logpx) / x.nelement() # averaged over batches + bits_per_dim = -(logpx_per_dim - np.log(256)) / np.log(2) + + return bits_per_dim + + +def create_model(args, data_shape, regularization_fns): + hidden_dims = tuple(map(int, args.dims.split(","))) + strides = tuple(map(int, args.strides.split(","))) + + if args.multiscale: + model = odenvp.ODENVP( + (args.batch_size, *data_shape), + n_blocks=args.num_blocks, + intermediate_dims=hidden_dims, + nonlinearity=args.nonlinearity, + alpha=args.alpha, + cnf_kwargs={"T": args.time_length, "train_T": args.train_T, "regularization_fns": regularization_fns}, + ) + elif args.parallel: + model = multiscale_parallel.MultiscaleParallelCNF( + (args.batch_size, *data_shape), + n_blocks=args.num_blocks, + intermediate_dims=hidden_dims, + alpha=args.alpha, + time_length=args.time_length, + ) + else: + if args.autoencode: + + def build_cnf(): + autoencoder_diffeq = layers.AutoencoderDiffEqNet( + hidden_dims=hidden_dims, + input_shape=data_shape, + strides=strides, + conv=args.conv, + layer_type=args.layer_type, + nonlinearity=args.nonlinearity, + ) + odefunc = layers.AutoencoderODEfunc( + autoencoder_diffeq=autoencoder_diffeq, + divergence_fn=args.divergence_fn, + residual=args.residual, + rademacher=args.rademacher, + ) + cnf = layers.CNF( + odefunc=odefunc, + T=args.time_length, + regularization_fns=regularization_fns, + solver=args.solver, + ) + return cnf + else: + + def build_cnf(): + diffeq = layers.ODEnet( + hidden_dims=hidden_dims, + input_shape=data_shape, + strides=strides, + conv=args.conv, + layer_type=args.layer_type, + nonlinearity=args.nonlinearity, + ) + odefunc = layers.ODEfunc( + diffeq=diffeq, + divergence_fn=args.divergence_fn, + residual=args.residual, + rademacher=args.rademacher, + ) + cnf = layers.CNF( + odefunc=odefunc, + T=args.time_length, + train_T=args.train_T, + regularization_fns=regularization_fns, + solver=args.solver, + ) + return cnf + + chain = [layers.LogitTransform(alpha=args.alpha)] if args.alpha > 0 else [layers.ZeroMeanTransform()] + chain = chain + [build_cnf() for _ in range(args.num_blocks)] + if args.batch_norm: + chain.append(layers.MovingBatchNorm2d(data_shape[0])) + model = layers.SequentialFlow(chain) + return model + + +if __name__ == "__main__": + + # get deivce + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + cvt = lambda x: x.type(torch.float32).to(device, non_blocking=True) + + # load dataset + train_set, test_loader, data_shape = get_dataset(args) + + # build model + regularization_fns, regularization_coeffs = create_regularization_fns(args) + model = create_model(args, data_shape, regularization_fns) + + if args.spectral_norm: add_spectral_norm(model, logger) + set_cnf_options(args, model) + + logger.info(model) + logger.info("Number of trainable parameters: {}".format(count_parameters(model))) + + # optimizer + optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) + + # restore parameters + if args.resume is not None: + checkpt = torch.load(args.resume, map_location=lambda storage, loc: storage) + model.load_state_dict(checkpt["state_dict"]) + if "optim_state_dict" in checkpt.keys(): + optimizer.load_state_dict(checkpt["optim_state_dict"]) + # Manually move optimizer state to device. 
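+            # (map_location above forces the checkpoint onto the CPU, so Adam's
+            # exp_avg / exp_avg_sq buffers would otherwise stay there and fail
+            # with a device mismatch on the first optimizer.step() on GPU.)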
+ for state in optimizer.state.values(): + for k, v in state.items(): + if torch.is_tensor(v): + state[k] = cvt(v) + + if torch.cuda.is_available(): + model = torch.nn.DataParallel(model).cuda() + + # For visualization. + fixed_z = cvt(torch.randn(100, *data_shape)) + + time_meter = utils.RunningAverageMeter(0.97) + loss_meter = utils.RunningAverageMeter(0.97) + steps_meter = utils.RunningAverageMeter(0.97) + grad_meter = utils.RunningAverageMeter(0.97) + tt_meter = utils.RunningAverageMeter(0.97) + + if args.spectral_norm and not args.resume: spectral_norm_power_iteration(model, 500) + + best_loss = float("inf") + itr = 0 + for epoch in range(args.begin_epoch, args.num_epochs + 1): + model.train() + train_loader = get_train_loader(train_set, epoch) + for _, (x, y) in enumerate(train_loader): + start = time.time() + update_lr(optimizer, itr) + optimizer.zero_grad() + + if not args.conv: + x = x.view(x.shape[0], -1) + + # cast data and move to device + x = cvt(x) + # compute loss + loss = compute_bits_per_dim(x, model) + if regularization_coeffs: + reg_states = get_regularization(model, regularization_coeffs) + reg_loss = sum( + reg_state * coeff for reg_state, coeff in zip(reg_states, regularization_coeffs) if coeff != 0 + ) + loss = loss + reg_loss + total_time = count_total_time(model) + loss = loss + total_time * args.time_penalty + + loss.backward() + grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) + + optimizer.step() + + if args.spectral_norm: spectral_norm_power_iteration(model, args.spectral_norm_niter) + + time_meter.update(time.time() - start) + loss_meter.update(loss.item()) + steps_meter.update(count_nfe(model)) + grad_meter.update(grad_norm) + tt_meter.update(total_time) + + if itr % args.log_freq == 0: + log_message = ( + "Iter {:04d} | Time {:.4f}({:.4f}) | Bit/dim {:.4f}({:.4f}) | " + "Steps {:.0f}({:.2f}) | Grad Norm {:.4f}({:.4f}) | Total Time {:.2f}({:.2f})".format( + itr, time_meter.val, time_meter.avg, loss_meter.val, loss_meter.avg, steps_meter.val, + steps_meter.avg, grad_meter.val, grad_meter.avg, tt_meter.val, tt_meter.avg + ) + ) + if regularization_coeffs: + log_message = append_regularization_to_log(log_message, regularization_fns, reg_states) + logger.info(log_message) + + itr += 1 + + # compute test loss + model.eval() + if epoch % args.val_freq == 0: + with torch.no_grad(): + start = time.time() + logger.info("validating...") + losses = [] + for (x, y) in test_loader: + if not args.conv: + x = x.view(x.shape[0], -1) + x = cvt(x) + loss = compute_bits_per_dim(x, model) + losses.append(loss) + + loss = np.mean(losses) + logger.info("Epoch {:04d} | Time {:.4f}, Bit/dim {:.4f}".format(epoch, time.time() - start, loss)) + if loss < best_loss: + best_loss = loss + utils.makedirs(args.save) + torch.save({ + "args": args, + "state_dict": model.module.state_dict() if torch.cuda.is_available() else model.state_dict(), + "optim_state_dict": optimizer.state_dict(), + }, os.path.join(args.save, "checkpt.pth")) + + # visualize samples and density + with torch.no_grad(): + fig_filename = os.path.join(args.save, "figs", "{:04d}.jpg".format(epoch)) + utils.makedirs(os.path.dirname(fig_filename)) + generated_samples = model(fixed_z, reverse=True).view(-1, *data_shape) + save_image(generated_samples, fig_filename, nrow=10) diff --git a/src/torchprune/torchprune/util/external/ffjord/train_discrete_tabular.py b/src/torchprune/torchprune/util/external/ffjord/train_discrete_tabular.py new file mode 100644 index 0000000..4492842 --- /dev/null +++ 
b/src/torchprune/torchprune/util/external/ffjord/train_discrete_tabular.py @@ -0,0 +1,235 @@ +import argparse +import os +import time + +import torch + +import lib.utils as utils +from lib.custom_optimizers import Adam +import lib.layers as layers + +import datasets + +from train_misc import standard_normal_logprob, count_parameters + +parser = argparse.ArgumentParser() +parser.add_argument( + '--data', choices=['power', 'gas', 'hepmass', 'miniboone', 'bsds300'], type=str, default='miniboone' +) + +parser.add_argument('--depth', type=int, default=10) +parser.add_argument('--dims', type=str, default="100-100") +parser.add_argument('--nonlinearity', type=str, default="tanh") +parser.add_argument('--glow', type=eval, default=False, choices=[True, False]) +parser.add_argument('--batch_norm', type=eval, default=False, choices=[True, False]) +parser.add_argument('--bn_lag', type=float, default=0) + +parser.add_argument('--early_stopping', type=int, default=30) +parser.add_argument('--batch_size', type=int, default=1000) +parser.add_argument('--test_batch_size', type=int, default=None) +parser.add_argument('--lr', type=float, default=1e-4) +parser.add_argument('--weight_decay', type=float, default=1e-6) + +parser.add_argument('--resume', type=str, default=None) +parser.add_argument('--save', type=str, default='experiments/cnf') +parser.add_argument('--evaluate', action='store_true') +parser.add_argument('--val_freq', type=int, default=200) +parser.add_argument('--log_freq', type=int, default=10) +args = parser.parse_args() + +# logger +utils.makedirs(args.save) +logger = utils.get_logger(logpath=os.path.join(args.save, 'logs'), filepath=os.path.abspath(__file__)) +logger.info(args) + +test_batch_size = args.test_batch_size if args.test_batch_size else args.batch_size + + +def batch_iter(X, batch_size=args.batch_size, shuffle=False): + """ + X: feature tensor (shape: num_instances x num_features) + """ + if shuffle: + idxs = torch.randperm(X.shape[0]) + else: + idxs = torch.arange(X.shape[0]) + if X.is_cuda: + idxs = idxs.cuda() + for batch_idxs in idxs.split(batch_size): + yield X[batch_idxs] + + +ndecs = 0 + + +def update_lr(optimizer, n_vals_without_improvement): + global ndecs + if ndecs == 0 and n_vals_without_improvement > args.early_stopping // 3: + for param_group in optimizer.param_groups: + param_group["lr"] = args.lr / 10 + ndecs = 1 + elif ndecs == 1 and n_vals_without_improvement > args.early_stopping // 3 * 2: + for param_group in optimizer.param_groups: + param_group["lr"] = args.lr / 100 + ndecs = 2 + else: + for param_group in optimizer.param_groups: + param_group["lr"] = args.lr / 10**ndecs + + +def load_data(name): + + if name == 'bsds300': + return datasets.BSDS300() + + elif name == 'power': + return datasets.POWER() + + elif name == 'gas': + return datasets.GAS() + + elif name == 'hepmass': + return datasets.HEPMASS() + + elif name == 'miniboone': + return datasets.MINIBOONE() + + else: + raise ValueError('Unknown dataset') + + +def build_model(input_dim): + hidden_dims = tuple(map(int, args.dims.split("-"))) + chain = [] + for i in range(args.depth): + if args.glow: chain.append(layers.BruteForceLayer(input_dim)) + chain.append(layers.MaskedCouplingLayer(input_dim, hidden_dims, 'alternate', swap=i % 2 == 0)) + if args.batch_norm: chain.append(layers.MovingBatchNorm1d(input_dim, bn_lag=args.bn_lag)) + return layers.SequentialFlow(chain) + + +def compute_loss(x, model): + zero = torch.zeros(x.shape[0], 1).to(x) + + z, delta_logp = model(x, zero) # run model forward + + logpz = 
standard_normal_logprob(z).view(z.shape[0], -1).sum(1, keepdim=True) # logp(z) + logpx = logpz - delta_logp + loss = -torch.mean(logpx) + return loss + + +def restore_model(model, filename): + checkpt = torch.load(filename, map_location=lambda storage, loc: storage) + model.load_state_dict(checkpt["state_dict"]) + return model + + +if __name__ == '__main__': + + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + cvt = lambda x: x.type(torch.float32).to(device, non_blocking=True) + + logger.info('Using {} GPUs.'.format(torch.cuda.device_count())) + + data = load_data(args.data) + data.trn.x = torch.from_numpy(data.trn.x) + data.val.x = torch.from_numpy(data.val.x) + data.tst.x = torch.from_numpy(data.tst.x) + + model = build_model(data.n_dims).to(device) + + if args.resume is not None: + checkpt = torch.load(args.resume) + model.load_state_dict(checkpt['state_dict']) + + logger.info(model) + logger.info("Number of trainable parameters: {}".format(count_parameters(model))) + + if not args.evaluate: + optimizer = Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) + + time_meter = utils.RunningAverageMeter(0.98) + loss_meter = utils.RunningAverageMeter(0.98) + + best_loss = float('inf') + itr = 0 + n_vals_without_improvement = 0 + end = time.time() + model.train() + while True: + if args.early_stopping > 0 and n_vals_without_improvement > args.early_stopping: + break + + for x in batch_iter(data.trn.x, shuffle=True): + if args.early_stopping > 0 and n_vals_without_improvement > args.early_stopping: + break + + optimizer.zero_grad() + + x = cvt(x) + loss = compute_loss(x, model) + loss_meter.update(loss.item()) + + loss.backward() + optimizer.step() + + time_meter.update(time.time() - end) + + if itr % args.log_freq == 0: + log_message = ( + 'Iter {:06d} | Epoch {:.2f} | Time {:.4f}({:.4f}) | Loss {:.6f}({:.6f}) | '.format( + itr, + float(itr) / (data.trn.x.shape[0] / float(args.batch_size)), time_meter.val, time_meter.avg, + loss_meter.val, loss_meter.avg + ) + ) + logger.info(log_message) + itr += 1 + end = time.time() + + # Validation loop. 
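+            # (Runs every val_freq iterations: the best validation loss is
+            # checkpointed, and validations without improvement drive both the
+            # step-wise LR decay in update_lr and the early-stopping counter.)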
+ if itr % args.val_freq == 0: + model.eval() + start_time = time.time() + with torch.no_grad(): + val_loss = utils.AverageMeter() + for x in batch_iter(data.val.x, batch_size=test_batch_size): + x = cvt(x) + val_loss.update(compute_loss(x, model).item(), x.shape[0]) + + if val_loss.avg < best_loss: + best_loss = val_loss.avg + utils.makedirs(args.save) + torch.save({ + 'args': args, + 'state_dict': model.state_dict(), + }, os.path.join(args.save, 'checkpt.pth')) + n_vals_without_improvement = 0 + else: + n_vals_without_improvement += 1 + update_lr(optimizer, n_vals_without_improvement) + + log_message = ( + '[VAL] Iter {:06d} | Val Loss {:.6f} | ' + 'NoImproveEpochs {:02d}/{:02d}'.format( + itr, val_loss.avg, n_vals_without_improvement, args.early_stopping + ) + ) + logger.info(log_message) + model.train() + + logger.info('Training has finished.') + model = restore_model(model, os.path.join(args.save, 'checkpt.pth')).to(device) + + logger.info('Evaluating model on test set.') + model.eval() + + with torch.no_grad(): + test_loss = utils.AverageMeter() + for itr, x in enumerate(batch_iter(data.tst.x, batch_size=test_batch_size)): + x = cvt(x) + test_loss.update(compute_loss(x, model).item(), x.shape[0]) + logger.info('Progress: {:.2f}%'.format(itr / (data.tst.x.shape[0] / test_batch_size))) + log_message = '[TEST] Iter {:06d} | Test Loss {:.6f} '.format(itr, test_loss.avg) + logger.info(log_message) diff --git a/src/torchprune/torchprune/util/external/ffjord/train_discrete_toy.py b/src/torchprune/torchprune/util/external/ffjord/train_discrete_toy.py new file mode 100644 index 0000000..afa536e --- /dev/null +++ b/src/torchprune/torchprune/util/external/ffjord/train_discrete_toy.py @@ -0,0 +1,186 @@ +import matplotlib +matplotlib.use('Agg') +import matplotlib.pyplot as plt + +import argparse +import os +import time + +import torch +import torch.optim as optim + +import lib.layers as layers +import lib.toy_data as toy_data +import lib.utils as utils +from lib.visualize_flow import visualize_transform + +from train_misc import standard_normal_logprob +from train_misc import count_parameters + +SOLVERS = ["dopri5", "bdf", "rk4", "midpoint", 'adams', 'explicit_adams', 'fixed_adams'] +parser = argparse.ArgumentParser('Continuous Normalizing Flow') +parser.add_argument( + '--data', choices=['swissroll', '8gaussians', 'pinwheel', 'circles', 'moons', '2spirals', 'checkerboard', 'rings'], + type=str, default='pinwheel' +) + +parser.add_argument('--depth', help='number of coupling layers', type=int, default=10) +parser.add_argument('--glow', type=eval, choices=[True, False], default=False) +parser.add_argument('--nf', type=eval, choices=[True, False], default=False) + +parser.add_argument('--niters', type=int, default=100001) +parser.add_argument('--batch_size', type=int, default=100) +parser.add_argument('--test_batch_size', type=int, default=1000) +parser.add_argument('--lr', type=float, default=1e-4) +parser.add_argument('--weight_decay', type=float, default=0) + +# Track quantities +parser.add_argument('--l1int', type=float, default=None, help="int_t ||f||_1") +parser.add_argument('--l2int', type=float, default=None, help="int_t ||f||_2") +parser.add_argument('--dl2int', type=float, default=None, help="int_t ||f^T df/dt||_2") +parser.add_argument('--JFrobint', type=float, default=None, help="int_t ||df/dx||_F") +parser.add_argument('--JdiagFrobint', type=float, default=None, help="int_t ||df_i/dx_i||_F") +parser.add_argument('--JoffdiagFrobint', type=float, default=None, help="int_t ||df/dx - 
df_i/dx_i||_F") + +parser.add_argument('--save', type=str, default='experiments/cnf') +parser.add_argument('--viz_freq', type=int, default=1000) +parser.add_argument('--val_freq', type=int, default=1000) +parser.add_argument('--log_freq', type=int, default=100) +parser.add_argument('--gpu', type=int, default=0) +args = parser.parse_args() + +# logger +utils.makedirs(args.save) +logger = utils.get_logger(logpath=os.path.join(args.save, 'logs'), filepath=os.path.abspath(__file__)) + +logger.info(args) + +device = torch.device('cuda:' + str(args.gpu) if torch.cuda.is_available() else 'cpu') + + +def construct_model(): + + if args.nf: + chain = [] + for i in range(args.depth): + chain.append(layers.PlanarFlow(2)) + return layers.SequentialFlow(chain) + else: + chain = [] + for i in range(args.depth): + if args.glow: chain.append(layers.BruteForceLayer(2)) + chain.append(layers.CouplingLayer(2, swap=i % 2 == 0)) + return layers.SequentialFlow(chain) + + +def get_transforms(model): + + if args.nf: + sample_fn = None + else: + + def sample_fn(z, logpz=None): + if logpz is not None: + return model(z, logpz, reverse=True) + else: + return model(z, reverse=True) + + def density_fn(x, logpx=None): + if logpx is not None: + return model(x, logpx, reverse=False) + else: + return model(x, reverse=False) + + return sample_fn, density_fn + + +def compute_loss(args, model, batch_size=None): + if batch_size is None: batch_size = args.batch_size + + # load data + x = toy_data.inf_train_gen(args.data, batch_size=batch_size) + x = torch.from_numpy(x).type(torch.float32).to(device) + zero = torch.zeros(x.shape[0], 1).to(x) + + # transform to z + z, delta_logp = model(x, zero) + + # compute log q(z) + logpz = standard_normal_logprob(z).sum(1, keepdim=True) + + logpx = logpz - delta_logp + loss = -torch.mean(logpx) + return loss + + +if __name__ == '__main__': + + model = construct_model().to(device) + + logger.info(model) + logger.info("Number of trainable parameters: {}".format(count_parameters(model))) + + optimizer = optim.Adamax(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) + + time_meter = utils.RunningAverageMeter(0.98) + loss_meter = utils.RunningAverageMeter(0.98) + + end = time.time() + best_loss = float('inf') + model.train() + for itr in range(1, args.niters + 1): + optimizer.zero_grad() + + loss = compute_loss(args, model) + loss_meter.update(loss.item()) + + loss.backward() + optimizer.step() + + time_meter.update(time.time() - end) + + if itr % args.log_freq == 0: + log_message = ( + 'Iter {:04d} | Time {:.4f}({:.4f}) | Loss {:.6f}({:.6f})'.format( + itr, time_meter.val, time_meter.avg, loss_meter.val, loss_meter.avg + ) + ) + logger.info(log_message) + + if itr % args.val_freq == 0 or itr == args.niters: + with torch.no_grad(): + model.eval() + test_loss = compute_loss(args, model, batch_size=args.test_batch_size) + log_message = '[TEST] Iter {:04d} | Test Loss {:.6f}'.format(itr, test_loss) + logger.info(log_message) + + if test_loss.item() < best_loss: + best_loss = test_loss.item() + utils.makedirs(args.save) + torch.save({ + 'args': args, + 'state_dict': model.state_dict(), + }, os.path.join(args.save, 'checkpt.pth')) + model.train() + + if itr % args.viz_freq == 0: + with torch.no_grad(): + model.eval() + p_samples = toy_data.inf_train_gen(args.data, batch_size=2000) + + sample_fn, density_fn = get_transforms(model) + + plt.figure(figsize=(9, 3)) + visualize_transform( + p_samples, torch.randn, standard_normal_logprob, transform=sample_fn, inverse_transform=density_fn, + 
samples=True, npts=800, device=device + ) + fig_filename = os.path.join(args.save, 'figs', '{:04d}.jpg'.format(itr)) + utils.makedirs(os.path.dirname(fig_filename)) + plt.savefig(fig_filename) + plt.close() + model.train() + + end = time.time() + + logger.info('Training has finished.') diff --git a/src/torchprune/torchprune/util/external/ffjord/train_img2d.py b/src/torchprune/torchprune/util/external/ffjord/train_img2d.py new file mode 100644 index 0000000..5041e25 --- /dev/null +++ b/src/torchprune/torchprune/util/external/ffjord/train_img2d.py @@ -0,0 +1,253 @@ +import matplotlib +matplotlib.use('Agg') +import matplotlib.pyplot as plt + +import argparse +import os +import time + +import torch +import torch.optim as optim + +import lib.utils as utils +from lib.visualize_flow import visualize_transform +import lib.layers.odefunc as odefunc + +from train_misc import standard_normal_logprob +from train_misc import set_cnf_options, count_nfe, count_parameters, count_total_time +from train_misc import add_spectral_norm, spectral_norm_power_iteration +from train_misc import create_regularization_fns, get_regularization, append_regularization_to_log +from train_misc import build_model_tabular + +from diagnostics.viz_toy import save_trajectory, trajectory_to_video + +SOLVERS = ["dopri5", "bdf", "rk4", "midpoint", 'adams', 'explicit_adams', 'fixed_adams'] +parser = argparse.ArgumentParser('Continuous Normalizing Flow') +parser.add_argument('--img', type=str, required=True) +parser.add_argument('--data', type=str, default='dummy') +parser.add_argument( + "--layer_type", type=str, default="concatsquash", + choices=["ignore", "concat", "concat_v2", "squash", "concatsquash", "concatcoord", "hyper", "blend"] +) +parser.add_argument('--dims', type=str, default='64-64-64') +parser.add_argument("--num_blocks", type=int, default=1, help='Number of stacked CNFs.') +parser.add_argument('--time_length', type=float, default=0.5) +parser.add_argument('--train_T', type=eval, default=True) +parser.add_argument("--divergence_fn", type=str, default="brute_force", choices=["brute_force", "approximate"]) +parser.add_argument("--nonlinearity", type=str, default="tanh", choices=odefunc.NONLINEARITIES) + +parser.add_argument('--solver', type=str, default='dopri5', choices=SOLVERS) +parser.add_argument('--atol', type=float, default=1e-5) +parser.add_argument('--rtol', type=float, default=1e-5) +parser.add_argument("--step_size", type=float, default=None, help="Optional fixed step size.") + +parser.add_argument('--test_solver', type=str, default=None, choices=SOLVERS + [None]) +parser.add_argument('--test_atol', type=float, default=None) +parser.add_argument('--test_rtol', type=float, default=None) + +parser.add_argument('--residual', type=eval, default=False, choices=[True, False]) +parser.add_argument('--rademacher', type=eval, default=False, choices=[True, False]) +parser.add_argument('--spectral_norm', type=eval, default=False, choices=[True, False]) +parser.add_argument('--batch_norm', type=eval, default=False, choices=[True, False]) +parser.add_argument('--bn_lag', type=float, default=0) + +parser.add_argument('--niters', type=int, default=10000) +parser.add_argument('--batch_size', type=int, default=1000) +parser.add_argument('--test_batch_size', type=int, default=1000) +parser.add_argument('--lr', type=float, default=1e-3) +parser.add_argument('--weight_decay', type=float, default=1e-5) + +# Track quantities +parser.add_argument('--l1int', type=float, default=None, help="int_t ||f||_1") 
+parser.add_argument('--l2int', type=float, default=None, help="int_t ||f||_2") +parser.add_argument('--dl2int', type=float, default=None, help="int_t ||f^T df/dt||_2") +parser.add_argument('--JFrobint', type=float, default=None, help="int_t ||df/dx||_F") +parser.add_argument('--JdiagFrobint', type=float, default=None, help="int_t ||df_i/dx_i||_F") +parser.add_argument('--JoffdiagFrobint', type=float, default=None, help="int_t ||df/dx - df_i/dx_i||_F") + +parser.add_argument('--save', type=str, default='experiments/cnf') +parser.add_argument('--viz_freq', type=int, default=100) +parser.add_argument('--val_freq', type=int, default=100) +parser.add_argument('--log_freq', type=int, default=10) +parser.add_argument('--gpu', type=int, default=0) +args = parser.parse_args() + +# logger +utils.makedirs(args.save) +logger = utils.get_logger(logpath=os.path.join(args.save, 'logs'), filepath=os.path.abspath(__file__)) + +if args.layer_type == "blend": + logger.info("!! Setting time_length from None to 1.0 due to use of Blend layers.") + args.time_length = 1.0 + +logger.info(args) + +device = torch.device('cuda:' + str(args.gpu) if torch.cuda.is_available() else 'cpu') + +from PIL import Image +import numpy as np + +img = np.array(Image.open(args.img).convert('L')) +h, w = img.shape +xx = np.linspace(-4, 4, w) +yy = np.linspace(-4, 4, h) +xx, yy = np.meshgrid(xx, yy) +xx = xx.reshape(-1, 1) +yy = yy.reshape(-1, 1) + +means = np.concatenate([xx, yy], 1) +img = img.max() - img +probs = img.reshape(-1) / img.sum() + +std = np.array([8 / w / 2, 8 / h / 2]) + + +def sample_data(data=None, rng=None, batch_size=200): + """data and rng are ignored.""" + inds = np.random.choice(int(probs.shape[0]), int(batch_size), p=probs) + m = means[inds] + samples = np.random.randn(*m.shape) * std + m + return samples + + +def get_transforms(model): + + def sample_fn(z, logpz=None): + if logpz is not None: + return model(z, logpz, reverse=True) + else: + return model(z, reverse=True) + + def density_fn(x, logpx=None): + if logpx is not None: + return model(x, logpx, reverse=False) + else: + return model(x, reverse=False) + + return sample_fn, density_fn + + +def compute_loss(args, model, batch_size=None): + if batch_size is None: batch_size = args.batch_size + + # load data + x = sample_data(args.data, batch_size=batch_size) + x = torch.from_numpy(x).type(torch.float32).to(device) + zero = torch.zeros(x.shape[0], 1).to(x) + + # transform to z + z, delta_logp = model(x, zero) + + # compute log q(z) + logpz = standard_normal_logprob(z).sum(1, keepdim=True) + + logpx = logpz - delta_logp + loss = -torch.mean(logpx) + return loss + + +if __name__ == '__main__': + + regularization_fns, regularization_coeffs = create_regularization_fns(args) + model = build_model_tabular(args, 2, regularization_fns).to(device) + if args.spectral_norm: add_spectral_norm(model) + set_cnf_options(args, model) + + logger.info(model) + logger.info("Number of trainable parameters: {}".format(count_parameters(model))) + + optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) + + time_meter = utils.RunningAverageMeter(0.93) + loss_meter = utils.RunningAverageMeter(0.93) + nfef_meter = utils.RunningAverageMeter(0.93) + nfeb_meter = utils.RunningAverageMeter(0.93) + tt_meter = utils.RunningAverageMeter(0.93) + + end = time.time() + best_loss = float('inf') + model.train() + for itr in range(1, args.niters + 1): + optimizer.zero_grad() + if args.spectral_norm: spectral_norm_power_iteration(model, 1) + + loss = 
compute_loss(args, model) + loss_meter.update(loss.item()) + + if len(regularization_coeffs) > 0: + reg_states = get_regularization(model, regularization_coeffs) + reg_loss = sum( + reg_state * coeff for reg_state, coeff in zip(reg_states, regularization_coeffs) if coeff != 0 + ) + loss = loss + reg_loss + + total_time = count_total_time(model) + nfe_forward = count_nfe(model) + + loss.backward() + optimizer.step() + + nfe_total = count_nfe(model) + nfe_backward = nfe_total - nfe_forward + nfef_meter.update(nfe_forward) + nfeb_meter.update(nfe_backward) + + time_meter.update(time.time() - end) + tt_meter.update(total_time) + + log_message = ( + 'Iter {:04d} | Time {:.4f}({:.4f}) | Loss {:.6f}({:.6f}) | NFE Forward {:.0f}({:.1f})' + ' | NFE Backward {:.0f}({:.1f}) | CNF Time {:.4f}({:.4f})'.format( + itr, time_meter.val, time_meter.avg, loss_meter.val, loss_meter.avg, nfef_meter.val, nfef_meter.avg, + nfeb_meter.val, nfeb_meter.avg, tt_meter.val, tt_meter.avg + ) + ) + if len(regularization_coeffs) > 0: + log_message = append_regularization_to_log(log_message, regularization_fns, reg_states) + + logger.info(log_message) + + if itr % args.val_freq == 0 or itr == args.niters: + with torch.no_grad(): + model.eval() + test_loss = compute_loss(args, model, batch_size=args.test_batch_size) + test_nfe = count_nfe(model) + log_message = '[TEST] Iter {:04d} | Test Loss {:.6f} | NFE {:.0f}'.format(itr, test_loss, test_nfe) + logger.info(log_message) + + if test_loss.item() < best_loss: + best_loss = test_loss.item() + utils.makedirs(args.save) + torch.save({ + 'args': args, + 'state_dict': model.state_dict(), + }, os.path.join(args.save, 'checkpt.pth')) + model.train() + + if itr % args.viz_freq == 0: + with torch.no_grad(): + model.eval() + p_samples = sample_data(args.data, batch_size=2000) + + sample_fn, density_fn = get_transforms(model) + + plt.figure(figsize=(9, 3)) + visualize_transform( + p_samples, torch.randn, standard_normal_logprob, transform=sample_fn, inverse_transform=density_fn, + samples=True, npts=800, device=device + ) + fig_filename = os.path.join(args.save, 'figs', '{:04d}.jpg'.format(itr)) + utils.makedirs(os.path.dirname(fig_filename)) + plt.savefig(fig_filename) + plt.close() + model.train() + + end = time.time() + + logger.info('Training has finished.') + + save_traj_dir = os.path.join(args.save, 'trajectory') + logger.info('Plotting trajectory to {}'.format(save_traj_dir)) + data_samples = sample_data(args.data, batch_size=2000) + save_trajectory(model, data_samples, save_traj_dir, device=device) + trajectory_to_video(save_traj_dir) diff --git a/src/torchprune/torchprune/util/external/ffjord/train_misc.py b/src/torchprune/torchprune/util/external/ffjord/train_misc.py new file mode 100644 index 0000000..d2ee9e1 --- /dev/null +++ b/src/torchprune/torchprune/util/external/ffjord/train_misc.py @@ -0,0 +1,200 @@ +import six +import math + +from .lib.layers.wrappers import cnf_regularization as reg_lib +from .lib import spectral_norm as spectral_norm +from .lib import layers as layers +from .lib.layers.odefunc import divergence_bf, divergence_approx + + +def standard_normal_logprob(z): + logZ = -0.5 * math.log(2 * math.pi) + return logZ - z.pow(2) / 2 + + +def set_cnf_options(args, model): + + def _set(module): + if isinstance(module, layers.CNF): + # Set training settings + module.solver = args.solver + module.atol = args.atol + module.rtol = args.rtol + if args.step_size is not None: + module.solver_options['step_size'] = args.step_size + + # If using fixed-grid adams, restrict 
order to not be too high. + if args.solver in ['fixed_adams', 'explicit_adams']: + module.solver_options['max_order'] = 4 + + # Set the test settings + module.test_solver = args.test_solver if args.test_solver else args.solver + module.test_atol = args.test_atol if args.test_atol else args.atol + module.test_rtol = args.test_rtol if args.test_rtol else args.rtol + + if isinstance(module, layers.ODEfunc): + module.rademacher = args.rademacher + module.residual = args.residual + + model.apply(_set) + + +def override_divergence_fn(model, divergence_fn): + + def _set(module): + if isinstance(module, layers.ODEfunc): + if divergence_fn == "brute_force": + module.divergence_fn = divergence_bf + elif divergence_fn == "approximate": + module.divergence_fn = divergence_approx + + model.apply(_set) + + +def count_nfe(model): + + class AccNumEvals(object): + + def __init__(self): + self.num_evals = 0 + + def __call__(self, module): + if isinstance(module, layers.ODEfunc): + self.num_evals += module.num_evals() + + accumulator = AccNumEvals() + model.apply(accumulator) + return accumulator.num_evals + + +def count_parameters(model): + return sum(p.numel() for p in model.parameters() if p.requires_grad) + + +def count_total_time(model): + + class Accumulator(object): + + def __init__(self): + self.total_time = 0 + + def __call__(self, module): + if isinstance(module, layers.CNF): + self.total_time = self.total_time + module.sqrt_end_time * module.sqrt_end_time + + accumulator = Accumulator() + model.apply(accumulator) + return accumulator.total_time + + +def add_spectral_norm(model, logger=None): + """Applies spectral norm to all modules within the scope of a CNF.""" + + def apply_spectral_norm(module): + if 'weight' in module._parameters: + if logger: logger.info("Adding spectral norm to {}".format(module)) + spectral_norm.inplace_spectral_norm(module, 'weight') + + def find_cnf(module): + if isinstance(module, layers.CNF): + module.apply(apply_spectral_norm) + else: + for child in module.children(): + find_cnf(child) + + find_cnf(model) + + +def spectral_norm_power_iteration(model, n_power_iterations=1): + + def recursive_power_iteration(module): + if hasattr(module, spectral_norm.POWER_ITERATION_FN): + getattr(module, spectral_norm.POWER_ITERATION_FN)(n_power_iterations) + + model.apply(recursive_power_iteration) + + +REGULARIZATION_FNS = { + "l1int": reg_lib.l1_regularzation_fn, + "l2int": reg_lib.l2_regularzation_fn, + "dl2int": reg_lib.directional_l2_regularization_fn, + "JFrobint": reg_lib.jacobian_frobenius_regularization_fn, + "JdiagFrobint": reg_lib.jacobian_diag_frobenius_regularization_fn, + "JoffdiagFrobint": reg_lib.jacobian_offdiag_frobenius_regularization_fn, +} + +INV_REGULARIZATION_FNS = {v: k for k, v in six.iteritems(REGULARIZATION_FNS)} + + +def append_regularization_to_log(log_message, regularization_fns, reg_states): + for i, reg_fn in enumerate(regularization_fns): + log_message = log_message + " | " + INV_REGULARIZATION_FNS[reg_fn] + ": {:.8f}".format(reg_states[i].item()) + return log_message + + +def create_regularization_fns(args): + regularization_fns = [] + regularization_coeffs = [] + + for arg_key, reg_fn in six.iteritems(REGULARIZATION_FNS): + if getattr(args, arg_key) is not None: + regularization_fns.append(reg_fn) + regularization_coeffs.append(eval("args." 
+ arg_key)) + + regularization_fns = tuple(regularization_fns) + regularization_coeffs = tuple(regularization_coeffs) + return regularization_fns, regularization_coeffs + + +def get_regularization(model, regularization_coeffs): + if len(regularization_coeffs) == 0: + return None + + acc_reg_states = tuple([0.] * len(regularization_coeffs)) + for module in model.modules(): + if isinstance(module, layers.CNF): + acc_reg_states = tuple(acc + reg for acc, reg in zip(acc_reg_states, module.get_regularization_states())) + return acc_reg_states + + +def build_model_tabular(args, dims, regularization_fns=None): + + hidden_dims = tuple(map(int, args.dims.split("-"))) + + def build_cnf(): + diffeq = layers.ODEnet( + hidden_dims=hidden_dims, + input_shape=(dims,), + strides=None, + conv=False, + layer_type=args.layer_type, + nonlinearity=args.nonlinearity, + ) + odefunc = layers.ODEfunc( + diffeq=diffeq, + divergence_fn=args.divergence_fn, + residual=args.residual, + rademacher=args.rademacher, + ) + cnf = layers.CNF( + odefunc=odefunc, + T=args.time_length, + train_T=args.train_T, + regularization_fns=regularization_fns, + solver=args.solver, + ) + return cnf + + chain = [build_cnf() for _ in range(args.num_blocks)] + if args.batch_norm: + bn_layers = [layers.MovingBatchNorm1d(dims, bn_lag=args.bn_lag) for _ in range(args.num_blocks)] + bn_chain = [layers.MovingBatchNorm1d(dims, bn_lag=args.bn_lag)] + for a, b in zip(chain, bn_layers): + bn_chain.append(a) + bn_chain.append(b) + chain = bn_chain + model = layers.SequentialFlow(chain) + + set_cnf_options(args, model) + + return model diff --git a/src/torchprune/torchprune/util/external/ffjord/train_tabular.py b/src/torchprune/torchprune/util/external/ffjord/train_tabular.py new file mode 100644 index 0000000..b66d1d7 --- /dev/null +++ b/src/torchprune/torchprune/util/external/ffjord/train_tabular.py @@ -0,0 +1,306 @@ +import argparse +import os +import time + +import torch + +import lib.utils as utils +import lib.layers.odefunc as odefunc +from lib.custom_optimizers import Adam + +import datasets + +from train_misc import standard_normal_logprob +from train_misc import set_cnf_options, count_nfe, count_parameters, count_total_time +from train_misc import create_regularization_fns, get_regularization, append_regularization_to_log +from train_misc import build_model_tabular, override_divergence_fn + +SOLVERS = ["dopri5", "bdf", "rk4", "midpoint", 'adams', 'explicit_adams', 'fixed_adams'] +parser = argparse.ArgumentParser('Continuous Normalizing Flow') +parser.add_argument( + '--data', choices=['power', 'gas', 'hepmass', 'miniboone', 'bsds300'], type=str, default='miniboone' +) +parser.add_argument( + "--layer_type", type=str, default="concatsquash", + choices=["ignore", "concat", "concat_v2", "squash", "concatsquash", "concatcoord", "hyper", "blend"] +) +parser.add_argument('--hdim_factor', type=int, default=10) +parser.add_argument('--nhidden', type=int, default=1) +parser.add_argument("--num_blocks", type=int, default=1, help='Number of stacked CNFs.') +parser.add_argument('--time_length', type=float, default=1.0) +parser.add_argument('--train_T', type=eval, default=True) +parser.add_argument("--divergence_fn", type=str, default="approximate", choices=["brute_force", "approximate"]) +parser.add_argument("--nonlinearity", type=str, default="softplus", choices=odefunc.NONLINEARITIES) + +parser.add_argument('--solver', type=str, default='dopri5', choices=SOLVERS) +parser.add_argument('--atol', type=float, default=1e-8) +parser.add_argument('--rtol', 
type=float, default=1e-6) +parser.add_argument("--step_size", type=float, default=None, help="Optional fixed step size.") + +parser.add_argument('--test_solver', type=str, default=None, choices=SOLVERS + [None]) +parser.add_argument('--test_atol', type=float, default=None) +parser.add_argument('--test_rtol', type=float, default=None) + +parser.add_argument('--residual', type=eval, default=False, choices=[True, False]) +parser.add_argument('--rademacher', type=eval, default=False, choices=[True, False]) +parser.add_argument('--batch_norm', type=eval, default=False, choices=[True, False]) +parser.add_argument('--bn_lag', type=float, default=0) + +parser.add_argument('--early_stopping', type=int, default=30) +parser.add_argument('--batch_size', type=int, default=1000) +parser.add_argument('--test_batch_size', type=int, default=None) +parser.add_argument('--lr', type=float, default=1e-3) +parser.add_argument('--weight_decay', type=float, default=1e-6) + +# Track quantities +parser.add_argument('--l1int', type=float, default=None, help="int_t ||f||_1") +parser.add_argument('--l2int', type=float, default=None, help="int_t ||f||_2") +parser.add_argument('--dl2int', type=float, default=None, help="int_t ||f^T df/dt||_2") +parser.add_argument('--JFrobint', type=float, default=None, help="int_t ||df/dx||_F") +parser.add_argument('--JdiagFrobint', type=float, default=None, help="int_t ||df_i/dx_i||_F") +parser.add_argument('--JoffdiagFrobint', type=float, default=None, help="int_t ||df/dx - df_i/dx_i||_F") + +parser.add_argument('--resume', type=str, default=None) +parser.add_argument('--save', type=str, default='experiments/cnf') +parser.add_argument('--evaluate', action='store_true') +parser.add_argument('--val_freq', type=int, default=200) +parser.add_argument('--log_freq', type=int, default=10) +args = parser.parse_args() + +# logger +utils.makedirs(args.save) +logger = utils.get_logger(logpath=os.path.join(args.save, 'logs'), filepath=os.path.abspath(__file__)) + +if args.layer_type == "blend": + logger.info("!! 
Setting time_length from None to 1.0 due to use of Blend layers.") + args.time_length = 1.0 + args.train_T = False + +logger.info(args) + +test_batch_size = args.test_batch_size if args.test_batch_size else args.batch_size + + +def batch_iter(X, batch_size=args.batch_size, shuffle=False): + """ + X: feature tensor (shape: num_instances x num_features) + """ + if shuffle: + idxs = torch.randperm(X.shape[0]) + else: + idxs = torch.arange(X.shape[0]) + if X.is_cuda: + idxs = idxs.cuda() + for batch_idxs in idxs.split(batch_size): + yield X[batch_idxs] + + +ndecs = 0 + + +def update_lr(optimizer, n_vals_without_improvement): + global ndecs + if ndecs == 0 and n_vals_without_improvement > args.early_stopping // 3: + for param_group in optimizer.param_groups: + param_group["lr"] = args.lr / 10 + ndecs = 1 + elif ndecs == 1 and n_vals_without_improvement > args.early_stopping // 3 * 2: + for param_group in optimizer.param_groups: + param_group["lr"] = args.lr / 100 + ndecs = 2 + else: + for param_group in optimizer.param_groups: + param_group["lr"] = args.lr / 10**ndecs + + +def load_data(name): + + if name == 'bsds300': + return datasets.BSDS300() + + elif name == 'power': + return datasets.POWER() + + elif name == 'gas': + return datasets.GAS() + + elif name == 'hepmass': + return datasets.HEPMASS() + + elif name == 'miniboone': + return datasets.MINIBOONE() + + else: + raise ValueError('Unknown dataset') + + +def compute_loss(x, model): + zero = torch.zeros(x.shape[0], 1).to(x) + + z, delta_logp = model(x, zero) # run model forward + + logpz = standard_normal_logprob(z).view(z.shape[0], -1).sum(1, keepdim=True) # logp(z) + logpx = logpz - delta_logp + loss = -torch.mean(logpx) + return loss + + +def restore_model(model, filename): + checkpt = torch.load(filename, map_location=lambda storage, loc: storage) + model.load_state_dict(checkpt["state_dict"]) + return model + + +if __name__ == '__main__': + + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + cvt = lambda x: x.type(torch.float32).to(device, non_blocking=True) + + logger.info('Using {} GPUs.'.format(torch.cuda.device_count())) + + data = load_data(args.data) + data.trn.x = torch.from_numpy(data.trn.x) + data.val.x = torch.from_numpy(data.val.x) + data.tst.x = torch.from_numpy(data.tst.x) + + args.dims = '-'.join([str(args.hdim_factor * data.n_dims)] * args.nhidden) + + regularization_fns, regularization_coeffs = create_regularization_fns(args) + model = build_model_tabular(args, data.n_dims, regularization_fns).to(device) + set_cnf_options(args, model) + + for k in model.state_dict().keys(): + logger.info(k) + + if args.resume is not None: + checkpt = torch.load(args.resume) + + # Backwards compatibility with an older version of the code. + # TODO: remove upon release. 
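The key filtering that follows is easier to read as a small helper. A minimal sketch under the assumption that the `module.` prefixes come from checkpoints saved through an `nn.DataParallel` wrapper (the helper name `filter_legacy_state_dict` is ours, not part of ffjord):

```python
# Illustrative sketch, not the repo's API: drop legacy 'diffeq.diffeq' keys and
# strip the 'module.' prefix that nn.DataParallel adds to every parameter name.
def filter_legacy_state_dict(state_dict, drop_substring="diffeq.diffeq"):
    return {
        k.replace("module.", ""): v
        for k, v in state_dict.items()
        if drop_substring not in k
    }
```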
+ filtered_state_dict = {} + for k, v in checkpt['state_dict'].items(): + if 'diffeq.diffeq' not in k: + filtered_state_dict[k.replace('module.', '')] = v + model.load_state_dict(filtered_state_dict) + + logger.info(model) + logger.info("Number of trainable parameters: {}".format(count_parameters(model))) + + if not args.evaluate: + optimizer = Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) + + time_meter = utils.RunningAverageMeter(0.98) + loss_meter = utils.RunningAverageMeter(0.98) + nfef_meter = utils.RunningAverageMeter(0.98) + nfeb_meter = utils.RunningAverageMeter(0.98) + tt_meter = utils.RunningAverageMeter(0.98) + + best_loss = float('inf') + itr = 0 + n_vals_without_improvement = 0 + end = time.time() + model.train() + while True: + if args.early_stopping > 0 and n_vals_without_improvement > args.early_stopping: + break + + for x in batch_iter(data.trn.x, shuffle=True): + if args.early_stopping > 0 and n_vals_without_improvement > args.early_stopping: + break + + optimizer.zero_grad() + + x = cvt(x) + loss = compute_loss(x, model) + loss_meter.update(loss.item()) + + if len(regularization_coeffs) > 0: + reg_states = get_regularization(model, regularization_coeffs) + reg_loss = sum( + reg_state * coeff for reg_state, coeff in zip(reg_states, regularization_coeffs) if coeff != 0 + ) + loss = loss + reg_loss + + total_time = count_total_time(model) + nfe_forward = count_nfe(model) + + loss.backward() + optimizer.step() + + nfe_total = count_nfe(model) + nfe_backward = nfe_total - nfe_forward + nfef_meter.update(nfe_forward) + nfeb_meter.update(nfe_backward) + + time_meter.update(time.time() - end) + tt_meter.update(total_time) + + if itr % args.log_freq == 0: + log_message = ( + 'Iter {:06d} | Epoch {:.2f} | Time {:.4f}({:.4f}) | Loss {:.6f}({:.6f}) | ' + 'NFE Forward {:.0f}({:.1f}) | NFE Backward {:.0f}({:.1f}) | CNF Time {:.4f}({:.4f})'.format( + itr, + float(itr) / (data.trn.x.shape[0] / float(args.batch_size)), time_meter.val, time_meter.avg, + loss_meter.val, loss_meter.avg, nfef_meter.val, nfef_meter.avg, nfeb_meter.val, + nfeb_meter.avg, tt_meter.val, tt_meter.avg + ) + ) + if len(regularization_coeffs) > 0: + log_message = append_regularization_to_log(log_message, regularization_fns, reg_states) + + logger.info(log_message) + itr += 1 + end = time.time() + + # Validation loop. 
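The validation block below drives patience-based early stopping: `n_vals_without_improvement` resets whenever the validation loss improves and training halts once it exceeds `args.early_stopping`, with `update_lr` above decaying the learning rate in stages along the way. A minimal, standalone sketch of the same pattern (the class name and interface are ours, not the repo's):

```python
class EarlyStopper:
    """Patience-based early stopping, mirroring n_vals_without_improvement."""

    def __init__(self, patience=30):
        self.patience = patience
        self.best = float("inf")
        self.bad_evals = 0

    def step(self, val_loss):
        """Record one validation result; return True once patience runs out."""
        if val_loss < self.best:
            self.best, self.bad_evals = val_loss, 0
        else:
            self.bad_evals += 1
        return self.patience > 0 and self.bad_evals > self.patience
```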
+ if itr % args.val_freq == 0: + model.eval() + start_time = time.time() + with torch.no_grad(): + val_loss = utils.AverageMeter() + val_nfe = utils.AverageMeter() + for x in batch_iter(data.val.x, batch_size=test_batch_size): + x = cvt(x) + val_loss.update(compute_loss(x, model).item(), x.shape[0]) + val_nfe.update(count_nfe(model)) + + if val_loss.avg < best_loss: + best_loss = val_loss.avg + utils.makedirs(args.save) + torch.save({ + 'args': args, + 'state_dict': model.state_dict(), + }, os.path.join(args.save, 'checkpt.pth')) + n_vals_without_improvement = 0 + else: + n_vals_without_improvement += 1 + update_lr(optimizer, n_vals_without_improvement) + + log_message = ( + '[VAL] Iter {:06d} | Val Loss {:.6f} | NFE {:.0f} | ' + 'NoImproveEpochs {:02d}/{:02d}'.format( + itr, val_loss.avg, val_nfe.avg, n_vals_without_improvement, args.early_stopping + ) + ) + logger.info(log_message) + model.train() + + logger.info('Training has finished.') + model = restore_model(model, os.path.join(args.save, 'checkpt.pth')).to(device) + set_cnf_options(args, model) + + logger.info('Evaluating model on test set.') + model.eval() + + override_divergence_fn(model, "brute_force") + + with torch.no_grad(): + test_loss = utils.AverageMeter() + test_nfe = utils.AverageMeter() + for itr, x in enumerate(batch_iter(data.tst.x, batch_size=test_batch_size)): + x = cvt(x) + test_loss.update(compute_loss(x, model).item(), x.shape[0]) + test_nfe.update(count_nfe(model)) + logger.info('Progress: {:.2f}%'.format(100. * itr / (data.tst.x.shape[0] / test_batch_size))) + log_message = '[TEST] Iter {:06d} | Test Loss {:.6f} | NFE {:.0f}'.format(itr, test_loss.avg, test_nfe.avg) + logger.info(log_message) diff --git a/src/torchprune/torchprune/util/external/ffjord/train_toy.py b/src/torchprune/torchprune/util/external/ffjord/train_toy.py new file mode 100644 index 0000000..72ab811 --- /dev/null +++ b/src/torchprune/torchprune/util/external/ffjord/train_toy.py @@ -0,0 +1,231 @@ +import matplotlib +matplotlib.use('Agg') +import matplotlib.pyplot as plt + +import argparse +import os +import time + +import torch +import torch.optim as optim + +import lib.toy_data as toy_data +import lib.utils as utils +from lib.visualize_flow import visualize_transform +import lib.layers.odefunc as odefunc + +from train_misc import standard_normal_logprob +from train_misc import set_cnf_options, count_nfe, count_parameters, count_total_time +from train_misc import add_spectral_norm, spectral_norm_power_iteration +from train_misc import create_regularization_fns, get_regularization, append_regularization_to_log +from train_misc import build_model_tabular + +from diagnostics.viz_toy import save_trajectory, trajectory_to_video + +SOLVERS = ["dopri5", "bdf", "rk4", "midpoint", 'adams', 'explicit_adams', 'fixed_adams'] +parser = argparse.ArgumentParser('Continuous Normalizing Flow') +parser.add_argument( + '--data', choices=['swissroll', '8gaussians', 'pinwheel', 'circles', 'moons', '2spirals', 'checkerboard', 'rings'], + type=str, default='pinwheel' +) +parser.add_argument( + "--layer_type", type=str, default="concatsquash", + choices=["ignore", "concat", "concat_v2", "squash", "concatsquash", "concatcoord", "hyper", "blend"] +) +parser.add_argument('--dims', type=str, default='64-64-64') +parser.add_argument("--num_blocks", type=int, default=1, help='Number of stacked CNFs.') +parser.add_argument('--time_length', type=float, default=0.5) +parser.add_argument('--train_T', type=eval, default=True) +parser.add_argument("--divergence_fn", type=str, 
default="brute_force", choices=["brute_force", "approximate"]) +parser.add_argument("--nonlinearity", type=str, default="tanh", choices=odefunc.NONLINEARITIES) + +parser.add_argument('--solver', type=str, default='dopri5', choices=SOLVERS) +parser.add_argument('--atol', type=float, default=1e-5) +parser.add_argument('--rtol', type=float, default=1e-5) +parser.add_argument("--step_size", type=float, default=None, help="Optional fixed step size.") + +parser.add_argument('--test_solver', type=str, default=None, choices=SOLVERS + [None]) +parser.add_argument('--test_atol', type=float, default=None) +parser.add_argument('--test_rtol', type=float, default=None) + +parser.add_argument('--residual', type=eval, default=False, choices=[True, False]) +parser.add_argument('--rademacher', type=eval, default=False, choices=[True, False]) +parser.add_argument('--spectral_norm', type=eval, default=False, choices=[True, False]) +parser.add_argument('--batch_norm', type=eval, default=False, choices=[True, False]) +parser.add_argument('--bn_lag', type=float, default=0) + +parser.add_argument('--niters', type=int, default=10000) +parser.add_argument('--batch_size', type=int, default=100) +parser.add_argument('--test_batch_size', type=int, default=1000) +parser.add_argument('--lr', type=float, default=1e-3) +parser.add_argument('--weight_decay', type=float, default=1e-5) + +# Track quantities +parser.add_argument('--l1int', type=float, default=None, help="int_t ||f||_1") +parser.add_argument('--l2int', type=float, default=None, help="int_t ||f||_2") +parser.add_argument('--dl2int', type=float, default=None, help="int_t ||f^T df/dt||_2") +parser.add_argument('--JFrobint', type=float, default=None, help="int_t ||df/dx||_F") +parser.add_argument('--JdiagFrobint', type=float, default=None, help="int_t ||df_i/dx_i||_F") +parser.add_argument('--JoffdiagFrobint', type=float, default=None, help="int_t ||df/dx - df_i/dx_i||_F") + +parser.add_argument('--save', type=str, default='experiments/cnf') +parser.add_argument('--viz_freq', type=int, default=100) +parser.add_argument('--val_freq', type=int, default=100) +parser.add_argument('--log_freq', type=int, default=10) +parser.add_argument('--gpu', type=int, default=0) +args = parser.parse_args() + +# logger +utils.makedirs(args.save) +logger = utils.get_logger(logpath=os.path.join(args.save, 'logs'), filepath=os.path.abspath(__file__)) + +if args.layer_type == "blend": + logger.info("!! 
Setting time_length from None to 1.0 due to use of Blend layers.") + args.time_length = 1.0 + +logger.info(args) + +device = torch.device('cuda:' + str(args.gpu) if torch.cuda.is_available() else 'cpu') + + +def get_transforms(model): + + def sample_fn(z, logpz=None): + if logpz is not None: + return model(z, logpz, reverse=True) + else: + return model(z, reverse=True) + + def density_fn(x, logpx=None): + if logpx is not None: + return model(x, logpx, reverse=False) + else: + return model(x, reverse=False) + + return sample_fn, density_fn + + +def compute_loss(args, model, batch_size=None): + if batch_size is None: batch_size = args.batch_size + + # load data + x = toy_data.inf_train_gen(args.data, batch_size=batch_size) + x = torch.from_numpy(x).type(torch.float32).to(device) + zero = torch.zeros(x.shape[0], 1).to(x) + + # transform to z + z, delta_logp = model(x, zero) + + # compute log q(z) + logpz = standard_normal_logprob(z).sum(1, keepdim=True) + + logpx = logpz - delta_logp + loss = -torch.mean(logpx) + return loss + + +if __name__ == '__main__': + + regularization_fns, regularization_coeffs = create_regularization_fns(args) + model = build_model_tabular(args, 2, regularization_fns).to(device) + if args.spectral_norm: add_spectral_norm(model) + set_cnf_options(args, model) + + logger.info(model) + logger.info("Number of trainable parameters: {}".format(count_parameters(model))) + + optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) + + time_meter = utils.RunningAverageMeter(0.93) + loss_meter = utils.RunningAverageMeter(0.93) + nfef_meter = utils.RunningAverageMeter(0.93) + nfeb_meter = utils.RunningAverageMeter(0.93) + tt_meter = utils.RunningAverageMeter(0.93) + + end = time.time() + best_loss = float('inf') + model.train() + for itr in range(1, args.niters + 1): + optimizer.zero_grad() + if args.spectral_norm: spectral_norm_power_iteration(model, 1) + + loss = compute_loss(args, model) + loss_meter.update(loss.item()) + + if len(regularization_coeffs) > 0: + reg_states = get_regularization(model, regularization_coeffs) + reg_loss = sum( + reg_state * coeff for reg_state, coeff in zip(reg_states, regularization_coeffs) if coeff != 0 + ) + loss = loss + reg_loss + + total_time = count_total_time(model) + nfe_forward = count_nfe(model) + + loss.backward() + optimizer.step() + + nfe_total = count_nfe(model) + nfe_backward = nfe_total - nfe_forward + nfef_meter.update(nfe_forward) + nfeb_meter.update(nfe_backward) + + time_meter.update(time.time() - end) + tt_meter.update(total_time) + + log_message = ( + 'Iter {:04d} | Time {:.4f}({:.4f}) | Loss {:.6f}({:.6f}) | NFE Forward {:.0f}({:.1f})' + ' | NFE Backward {:.0f}({:.1f}) | CNF Time {:.4f}({:.4f})'.format( + itr, time_meter.val, time_meter.avg, loss_meter.val, loss_meter.avg, nfef_meter.val, nfef_meter.avg, + nfeb_meter.val, nfeb_meter.avg, tt_meter.val, tt_meter.avg + ) + ) + if len(regularization_coeffs) > 0: + log_message = append_regularization_to_log(log_message, regularization_fns, reg_states) + + logger.info(log_message) + + if itr % args.val_freq == 0 or itr == args.niters: + with torch.no_grad(): + model.eval() + test_loss = compute_loss(args, model, batch_size=args.test_batch_size) + test_nfe = count_nfe(model) + log_message = '[TEST] Iter {:04d} | Test Loss {:.6f} | NFE {:.0f}'.format(itr, test_loss, test_nfe) + logger.info(log_message) + + if test_loss.item() < best_loss: + best_loss = test_loss.item() + utils.makedirs(args.save) + torch.save({ + 'args': args, + 'state_dict': 
model.state_dict(), + }, os.path.join(args.save, 'checkpt.pth')) + model.train() + + if itr % args.viz_freq == 0: + with torch.no_grad(): + model.eval() + p_samples = toy_data.inf_train_gen(args.data, batch_size=2000) + + sample_fn, density_fn = get_transforms(model) + + plt.figure(figsize=(9, 3)) + visualize_transform( + p_samples, torch.randn, standard_normal_logprob, transform=sample_fn, inverse_transform=density_fn, + samples=True, npts=800, device=device + ) + fig_filename = os.path.join(args.save, 'figs', '{:04d}.jpg'.format(itr)) + utils.makedirs(os.path.dirname(fig_filename)) + plt.savefig(fig_filename) + plt.close() + model.train() + + end = time.time() + + logger.info('Training has finished.') + + save_traj_dir = os.path.join(args.save, 'trajectory') + logger.info('Plotting trajectory to {}'.format(save_traj_dir)) + data_samples = toy_data.inf_train_gen(args.data, batch_size=2000) + save_trajectory(model, data_samples, save_traj_dir, device=device) + trajectory_to_video(save_traj_dir) diff --git a/src/torchprune/torchprune/util/external/ffjord/train_vae_flow.py b/src/torchprune/torchprune/util/external/ffjord/train_vae_flow.py new file mode 100644 index 0000000..7bec82b --- /dev/null +++ b/src/torchprune/torchprune/util/external/ffjord/train_vae_flow.py @@ -0,0 +1,358 @@ +# !/usr/bin/env python +# -*- coding: utf-8 -*- + +from __future__ import print_function +import argparse +import time +import torch +import torch.utils.data +import torch.optim as optim +import numpy as np +import math +import random + +import os + +import datetime + +import lib.utils as utils +import lib.layers.odefunc as odefunc + +import vae_lib.models.VAE as VAE +import vae_lib.models.CNFVAE as CNFVAE +from vae_lib.optimization.training import train, evaluate +from vae_lib.utils.load_data import load_dataset +from vae_lib.utils.plotting import plot_training_curve + +SOLVERS = ["dopri5", "bdf", "rk4", "midpoint", 'adams', 'explicit_adams', 'fixed_adams'] +parser = argparse.ArgumentParser(description='PyTorch Sylvester Normalizing flows') + +parser.add_argument( + '-d', '--dataset', type=str, default='mnist', choices=['mnist', 'freyfaces', 'omniglot', 'caltech'], + metavar='DATASET', help='Dataset choice.' +) + +parser.add_argument( + '-freys', '--freyseed', type=int, default=123, metavar='FREYSEED', + help="""Seed for shuffling frey face dataset for test split. Ignored for other datasets. + Results in paper are produced with seeds 123, 321, 231""" +) + +parser.add_argument('-nc', '--no_cuda', action='store_true', default=False, help='disables CUDA training') + +parser.add_argument('--manual_seed', type=int, help='manual seed, if not given resorts to random seed.') + +parser.add_argument( + '-li', '--log_interval', type=int, default=10, metavar='LOG_INTERVAL', + help='how many batches to wait before logging training status' +) + +parser.add_argument( + '-od', '--out_dir', type=str, default='snapshots', metavar='OUT_DIR', + help='output directory for model snapshots etc.' 
+) + +# optimization settings +parser.add_argument( + '-e', '--epochs', type=int, default=2000, metavar='EPOCHS', help='number of epochs to train (default: 2000)' +) +parser.add_argument( + '-es', '--early_stopping_epochs', type=int, default=35, metavar='EARLY_STOPPING', + help='number of early stopping epochs' +) + +parser.add_argument( + '-bs', '--batch_size', type=int, default=100, metavar='BATCH_SIZE', help='input batch size for training' +) +parser.add_argument('-lr', '--learning_rate', type=float, default=0.0005, metavar='LEARNING_RATE', help='learning rate') + +parser.add_argument( + '-w', '--warmup', type=int, default=100, metavar='N', + help='number of epochs for warm-up. Set to 0 to turn warmup off.' +) +parser.add_argument('--max_beta', type=float, default=1., metavar='MB', help='max beta for warm-up') +parser.add_argument('--min_beta', type=float, default=0.0, metavar='MB', help='min beta for warm-up') +parser.add_argument( + '-f', '--flow', type=str, default='no_flow', choices=[ + 'planar', 'iaf', 'householder', 'orthogonal', 'triangular', 'cnf', 'cnf_bias', 'cnf_hyper', 'cnf_rank', + 'cnf_lyper', 'no_flow' + ], help="""Type of flow to use; 'no_flow' disables flows entirely.""" +) +parser.add_argument('-r', '--rank', type=int, default=1) +parser.add_argument( + '-nf', '--num_flows', type=int, default=4, metavar='NUM_FLOWS', + help='Number of flow layers; ignored in the absence of flows' +) +parser.add_argument( + '-nv', '--num_ortho_vecs', type=int, default=8, metavar='NUM_ORTHO_VECS', + help="""For orthogonal flow: how many orthogonal vectors per flow are needed. + Ignored for other flow types.""" +) +parser.add_argument( + '-nh', '--num_householder', type=int, default=8, metavar='NUM_HOUSEHOLDERS', + help="""For Householder Sylvester flow: number of Householder matrices per flow. + Ignored for other flow types.""" +) +parser.add_argument( + '-mhs', '--made_h_size', type=int, default=320, metavar='MADEHSIZE', + help='Width of the MADEs for IAF. Ignored for all other flows.'
+) +parser.add_argument('--z_size', type=int, default=64, metavar='ZSIZE', help='how many stochastic hidden units') +# gpu/cpu +parser.add_argument('--gpu_num', type=int, default=0, metavar='GPU', help='choose GPU to run on.') + +# CNF settings +parser.add_argument( + "--layer_type", type=str, default="concat", + choices=["ignore", "concat", "concat_v2", "squash", "concatsquash", "concatcoord", "hyper", "blend"] +) +parser.add_argument('--dims', type=str, default='512-512') +parser.add_argument("--num_blocks", type=int, default=1, help='Number of stacked CNFs.') +parser.add_argument('--time_length', type=float, default=0.5) +parser.add_argument('--train_T', type=eval, default=False) +parser.add_argument("--divergence_fn", type=str, default="approximate", choices=["brute_force", "approximate"]) +parser.add_argument("--nonlinearity", type=str, default="softplus", choices=odefunc.NONLINEARITIES) + +parser.add_argument('--solver', type=str, default='dopri5', choices=SOLVERS) +parser.add_argument('--atol', type=float, default=1e-5) +parser.add_argument('--rtol', type=float, default=1e-5) +parser.add_argument("--step_size", type=float, default=None, help="Optional fixed step size.") + +parser.add_argument('--test_solver', type=str, default=None, choices=SOLVERS + [None]) +parser.add_argument('--test_atol', type=float, default=None) +parser.add_argument('--test_rtol', type=float, default=None) + +parser.add_argument('--residual', type=eval, default=False, choices=[True, False]) +parser.add_argument('--rademacher', type=eval, default=False, choices=[True, False]) +parser.add_argument('--batch_norm', type=eval, default=False, choices=[True, False]) +parser.add_argument('--bn_lag', type=float, default=0) +# evaluation +parser.add_argument('--evaluate', type=eval, default=False, choices=[True, False]) +parser.add_argument('--model_path', type=str, default='') +parser.add_argument('--retrain_encoder', type=eval, default=False, choices=[True, False]) + +args = parser.parse_args() +args.cuda = not args.no_cuda and torch.cuda.is_available() + +if args.manual_seed is None: + args.manual_seed = random.randint(1, 100000) +random.seed(args.manual_seed) +torch.manual_seed(args.manual_seed) +np.random.seed(args.manual_seed) + +if args.cuda: + # gpu device number + torch.cuda.set_device(args.gpu_num) + +kwargs = {'num_workers': 0, 'pin_memory': True} if args.cuda else {} + + +def run(args, kwargs): + # ================================================================================================================== + # SNAPSHOTS + # ================================================================================================================== + args.model_signature = str(datetime.datetime.now())[0:19].replace(' ', '_') + args.model_signature = args.model_signature.replace(':', '_') + + snapshots_path = os.path.join(args.out_dir, 'vae_' + args.dataset + '_') + snap_dir = snapshots_path + args.flow + + if args.flow != 'no_flow': + snap_dir += '_' + 'num_flows_' + str(args.num_flows) + + if args.flow == 'orthogonal': + snap_dir = snap_dir + '_num_vectors_' + str(args.num_ortho_vecs) + elif args.flow == 'orthogonalH': + snap_dir = snap_dir + '_num_householder_' + str(args.num_householder) + elif args.flow == 'iaf': + snap_dir = snap_dir + '_madehsize_' + str(args.made_h_size) + + elif args.flow == 'permutation': + snap_dir = snap_dir + '_' + 'kernelsize_' + str(args.kernel_size) + elif args.flow == 'mixed': + snap_dir = snap_dir + '_' + 'num_householder_' + str(args.num_householder) + elif args.flow == 
'cnf_rank': + snap_dir = snap_dir + '_rank_' + str(args.rank) + '_' + args.dims + '_num_blocks_' + str(args.num_blocks) + elif 'cnf' in args.flow: + snap_dir = snap_dir + '_' + args.dims + '_num_blocks_' + str(args.num_blocks) + + if args.retrain_encoder: + snap_dir = snap_dir + '_retrain-encoder_' + elif args.evaluate: + snap_dir = snap_dir + '_evaluate_' + + snap_dir = snap_dir + '__' + args.model_signature + '/' + + args.snap_dir = snap_dir + + if not os.path.exists(snap_dir): + os.makedirs(snap_dir) + + # logger + utils.makedirs(args.snap_dir) + logger = utils.get_logger(logpath=os.path.join(args.snap_dir, 'logs'), filepath=os.path.abspath(__file__)) + + logger.info(args) + + # SAVING + torch.save(args, snap_dir + args.flow + '.config') + + # ================================================================================================================== + # LOAD DATA + # ================================================================================================================== + train_loader, val_loader, test_loader, args = load_dataset(args, **kwargs) + + if not args.evaluate: + + # ============================================================================================================== + # SELECT MODEL + # ============================================================================================================== + # flow parameters and architecture choice are passed on to model through args + + if args.flow == 'no_flow': + model = VAE.VAE(args) + elif args.flow == 'planar': + model = VAE.PlanarVAE(args) + elif args.flow == 'iaf': + model = VAE.IAFVAE(args) + elif args.flow == 'orthogonal': + model = VAE.OrthogonalSylvesterVAE(args) + elif args.flow == 'householder': + model = VAE.HouseholderSylvesterVAE(args) + elif args.flow == 'triangular': + model = VAE.TriangularSylvesterVAE(args) + elif args.flow == 'cnf': + model = CNFVAE.CNFVAE(args) + elif args.flow == 'cnf_bias': + model = CNFVAE.AmortizedBiasCNFVAE(args) + elif args.flow == 'cnf_hyper': + model = CNFVAE.HypernetCNFVAE(args) + elif args.flow == 'cnf_lyper': + model = CNFVAE.LypernetCNFVAE(args) + elif args.flow == 'cnf_rank': + model = CNFVAE.AmortizedLowRankCNFVAE(args) + else: + raise ValueError('Invalid flow choice') + + if args.retrain_encoder: + logger.info(f"Initializing decoder from {args.model_path}") + dec_model = torch.load(args.model_path) + dec_sd = {} + for k, v in dec_model.state_dict().items(): + if 'p_x' in k: + dec_sd[k] = v + model.load_state_dict(dec_sd, strict=False) + + if args.cuda: + logger.info("Model on GPU") + model.cuda() + + logger.info(model) + + if args.retrain_encoder: + parameters = [] + logger.info('Optimizing over:') + for name, param in model.named_parameters(): + if 'p_x' not in name: + logger.info(name) + parameters.append(param) + else: + parameters = model.parameters() + + optimizer = optim.Adamax(parameters, lr=args.learning_rate, eps=1.e-7) + + # ================================================================================================================== + # TRAINING + # ================================================================================================================== + train_loss = [] + val_loss = [] + + # for early stopping + best_loss = np.inf + best_bpd = np.inf + e = 0 + epoch = 0 + + train_times = [] + + for epoch in range(1, args.epochs + 1): + + t_start = time.time() + tr_loss = train(epoch, train_loader, model, optimizer, args, logger) + train_loss.append(tr_loss) + train_times.append(time.time() - t_start) + logger.info('One training 
epoch took %.2f seconds' % (time.time() - t_start)) + + v_loss, v_bpd = evaluate(val_loader, model, args, logger, epoch=epoch) + + val_loss.append(v_loss) + + # early-stopping + if v_loss < best_loss: + e = 0 + best_loss = v_loss + if args.input_type != 'binary': + best_bpd = v_bpd + logger.info('->model saved<-') + torch.save(model, snap_dir + args.flow + '.model') + # torch.save(model, snap_dir + args.flow + '_' + args.architecture + '.model') + + elif (args.early_stopping_epochs > 0) and (epoch >= args.warmup): + e += 1 + if e > args.early_stopping_epochs: + break + + if args.input_type == 'binary': + logger.info( + '--> Early stopping: {}/{} (BEST: loss {:.4f})\n'.format(e, args.early_stopping_epochs, best_loss) + ) + + else: + logger.info( + '--> Early stopping: {}/{} (BEST: loss {:.4f}, bpd {:.4f})\n'. + format(e, args.early_stopping_epochs, best_loss, best_bpd) + ) + + if math.isnan(v_loss): + raise ValueError('NaN encountered!') + + train_loss = np.hstack(train_loss) + val_loss = np.array(val_loss) + + plot_training_curve(train_loss, val_loss, fname=snap_dir + '/training_curve_%s.pdf' % args.flow) + + # training time per epoch + train_times = np.array(train_times) + mean_train_time = np.mean(train_times) + std_train_time = np.std(train_times, ddof=1) + logger.info('Average train time per epoch: %.2f +/- %.2f' % (mean_train_time, std_train_time)) + + # ================================================================================================================== + # EVALUATION + # ================================================================================================================== + + logger.info(args) + logger.info('Stopped after %d epochs' % epoch) + logger.info('Average train time per epoch: %.2f +/- %.2f' % (mean_train_time, std_train_time)) + + final_model = torch.load(snap_dir + args.flow + '.model') + validation_loss, validation_bpd = evaluate(val_loader, final_model, args, logger) + + else: + validation_loss = "N/A" + validation_bpd = "N/A" + logger.info(f"Loading model from {args.model_path}") + final_model = torch.load(args.model_path) + + test_loss, test_bpd = evaluate(test_loader, final_model, args, logger, testing=True) + + logger.info('FINAL EVALUATION ON VALIDATION SET. ELBO (VAL): {:.4f}'.format(validation_loss)) + logger.info('FINAL EVALUATION ON TEST SET. NLL (TEST): {:.4f}'.format(test_loss)) + if args.input_type != 'binary': + logger.info('FINAL EVALUATION ON VALIDATION SET. ELBO (VAL) BPD : {:.4f}'.format(validation_bpd)) + logger.info('FINAL EVALUATION ON TEST SET. 
NLL (TEST) BPD: {:.4f}'.format(test_bpd)) + + +if __name__ == "__main__": + + run(args, kwargs) diff --git a/src/torchprune/torchprune/util/models/cnn/LICENSE b/src/torchprune/torchprune/util/external/ffjord/vae_lib/LICENSE similarity index 96% rename from src/torchprune/torchprune/util/models/cnn/LICENSE rename to src/torchprune/torchprune/util/external/ffjord/vae_lib/LICENSE index 0482bd0..a37f7eb 100644 --- a/src/torchprune/torchprune/util/models/cnn/LICENSE +++ b/src/torchprune/torchprune/util/external/ffjord/vae_lib/LICENSE @@ -1,6 +1,6 @@ MIT License -Copyright (c) 2017 Wei Yang +Copyright (c) 2018 Rianne van den Berg Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/src/torchprune/torchprune/util/external/ffjord/vae_lib/models/CNFVAE.py b/src/torchprune/torchprune/util/external/ffjord/vae_lib/models/CNFVAE.py new file mode 100644 index 0000000..3e24dc4 --- /dev/null +++ b/src/torchprune/torchprune/util/external/ffjord/vae_lib/models/CNFVAE.py @@ -0,0 +1,412 @@ +import torch +import torch.nn as nn +from train_misc import build_model_tabular +import lib.layers as layers +from .VAE import VAE +import lib.layers.diffeq_layers as diffeq_layers +from lib.layers.odefunc import NONLINEARITIES + +from torchdiffeq import odeint_adjoint as odeint + + +def get_hidden_dims(args): + return tuple(map(int, args.dims.split("-"))) + (args.z_size,) + + +def concat_layer_num_params(in_dim, out_dim): + return (in_dim + 1) * out_dim + out_dim + + +class CNFVAE(VAE): + + def __init__(self, args): + super(CNFVAE, self).__init__(args) + + # CNF model + self.cnf = build_model_tabular(args, args.z_size) + + if args.cuda: + self.cuda() + + def encode(self, x): + """ + Encoder that outputs parameters for the base distribution of z and the flow parameters. + """ + + h = self.q_z_nn(x) + h = h.view(-1, self.q_z_nn_output_dim) + mean_z = self.q_z_mean(h) + var_z = self.q_z_var(h) + + return mean_z, var_z + + def forward(self, x): + """ + Forward pass with a continuous normalizing flow for the transformation z_0 -> z_K. + The change in log-density delta_logp is accumulated by the ODE solver and returned as -delta_logp.
+ """ + + z_mu, z_var = self.encode(x) + + # Sample z_0 + z0 = self.reparameterize(z_mu, z_var) + + zero = torch.zeros(x.shape[0], 1).to(x) + zk, delta_logp = self.cnf(z0, zero) # run model forward + + x_mean = self.decode(zk) + + return x_mean, z_mu, z_var, -delta_logp.view(-1), z0, zk + + +class AmortizedBiasODEnet(nn.Module): + + def __init__(self, hidden_dims, input_dim, layer_type="concat", nonlinearity="softplus"): + super(AmortizedBiasODEnet, self).__init__() + base_layer = { + "ignore": diffeq_layers.IgnoreLinear, + "hyper": diffeq_layers.HyperLinear, + "squash": diffeq_layers.SquashLinear, + "concat": diffeq_layers.ConcatLinear, + "concat_v2": diffeq_layers.ConcatLinear_v2, + "concatsquash": diffeq_layers.ConcatSquashLinear, + "blend": diffeq_layers.BlendLinear, + "concatcoord": diffeq_layers.ConcatLinear, + }[layer_type] + self.input_dim = input_dim + + # build layers and add them + layers = [] + activation_fns = [] + hidden_shape = input_dim + for dim_out in hidden_dims: + layer = base_layer(hidden_shape, dim_out) + layers.append(layer) + activation_fns.append(NONLINEARITIES[nonlinearity]) + hidden_shape = dim_out + + self.layers = nn.ModuleList(layers) + self.activation_fns = nn.ModuleList(activation_fns[:-1]) + + def _unpack_params(self, params): + return [params] + + def forward(self, t, y, am_biases): + dx = y + for l, layer in enumerate(self.layers): + dx = layer(t, dx) + this_bias, am_biases = am_biases[:, :dx.size(1)], am_biases[:, dx.size(1):] + dx = dx + this_bias + # if not last layer, use nonlinearity + if l < len(self.layers) - 1: + dx = self.activation_fns[l](dx) + return dx + + +class AmortizedLowRankODEnet(nn.Module): + + def __init__(self, hidden_dims, input_dim, rank=1, layer_type="concat", nonlinearity="softplus"): + super(AmortizedLowRankODEnet, self).__init__() + base_layer = { + "ignore": diffeq_layers.IgnoreLinear, + "hyper": diffeq_layers.HyperLinear, + "squash": diffeq_layers.SquashLinear, + "concat": diffeq_layers.ConcatLinear, + "concat_v2": diffeq_layers.ConcatLinear_v2, + "concatsquash": diffeq_layers.ConcatSquashLinear, + "blend": diffeq_layers.BlendLinear, + "concatcoord": diffeq_layers.ConcatLinear, + }[layer_type] + self.input_dim = input_dim + + # build layers and add them + layers = [] + activation_fns = [] + hidden_shape = input_dim + self.output_dims = hidden_dims + self.input_dims = (input_dim,) + hidden_dims[:-1] + for dim_out in hidden_dims: + layer = base_layer(hidden_shape, dim_out) + layers.append(layer) + activation_fns.append(NONLINEARITIES[nonlinearity]) + hidden_shape = dim_out + + self.layers = nn.ModuleList(layers) + self.activation_fns = nn.ModuleList(activation_fns[:-1]) + self.rank = rank + + def _unpack_params(self, params): + return [params] + + def _rank_k_bmm(self, x, u, v): + xu = torch.bmm(x[:, None], u.view(x.shape[0], x.shape[-1], self.rank)) + xuv = torch.bmm(xu, v.view(x.shape[0], self.rank, -1)) + return xuv[:, 0] + + def forward(self, t, y, am_params): + dx = y + for l, (layer, in_dim, out_dim) in enumerate(zip(self.layers, self.input_dims, self.output_dims)): + this_u, am_params = am_params[:, :in_dim * self.rank], am_params[:, in_dim * self.rank:] + this_v, am_params = am_params[:, :out_dim * self.rank], am_params[:, out_dim * self.rank:] + this_bias, am_params = am_params[:, :out_dim], am_params[:, out_dim:] + + xw = layer(t, dx) + xw_am = self._rank_k_bmm(dx, this_u, this_v) + dx = xw + xw_am + this_bias + # if not last layer, use nonlinearity + if l < len(self.layers) - 1: + dx = self.activation_fns[l](dx) + 
return dx + + +class HyperODEnet(nn.Module): + + def __init__(self, hidden_dims, input_dim, layer_type="concat", nonlinearity="softplus"): + super(HyperODEnet, self).__init__() + assert layer_type == "concat" + self.input_dim = input_dim + + # build layers and add them + activation_fns = [] + for dim_out in hidden_dims + (input_dim,): + activation_fns.append(NONLINEARITIES[nonlinearity]) + self.activation_fns = nn.ModuleList(activation_fns[:-1]) + self.output_dims = hidden_dims + self.input_dims = (input_dim,) + hidden_dims[:-1] + + def _pack_inputs(self, t, x): + tt = torch.ones_like(x[:, :1]) * t + ttx = torch.cat([tt, x], 1) + return ttx + + def _unpack_params(self, params): + layer_params = [] + for in_dim, out_dim in zip(self.input_dims, self.output_dims): + this_num_params = concat_layer_num_params(in_dim, out_dim) + # get params for this layer + this_params, params = params[:, :this_num_params], params[:, this_num_params:] + # split into weight and bias + bias, weight_params = this_params[:, :out_dim], this_params[:, out_dim:] + weight = weight_params.view(weight_params.size(0), in_dim + 1, out_dim) + layer_params.append(weight) + layer_params.append(bias) + return layer_params + + def _layer(self, t, x, weight, bias): + # weights is (batch, in_dim + 1, out_dim) + ttx = self._pack_inputs(t, x) # (batch, in_dim + 1) + ttx = ttx.view(ttx.size(0), 1, ttx.size(1)) # (batch, 1, in_dim + 1) + xw = torch.bmm(ttx, weight)[:, 0, :] # (batch, out_dim) + return xw + bias + + def forward(self, t, y, *layer_params): + dx = y + for l, (weight, bias) in enumerate(zip(layer_params[::2], layer_params[1::2])): + dx = self._layer(t, dx, weight, bias) + # if not last layer, use nonlinearity + if l < len(layer_params) - 1: + dx = self.activation_fns[l](dx) + return dx + + +class LyperODEnet(nn.Module): + + def __init__(self, hidden_dims, input_dim, layer_type="concat", nonlinearity="softplus"): + super(LyperODEnet, self).__init__() + base_layer = { + "ignore": diffeq_layers.IgnoreLinear, + "hyper": diffeq_layers.HyperLinear, + "squash": diffeq_layers.SquashLinear, + "concat": diffeq_layers.ConcatLinear, + "concat_v2": diffeq_layers.ConcatLinear_v2, + "concatsquash": diffeq_layers.ConcatSquashLinear, + "blend": diffeq_layers.BlendLinear, + "concatcoord": diffeq_layers.ConcatLinear, + }[layer_type] + self.input_dim = input_dim + + # build layers and add them + layers = [] + activation_fns = [] + hidden_shape = input_dim + self.dims = (input_dim,) + hidden_dims + self.output_dims = hidden_dims + self.input_dims = (input_dim,) + hidden_dims[:-1] + for dim_out in hidden_dims[:-1]: + layer = base_layer(hidden_shape, dim_out) + layers.append(layer) + activation_fns.append(NONLINEARITIES[nonlinearity]) + hidden_shape = dim_out + + self.layers = nn.ModuleList(layers) + self.activation_fns = nn.ModuleList(activation_fns) + + def _pack_inputs(self, t, x): + tt = torch.ones_like(x[:, :1]) * t + ttx = torch.cat([tt, x], 1) + return ttx + + def _unpack_params(self, params): + return [params] + + def _am_layer(self, t, x, weight, bias): + # weights is (batch, in_dim + 1, out_dim) + ttx = self._pack_inputs(t, x) # (batch, in_dim + 1) + ttx = ttx.view(ttx.size(0), 1, ttx.size(1)) # (batch, 1, in_dim + 1) + xw = torch.bmm(ttx, weight)[:, 0, :] # (batch, out_dim) + return xw + bias + + def forward(self, t, x, am_params): + dx = x + for layer, act in zip(self.layers, self.activation_fns): + dx = act(layer(t, dx)) + bias, weight_params = am_params[:, :self.dims[-1]], am_params[:, self.dims[-1]:] + weight = 
weight_params.view(weight_params.size(0), self.dims[-2] + 1, self.dims[-1]) + dx = self._am_layer(t, dx, weight, bias) + return dx + + +def construct_amortized_odefunc(args, z_dim, amortization_type="bias"): + + hidden_dims = get_hidden_dims(args) + + if amortization_type == "bias": + diffeq = AmortizedBiasODEnet( + hidden_dims=hidden_dims, + input_dim=z_dim, + layer_type=args.layer_type, + nonlinearity=args.nonlinearity, + ) + elif amortization_type == "hyper": + diffeq = HyperODEnet( + hidden_dims=hidden_dims, + input_dim=z_dim, + layer_type=args.layer_type, + nonlinearity=args.nonlinearity, + ) + elif amortization_type == "lyper": + diffeq = LyperODEnet( + hidden_dims=hidden_dims, + input_dim=z_dim, + layer_type=args.layer_type, + nonlinearity=args.nonlinearity, + ) + elif amortization_type == "low_rank": + diffeq = AmortizedLowRankODEnet( + hidden_dims=hidden_dims, + input_dim=z_dim, + layer_type=args.layer_type, + nonlinearity=args.nonlinearity, + rank=args.rank, + ) + odefunc = layers.ODEfunc( + diffeq=diffeq, + divergence_fn=args.divergence_fn, + residual=args.residual, + rademacher=args.rademacher, + ) + return odefunc + + +class AmortizedCNFVAE(VAE): + h_size = 256 + + def __init__(self, args): + super(AmortizedCNFVAE, self).__init__(args) + + # CNF model + self.odefuncs = nn.ModuleList([ + construct_amortized_odefunc(args, args.z_size, self.amortization_type) for _ in range(args.num_blocks) + ]) + self.q_am = self._amortized_layers(args) + assert len(self.q_am) == args.num_blocks or len(self.q_am) == 0 + + if args.cuda: + self.cuda() + + self.register_buffer('integration_times', torch.tensor([0.0, args.time_length])) + + self.atol = args.atol + self.rtol = args.rtol + self.solver = args.solver + + def encode(self, x): + """ + Encoder that ouputs parameters for base distribution of z and flow parameters. + """ + + h = self.q_z_nn(x) + h = h.view(-1, self.q_z_nn_output_dim) + mean_z = self.q_z_mean(h) + var_z = self.q_z_var(h) + am_params = [q_am(h) for q_am in self.q_am] + + return mean_z, var_z, am_params + + def forward(self, x): + + self.log_det_j = 0. 
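The forward pass below integrates z and the accumulated log-density change jointly by passing a tuple as the ODE state; `torchdiffeq`'s `odeint` accepts tuples of tensors natively, which is also how the amortized parameters ride along as extra state. A standalone toy example of that mechanism (the dynamics are invented purely for illustration):

```python
import torch
from torchdiffeq import odeint  # same call signature as odeint_adjoint above

def toy_dynamics(t, state):
    z, delta_logp = state
    # Contractive toy field f(z) = -z; its divergence is -dim(z), so under the
    # CNF convention d(delta_logp)/dt = -div f the second state grows linearly.
    return -z, torch.full_like(delta_logp, float(z.shape[1]))

z0 = torch.randn(4, 2)
delta_logp0 = torch.zeros(4, 1)
z_t, delta_logp_t = odeint(toy_dynamics, (z0, delta_logp0), torch.tensor([0.0, 1.0]))
z1, delta_logp1 = z_t[-1], delta_logp_t[-1]  # states at t = 1
```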
+ + z_mu, z_var, am_params = self.encode(x) + + # Sample z_0 + z0 = self.reparameterize(z_mu, z_var) + + delta_logp = torch.zeros(x.shape[0], 1).to(x) + z = z0 + for odefunc, am_param in zip(self.odefuncs, am_params): + am_param_unpacked = odefunc.diffeq._unpack_params(am_param) + odefunc.before_odeint() + states = odeint( + odefunc, + (z, delta_logp) + tuple(am_param_unpacked), + self.integration_times.to(z), + atol=self.atol, + rtol=self.rtol, + method=self.solver, + ) + z, delta_logp = states[0][-1], states[1][-1] + + x_mean = self.decode(z) + + return x_mean, z_mu, z_var, -delta_logp.view(-1), z0, z + + +class AmortizedBiasCNFVAE(AmortizedCNFVAE): + amortization_type = "bias" + + def _amortized_layers(self, args): + hidden_dims = get_hidden_dims(args) + bias_size = sum(hidden_dims) + return nn.ModuleList([nn.Linear(self.h_size, bias_size) for _ in range(args.num_blocks)]) + + +class AmortizedLowRankCNFVAE(AmortizedCNFVAE): + amortization_type = "low_rank" + + def _amortized_layers(self, args): + out_dims = get_hidden_dims(args) + in_dims = (out_dims[-1],) + out_dims[:-1] + params_size = (sum(in_dims) + sum(out_dims)) * args.rank + sum(out_dims) + return nn.ModuleList([nn.Linear(self.h_size, params_size) for _ in range(args.num_blocks)]) + + +class HypernetCNFVAE(AmortizedCNFVAE): + amortization_type = "hyper" + + def _amortized_layers(self, args): + hidden_dims = get_hidden_dims(args) + input_dims = (args.z_size,) + hidden_dims[:-1] + assert args.layer_type == "concat", "hypernets only support concat layers at the moment" + weight_dims = [concat_layer_num_params(in_dim, out_dim) for in_dim, out_dim in zip(input_dims, hidden_dims)] + weight_size = sum(weight_dims) + return nn.ModuleList([nn.Linear(self.h_size, weight_size) for _ in range(args.num_blocks)]) + + +class LypernetCNFVAE(AmortizedCNFVAE): + amortization_type = "lyper" + + def _amortized_layers(self, args): + dims = (args.z_size,) + get_hidden_dims(args) + weight_size = concat_layer_num_params(dims[-2], dims[-1]) + return nn.ModuleList([nn.Linear(self.h_size, weight_size) for _ in range(args.num_blocks)]) diff --git a/src/torchprune/torchprune/util/external/ffjord/vae_lib/models/VAE.py b/src/torchprune/torchprune/util/external/ffjord/vae_lib/models/VAE.py new file mode 100644 index 0000000..96e4ea5 --- /dev/null +++ b/src/torchprune/torchprune/util/external/ffjord/vae_lib/models/VAE.py @@ -0,0 +1,735 @@ +from __future__ import print_function + +import torch +import torch.nn as nn +import vae_lib.models.flows as flows +from vae_lib.models.layers import GatedConv2d, GatedConvTranspose2d + + +class VAE(nn.Module): + """ + The base VAE class containing gated convolutional encoder and decoder architecture. + Can be used as a base class for VAE's with normalizing flows. 
+ """ + + def __init__(self, args): + super(VAE, self).__init__() + + # extract model settings from args + self.z_size = args.z_size + self.input_size = args.input_size + self.input_type = args.input_type + + if self.input_size == [1, 28, 28] or self.input_size == [3, 28, 28]: + self.last_kernel_size = 7 + elif self.input_size == [1, 28, 20]: + self.last_kernel_size = (7, 5) + else: + raise ValueError('invalid input size!!') + + self.q_z_nn, self.q_z_mean, self.q_z_var = self.create_encoder() + self.p_x_nn, self.p_x_mean = self.create_decoder() + + self.q_z_nn_output_dim = 256 + + # auxiliary + if args.cuda: + self.FloatTensor = torch.cuda.FloatTensor + else: + self.FloatTensor = torch.FloatTensor + + # log-det-jacobian = 0 without flows + self.log_det_j = self.FloatTensor(1).zero_() + + def create_encoder(self): + """ + Helper function to create the elemental blocks for the encoder. Creates a gated convnet encoder. + the encoder expects data as input of shape (batch_size, num_channels, width, height). + """ + + if self.input_type == 'binary': + q_z_nn = nn.Sequential( + GatedConv2d(self.input_size[0], 32, 5, 1, 2), + GatedConv2d(32, 32, 5, 2, 2), + GatedConv2d(32, 64, 5, 1, 2), + GatedConv2d(64, 64, 5, 2, 2), + GatedConv2d(64, 64, 5, 1, 2), + GatedConv2d(64, 256, self.last_kernel_size, 1, 0), + ) + q_z_mean = nn.Linear(256, self.z_size) + q_z_var = nn.Sequential( + nn.Linear(256, self.z_size), + nn.Softplus(), + ) + return q_z_nn, q_z_mean, q_z_var + + elif self.input_type == 'multinomial': + act = None + + q_z_nn = nn.Sequential( + GatedConv2d(self.input_size[0], 32, 5, 1, 2, activation=act), + GatedConv2d(32, 32, 5, 2, 2, activation=act), + GatedConv2d(32, 64, 5, 1, 2, activation=act), + GatedConv2d(64, 64, 5, 2, 2, activation=act), + GatedConv2d(64, 64, 5, 1, 2, activation=act), + GatedConv2d(64, 256, self.last_kernel_size, 1, 0, activation=act) + ) + q_z_mean = nn.Linear(256, self.z_size) + q_z_var = nn.Sequential(nn.Linear(256, self.z_size), nn.Softplus(), nn.Hardtanh(min_val=0.01, max_val=7.)) + return q_z_nn, q_z_mean, q_z_var + + def create_decoder(self): + """ + Helper function to create the elemental blocks for the decoder. Creates a gated convnet decoder. 
+ """ + + num_classes = 256 + + if self.input_type == 'binary': + p_x_nn = nn.Sequential( + GatedConvTranspose2d(self.z_size, 64, self.last_kernel_size, 1, 0), + GatedConvTranspose2d(64, 64, 5, 1, 2), + GatedConvTranspose2d(64, 32, 5, 2, 2, 1), + GatedConvTranspose2d(32, 32, 5, 1, 2), + GatedConvTranspose2d(32, 32, 5, 2, 2, 1), GatedConvTranspose2d(32, 32, 5, 1, 2) + ) + + p_x_mean = nn.Sequential(nn.Conv2d(32, self.input_size[0], 1, 1, 0), nn.Sigmoid()) + return p_x_nn, p_x_mean + + elif self.input_type == 'multinomial': + act = None + p_x_nn = nn.Sequential( + GatedConvTranspose2d(self.z_size, 64, self.last_kernel_size, 1, 0, activation=act), + GatedConvTranspose2d(64, 64, 5, 1, 2, activation=act), + GatedConvTranspose2d(64, 32, 5, 2, 2, 1, activation=act), + GatedConvTranspose2d(32, 32, 5, 1, 2, activation=act), + GatedConvTranspose2d(32, 32, 5, 2, 2, 1, activation=act), + GatedConvTranspose2d(32, 32, 5, 1, 2, activation=act) + ) + + p_x_mean = nn.Sequential( + nn.Conv2d(32, 256, 5, 1, 2), + nn.Conv2d(256, self.input_size[0] * num_classes, 1, 1, 0), + # output shape: batch_size, num_channels * num_classes, pixel_width, pixel_height + ) + + return p_x_nn, p_x_mean + + else: + raise ValueError('invalid input type!!') + + def reparameterize(self, mu, var): + """ + Samples z from a multivariate Gaussian with diagonal covariance matrix using the + reparameterization trick. + """ + + std = var.sqrt() + eps = self.FloatTensor(std.size()).normal_() + z = eps.mul(std).add_(mu) + + return z + + def encode(self, x): + """ + Encoder expects following data shapes as input: shape = (batch_size, num_channels, width, height) + """ + + h = self.q_z_nn(x) + h = h.view(h.size(0), -1) + mean = self.q_z_mean(h) + var = self.q_z_var(h) + + return mean, var + + def decode(self, z): + """ + Decoder outputs reconstructed image in the following shapes: + x_mean.shape = (batch_size, num_channels, width, height) + """ + + z = z.view(z.size(0), self.z_size, 1, 1) + h = self.p_x_nn(z) + x_mean = self.p_x_mean(h) + + return x_mean + + def forward(self, x): + """ + Evaluates the model as a whole, encodes and decodes. Note that the log det jacobian is zero + for a plain VAE (without flows), and z_0 = z_k. + """ + + # mean and variance of z + z_mu, z_var = self.encode(x) + # sample z + z = self.reparameterize(z_mu, z_var) + x_mean = self.decode(z) + + return x_mean, z_mu, z_var, self.log_det_j, z, z + + +class PlanarVAE(VAE): + """ + Variational auto-encoder with planar flows in the encoder. + """ + + def __init__(self, args): + super(PlanarVAE, self).__init__(args) + + # Initialize log-det-jacobian to zero + self.log_det_j = 0. + + # Flow parameters + flow = flows.Planar + self.num_flows = args.num_flows + + # Amortized flow parameters + self.amor_u = nn.Linear(self.q_z_nn_output_dim, self.num_flows * self.z_size) + self.amor_w = nn.Linear(self.q_z_nn_output_dim, self.num_flows * self.z_size) + self.amor_b = nn.Linear(self.q_z_nn_output_dim, self.num_flows) + + # Normalizing flow layers + for k in range(self.num_flows): + flow_k = flow() + self.add_module('flow_' + str(k), flow_k) + + def encode(self, x): + """ + Encoder that ouputs parameters for base distribution of z and flow parameters. 
+ """ + + batch_size = x.size(0) + + h = self.q_z_nn(x) + h = h.view(-1, self.q_z_nn_output_dim) + mean_z = self.q_z_mean(h) + var_z = self.q_z_var(h) + + # return amortized u an w for all flows + u = self.amor_u(h).view(batch_size, self.num_flows, self.z_size, 1) + w = self.amor_w(h).view(batch_size, self.num_flows, 1, self.z_size) + b = self.amor_b(h).view(batch_size, self.num_flows, 1, 1) + + return mean_z, var_z, u, w, b + + def forward(self, x): + """ + Forward pass with planar flows for the transformation z_0 -> z_1 -> ... -> z_k. + Log determinant is computed as log_det_j = N E_q_z0[\sum_k log |det dz_k/dz_k-1| ]. + """ + + self.log_det_j = 0. + + z_mu, z_var, u, w, b = self.encode(x) + + # Sample z_0 + z = [self.reparameterize(z_mu, z_var)] + + # Normalizing flows + for k in range(self.num_flows): + flow_k = getattr(self, 'flow_' + str(k)) + z_k, log_det_jacobian = flow_k(z[k], u[:, k, :, :], w[:, k, :, :], b[:, k, :, :]) + z.append(z_k) + self.log_det_j += log_det_jacobian + + x_mean = self.decode(z[-1]) + + return x_mean, z_mu, z_var, self.log_det_j, z[0], z[-1] + + +class OrthogonalSylvesterVAE(VAE): + """ + Variational auto-encoder with orthogonal flows in the encoder. + """ + + def __init__(self, args): + super(OrthogonalSylvesterVAE, self).__init__(args) + + # Initialize log-det-jacobian to zero + self.log_det_j = 0. + + # Flow parameters + flow = flows.Sylvester + self.num_flows = args.num_flows + self.num_ortho_vecs = args.num_ortho_vecs + + assert (self.num_ortho_vecs <= self.z_size) and (self.num_ortho_vecs > 0) + + # Orthogonalization parameters + if self.num_ortho_vecs == self.z_size: + self.cond = 1.e-5 + else: + self.cond = 1.e-6 + + self.steps = 100 + identity = torch.eye(self.num_ortho_vecs, self.num_ortho_vecs) + # Add batch dimension + identity = identity.unsqueeze(0) + # Put identity in buffer so that it will be moved to GPU if needed by any call of .cuda + self.register_buffer('_eye', identity) + self._eye.requires_grad = False + + # Masks needed for triangular R1 and R2. + triu_mask = torch.triu(torch.ones(self.num_ortho_vecs, self.num_ortho_vecs), diagonal=1) + triu_mask = triu_mask.unsqueeze(0).unsqueeze(3) + diag_idx = torch.arange(0, self.num_ortho_vecs).long() + + self.register_buffer('triu_mask', triu_mask) + self.triu_mask.requires_grad = False + self.register_buffer('diag_idx', diag_idx) + + # Amortized flow parameters + # Diagonal elements of R1 * R2 have to satisfy -1 < R1 * R2 for flow to be invertible + self.diag_activation = nn.Tanh() + + self.amor_d = nn.Linear(self.q_z_nn_output_dim, self.num_flows * self.num_ortho_vecs * self.num_ortho_vecs) + + self.amor_diag1 = nn.Sequential( + nn.Linear(self.q_z_nn_output_dim, self.num_flows * self.num_ortho_vecs), self.diag_activation + ) + self.amor_diag2 = nn.Sequential( + nn.Linear(self.q_z_nn_output_dim, self.num_flows * self.num_ortho_vecs), self.diag_activation + ) + + self.amor_q = nn.Linear(self.q_z_nn_output_dim, self.num_flows * self.z_size * self.num_ortho_vecs) + self.amor_b = nn.Linear(self.q_z_nn_output_dim, self.num_flows * self.num_ortho_vecs) + + # Normalizing flow layers + for k in range(self.num_flows): + flow_k = flow(self.num_ortho_vecs) + self.add_module('flow_' + str(k), flow_k) + + def batch_construct_orthogonal(self, q): + """ + Batch orthogonal matrix construction. 
+ :param q: q contains batches of matrices, shape : (batch_size * num_flows, z_size * num_ortho_vecs) + :return: batches of orthogonalized matrices, shape: (batch_size * num_flows, z_size, num_ortho_vecs) + """ + + # Reshape to shape (num_flows * batch_size, z_size * num_ortho_vecs) + q = q.view(-1, self.z_size * self.num_ortho_vecs) + + norm = torch.norm(q, p=2, dim=1, keepdim=True) + amat = torch.div(q, norm) + dim0 = amat.size(0) + amat = amat.resize(dim0, self.z_size, self.num_ortho_vecs) + + max_norm = 0. + + # Iterative orthogonalization + for s in range(self.steps): + tmp = torch.bmm(amat.transpose(2, 1), amat) + tmp = self._eye - tmp + tmp = self._eye + 0.5 * tmp + amat = torch.bmm(amat, tmp) + + # Testing for convergence + test = torch.bmm(amat.transpose(2, 1), amat) - self._eye + norms2 = torch.sum(torch.norm(test, p=2, dim=2)**2, dim=1) + norms = torch.sqrt(norms2) + max_norm = torch.max(norms).item() + if max_norm <= self.cond: + break + + if max_norm > self.cond: + print('\nWARNING WARNING WARNING: orthogonalization not complete') + print('\t Final max norm =', max_norm) + + print() + + # Reshaping: first dimension is batch_size + amat = amat.view(-1, self.num_flows, self.z_size, self.num_ortho_vecs) + amat = amat.transpose(0, 1) + + return amat + + def encode(self, x): + """ + Encoder that ouputs parameters for base distribution of z and flow parameters. + """ + + batch_size = x.size(0) + + h = self.q_z_nn(x) + h = h.view(-1, self.q_z_nn_output_dim) + mean_z = self.q_z_mean(h) + var_z = self.q_z_var(h) + + # Amortized r1, r2, q, b for all flows + + full_d = self.amor_d(h) + diag1 = self.amor_diag1(h) + diag2 = self.amor_diag2(h) + + full_d = full_d.resize(batch_size, self.num_ortho_vecs, self.num_ortho_vecs, self.num_flows) + diag1 = diag1.resize(batch_size, self.num_ortho_vecs, self.num_flows) + diag2 = diag2.resize(batch_size, self.num_ortho_vecs, self.num_flows) + + r1 = full_d * self.triu_mask + r2 = full_d.transpose(2, 1) * self.triu_mask + + r1[:, self.diag_idx, self.diag_idx, :] = diag1 + r2[:, self.diag_idx, self.diag_idx, :] = diag2 + + q = self.amor_q(h) + b = self.amor_b(h) + + # Resize flow parameters to divide over K flows + b = b.resize(batch_size, 1, self.num_ortho_vecs, self.num_flows) + + return mean_z, var_z, r1, r2, q, b + + def forward(self, x): + """ + Forward pass with orthogonal sylvester flows for the transformation z_0 -> z_1 -> ... -> z_k. + Log determinant is computed as log_det_j = N E_q_z0[\sum_k log |det dz_k/dz_k-1| ]. + """ + + self.log_det_j = 0. + + z_mu, z_var, r1, r2, q, b = self.encode(x) + + # Orthogonalize all q matrices + q_ortho = self.batch_construct_orthogonal(q) + + # Sample z_0 + z = [self.reparameterize(z_mu, z_var)] + + # Normalizing flows + for k in range(self.num_flows): + + flow_k = getattr(self, 'flow_' + str(k)) + z_k, log_det_jacobian = flow_k(z[k], r1[:, :, :, k], r2[:, :, :, k], q_ortho[k, :, :, :], b[:, :, :, k]) + + z.append(z_k) + self.log_det_j += log_det_jacobian + + x_mean = self.decode(z[-1]) + + return x_mean, z_mu, z_var, self.log_det_j, z[0], z[-1] + + +class HouseholderSylvesterVAE(VAE): + """ + Variational auto-encoder with householder sylvester flows in the encoder. + """ + + def __init__(self, args): + super(HouseholderSylvesterVAE, self).__init__(args) + + # Initialize log-det-jacobian to zero + self.log_det_j = 0. 
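`batch_construct_orthogonal` below builds one orthogonal matrix per flow as a product of Householder reflections H = I - 2vv^T over unit vectors v, so no iterative orthogonalization is needed. A small non-batched sketch of that construction (the function name is ours, not the repo's):

```python
import torch

def householder_orthogonal(vs):
    """Compose reflections I - 2vv^T into an orthogonal (d, d) matrix.

    vs: (num_householder, d) unnormalized reflection vectors.
    """
    d = vs.shape[1]
    q = torch.eye(d)
    for v in vs:
        v = v / v.norm()  # each reflection only needs the direction of v
        q = (torch.eye(d) - 2.0 * torch.outer(v, v)) @ q
    return q

q = householder_orthogonal(torch.randn(8, 64))
assert torch.allclose(q @ q.t(), torch.eye(64), atol=1e-4)  # q is orthogonal
```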
+ + # Flow parameters + flow = flows.Sylvester + self.num_flows = args.num_flows + self.num_householder = args.num_householder + assert self.num_householder > 0 + + identity = torch.eye(self.z_size, self.z_size) + # Add batch dimension + identity = identity.unsqueeze(0) + # Put identity in buffer so that it will be moved to GPU if needed by any call of .cuda + self.register_buffer('_eye', identity) + self._eye.requires_grad = False + + # Masks needed for triangular r1 and r2. + triu_mask = torch.triu(torch.ones(self.z_size, self.z_size), diagonal=1) + triu_mask = triu_mask.unsqueeze(0).unsqueeze(3) + diag_idx = torch.arange(0, self.z_size).long() + + self.register_buffer('triu_mask', triu_mask) + self.triu_mask.requires_grad = False + self.register_buffer('diag_idx', diag_idx) + + # Amortized flow parameters + # Diagonal elements of r1 * r2 have to satisfy -1 < r1 * r2 for flow to be invertible + self.diag_activation = nn.Tanh() + + self.amor_d = nn.Linear(self.q_z_nn_output_dim, self.num_flows * self.z_size * self.z_size) + + self.amor_diag1 = nn.Sequential( + nn.Linear(self.q_z_nn_output_dim, self.num_flows * self.z_size), self.diag_activation + ) + self.amor_diag2 = nn.Sequential( + nn.Linear(self.q_z_nn_output_dim, self.num_flows * self.z_size), self.diag_activation + ) + + self.amor_q = nn.Linear(self.q_z_nn_output_dim, self.num_flows * self.z_size * self.num_householder) + + self.amor_b = nn.Linear(self.q_z_nn_output_dim, self.num_flows * self.z_size) + + # Normalizing flow layers + for k in range(self.num_flows): + flow_k = flow(self.z_size) + + self.add_module('flow_' + str(k), flow_k) + + def batch_construct_orthogonal(self, q): + """ + Batch orthogonal matrix construction. + :param q: q contains batches of matrices, shape : (batch_size, num_flows * z_size * num_householder) + :return: batches of orthogonalized matrices, shape: (batch_size * num_flows, z_size, z_size) + """ + + # Reshape to shape (num_flows * batch_size * num_householder, z_size) + q = q.view(-1, self.z_size) + + norm = torch.norm(q, p=2, dim=1, keepdim=True) # ||v||_2 + v = torch.div(q, norm) # v / ||v||_2 + + # Calculate Householder Matrices + vvT = torch.bmm(v.unsqueeze(2), v.unsqueeze(1)) # v * v_T : batch_dot( B x L x 1 * B x 1 x L ) = B x L x L + + amat = self._eye - 2 * vvT # NOTICE: v is already normalized! so there is no need to calculate vvT/vTv + + # Reshaping: first dimension is batch_size * num_flows + amat = amat.view(-1, self.num_householder, self.z_size, self.z_size) + + tmp = amat[:, 0] + for k in range(1, self.num_householder): + tmp = torch.bmm(amat[:, k], tmp) + + amat = tmp.view(-1, self.num_flows, self.z_size, self.z_size) + amat = amat.transpose(0, 1) + + return amat + + def encode(self, x): + """ + Encoder that ouputs parameters for base distribution of z and flow parameters. 
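+        Returns mean_z, var_z and the amortized flow parameters r1, r2, q, b.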
+ """ + + batch_size = x.size(0) + + h = self.q_z_nn(x) + h = h.view(-1, self.q_z_nn_output_dim) + mean_z = self.q_z_mean(h) + var_z = self.q_z_var(h) + + # Amortized r1, r2, q, b for all flows + full_d = self.amor_d(h) + diag1 = self.amor_diag1(h) + diag2 = self.amor_diag2(h) + + full_d = full_d.resize(batch_size, self.z_size, self.z_size, self.num_flows) + diag1 = diag1.resize(batch_size, self.z_size, self.num_flows) + diag2 = diag2.resize(batch_size, self.z_size, self.num_flows) + + r1 = full_d * self.triu_mask + r2 = full_d.transpose(2, 1) * self.triu_mask + + r1[:, self.diag_idx, self.diag_idx, :] = diag1 + r2[:, self.diag_idx, self.diag_idx, :] = diag2 + + q = self.amor_q(h) + + b = self.amor_b(h) + + # Resize flow parameters to divide over K flows + b = b.resize(batch_size, 1, self.z_size, self.num_flows) + + return mean_z, var_z, r1, r2, q, b + + def forward(self, x): + """ + Forward pass with orthogonal flows for the transformation z_0 -> z_1 -> ... -> z_k. + Log determinant is computed as log_det_j = N E_q_z0[\sum_k log |det dz_k/dz_k-1| ]. + """ + + self.log_det_j = 0. + + z_mu, z_var, r1, r2, q, b = self.encode(x) + + # Orthogonalize all q matrices + q_ortho = self.batch_construct_orthogonal(q) + + # Sample z_0 + z = [self.reparameterize(z_mu, z_var)] + + # Normalizing flows + for k in range(self.num_flows): + + flow_k = getattr(self, 'flow_' + str(k)) + q_k = q_ortho[k] + + z_k, log_det_jacobian = flow_k(z[k], r1[:, :, :, k], r2[:, :, :, k], q_k, b[:, :, :, k], sum_ldj=True) + + z.append(z_k) + self.log_det_j += log_det_jacobian + + x_mean = self.decode(z[-1]) + + return x_mean, z_mu, z_var, self.log_det_j, z[0], z[-1] + + +class TriangularSylvesterVAE(VAE): + """ + Variational auto-encoder with triangular Sylvester flows in the encoder. Alternates between setting + the orthogonal matrix equal to permutation and identity matrix for each flow. + """ + + def __init__(self, args): + super(TriangularSylvesterVAE, self).__init__(args) + + # Initialize log-det-jacobian to zero + self.log_det_j = 0. + + # Flow parameters + flow = flows.TriangularSylvester + self.num_flows = args.num_flows + + # permuting indices corresponding to Q=P (permutation matrix) for every other flow + flip_idx = torch.arange(self.z_size - 1, -1, -1).long() + self.register_buffer('flip_idx', flip_idx) + + # Masks needed for triangular r1 and r2. + triu_mask = torch.triu(torch.ones(self.z_size, self.z_size), diagonal=1) + triu_mask = triu_mask.unsqueeze(0).unsqueeze(3) + diag_idx = torch.arange(0, self.z_size).long() + + self.register_buffer('triu_mask', triu_mask) + self.triu_mask.requires_grad = False + self.register_buffer('diag_idx', diag_idx) + + # Amortized flow parameters + # Diagonal elements of r1 * r2 have to satisfy -1 < r1 * r2 for flow to be invertible + self.diag_activation = nn.Tanh() + + self.amor_d = nn.Linear(self.q_z_nn_output_dim, self.num_flows * self.z_size * self.z_size) + + self.amor_diag1 = nn.Sequential( + nn.Linear(self.q_z_nn_output_dim, self.num_flows * self.z_size), self.diag_activation + ) + self.amor_diag2 = nn.Sequential( + nn.Linear(self.q_z_nn_output_dim, self.num_flows * self.z_size), self.diag_activation + ) + + self.amor_b = nn.Linear(self.q_z_nn_output_dim, self.num_flows * self.z_size) + + # Normalizing flow layers + for k in range(self.num_flows): + flow_k = flow(self.z_size) + + self.add_module('flow_' + str(k), flow_k) + + def encode(self, x): + """ + Encoder that ouputs parameters for base distribution of z and flow parameters. 
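+        Returns mean_z, var_z and the amortized flow parameters r1, r2, b;
+        no q is returned since Q is fixed to the identity or a permutation.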
+        """
+
+        batch_size = x.size(0)
+
+        h = self.q_z_nn(x)
+        h = h.view(-1, self.q_z_nn_output_dim)
+        mean_z = self.q_z_mean(h)
+        var_z = self.q_z_var(h)
+
+        # Amortized r1, r2, b for all flows
+        full_d = self.amor_d(h)
+        diag1 = self.amor_diag1(h)
+        diag2 = self.amor_diag2(h)
+
+        full_d = full_d.resize(batch_size, self.z_size, self.z_size, self.num_flows)
+        diag1 = diag1.resize(batch_size, self.z_size, self.num_flows)
+        diag2 = diag2.resize(batch_size, self.z_size, self.num_flows)
+
+        r1 = full_d * self.triu_mask
+        r2 = full_d.transpose(2, 1) * self.triu_mask
+
+        r1[:, self.diag_idx, self.diag_idx, :] = diag1
+        r2[:, self.diag_idx, self.diag_idx, :] = diag2
+
+        b = self.amor_b(h)
+
+        # Resize flow parameters to divide over K flows
+        b = b.resize(batch_size, 1, self.z_size, self.num_flows)
+
+        return mean_z, var_z, r1, r2, b
+
+    def forward(self, x):
+        """
+        Forward pass with triangular Sylvester flows for the transformation z_0 -> z_1 -> ... -> z_k.
+        Log determinant is computed as log_det_j = N E_q_z0[\sum_k log |det dz_k/dz_k-1| ].
+        """
+
+        self.log_det_j = 0.
+
+        z_mu, z_var, r1, r2, b = self.encode(x)
+
+        # Sample z_0
+        z = [self.reparameterize(z_mu, z_var)]
+
+        # Normalizing flows
+        for k in range(self.num_flows):
+
+            flow_k = getattr(self, 'flow_' + str(k))
+            if k % 2 == 1:
+                # Alternate by reordering z for the triangular flow
+                permute_z = self.flip_idx
+            else:
+                permute_z = None
+
+            z_k, log_det_jacobian = flow_k(z[k], r1[:, :, :, k], r2[:, :, :, k], b[:, :, :, k], permute_z, sum_ldj=True)
+
+            z.append(z_k)
+            self.log_det_j += log_det_jacobian
+
+        x_mean = self.decode(z[-1])
+
+        return x_mean, z_mu, z_var, self.log_det_j, z[0], z[-1]
+
+
+class IAFVAE(VAE):
+    """
+    Variational auto-encoder with inverse autoregressive flows in the encoder.
+    """
+
+    def __init__(self, args):
+        super(IAFVAE, self).__init__(args)
+
+        # Initialize log-det-jacobian to zero
+        self.log_det_j = 0.
+        self.h_size = args.made_h_size
+
+        self.h_context = nn.Linear(self.q_z_nn_output_dim, self.h_size)
+
+        # Flow parameters
+        self.num_flows = args.num_flows
+        self.flow = flows.IAF(
+            z_size=self.z_size, num_flows=self.num_flows, num_hidden=1, h_size=self.h_size, conv2d=False
+        )
+
+    def encode(self, x):
+        """
+        Encoder that outputs parameters for base distribution of z and context h for flows.
+        """
+
+        h = self.q_z_nn(x)
+        h = h.view(-1, self.q_z_nn_output_dim)
+        mean_z = self.q_z_mean(h)
+        var_z = self.q_z_var(h)
+        h_context = self.h_context(h)
+
+        return mean_z, var_z, h_context
+
+    def forward(self, x):
+        """
+        Forward pass with inverse autoregressive flows for the transformation z_0 -> z_1 -> ... -> z_k.
+        Log determinant is computed as log_det_j = N E_q_z0[\sum_k log |det dz_k/dz_k-1| ].
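+        The IAF module returns z_k together with the already accumulated
+        log-det-Jacobian in a single call.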
+ """ + + # mean and variance of z + z_mu, z_var, h_context = self.encode(x) + # sample z + z_0 = self.reparameterize(z_mu, z_var) + + # iaf flows + z_k, self.log_det_j = self.flow(z_0, h_context) + + # decode + x_mean = self.decode(z_k) + + return x_mean, z_mu, z_var, self.log_det_j, z_0, z_k diff --git a/src/torchprune/torchprune/util/models/cnn/models/__init__.py b/src/torchprune/torchprune/util/external/ffjord/vae_lib/models/__init__.py similarity index 100% rename from src/torchprune/torchprune/util/models/cnn/models/__init__.py rename to src/torchprune/torchprune/util/external/ffjord/vae_lib/models/__init__.py diff --git a/src/torchprune/torchprune/util/external/ffjord/vae_lib/models/flows.py b/src/torchprune/torchprune/util/external/ffjord/vae_lib/models/flows.py new file mode 100644 index 0000000..0894043 --- /dev/null +++ b/src/torchprune/torchprune/util/external/ffjord/vae_lib/models/flows.py @@ -0,0 +1,299 @@ +""" +Collection of flow strategies +""" + +from __future__ import print_function + +import torch +import torch.nn as nn +from torch.autograd import Variable +import torch.nn.functional as F + +from vae_lib.models.layers import MaskedConv2d, MaskedLinear + + +class Planar(nn.Module): + """ + PyTorch implementation of planar flows as presented in "Variational Inference with Normalizing Flows" + by Danilo Jimenez Rezende, Shakir Mohamed. Model assumes amortized flow parameters. + """ + + def __init__(self): + + super(Planar, self).__init__() + + self.h = nn.Tanh() + self.softplus = nn.Softplus() + + def der_h(self, x): + """ Derivative of tanh """ + + return 1 - self.h(x)**2 + + def forward(self, zk, u, w, b): + """ + Forward pass. Assumes amortized u, w and b. Conditions on diagonals of u and w for invertibility + will be be satisfied inside this function. Computes the following transformation: + z' = z + u h( w^T z + b) + or actually + z'^T = z^T + h(z^T w + b)u^T + Assumes the following input shapes: + shape u = (batch_size, z_size, 1) + shape w = (batch_size, 1, z_size) + shape b = (batch_size, 1, 1) + shape z = (batch_size, z_size). + """ + + zk = zk.unsqueeze(2) + + # reparameterize u such that the flow becomes invertible (see appendix paper) + uw = torch.bmm(w, u) + m_uw = -1. + self.softplus(uw) + w_norm_sq = torch.sum(w**2, dim=2, keepdim=True) + u_hat = u + ((m_uw - uw) * w.transpose(2, 1) / w_norm_sq) + + # compute flow with u_hat + wzb = torch.bmm(w, zk) + b + z = zk + u_hat * self.h(wzb) + z = z.squeeze(2) + + # compute logdetJ + psi = w * self.der_h(wzb) + log_det_jacobian = torch.log(torch.abs(1 + torch.bmm(psi, u_hat))) + log_det_jacobian = log_det_jacobian.squeeze(2).squeeze(1) + + return z, log_det_jacobian + + +class Sylvester(nn.Module): + """ + Sylvester normalizing flow. + """ + + def __init__(self, num_ortho_vecs): + + super(Sylvester, self).__init__() + + self.num_ortho_vecs = num_ortho_vecs + + self.h = nn.Tanh() + + triu_mask = torch.triu(torch.ones(num_ortho_vecs, num_ortho_vecs), diagonal=1).unsqueeze(0) + diag_idx = torch.arange(0, num_ortho_vecs).long() + + self.register_buffer('triu_mask', Variable(triu_mask)) + self.triu_mask.requires_grad = False + self.register_buffer('diag_idx', diag_idx) + + def der_h(self, x): + return self.der_tanh(x) + + def der_tanh(self, x): + return 1 - self.h(x)**2 + + def _forward(self, zk, r1, r2, q_ortho, b, sum_ldj=True): + """ + All flow parameters are amortized. Conditions on diagonals of R1 and R2 for invertibility need to be satisfied + outside of this function. 
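+        (The VAE classes above enforce this by passing the diagonals of R1 and
+        R2 through a tanh, so that -1 < diag(R1) * diag(R2) < 1.)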
Computes the following transformation: + z' = z + QR1 h( R2Q^T z + b) + or actually + z'^T = z^T + h(z^T Q R2^T + b^T)R1^T Q^T + :param zk: shape: (batch_size, z_size) + :param r1: shape: (batch_size, num_ortho_vecs, num_ortho_vecs) + :param r2: shape: (batch_size, num_ortho_vecs, num_ortho_vecs) + :param q_ortho: shape (batch_size, z_size , num_ortho_vecs) + :param b: shape: (batch_size, 1, self.z_size) + :return: z, log_det_j + """ + + # Amortized flow parameters + zk = zk.unsqueeze(1) + + # Save diagonals for log_det_j + diag_r1 = r1[:, self.diag_idx, self.diag_idx] + diag_r2 = r2[:, self.diag_idx, self.diag_idx] + + r1_hat = r1 + r2_hat = r2 + + qr2 = torch.bmm(q_ortho, r2_hat.transpose(2, 1)) + qr1 = torch.bmm(q_ortho, r1_hat) + + r2qzb = torch.bmm(zk, qr2) + b + z = torch.bmm(self.h(r2qzb), qr1.transpose(2, 1)) + zk + z = z.squeeze(1) + + # Compute log|det J| + # Output log_det_j in shape (batch_size) instead of (batch_size,1) + diag_j = diag_r1 * diag_r2 + diag_j = self.der_h(r2qzb).squeeze(1) * diag_j + diag_j += 1. + log_diag_j = diag_j.abs().log() + + if sum_ldj: + log_det_j = log_diag_j.sum(-1) + else: + log_det_j = log_diag_j + + return z, log_det_j + + def forward(self, zk, r1, r2, q_ortho, b, sum_ldj=True): + + return self._forward(zk, r1, r2, q_ortho, b, sum_ldj) + + +class TriangularSylvester(nn.Module): + """ + Sylvester normalizing flow with Q=P or Q=I. + """ + + def __init__(self, z_size): + + super(TriangularSylvester, self).__init__() + + self.z_size = z_size + self.h = nn.Tanh() + + diag_idx = torch.arange(0, z_size).long() + self.register_buffer('diag_idx', diag_idx) + + def der_h(self, x): + return self.der_tanh(x) + + def der_tanh(self, x): + return 1 - self.h(x)**2 + + def _forward(self, zk, r1, r2, b, permute_z=None, sum_ldj=True): + """ + All flow parameters are amortized. conditions on diagonals of R1 and R2 need to be satisfied + outside of this function. + Computes the following transformation: + z' = z + QR1 h( R2Q^T z + b) + or actually + z'^T = z^T + h(z^T Q R2^T + b^T)R1^T Q^T + with Q = P a permutation matrix (equal to identity matrix if permute_z=None) + :param zk: shape: (batch_size, z_size) + :param r1: shape: (batch_size, num_ortho_vecs, num_ortho_vecs). + :param r2: shape: (batch_size, num_ortho_vecs, num_ortho_vecs). + :param b: shape: (batch_size, 1, self.z_size) + :return: z, log_det_j + """ + + # Amortized flow parameters + zk = zk.unsqueeze(1) + + # Save diagonals for log_det_j + diag_r1 = r1[:, self.diag_idx, self.diag_idx] + diag_r2 = r2[:, self.diag_idx, self.diag_idx] + + if permute_z is not None: + # permute order of z + z_per = zk[:, :, permute_z] + else: + z_per = zk + + r2qzb = torch.bmm(z_per, r2.transpose(2, 1)) + b + z = torch.bmm(self.h(r2qzb), r1.transpose(2, 1)) + + if permute_z is not None: + # permute order of z again back again + z = z[:, :, permute_z] + + z += zk + z = z.squeeze(1) + + # Compute log|det J| + # Output log_det_j in shape (batch_size) instead of (batch_size,1) + diag_j = diag_r1 * diag_r2 + diag_j = self.der_h(r2qzb).squeeze(1) * diag_j + diag_j += 1. + log_diag_j = diag_j.abs().log() + + if sum_ldj: + log_det_j = log_diag_j.sum(-1) + else: + log_det_j = log_diag_j + + return z, log_det_j + + def forward(self, zk, r1, r2, q_ortho, b, sum_ldj=True): + + return self._forward(zk, r1, r2, q_ortho, b, sum_ldj) + + +class IAF(nn.Module): + """ + PyTorch implementation of inverse autoregressive flows as presented in + "Improving Variational Inference with Inverse Autoregressive Flow" by Diederik P. 
Kingma, Tim Salimans, + Rafal Jozefowicz, Xi Chen, Ilya Sutskever, Max Welling. + Inverse Autoregressive Flow with either MADE MLPs or Pixel CNNs. Contains several flows. Each transformation + takes as an input the previous stochastic z, and a context h. The structure of each flow is then as follows: + z <- autoregressive_layer(z) + h, allow for diagonal connections + z <- autoregressive_layer(z), allow for diagonal connections + : + z <- autoregressive_layer(z), do not allow for diagonal connections. + + Note that the size of h needs to be the same as h_size, which is the width of the MADE layers. + """ + + def __init__(self, z_size, num_flows=2, num_hidden=0, h_size=50, forget_bias=1., conv2d=False): + super(IAF, self).__init__() + self.z_size = z_size + self.num_flows = num_flows + self.num_hidden = num_hidden + self.h_size = h_size + self.conv2d = conv2d + if not conv2d: + ar_layer = MaskedLinear + else: + ar_layer = MaskedConv2d + self.activation = torch.nn.ELU + # self.activation = torch.nn.ReLU + + self.forget_bias = forget_bias + self.flows = [] + self.param_list = [] + + # For reordering z after each flow + flip_idx = torch.arange(self.z_size - 1, -1, -1).long() + self.register_buffer('flip_idx', flip_idx) + + for k in range(num_flows): + arch_z = [ar_layer(z_size, h_size), self.activation()] + self.param_list += list(arch_z[0].parameters()) + z_feats = torch.nn.Sequential(*arch_z) + arch_zh = [] + for j in range(num_hidden): + arch_zh += [ar_layer(h_size, h_size), self.activation()] + self.param_list += list(arch_zh[-2].parameters()) + zh_feats = torch.nn.Sequential(*arch_zh) + linear_mean = ar_layer(h_size, z_size, diagonal_zeros=True) + linear_std = ar_layer(h_size, z_size, diagonal_zeros=True) + self.param_list += list(linear_mean.parameters()) + self.param_list += list(linear_std.parameters()) + + if torch.cuda.is_available(): + z_feats = z_feats.cuda() + zh_feats = zh_feats.cuda() + linear_mean = linear_mean.cuda() + linear_std = linear_std.cuda() + self.flows.append((z_feats, zh_feats, linear_mean, linear_std)) + + self.param_list = torch.nn.ParameterList(self.param_list) + + def forward(self, z, h_context): + + logdets = 0. 
+        for i, flow in enumerate(self.flows):
+            if (i + 1) % 2 == 0 and not self.conv2d:
+                # reverse ordering to help mixing
+                z = z[:, self.flip_idx]
+
+            h = flow[0](z)
+            h = h + h_context
+            h = flow[1](h)
+            mean = flow[2](h)
+            gate = torch.sigmoid(flow[3](h) + self.forget_bias)
+            z = gate * z + (1 - gate) * mean
+            logdets += torch.sum(gate.log().view(gate.size(0), -1), 1)
+        return z, logdets
diff --git a/src/torchprune/torchprune/util/external/ffjord/vae_lib/models/layers.py b/src/torchprune/torchprune/util/external/ffjord/vae_lib/models/layers.py
new file mode 100644
index 0000000..e15c453
--- /dev/null
+++ b/src/torchprune/torchprune/util/external/ffjord/vae_lib/models/layers.py
@@ -0,0 +1,209 @@
+import torch
+import torch.nn as nn
+from torch.nn.parameter import Parameter
+import numpy as np
+import torch.nn.functional as F
+
+
+class Identity(nn.Module):
+
+    def __init__(self):
+        super(Identity, self).__init__()
+
+    def forward(self, x):
+        return x
+
+
+class GatedConv2d(nn.Module):
+
+    def __init__(self, input_channels, output_channels, kernel_size, stride, padding, dilation=1, activation=None):
+        super(GatedConv2d, self).__init__()
+
+        self.activation = activation
+        self.sigmoid = nn.Sigmoid()
+
+        self.h = nn.Conv2d(input_channels, output_channels, kernel_size, stride, padding, dilation)
+        self.g = nn.Conv2d(input_channels, output_channels, kernel_size, stride, padding, dilation)
+
+    def forward(self, x):
+        if self.activation is None:
+            h = self.h(x)
+        else:
+            h = self.activation(self.h(x))
+
+        g = self.sigmoid(self.g(x))
+
+        return h * g
+
+
+class GatedConvTranspose2d(nn.Module):
+
+    def __init__(
+        self, input_channels, output_channels, kernel_size, stride, padding, output_padding=0, dilation=1,
+        activation=None
+    ):
+        super(GatedConvTranspose2d, self).__init__()
+
+        self.activation = activation
+        self.sigmoid = nn.Sigmoid()
+
+        self.h = nn.ConvTranspose2d(
+            input_channels, output_channels, kernel_size, stride, padding, output_padding, dilation=dilation
+        )
+        self.g = nn.ConvTranspose2d(
+            input_channels, output_channels, kernel_size, stride, padding, output_padding, dilation=dilation
+        )
+
+    def forward(self, x):
+        if self.activation is None:
+            h = self.h(x)
+        else:
+            h = self.activation(self.h(x))
+
+        g = self.sigmoid(self.g(x))
+
+        return h * g
+
+
+class MaskedLinear(nn.Module):
+    """
+    Creates masked linear layer for MLP MADE.
+    For input (x) to hidden (h) or hidden to hidden layers choose diagonal_zeros = False.
+    For hidden to output (y) layers:
+    If output depends on input through y_i = f(x_{<i}), set diagonal_zeros = True.
+    Else, if output depends on input through y_i = f(x_{<=i}), set diagonal_zeros = False.
+    """
+
+    def __init__(self, in_features, out_features, diagonal_zeros=False, bias=True):
+        super(MaskedLinear, self).__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.diagonal_zeros = diagonal_zeros
+        weight = torch.FloatTensor(in_features, out_features)
+        self.weight = Parameter(weight)
+        if bias:
+            bias = torch.FloatTensor(out_features)
+            self.bias = Parameter(bias)
+        else:
+            self.register_parameter('bias', None)
+        mask = torch.from_numpy(self.build_mask())
+        if torch.cuda.is_available():
+            mask = mask.cuda()
+        self.mask = torch.autograd.Variable(mask, requires_grad=False)
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        nn.init.kaiming_normal_(self.weight)
+        if self.bias is not None:
+            self.bias.data.zero_()
+
+    def build_mask(self):
+        n_in, n_out = self.in_features, self.out_features
+        assert n_in % n_out == 0 or n_out % n_in == 0
+
+        # Lower-triangular connectivity: output block i only sees inputs <= i
+        # (strictly < i when diagonal_zeros is set).
+        mask = np.ones((n_in, n_out), dtype=np.float32)
+        if n_out >= n_in:
+            k = n_out // n_in
+            for i in range(n_in):
+                mask[i + 1:, i * k:(i + 1) * k] = 0
+                if self.diagonal_zeros:
+                    mask[i:i + 1, i * k:(i + 1) * k] = 0
+        else:
+            k = n_in // n_out
+            for i in range(n_out):
+                mask[(i + 1) * k:, i:i + 1] = 0
+                if self.diagonal_zeros:
+                    mask[i * k:(i + 1) * k:, i:i + 1] = 0
+        return mask
+
+    def forward(self, x):
+        output = x.mm(self.mask * self.weight)
+
+        if self.bias is not None:
+            return output.add(self.bias.expand_as(output))
+        else:
+            return output
+
+    def __repr__(self):
+        if self.bias is not None:
+            bias = True
+        else:
+            bias = False
+        return self.__class__.__name__ + ' (' \
+            + str(self.in_features) + ' -> ' \
+            + str(self.out_features) + ', diagonal_zeros=' \
+            + str(self.diagonal_zeros) + ', bias=' \
+            + str(bias) + ')'
+
+
+class MaskedConv2d(nn.Module):
+    """
+    Creates masked convolutional autoregressive layer for pixelCNN.
+    For input (x) to hidden (h) or hidden to hidden layers choose diagonal_zeros = False.
+    For hidden to output (y) layers:
+    If output depends on input through y_i = f(x_{<i}), set diagonal_zeros = True.
+    Else, if output depends on input through y_i = f(x_{<=i}), set diagonal_zeros = False.
+    """
+
+    def __init__(self, in_features, out_features, size_kernel=(3, 3), diagonal_zeros=False, bias=True):
+        super(MaskedConv2d, self).__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.size_kernel = size_kernel
+        self.diagonal_zeros = diagonal_zeros
+        weight = torch.FloatTensor(out_features, in_features, *self.size_kernel)
+        self.weight = Parameter(weight)
+        if bias:
+            bias = torch.FloatTensor(out_features)
+            self.bias = Parameter(bias)
+        else:
+            self.register_parameter('bias', None)
+        mask = torch.from_numpy(self.build_mask())
+        if torch.cuda.is_available():
+            mask = mask.cuda()
+        self.mask = torch.autograd.Variable(mask, requires_grad=False)
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        nn.init.kaiming_normal_(self.weight)
+        if self.bias is not None:
+            self.bias.data.zero_()
+
+    def build_mask(self):
+        n_in, n_out = self.in_features, self.out_features
+
+        assert n_out % n_in == 0 or n_in % n_out == 0
+
+        # Build autoregressive mask: zero out kernel positions above and to the
+        # left of the center pixel, then restrict channel connectivity.
+        l = (self.size_kernel[0] - 1) // 2
+        m = (self.size_kernel[1] - 1) // 2
+
+        mask = np.ones((n_out, n_in, *self.size_kernel), dtype=np.float32)
+        mask[:, :, :l, :] = 0
+        mask[:, :, l, :m] = 0
+
+        if n_out >= n_in:
+            k = n_out // n_in
+            for i in range(n_in):
+                mask[i * k:(i + 1) * k, i + 1:, l, m] = 0
+                if self.diagonal_zeros:
+                    mask[i * k:(i + 1) * k, i:i + 1, l, m] = 0
+        else:
+            k = n_in // n_out
+            for i in range(n_out):
+                mask[i:i + 1, (i + 1) * k:, l, m] = 0
+                if self.diagonal_zeros:
+                    mask[i:i + 1, i * k:(i + 1) * k:, l, m] = 0
+
+        return mask
+
+    def forward(self, x):
+        output = F.conv2d(x, self.mask * self.weight, bias=self.bias, padding=(1, 1))
+        return output
+
+    def __repr__(self):
+        if self.bias is not None:
+            bias = True
+        else:
+            bias = False
+        return self.__class__.__name__ + ' (' \
+            + str(self.in_features) + ' -> ' \
+            + str(self.out_features) + ', diagonal_zeros=' \
+            + str(self.diagonal_zeros) + ', bias=' \
+            + str(bias) + ', size_kernel=' \
+            + str(self.size_kernel) + ')'
diff --git a/src/torchprune/torchprune/util/external/ffjord/vae_lib/optimization/__init__.py b/src/torchprune/torchprune/util/external/ffjord/vae_lib/optimization/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/torchprune/torchprune/util/external/ffjord/vae_lib/optimization/loss.py b/src/torchprune/torchprune/util/external/ffjord/vae_lib/optimization/loss.py
new file mode 100644
index 0000000..4e56712
--- /dev/null
+++ b/src/torchprune/torchprune/util/external/ffjord/vae_lib/optimization/loss.py
@@ -0,0 +1,271 @@
+from __future__ import print_function
+
+import numpy as np
+import torch
+import torch.nn as nn
+from vae_lib.utils.distributions import log_normal_diag, log_normal_standard, log_bernoulli
+import torch.nn.functional as F
+
+
+def binary_loss_function(recon_x, x, z_mu, z_var, z_0, z_k, ldj, beta=1.):
+    """
+    Computes the binary loss function while summing over batch dimension, not averaged!
+    :param recon_x: shape: (batch_size, num_channels, pixel_width, pixel_height), bernoulli parameters p(x=1)
+    :param x: shape (batchsize, num_channels, pixel_width, pixel_height), pixel values rescaled between [0, 1].
+    :param z_mu: mean of z_0
+    :param z_var: variance of z_0
+    :param z_0: first stochastic latent variable
+    :param z_k: last stochastic latent variable
+    :param ldj: log det jacobian
+    :param beta: beta for kl loss
+    :return: loss, ce, kl
+    """
+
+    reconstruction_function = nn.BCELoss(size_average=False)
+
+    batch_size = x.size(0)
+
+    # - N E_q0 [ ln p(x|z_k) ]
+    bce = reconstruction_function(recon_x, x)
+
+    # ln p(z_k)  (not averaged)
+    log_p_zk = log_normal_standard(z_k, dim=1)
+    # ln q(z_0)  (not averaged)
+    log_q_z0 = log_normal_diag(z_0, mean=z_mu, log_var=z_var.log(), dim=1)
+    # N E_q0[ ln q(z_0) - ln p(z_k) ]
+    summed_logs = torch.sum(log_q_z0 - log_p_zk)
+
+    # sum over batches
+    summed_ldj = torch.sum(ldj)
+
+    # ldj = N E_q_z0[\sum_k log |det dz_k/dz_k-1| ]
+    kl = (summed_logs - summed_ldj)
+    loss = bce + beta * kl
+
+    loss /= float(batch_size)
+    bce /= float(batch_size)
+    kl /= float(batch_size)
+
+    return loss, bce, kl
+
+
+def multinomial_loss_function(x_logit, x, z_mu, z_var, z_0, z_k, ldj, args, beta=1.):
+    """
+    Computes the cross entropy loss function while summing over batch dimension, not averaged!
+    :param x_logit: shape: (batch_size, num_classes * num_channels, pixel_width, pixel_height), real valued logits
+    :param x: shape (batchsize, num_channels, pixel_width, pixel_height), pixel values rescaled between [0, 1].
+ :param z_mu: mean of z_0 + :param z_var: variance of z_0 + :param z_0: first stochastic latent variable + :param z_k: last stochastic latent variable + :param ldj: log det jacobian + :param args: global parameter settings + :param beta: beta for kl loss + :return: loss, ce, kl + """ + + num_classes = 256 + batch_size = x.size(0) + + x_logit = x_logit.view(batch_size, num_classes, args.input_size[0], args.input_size[1], args.input_size[2]) + + # make integer class labels + target = (x * (num_classes - 1)).long() + + # - N E_q0 [ ln p(x|z_k) ] + # sums over batch dimension (and feature dimension) + ce = cross_entropy(x_logit, target, size_average=False) + + # ln p(z_k) (not averaged) + log_p_zk = log_normal_standard(z_k, dim=1) + # ln q(z_0) (not averaged) + log_q_z0 = log_normal_diag(z_0, mean=z_mu, log_var=z_var.log(), dim=1) + # N E_q0[ ln q(z_0) - ln p(z_k) ] + summed_logs = torch.sum(log_q_z0 - log_p_zk) + + # sum over batches + summed_ldj = torch.sum(ldj) + + # ldj = N E_q_z0[\sum_k log |det dz_k/dz_k-1| ] + kl = (summed_logs - summed_ldj) + loss = ce + beta * kl + + loss /= float(batch_size) + ce /= float(batch_size) + kl /= float(batch_size) + + return loss, ce, kl + + +def binary_loss_array(recon_x, x, z_mu, z_var, z_0, z_k, ldj, beta=1.): + """ + Computes the binary loss without averaging or summing over the batch dimension. + """ + + batch_size = x.size(0) + + # if not summed over batch_dimension + if len(ldj.size()) > 1: + ldj = ldj.view(ldj.size(0), -1).sum(-1) + + # TODO: upgrade to newest pytorch version on master branch, there the nn.BCELoss comes with the option + # reduce, which when set to False, does no sum over batch dimension. + bce = -log_bernoulli(x.view(batch_size, -1), recon_x.view(batch_size, -1), dim=1) + # ln p(z_k) (not averaged) + log_p_zk = log_normal_standard(z_k, dim=1) + # ln q(z_0) (not averaged) + log_q_z0 = log_normal_diag(z_0, mean=z_mu, log_var=z_var.log(), dim=1) + # ln q(z_0) - ln p(z_k) ] + logs = log_q_z0 - log_p_zk + + loss = bce + beta * (logs - ldj) + + return loss + + +def multinomial_loss_array(x_logit, x, z_mu, z_var, z_0, z_k, ldj, args, beta=1.): + """ + Computes the discritezed logistic loss without averaging or summing over the batch dimension. + """ + + num_classes = 256 + batch_size = x.size(0) + + x_logit = x_logit.view(batch_size, num_classes, args.input_size[0], args.input_size[1], args.input_size[2]) + + # make integer class labels + target = (x * (num_classes - 1)).long() + + # - N E_q0 [ ln p(x|z_k) ] + # computes cross entropy over all dimensions separately: + ce = cross_entropy(x_logit, target, size_average=False, reduce=False) + # sum over feature dimension + ce = ce.view(batch_size, -1).sum(dim=1) + + # ln p(z_k) (not averaged) + log_p_zk = log_normal_standard(z_k.view(batch_size, -1), dim=1) + # ln q(z_0) (not averaged) + log_q_z0 = log_normal_diag( + z_0.view(batch_size, -1), mean=z_mu.view(batch_size, -1), log_var=z_var.log().view(batch_size, -1), dim=1 + ) + + # ln q(z_0) - ln p(z_k) ] + logs = log_q_z0 - log_p_zk + + loss = ce + beta * (logs - ldj) + + return loss + + +def cross_entropy(input, target, weight=None, size_average=True, ignore_index=-100, reduce=True): + r""" + Taken from the master branch of pytorch, accepts (N, C, d_1, d_2, ..., d_K) input shapes + instead of only (N, C, d_1, d_2) or (N, C). + This criterion combines `log_softmax` and `nll_loss` in a single + function. + See :class:`~torch.nn.CrossEntropyLoss` for details. 
+ Args: + input: Variable :math:`(N, C)` where `C = number of classes` + target: Variable :math:`(N)` where each value is + `0 <= targets[i] <= C-1` + weight (Tensor, optional): a manual rescaling weight given to each + class. If given, has to be a Tensor of size `C` + size_average (bool, optional): By default, the losses are averaged + over observations for each minibatch. However, if the field + sizeAverage is set to False, the losses are instead summed + for each minibatch. Ignored if reduce is False. Default: ``True`` + ignore_index (int, optional): Specifies a target value that is ignored + and does not contribute to the input gradient. When size_average is + True, the loss is averaged over non-ignored targets. Default: -100 + reduce (bool, optional): By default, the losses are averaged or summed over + observations for each minibatch depending on size_average. When reduce + is False, returns a loss per batch element instead and ignores + size_average. Default: ``True`` + """ + return nll_loss(F.log_softmax(input, 1), target, weight, size_average, ignore_index, reduce) + + +def nll_loss(input, target, weight=None, size_average=True, ignore_index=-100, reduce=True): + r""" + Taken from the master branch of pytorch, accepts (N, C, d_1, d_2, ..., d_K) input shapes + instead of only (N, C, d_1, d_2) or (N, C). + The negative log likelihood loss. + See :class:`~torch.nn.NLLLoss` for details. + Args: + input: :math:`(N, C)` where `C = number of classes` or :math:`(N, C, H, W)` + in case of 2D Loss, or :math:`(N, C, d_1, d_2, ..., d_K)` where :math:`K > 1` + in the case of K-dimensional loss. + target: :math:`(N)` where each value is `0 <= targets[i] <= C-1`, + or :math:`(N, C, d_1, d_2, ..., d_K)` where :math:`K >= 1` for + K-dimensional loss. + weight (Tensor, optional): a manual rescaling weight given to each + class. If given, has to be a Tensor of size `C` + size_average (bool, optional): By default, the losses are averaged + over observations for each minibatch. If size_average + is False, the losses are summed for each minibatch. Default: ``True`` + ignore_index (int, optional): Specifies a target value that is ignored + and does not contribute to the input gradient. When size_average is + True, the loss is averaged over non-ignored targets. Default: -100 + """ + dim = input.dim() + if dim == 2: + return F.nll_loss( + input, target, weight=weight, size_average=size_average, ignore_index=ignore_index, reduce=reduce + ) + elif dim == 4: + return F.nll_loss( + input, target, weight=weight, size_average=size_average, ignore_index=ignore_index, reduce=reduce + ) + elif dim == 3 or dim > 4: + n = input.size(0) + c = input.size(1) + out_size = (n,) + input.size()[2:] + if target.size()[1:] != input.size()[2:]: + raise ValueError('Expected target size {}, got {}'.format(out_size, input.size())) + input = input.contiguous().view(n, c, 1, -1) + target = target.contiguous().view(n, 1, -1) + if reduce: + _loss = nn.NLLLoss2d(weight=weight, size_average=size_average, ignore_index=ignore_index, reduce=reduce) + return _loss(input, target) + out = F.nll_loss( + input, target, weight=weight, size_average=size_average, ignore_index=ignore_index, reduce=reduce + ) + return out.view(out_size) + else: + raise ValueError('Expected 2 or more dimensions (got {})'.format(dim)) + + +def calculate_loss(x_mean, x, z_mu, z_var, z_0, z_k, ldj, args, beta=1.): + """ + Picks the correct loss depending on the input type. 
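+    For multinomial inputs the loss is additionally converted to bits per
+    dimension: bpd = loss / (prod(input_size) * ln(2)); for binary inputs
+    bpd is reported as 0.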
+    """
+
+    if args.input_type == 'binary':
+        loss, rec, kl = binary_loss_function(x_mean, x, z_mu, z_var, z_0, z_k, ldj, beta=beta)
+        bpd = 0.
+
+    elif args.input_type == 'multinomial':
+        loss, rec, kl = multinomial_loss_function(x_mean, x, z_mu, z_var, z_0, z_k, ldj, args, beta=beta)
+        bpd = loss.item() / (np.prod(args.input_size) * np.log(2.))
+
+    else:
+        raise ValueError('Invalid input type for calculate loss: %s.' % args.input_type)
+
+    return loss, rec, kl, bpd
+
+
+def calculate_loss_array(x_mean, x, z_mu, z_var, z_0, z_k, ldj, args):
+    """
+    Picks the correct loss depending on the input type.
+    """
+
+    if args.input_type == 'binary':
+        loss = binary_loss_array(x_mean, x, z_mu, z_var, z_0, z_k, ldj)
+
+    elif args.input_type == 'multinomial':
+        loss = multinomial_loss_array(x_mean, x, z_mu, z_var, z_0, z_k, ldj, args)
+
+    else:
+        raise ValueError('Invalid input type for calculate loss: %s.' % args.input_type)
+
+    return loss
diff --git a/src/torchprune/torchprune/util/external/ffjord/vae_lib/optimization/training.py b/src/torchprune/torchprune/util/external/ffjord/vae_lib/optimization/training.py
new file mode 100644
index 0000000..e6f4b37
--- /dev/null
+++ b/src/torchprune/torchprune/util/external/ffjord/vae_lib/optimization/training.py
@@ -0,0 +1,171 @@
+from __future__ import print_function
+import time
+import torch
+
+from vae_lib.optimization.loss import calculate_loss
+from vae_lib.utils.visual_evaluation import plot_reconstructions
+from vae_lib.utils.log_likelihood import calculate_likelihood
+
+import numpy as np
+from train_misc import count_nfe, override_divergence_fn
+
+
+def train(epoch, train_loader, model, opt, args, logger):
+
+    model.train()
+    train_loss = np.zeros(len(train_loader))
+    train_bpd = np.zeros(len(train_loader))
+
+    num_data = 0
+
+    # set warmup coefficient
+    beta = min([(epoch * 1.) / max([args.warmup, 1.]), args.max_beta])
+    logger.info('beta = {:5.4f}'.format(beta))
+    end = time.time()
+    for batch_idx, (data, _) in enumerate(train_loader):
+        if args.cuda:
+            data = data.cuda()
+
+        if args.dynamic_binarization:
+            data = torch.bernoulli(data)
+
+        data = data.view(-1, *args.input_size)
+
+        opt.zero_grad()
+        x_mean, z_mu, z_var, ldj, z0, zk = model(data)
+
+        if 'cnf' in args.flow:
+            f_nfe = count_nfe(model)
+
+        loss, rec, kl, bpd = calculate_loss(x_mean, data, z_mu, z_var, z0, zk, ldj, args, beta=beta)
+
+        loss.backward()
+
+        if 'cnf' in args.flow:
+            t_nfe = count_nfe(model)
+            b_nfe = t_nfe - f_nfe
+
+        train_loss[batch_idx] = loss.item()
+        train_bpd[batch_idx] = bpd
+
+        opt.step()
+
+        rec = rec.item()
+        kl = kl.item()
+
+        num_data += len(data)
+
+        batch_time = time.time() - end
+        end = time.time()
+
+        if batch_idx % args.log_interval == 0:
+            if args.input_type == 'binary':
+                perc = 100. * batch_idx / len(train_loader)
+                log_msg = (
+                    'Epoch {:3d} [{:5d}/{:5d} ({:2.0f}%)] | Time {:.3f} | Loss {:11.6f} | '
+                    'Rec {:11.6f} | KL {:11.6f}'.format(
+                        epoch, num_data, len(train_loader.sampler), perc, batch_time, loss.item(), rec, kl
+                    )
+                )
+            else:
+                perc = 100.
* batch_idx / len(train_loader) + tmp = 'Epoch {:3d} [{:5d}/{:5d} ({:2.0f}%)] | Time {:.3f} | Loss {:11.6f} | Bits/dim {:8.6f}' + log_msg = tmp.format(epoch, num_data, len(train_loader.sampler), perc, batch_time, loss.item(), + bpd), '\trec: {:11.3f}\tkl: {:11.6f}'.format(rec, kl) + log_msg = "".join(log_msg) + if 'cnf' in args.flow: + log_msg += ' | NFE Forward {} | NFE Backward {}'.format(f_nfe, b_nfe) + logger.info(log_msg) + + if args.input_type == 'binary': + logger.info('====> Epoch: {:3d} Average train loss: {:.4f}'.format(epoch, train_loss.sum() / len(train_loader))) + else: + logger.info( + '====> Epoch: {:3d} Average train loss: {:.4f}, average bpd: {:.4f}'. + format(epoch, train_loss.sum() / len(train_loader), train_bpd.sum() / len(train_loader)) + ) + + return train_loss + + +def evaluate(data_loader, model, args, logger, testing=False, epoch=0): + model.eval() + loss = 0. + batch_idx = 0 + bpd = 0. + + if args.input_type == 'binary': + loss_type = 'elbo' + else: + loss_type = 'bpd' + + if testing and 'cnf' in args.flow: + override_divergence_fn(model, "brute_force") + + for data, _ in data_loader: + batch_idx += 1 + + if args.cuda: + data = data.cuda() + + with torch.no_grad(): + data = data.view(-1, *args.input_size) + + x_mean, z_mu, z_var, ldj, z0, zk = model(data) + + batch_loss, rec, kl, batch_bpd = calculate_loss(x_mean, data, z_mu, z_var, z0, zk, ldj, args) + + bpd += batch_bpd + loss += batch_loss.item() + + # PRINT RECONSTRUCTIONS + if batch_idx == 1 and testing is False: + plot_reconstructions(data, x_mean, batch_loss, loss_type, epoch, args) + + loss /= len(data_loader) + bpd /= len(data_loader) + + if testing: + logger.info('====> Test set loss: {:.4f}'.format(loss)) + + # Compute log-likelihood + if testing and not ("cnf" in args.flow): # don't compute log-likelihood for cnf models + + with torch.no_grad(): + test_data = data_loader.dataset.tensors[0] + + if args.cuda: + test_data = test_data.cuda() + + logger.info('Computing log-likelihood on test set') + + model.eval() + + if args.dataset == 'caltech': + log_likelihood, nll_bpd = calculate_likelihood(test_data, model, args, logger, S=2000, MB=500) + else: + log_likelihood, nll_bpd = calculate_likelihood(test_data, model, args, logger, S=5000, MB=500) + + if 'cnf' in args.flow: + override_divergence_fn(model, args.divergence_fn) + else: + log_likelihood = None + nll_bpd = None + + if args.input_type in ['multinomial']: + bpd = loss / (np.prod(args.input_size) * np.log(2.)) + + if testing and not ("cnf" in args.flow): + logger.info('====> Test set log-likelihood: {:.4f}'.format(log_likelihood)) + + if args.input_type != 'binary': + logger.info('====> Test set bpd (elbo): {:.4f}'.format(bpd)) + logger.info( + '====> Test set bpd (log-likelihood): {:.4f}'. 
+ format(log_likelihood / (np.prod(args.input_size) * np.log(2.))) + ) + + if not testing: + return loss, bpd + else: + return log_likelihood, nll_bpd diff --git a/src/torchprune/torchprune/util/external/ffjord/vae_lib/utils/__init__.py b/src/torchprune/torchprune/util/external/ffjord/vae_lib/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/torchprune/torchprune/util/external/ffjord/vae_lib/utils/distributions.py b/src/torchprune/torchprune/util/external/ffjord/vae_lib/utils/distributions.py new file mode 100644 index 0000000..c58e4bc --- /dev/null +++ b/src/torchprune/torchprune/util/external/ffjord/vae_lib/utils/distributions.py @@ -0,0 +1,65 @@ +from __future__ import print_function +import torch +import torch.utils.data + +import math + +MIN_EPSILON = 1e-5 +MAX_EPSILON = 1. - 1e-5 + +PI = torch.FloatTensor([math.pi]) +if torch.cuda.is_available(): + PI = PI.cuda() + +# N(x | mu, var) = 1/sqrt{2pi var} exp[-1/(2 var) (x-mean)(x-mean)] +# log N(x| mu, var) = -log sqrt(2pi) -0.5 log var - 0.5 (x-mean)(x-mean)/var + + +def log_normal_diag(x, mean, log_var, average=False, reduce=True, dim=None): + log_norm = -0.5 * (log_var + (x - mean) * (x - mean) * log_var.exp().reciprocal()) + if reduce: + if average: + return torch.mean(log_norm, dim) + else: + return torch.sum(log_norm, dim) + else: + return log_norm + + +def log_normal_normalized(x, mean, log_var, average=False, reduce=True, dim=None): + log_norm = -(x - mean) * (x - mean) + log_norm *= torch.reciprocal(2. * log_var.exp()) + log_norm += -0.5 * log_var + log_norm += -0.5 * torch.log(2. * PI) + + if reduce: + if average: + return torch.mean(log_norm, dim) + else: + return torch.sum(log_norm, dim) + else: + return log_norm + + +def log_normal_standard(x, average=False, reduce=True, dim=None): + log_norm = -0.5 * x * x + + if reduce: + if average: + return torch.mean(log_norm, dim) + else: + return torch.sum(log_norm, dim) + else: + return log_norm + + +def log_bernoulli(x, mean, average=False, reduce=True, dim=None): + probs = torch.clamp(mean, min=MIN_EPSILON, max=MAX_EPSILON) + log_bern = x * torch.log(probs) + (1. - x) * torch.log(1. - probs) + if reduce: + if average: + return torch.mean(log_bern, dim) + else: + return torch.sum(log_bern, dim) + else: + return log_bern diff --git a/src/torchprune/torchprune/util/external/ffjord/vae_lib/utils/load_data.py b/src/torchprune/torchprune/util/external/ffjord/vae_lib/utils/load_data.py new file mode 100644 index 0000000..ad1a836 --- /dev/null +++ b/src/torchprune/torchprune/util/external/ffjord/vae_lib/utils/load_data.py @@ -0,0 +1,205 @@ +from __future__ import print_function + +import torch +import torch.utils.data as data_utils +import pickle +from scipy.io import loadmat + +import numpy as np + +import os + + +def load_static_mnist(args, **kwargs): + """ + Dataloading function for static mnist. 
Outputs image data in vectorized form: each image is a vector of size 784 + """ + args.dynamic_binarization = False + args.input_type = 'binary' + + args.input_size = [1, 28, 28] + + # start processing + def lines_to_np_array(lines): + return np.array([[int(i) for i in line.split()] for line in lines]) + + with open(os.path.join('data', 'MNIST_static', 'binarized_mnist_train.amat')) as f: + lines = f.readlines() + x_train = lines_to_np_array(lines).astype('float32') + with open(os.path.join('data', 'MNIST_static', 'binarized_mnist_valid.amat')) as f: + lines = f.readlines() + x_val = lines_to_np_array(lines).astype('float32') + with open(os.path.join('data', 'MNIST_static', 'binarized_mnist_test.amat')) as f: + lines = f.readlines() + x_test = lines_to_np_array(lines).astype('float32') + + # shuffle train data + np.random.shuffle(x_train) + + # idle y's + y_train = np.zeros((x_train.shape[0], 1)) + y_val = np.zeros((x_val.shape[0], 1)) + y_test = np.zeros((x_test.shape[0], 1)) + + # pytorch data loader + train = data_utils.TensorDataset(torch.from_numpy(x_train), torch.from_numpy(y_train)) + train_loader = data_utils.DataLoader(train, batch_size=args.batch_size, shuffle=True, **kwargs) + + validation = data_utils.TensorDataset(torch.from_numpy(x_val).float(), torch.from_numpy(y_val)) + val_loader = data_utils.DataLoader(validation, batch_size=args.batch_size, shuffle=False, **kwargs) + + test = data_utils.TensorDataset(torch.from_numpy(x_test).float(), torch.from_numpy(y_test)) + test_loader = data_utils.DataLoader(test, batch_size=args.batch_size, shuffle=False, **kwargs) + + return train_loader, val_loader, test_loader, args + + +def load_freyfaces(args, **kwargs): + # set args + args.input_size = [1, 28, 20] + args.input_type = 'multinomial' + args.dynamic_binarization = False + + TRAIN = 1565 + VAL = 200 + TEST = 200 + + # start processing + with open('data/Freyfaces/freyfaces.pkl', 'rb') as f: + data = pickle.load(f, encoding="latin1")[0] + + data = data / 255. + + # NOTE: shuffling is done before splitting into train and test set, so test set is different for every run! 
+ # shuffle data: + np.random.seed(args.freyseed) + + np.random.shuffle(data) + + # train images + x_train = data[0:TRAIN].reshape(-1, 28 * 20) + # validation images + x_val = data[TRAIN:(TRAIN + VAL)].reshape(-1, 28 * 20) + # test images + x_test = data[(TRAIN + VAL):(TRAIN + VAL + TEST)].reshape(-1, 28 * 20) + + # idle y's + y_train = np.zeros((x_train.shape[0], 1)) + y_val = np.zeros((x_val.shape[0], 1)) + y_test = np.zeros((x_test.shape[0], 1)) + + # pytorch data loader + train = data_utils.TensorDataset(torch.from_numpy(x_train).float(), torch.from_numpy(y_train)) + train_loader = data_utils.DataLoader(train, batch_size=args.batch_size, shuffle=True, **kwargs) + + validation = data_utils.TensorDataset(torch.from_numpy(x_val).float(), torch.from_numpy(y_val)) + val_loader = data_utils.DataLoader(validation, batch_size=args.batch_size, shuffle=False, **kwargs) + + test = data_utils.TensorDataset(torch.from_numpy(x_test).float(), torch.from_numpy(y_test)) + test_loader = data_utils.DataLoader(test, batch_size=args.batch_size, shuffle=False, **kwargs) + return train_loader, val_loader, test_loader, args + + +def load_omniglot(args, **kwargs): + n_validation = 1345 + + # set args + args.input_size = [1, 28, 28] + args.input_type = 'binary' + args.dynamic_binarization = True + + # start processing + def reshape_data(data): + return data.reshape((-1, 28, 28)).reshape((-1, 28 * 28), order='F') + + omni_raw = loadmat(os.path.join('data', 'OMNIGLOT', 'chardata.mat')) + + # train and test data + train_data = reshape_data(omni_raw['data'].T.astype('float32')) + x_test = reshape_data(omni_raw['testdata'].T.astype('float32')) + + # shuffle train data + np.random.shuffle(train_data) + + # set train and validation data + x_train = train_data[:-n_validation] + x_val = train_data[-n_validation:] + + # binarize + if args.dynamic_binarization: + args.input_type = 'binary' + np.random.seed(777) + x_val = np.random.binomial(1, x_val) + x_test = np.random.binomial(1, x_test) + else: + args.input_type = 'gray' + + # idle y's + y_train = np.zeros((x_train.shape[0], 1)) + y_val = np.zeros((x_val.shape[0], 1)) + y_test = np.zeros((x_test.shape[0], 1)) + + # pytorch data loader + train = data_utils.TensorDataset(torch.from_numpy(x_train), torch.from_numpy(y_train)) + train_loader = data_utils.DataLoader(train, batch_size=args.batch_size, shuffle=True, **kwargs) + + validation = data_utils.TensorDataset(torch.from_numpy(x_val).float(), torch.from_numpy(y_val)) + val_loader = data_utils.DataLoader(validation, batch_size=args.batch_size, shuffle=False, **kwargs) + + test = data_utils.TensorDataset(torch.from_numpy(x_test).float(), torch.from_numpy(y_test)) + test_loader = data_utils.DataLoader(test, batch_size=args.batch_size, shuffle=False, **kwargs) + + return train_loader, val_loader, test_loader, args + + +def load_caltech101silhouettes(args, **kwargs): + # set args + args.input_size = [1, 28, 28] + args.input_type = 'binary' + args.dynamic_binarization = False + + # start processing + def reshape_data(data): + return data.reshape((-1, 28, 28)).reshape((-1, 28 * 28), order='F') + + caltech_raw = loadmat(os.path.join('data', 'Caltech101Silhouettes', 'caltech101_silhouettes_28_split1.mat')) + + # train, validation and test data + x_train = 1. - reshape_data(caltech_raw['train_data'].astype('float32')) + np.random.shuffle(x_train) + x_val = 1. - reshape_data(caltech_raw['val_data'].astype('float32')) + np.random.shuffle(x_val) + x_test = 1. 
- reshape_data(caltech_raw['test_data'].astype('float32'))
+
+    y_train = caltech_raw['train_labels']
+    y_val = caltech_raw['val_labels']
+    y_test = caltech_raw['test_labels']
+
+    # pytorch data loader
+    train = data_utils.TensorDataset(torch.from_numpy(x_train), torch.from_numpy(y_train))
+    train_loader = data_utils.DataLoader(train, batch_size=args.batch_size, shuffle=True, **kwargs)
+
+    validation = data_utils.TensorDataset(torch.from_numpy(x_val).float(), torch.from_numpy(y_val))
+    val_loader = data_utils.DataLoader(validation, batch_size=args.batch_size, shuffle=False, **kwargs)
+
+    test = data_utils.TensorDataset(torch.from_numpy(x_test).float(), torch.from_numpy(y_test))
+    test_loader = data_utils.DataLoader(test, batch_size=args.batch_size, shuffle=False, **kwargs)
+
+    return train_loader, val_loader, test_loader, args
+
+
+def load_dataset(args, **kwargs):
+
+    if args.dataset == 'mnist':
+        train_loader, val_loader, test_loader, args = load_static_mnist(args, **kwargs)
+    elif args.dataset == 'caltech':
+        train_loader, val_loader, test_loader, args = load_caltech101silhouettes(args, **kwargs)
+    elif args.dataset == 'freyfaces':
+        train_loader, val_loader, test_loader, args = load_freyfaces(args, **kwargs)
+    elif args.dataset == 'omniglot':
+        train_loader, val_loader, test_loader, args = load_omniglot(args, **kwargs)
+    else:
+        raise Exception('Wrong name of the dataset!')
+
+    return train_loader, val_loader, test_loader, args
diff --git a/src/torchprune/torchprune/util/external/ffjord/vae_lib/utils/log_likelihood.py b/src/torchprune/torchprune/util/external/ffjord/vae_lib/utils/log_likelihood.py
new file mode 100644
index 0000000..b5461b4
--- /dev/null
+++ b/src/torchprune/torchprune/util/external/ffjord/vae_lib/utils/log_likelihood.py
@@ -0,0 +1,60 @@
+from __future__ import print_function
+import time
+import numpy as np
+from scipy.special import logsumexp
+from vae_lib.optimization.loss import calculate_loss_array
+
+
+def calculate_likelihood(X, model, args, logger, S=5000, MB=500):
+
+    # set auxiliary variables for number of training and test sets
+    N_test = X.size(0)
+
+    X = X.view(-1, *args.input_size)
+
+    likelihood_test = []
+
+    if S <= MB:
+        R = 1
+    else:
+        R = S // MB
+        S = MB
+
+    end = time.time()
+    for j in range(N_test):
+
+        x_single = X[j].unsqueeze(0)
+
+        a = []
+        for r in range(0, R):
+            # Repeat the single test point S times (importance samples)
+            x = x_single.expand(S, *x_single.size()[1:]).contiguous()
+
+            x_mean, z_mu, z_var, ldj, z0, zk = model(x)
+
+            a_tmp = calculate_loss_array(x_mean, x, z_mu, z_var, z0, zk, ldj, args)
+
+            a.append(-a_tmp.cpu().data.numpy())
+
+        # calculate max
+        a = np.asarray(a)
+        a = np.reshape(a, (a.shape[0] * a.shape[1], 1))
+        likelihood_x = logsumexp(a)
+        likelihood_test.append(likelihood_x - np.log(len(a)))
+
+        if j % 1 == 0:
+            logger.info('Progress: {:.2f}% | Time: {:.4f}'.format(j / (1. * N_test) * 100, time.time() - end))
+            end = time.time()
+
+    likelihood_test = np.array(likelihood_test)
+
+    nll = -np.mean(likelihood_test)
+
+    if args.input_type == 'multinomial':
+        bpd = nll / (np.prod(args.input_size) * np.log(2.))
+    elif args.input_type == 'binary':
+        bpd = 0.
+ else: + raise ValueError('invalid input type!') + + return nll, bpd diff --git a/src/torchprune/torchprune/util/external/ffjord/vae_lib/utils/plotting.py b/src/torchprune/torchprune/util/external/ffjord/vae_lib/utils/plotting.py new file mode 100644 index 0000000..e5c614f --- /dev/null +++ b/src/torchprune/torchprune/util/external/ffjord/vae_lib/utils/plotting.py @@ -0,0 +1,104 @@ +from __future__ import division +from __future__ import print_function + +import numpy as np +import matplotlib +# noninteractive background +matplotlib.use('Agg') +import matplotlib.pyplot as plt + + +def plot_training_curve(train_loss, validation_loss, fname='training_curve.pdf', labels=None): + """ + Plots train_loss and validation loss as a function of optimization iteration + :param train_loss: np.array of train_loss (1D or 2D) + :param validation_loss: np.array of validation loss (1D or 2D) + :param fname: output file name + :param labels: if train_loss and validation loss are 2D, then labels indicate which variable is varied + accross training curves. + :return: None + """ + + plt.close() + + matplotlib.rcParams.update({'font.size': 14}) + matplotlib.rcParams['mathtext.fontset'] = 'stix' + matplotlib.rcParams['font.family'] = 'STIXGeneral' + + if len(train_loss.shape) == 1: + # Single training curve + fig, ax = plt.subplots(nrows=1, ncols=1) + figsize = (6, 4) + + if train_loss.shape[0] == validation_loss.shape[0]: + # validation score evaluated every iteration + x = np.arange(train_loss.shape[0]) + ax.plot(x, train_loss, '-', lw=2., color='black', label='train') + ax.plot(x, validation_loss, '-', lw=2., color='blue', label='val') + + elif train_loss.shape[0] % validation_loss.shape[0] == 0: + # validation score evaluated every epoch + x = np.arange(train_loss.shape[0]) + ax.plot(x, train_loss, '-', lw=2., color='black', label='train') + + x = np.arange(validation_loss.shape[0]) + x = (x + 1) * train_loss.shape[0] / validation_loss.shape[0] + ax.plot(x, validation_loss, '-', lw=2., color='blue', label='val') + else: + raise ValueError('Length of train_loss and validation_loss must be equal or divisible') + + miny = np.minimum(validation_loss.min(), train_loss.min()) - 20. + maxy = np.maximum(validation_loss.max(), train_loss.max()) + 30. + ax.set_ylim([miny, maxy]) + + elif len(train_loss.shape) == 2: + # Multiple training curves + + cmap = plt.cm.brg + + cNorm = matplotlib.colors.Normalize(vmin=0, vmax=train_loss.shape[0]) + scalarMap = matplotlib.cm.ScalarMappable(norm=cNorm, cmap=cmap) + + fig, ax = plt.subplots(nrows=1, ncols=1) + figsize = (6, 4) + + if labels is None: + labels = ['%d' % i for i in range(train_loss.shape[0])] + + if train_loss.shape[1] == validation_loss.shape[1]: + for i in range(train_loss.shape[0]): + color_val = scalarMap.to_rgba(i) + + # validation score evaluated every iteration + x = np.arange(train_loss.shape[0]) + ax.plot(x, train_loss[i], '-', lw=2., color=color_val, label=labels[i]) + ax.plot(x, validation_loss[i], '--', lw=2., color=color_val) + + elif train_loss.shape[1] % validation_loss.shape[1] == 0: + for i in range(train_loss.shape[0]): + color_val = scalarMap.to_rgba(i) + + # validation score evaluated every epoch + x = np.arange(train_loss.shape[1]) + ax.plot(x, train_loss[i], '-', lw=2., color=color_val, label=labels[i]) + + x = np.arange(validation_loss.shape[1]) + x = (x + 1) * train_loss.shape[1] / validation_loss.shape[1] + ax.plot(x, validation_loss[i], '-', lw=2., color=color_val) + + miny = np.minimum(validation_loss.min(), train_loss.min()) - 20. 
+ maxy = np.maximum(validation_loss.max(), train_loss.max()) + 30. + ax.set_ylim([miny, maxy]) + + else: + raise ValueError('train_loss and validation_loss must be 1D or 2D arrays') + + ax.set_xlabel('iteration') + ax.set_ylabel('loss') + plt.title('Training and validation loss') + + fig.set_size_inches(figsize) + fig.subplots_adjust(hspace=0.1) + plt.savefig(fname, bbox_inches='tight') + + plt.close() diff --git a/src/torchprune/torchprune/util/external/ffjord/vae_lib/utils/visual_evaluation.py b/src/torchprune/torchprune/util/external/ffjord/vae_lib/utils/visual_evaluation.py new file mode 100644 index 0000000..1281dea --- /dev/null +++ b/src/torchprune/torchprune/util/external/ffjord/vae_lib/utils/visual_evaluation.py @@ -0,0 +1,53 @@ +from __future__ import print_function +import os +import matplotlib +matplotlib.use('Agg') +import matplotlib.pyplot as plt +import matplotlib.gridspec as gridspec + + +def plot_reconstructions(data, recon_mean, loss, loss_type, epoch, args): + + if args.input_type == 'multinomial': + # data is already between 0 and 1 + num_classes = 256 + # Find largest class logit + tmp = recon_mean.view(-1, num_classes, *args.input_size).max(dim=1)[1] + recon_mean = tmp.float() / (num_classes - 1.) + if epoch == 1: + if not os.path.exists(args.snap_dir + 'reconstruction/'): + os.makedirs(args.snap_dir + 'reconstruction/') + # VISUALIZATION: plot real images + plot_images(args, data.data.cpu().numpy()[0:9], args.snap_dir + 'reconstruction/', 'real', size_x=3, size_y=3) + # VISUALIZATION: plot reconstructions + if loss_type == 'bpd': + fname = str(epoch) + '_bpd_%5.3f' % loss + elif loss_type == 'elbo': + fname = str(epoch) + '_elbo_%6.4f' % loss + plot_images(args, recon_mean.data.cpu().numpy()[0:9], args.snap_dir + 'reconstruction/', fname, size_x=3, size_y=3) + + +def plot_images(args, x_sample, dir, file_name, size_x=3, size_y=3): + + fig = plt.figure(figsize=(size_x, size_y)) + # fig = plt.figure(1) + gs = gridspec.GridSpec(size_x, size_y) + gs.update(wspace=0.05, hspace=0.05) + + for i, sample in enumerate(x_sample): + ax = plt.subplot(gs[i]) + plt.axis('off') + ax.set_xticklabels([]) + ax.set_yticklabels([]) + ax.set_aspect('equal') + sample = sample.reshape((args.input_size[0], args.input_size[1], args.input_size[2])) + sample = sample.swapaxes(0, 2) + sample = sample.swapaxes(0, 1) + if (args.input_type == 'binary') or (args.input_type in ['multinomial'] and args.input_size[0] == 1): + sample = sample[:, :, 0] + plt.imshow(sample, cmap='gray', vmin=0, vmax=1) + else: + plt.imshow(sample) + + plt.savefig(dir + file_name + '.png', bbox_inches='tight') + plt.close(fig) diff --git a/src/torchprune/torchprune/util/metrics.py b/src/torchprune/torchprune/util/metrics.py index 8349f3b..39f2b5f 100644 --- a/src/torchprune/torchprune/util/metrics.py +++ b/src/torchprune/torchprune/util/metrics.py @@ -4,6 +4,8 @@ from scipy import stats import torch +from .nn_loss import NLLPriorLoss, NLLNatsLoss, NLLBitsLoss + class AbstractMetric(ABC): """Functor template for metric.""" @@ -29,7 +31,7 @@ def __call__(self, output, target): # check if output is dict if isinstance(output, dict): if "out" in output: - # Segmentation networks + # Segmentation networks and ffjord networks output = output["out"] elif "logits" in output: # BERT @@ -263,3 +265,60 @@ def short_name(self): def _get_metric(self, output, target): """Compute metric and return as 0d tensor.""" return torch.tensor(0.0) + + +class NLLPrior(AbstractMetric): + """A wrapper for the NLLPriorLoss as metric.""" + + @property 
+ def name(self): + """Get the display name of this metric.""" + return "Negative Log Likelihood" + + @property + def short_name(self): + """Get the short name of this metric.""" + return "NLL" + + def _get_metric(self, output, target): + """Compute metric and return as 0d tensor.""" + # since for the metric higher is better we negate the loss + return -(NLLPriorLoss()(output["output"], target)) + + def __call__(self, output, target): + """Call metric like output but wrap output so we keep dictionary.""" + return super().__call__({"output": output}, target) + + +class NLLNats(NLLPrior): + """A wrapper for the NLLNatsLoss as metric.""" + + @property + def name(self): + """Get the display name of this metric.""" + return "Negative Log Probability (nats)" + + @property + def short_name(self): + """Get the short name of this metric.""" + return "Nats" + + def _get_metric(self, output, target): + return -(NLLNatsLoss()(output["output"], target)) + + +class NLLBits(NLLPrior): + """A wrapper for the NLLBitsLoss as metric.""" + + @property + def name(self): + """Get the display name of this metric.""" + return "Negative Log Probability (bits/dim)" + + @property + def short_name(self): + """Get the short name of this metric.""" + return "Bits" + + def _get_metric(self, output, target): + return -(NLLBitsLoss()(output["output"], target)) diff --git a/src/torchprune/torchprune/util/models/__init__.py b/src/torchprune/torchprune/util/models/__init__.py index 94d9b5b..fb86d1b 100644 --- a/src/torchprune/torchprune/util/models/__init__.py +++ b/src/torchprune/torchprune/util/models/__init__.py @@ -1,13 +1,18 @@ # flake8: noqa: F401, F403 """Package with all custom net implementation and CIFAR nets.""" # import custom nets -from .fcnet import FCNet, lenet300_100, lenet500_300_100 +from .fcnet import FCNet, lenet300_100, lenet500_300_100, fcnet_nettrim from .lenet5 import lenet5 from .deepknight import deepknight from .deeplab import * from .cnn60k import cnn60k from .cnn5 import cnn5 from .bert import bert +from .node import * +from .cnf import * +from .ffjord import * +from .ffjord_tabular import * +from .ffjord_cnf import * # import cifar nets from ..external.cnn.models.cifar import * diff --git a/src/torchprune/torchprune/util/models/cnf.py b/src/torchprune/torchprune/util/models/cnf.py new file mode 100644 index 0000000..1e8cbfa --- /dev/null +++ b/src/torchprune/torchprune/util/models/cnf.py @@ -0,0 +1,452 @@ +"""Module containing various FFjord NODE configurations with torchdyn lib.""" + +import torch +import torch.nn as nn +from torchdyn.models import autograd_trace + +from .ffjord import Ffjord + + +class VanillaCNF(Ffjord): + """Neural ODEs for CNFs via brute-force trace estimator.""" + + @property + def trace_estimator(self): + """Return the desired trace estimator.""" + return autograd_trace + + +def cnf_l4_h64_sigmoid(num_classes): + """Return a brute-force CNF with 4 layers, 64 neurons, and sigmoid.""" + return VanillaCNF( + num_in=num_classes, + num_layers=4, + hidden_size=64, + module_activate=nn.Sigmoid, + ) + + +def cnf_l4_h64_softplus(num_classes): + """Return a brute-force CNF with 4 layers, 64 neurons, and softplus.""" + return VanillaCNF( + num_in=num_classes, + num_layers=4, + hidden_size=64, + module_activate=nn.Softplus, + ) + + +def cnf_l4_h64_tanh(num_classes): + """Return a brute-force CNF with 4 layers, 64 neurons, and tanh.""" + return VanillaCNF( + num_in=num_classes, + num_layers=4, + hidden_size=64, + module_activate=nn.Tanh, + ) + + +def cnf_l4_h64_relu(num_classes): + 
"""Return a brute-force CNF with 4 layers, 64 neurons, and relu.""" + return VanillaCNF( + num_in=num_classes, + num_layers=4, + hidden_size=64, + module_activate=nn.ReLU, + ) + + +def cnf_l8_h64_sigmoid(num_classes): + """Return a brute-force CNF with 8 layers, 64 neurons, and sigmoid.""" + return VanillaCNF( + num_in=num_classes, + num_layers=8, + hidden_size=64, + module_activate=nn.Sigmoid, + ) + + +def cnf_l2_h128_sigmoid(num_classes): + """Return a brute-force CNF with 2 layers, 128 neurons, and sigmoid.""" + return VanillaCNF( + num_in=num_classes, + num_layers=2, + hidden_size=128, + module_activate=nn.Sigmoid, + ) + + +def cnf_l2_h64_sigmoid(num_classes): + """Return a brute-force CNF with 2 layers, 64 neurons, and sigmoid.""" + return VanillaCNF( + num_in=num_classes, + num_layers=2, + hidden_size=64, + module_activate=nn.Sigmoid, + ) + + +def cnf_l4_h64_sigmoid_dopri_adjoint(num_classes): + """Return a brute-force CNF with 4 layers, 64 neurons, and sigmoid.""" + return VanillaCNF( + num_in=num_classes, + num_layers=4, + hidden_size=64, + module_activate=nn.Sigmoid, + s_span=torch.linspace(0, 1, 2), + sensitivity="adjoint", + solver="dopri5", + ) + + +def cnf_l4_h64_sigmoid_dopri_autograd(num_classes): + """Return a brute-force CNF with 4 layers, 64 neurons, and sigmoid.""" + return VanillaCNF( + num_in=num_classes, + num_layers=4, + hidden_size=64, + module_activate=nn.Sigmoid, + s_span=torch.linspace(0, 1, 2), + sensitivity="autograd", + solver="dopri5", + ) + + +def cnf_l4_h64_sigmoid_rk4_autograd(num_classes): + """Return a brute-force CNF with 4 layers, 64 neurons, and sigmoid.""" + return VanillaCNF( + num_in=num_classes, + num_layers=4, + hidden_size=64, + module_activate=nn.Sigmoid, + s_span=torch.linspace(0, 1, 20), + sensitivity="autograd", + solver="rk4", + ) + + +def cnf_l4_h64_sigmoid_rk4_adjoint(num_classes): + """Return a brute-force CNF with 4 layers, 64 neurons, and sigmoid.""" + return VanillaCNF( + num_in=num_classes, + num_layers=4, + hidden_size=64, + module_activate=nn.Sigmoid, + s_span=torch.linspace(0, 1, 20), + sensitivity="adjoint", + solver="rk4", + ) + + +def cnf_l4_h64_sigmoid_euler_autograd(num_classes): + """Return a brute-force CNF with 4 layers, 64 neurons, and sigmoid.""" + return VanillaCNF( + num_in=num_classes, + num_layers=4, + hidden_size=64, + module_activate=nn.Sigmoid, + s_span=torch.linspace(0, 1, 100), + sensitivity="autograd", + solver="euler", + ) + + +def cnf_l4_h64_sigmoid_euler_adjoint(num_classes): + """Return a brute-force CNF with 4 layers, 64 neurons, and sigmoid.""" + return VanillaCNF( + num_in=num_classes, + num_layers=4, + hidden_size=64, + module_activate=nn.Sigmoid, + s_span=torch.linspace(0, 1, 100), + sensitivity="adjoint", + solver="euler", + ) + + +def cnf_l4_h64_sigmoid_da(num_classes): + """Return a brute-force CNF with 4 layers, 64 neurons, and sigmoid.""" + return VanillaCNF( + num_in=num_classes, + num_layers=4, + hidden_size=64, + module_activate=nn.Sigmoid, + s_span=torch.linspace(0, 1, 2), + sensitivity="adjoint", + solver="dopri5", + ) + + +def cnf_l4_h64_softplus_da(num_classes): + """Return a brute-force CNF with 4 layers, 64 neurons, and softplus.""" + return VanillaCNF( + num_in=num_classes, + num_layers=4, + hidden_size=64, + module_activate=nn.Softplus, + s_span=torch.linspace(0, 1, 2), + sensitivity="adjoint", + solver="dopri5", + ) + + +def cnf_l4_h64_tanh_da(num_classes): + """Return a brute-force CNF with 4 layers, 64 neurons, and tanh.""" + return VanillaCNF( + num_in=num_classes, + num_layers=4, + 
hidden_size=64, + module_activate=nn.Tanh, + s_span=torch.linspace(0, 1, 2), + sensitivity="adjoint", + solver="dopri5", + ) + + +def cnf_l4_h64_relu_da(num_classes): + """Return a brute-force CNF with 4 layers, 64 neurons, and relu.""" + return VanillaCNF( + num_in=num_classes, + num_layers=4, + hidden_size=64, + module_activate=nn.ReLU, + s_span=torch.linspace(0, 1, 2), + sensitivity="adjoint", + solver="dopri5", + ) + + +def cnf_l8_h64_sigmoid_da(num_classes): + """Return a brute-force CNF with 8 layers, 64 neurons, and sigmoid.""" + return VanillaCNF( + num_in=num_classes, + num_layers=8, + hidden_size=64, + module_activate=nn.Sigmoid, + s_span=torch.linspace(0, 1, 2), + sensitivity="adjoint", + solver="dopri5", + ) + + +def cnf_l8_h37_sigmoid_da(num_classes): + """Return a brute-force CNF with 8 layers, 37 neurons, and sigmoid.""" + return VanillaCNF( + num_in=num_classes, + num_layers=8, + hidden_size=37, + module_activate=nn.Sigmoid, + s_span=torch.linspace(0, 1, 2), + sensitivity="adjoint", + solver="dopri5", + ) + + +def cnf_l8_h18_sigmoid_da(num_classes): + """Return a brute-force CNF with 8 layers, 18 neurons, and sigmoid.""" + return VanillaCNF( + num_in=num_classes, + num_layers=8, + hidden_size=18, + module_activate=nn.Sigmoid, + s_span=torch.linspace(0, 1, 2), + sensitivity="adjoint", + solver="dopri5", + ) + + +def cnf_l8_h10_sigmoid_da(num_classes): + """Return a brute-force CNF with 8 layers, 10 neurons, and sigmoid.""" + return VanillaCNF( + num_in=num_classes, + num_layers=8, + hidden_size=10, + module_activate=nn.Sigmoid, + s_span=torch.linspace(0, 1, 2), + sensitivity="adjoint", + solver="dopri5", + ) + + +def cnf_l6_h45_sigmoid_da(num_classes): + """Return a brute-force CNF with 6 layers, 45 neurons, and sigmoid.""" + return VanillaCNF( + num_in=num_classes, + num_layers=6, + hidden_size=45, + module_activate=nn.Sigmoid, + s_span=torch.linspace(0, 1, 2), + sensitivity="adjoint", + solver="dopri5", + ) + + +def cnf_l6_h22_sigmoid_da(num_classes): + """Return a brute-force CNF with 6 layers, 22 neurons, and sigmoid.""" + return VanillaCNF( + num_in=num_classes, + num_layers=6, + hidden_size=22, + module_activate=nn.Sigmoid, + s_span=torch.linspace(0, 1, 2), + sensitivity="adjoint", + solver="dopri5", + ) + + +def cnf_l6_h12_sigmoid_da(num_classes): + """Return a brute-force CNF with 6 layers, 12 neurons, and sigmoid.""" + return VanillaCNF( + num_in=num_classes, + num_layers=6, + hidden_size=12, + module_activate=nn.Sigmoid, + s_span=torch.linspace(0, 1, 2), + sensitivity="adjoint", + solver="dopri5", + ) + + +def cnf_l4_h128_sigmoid_da(num_classes): + """Return a brute-force CNF with 4 layers, 128 neurons, and sigmoid.""" + return VanillaCNF( + num_in=num_classes, + num_layers=4, + hidden_size=128, + module_activate=nn.Sigmoid, + s_span=torch.linspace(0, 1, 2), + sensitivity="adjoint", + solver="dopri5", + ) + + +def cnf_l4_h30_sigmoid_da(num_classes): + """Return a brute-force CNF with 4 layers, 30 neurons, and sigmoid.""" + return VanillaCNF( + num_in=num_classes, + num_layers=4, + hidden_size=30, + module_activate=nn.Sigmoid, + s_span=torch.linspace(0, 1, 2), + sensitivity="adjoint", + solver="dopri5", + ) + + +def cnf_l4_h17_sigmoid_da(num_classes): + """Return a brute-force CNF with 4 layers, 17 neurons, and sigmoid.""" + return VanillaCNF( + num_in=num_classes, + num_layers=4, + hidden_size=17, + module_activate=nn.Sigmoid, + s_span=torch.linspace(0, 1, 2), + sensitivity="adjoint", + solver="dopri5", + ) + + +def cnf_l3_h90_sigmoid_da(num_classes): + """Return a 
brute-force CNF with 3 layers, 90 neurons, and sigmoid.""" + return VanillaCNF( + num_in=num_classes, + num_layers=3, + hidden_size=90, + module_activate=nn.Sigmoid, + s_span=torch.linspace(0, 1, 2), + sensitivity="adjoint", + solver="dopri5", + ) + + +def cnf_l3_h43_sigmoid_da(num_classes): + """Return a brute-force CNF with 3 layers, 43 neurons, and sigmoid.""" + return VanillaCNF( + num_in=num_classes, + num_layers=3, + hidden_size=43, + module_activate=nn.Sigmoid, + s_span=torch.linspace(0, 1, 2), + sensitivity="adjoint", + solver="dopri5", + ) + + +def cnf_l3_h23_sigmoid_da(num_classes): + """Return a brute-force CNF with 3 layers, 23 neurons, and sigmoid.""" + return VanillaCNF( + num_in=num_classes, + num_layers=3, + hidden_size=23, + module_activate=nn.Sigmoid, + s_span=torch.linspace(0, 1, 2), + sensitivity="adjoint", + solver="dopri5", + ) + + +def cnf_l2_h1700_sigmoid_da(num_classes): + """Return a brute-force CNF with 2 layers, 1700 neurons, and sigmoid.""" + return VanillaCNF( + num_in=num_classes, + num_layers=2, + hidden_size=1700, + module_activate=nn.Sigmoid, + s_span=torch.linspace(0, 1, 2), + sensitivity="adjoint", + solver="dopri5", + ) + + +def cnf_l2_h400_sigmoid_da(num_classes): + """Return a brute-force CNF with 2 layers, 400 neurons, and sigmoid.""" + return VanillaCNF( + num_in=num_classes, + num_layers=2, + hidden_size=400, + module_activate=nn.Sigmoid, + s_span=torch.linspace(0, 1, 2), + sensitivity="adjoint", + solver="dopri5", + ) + + +def cnf_l2_h128_sigmoid_da(num_classes): + """Return a brute-force CNF with 2 layers, 128 neurons, and sigmoid.""" + return VanillaCNF( + num_in=num_classes, + num_layers=2, + hidden_size=128, + module_activate=nn.Sigmoid, + s_span=torch.linspace(0, 1, 2), + sensitivity="adjoint", + solver="dopri5", + ) + + +def cnf_l2_h64_sigmoid_da(num_classes): + """Return a brute-force CNF with 2 layers, 64 neurons, and sigmoid.""" + return VanillaCNF( + num_in=num_classes, + num_layers=2, + hidden_size=64, + module_activate=nn.Sigmoid, + s_span=torch.linspace(0, 1, 2), + sensitivity="adjoint", + solver="dopri5", + ) + + +def cnf_l4_h64_sigmoid_da_high_tol(num_classes): + """Return a brute-force CNF with 4 layers, 64 neurons, and sigmoid.""" + return VanillaCNF( + num_in=num_classes, + num_layers=4, + hidden_size=64, + module_activate=nn.Sigmoid, + s_span=torch.linspace(0, 1, 2), + sensitivity="adjoint", + solver="dopri5", + atol=1e-4, + rtol=1e-4, + ) diff --git a/src/torchprune/torchprune/util/models/cnn/README.md b/src/torchprune/torchprune/util/models/cnn/README.md deleted file mode 100644 index 707c6fd..0000000 --- a/src/torchprune/torchprune/util/models/cnn/README.md +++ /dev/null @@ -1,74 +0,0 @@ -# pytorch-classification -Classification on CIFAR-10/100 and ImageNet with PyTorch. - -## Features -* Unified interface for different network architectures -* Multi-GPU support -* Training progress bar with rich info -* Training log and training curve visualization code (see `./utils/logger.py`) - -## Install -* Install [PyTorch](http://pytorch.org/) -* Clone recursively - ``` - git clone --recursive https://github.com/bearpaw/pytorch-classification.git - ``` - -## Training -Please see the [Training recipes](TRAINING.md) for how to train the models. - -## Results - -### CIFAR -Top1 error rate on the CIFAR-10/100 benchmarks are reported. You may get different results when training your models with different random seed. -Note that the number of parameters are computed on the CIFAR-10 dataset. 
- -| Model | Params (M) | CIFAR-10 (%) | CIFAR-100 (%) | -| ------------------- | ------------------ | ------------------ | ------------------ | -| alexnet | 2.47 | 22.78 | 56.13 | -| vgg19_bn | 20.04 | 6.66 | 28.05 | -| ResNet-110 | 1.70 | 6.11 | 28.86 | -| PreResNet-110 | 1.70 | 4.94 | 23.65 | -| WRN-28-10 (drop 0.3) | 36.48 | 3.79 | 18.14 | -| ResNeXt-29, 8x64 | 34.43 | 3.69 | 17.38 | -| ResNeXt-29, 16x64 | 68.16 | 3.53 | 17.30 | -| DenseNet-BC (L=100, k=12) | 0.77 | 4.54 | 22.88 | -| DenseNet-BC (L=190, k=40) | 25.62 | 3.32 | 17.17 | - - -![cifar](utils/images/cifar.png) - -### ImageNet -Single-crop (224x224) validation error rate is reported. - - -| Model | Params (M) | Top-1 Error (%) | Top-5 Error (%) | -| ------------------- | ------------------ | ------------------ | ------------------ | -| ResNet-18 | 11.69 | 30.09 | 10.78 | -| ResNeXt-50 (32x4d) | 25.03 | 22.6 | 6.29 | - -![Validation curve](utils/images/imagenet.png) - -## Pretrained models -Our trained models and training logs are downloadable at [OneDrive](https://mycuhk-my.sharepoint.com/personal/1155056070_link_cuhk_edu_hk/_layouts/15/guestaccess.aspx?folderid=0a380d1fece1443f0a2831b761df31905&authkey=Ac5yBC-FSE4oUJZ2Lsx7I5c). - -## Supported Architectures - -### CIFAR-10 / CIFAR-100 -Since the size of images in CIFAR dataset is `32x32`, popular network structures for ImageNet need some modifications to adapt this input size. The modified models is in the package `models.cifar`: -- [x] [AlexNet](https://arxiv.org/abs/1404.5997) -- [x] [VGG](https://arxiv.org/abs/1409.1556) (Imported from [pytorch-cifar](https://github.com/kuangliu/pytorch-cifar)) -- [x] [ResNet](https://arxiv.org/abs/1512.03385) -- [x] [Pre-act-ResNet](https://arxiv.org/abs/1603.05027) -- [x] [ResNeXt](https://arxiv.org/abs/1611.05431) (Imported from [ResNeXt.pytorch](https://github.com/prlz77/ResNeXt.pytorch)) -- [x] [Wide Residual Networks](http://arxiv.org/abs/1605.07146) (Imported from [WideResNet-pytorch](https://github.com/xternalz/WideResNet-pytorch)) -- [x] [DenseNet](https://arxiv.org/abs/1608.06993) - -### ImageNet -- [x] All models in `torchvision.models` (alexnet, vgg, resnet, densenet, inception_v3, squeezenet) -- [x] [ResNeXt](https://arxiv.org/abs/1611.05431) -- [ ] [Wide Residual Networks](http://arxiv.org/abs/1605.07146) - - -## Contribute -Feel free to create a pull request if you find any bugs or you want to contribute (e.g., more datasets and more network structures). 
diff --git a/src/torchprune/torchprune/util/models/cnn/TRAINING.md b/src/torchprune/torchprune/util/models/cnn/TRAINING.md deleted file mode 100644 index ff140ab..0000000 --- a/src/torchprune/torchprune/util/models/cnn/TRAINING.md +++ /dev/null @@ -1,119 +0,0 @@ - -## CIFAR-10 - -#### AlexNet -``` -python cifar.py -a alexnet --epochs 164 --schedule 81 122 --gamma 0.1 --checkpoint checkpoints/cifar10/alexnet -``` - - -#### VGG19 (BN) -``` -python cifar.py -a vgg19_bn --epochs 164 --schedule 81 122 --gamma 0.1 --checkpoint checkpoints/cifar10/vgg19_bn -``` - -#### ResNet-110 -``` -python cifar.py -a resnet --depth 110 --epochs 164 --schedule 81 122 --gamma 0.1 --wd 1e-4 --checkpoint checkpoints/cifar10/resnet-110 -``` - -#### ResNet-1202 -``` -python cifar.py -a resnet --depth 1202 --epochs 164 --schedule 81 122 --gamma 0.1 --wd 1e-4 --checkpoint checkpoints/cifar10/resnet-1202 -``` - -#### PreResNet-110 -``` -python cifar.py -a preresnet --depth 110 --epochs 164 --schedule 81 122 --gamma 0.1 --wd 1e-4 --checkpoint checkpoints/cifar10/preresnet-110 -``` - -#### ResNeXt-29, 8x64d -``` -python cifar.py -a resnext --depth 29 --cardinality 8 --widen-factor 4 --schedule 150 225 --wd 5e-4 --gamma 0.1 --checkpoint checkpoints/cifar10/resnext-8x64d -``` -#### ResNeXt-29, 16x64d -``` -python cifar.py -a resnext --depth 29 --cardinality 16 --widen-factor 4 --schedule 150 225 --wd 5e-4 --gamma 0.1 --checkpoint checkpoints/cifar10/resnext-16x64d -``` - -#### WRN-28-10-drop -``` -python cifar.py -a wrn --depth 28 --depth 28 --widen-factor 10 --drop 0.3 --epochs 200 --schedule 60 120 160 --wd 5e-4 --gamma 0.2 --checkpoint checkpoints/cifar10/WRN-28-10-drop -``` - -#### DenseNet-BC (L=100, k=12) -**Note**: -* DenseNet use weight decay value `1e-4`. Larger weight decay (`5e-4`) if harmful for the accuracy (95.46 vs. 94.05) -* Official batch size is 64. But there is no big difference using batchsize 64 or 128 (95.46 vs 95.11). 
- -``` -python cifar.py -a densenet --depth 100 --growthRate 12 --train-batch 64 --epochs 300 --schedule 150 225 --wd 1e-4 --gamma 0.1 --checkpoint checkpoints/cifar10/densenet-bc-100-12 -``` - -#### DenseNet-BC (L=190, k=40) -``` -python cifar.py -a densenet --depth 190 --growthRate 40 --train-batch 64 --epochs 300 --schedule 150 225 --wd 1e-4 --gamma 0.1 --checkpoint checkpoints/cifar10/densenet-bc-L190-k40 -``` - -## CIFAR-100 - -#### AlexNet -``` -python cifar.py -a alexnet --dataset cifar100 --checkpoint checkpoints/cifar100/alexnet --epochs 164 --schedule 81 122 --gamma 0.1 -``` - -#### VGG19 (BN) -``` -python cifar.py -a vgg19_bn --dataset cifar100 --checkpoint checkpoints/cifar100/vgg19_bn --epochs 164 --schedule 81 122 --gamma 0.1 -``` - -#### ResNet-110 -``` -python cifar.py -a resnet --dataset cifar100 --depth 110 --epochs 164 --schedule 81 122 --gamma 0.1 --wd 1e-4 --checkpoint checkpoints/cifar100/resnet-110 -``` - -#### ResNet-1202 -``` -python cifar.py -a resnet --dataset cifar100 --depth 1202 --epochs 164 --schedule 81 122 --gamma 0.1 --wd 1e-4 --checkpoint checkpoints/cifar100/resnet-1202 -``` - -#### PreResNet-110 -``` -python cifar.py -a preresnet --dataset cifar100 --depth 110 --epochs 164 --schedule 81 122 --gamma 0.1 --wd 1e-4 --checkpoint checkpoints/cifar100/preresnet-110 -``` - -#### ResNeXt-29, 8x64d -``` -python cifar.py -a resnext --dataset cifar100 --depth 29 --cardinality 8 --widen-factor 4 --checkpoint checkpoints/cifar100/resnext-8x64d --schedule 150 225 --wd 5e-4 --gamma 0.1 -``` -#### ResNeXt-29, 16x64d -``` -python cifar.py -a resnext --dataset cifar100 --depth 29 --cardinality 16 --widen-factor 4 --checkpoint checkpoints/cifar100/resnext-16x64d --schedule 150 225 --wd 5e-4 --gamma 0.1 -``` - -#### WRN-28-10-drop -``` -python cifar.py -a wrn --dataset cifar100 --depth 28 --depth 28 --widen-factor 10 --drop 0.3 --epochs 200 --schedule 60 120 160 --wd 5e-4 --gamma 0.2 --checkpoint checkpoints/cifar100/WRN-28-10-drop -``` - -#### DenseNet-BC (L=100, k=12) -``` -python cifar.py -a densenet --dataset cifar100 --depth 100 --growthRate 12 --train-batch 64 --epochs 300 --schedule 150 225 --wd 1e-4 --gamma 0.1 --checkpoint checkpoints/cifar100/densenet-bc-100-12 -``` - -#### DenseNet-BC (L=190, k=40) -``` -python cifar.py -a densenet --dataset cifar100 --depth 190 --growthRate 40 --train-batch 64 --epochs 300 --schedule 150 225 --wd 1e-4 --gamma 0.1 --checkpoint checkpoints/cifar100/densenet-bc-L190-k40 -``` - -## ImageNet -### ResNet-18 -``` -python imagenet.py -a resnet18 --data ~/dataset/ILSVRC2012/ --epochs 90 --schedule 31 61 --gamma 0.1 -c checkpoints/imagenet/resnet18 -``` - -### ResNeXt-50 (32x4d) -*(Originally trained on 8xGPUs)* -``` -python imagenet.py -a resnext50 --base-width 4 --cardinality 32 --data ~/dataset/ILSVRC2012/ --epochs 90 --schedule 31 61 --gamma 0.1 -c checkpoints/imagenet/resnext50-32x4d -``` \ No newline at end of file diff --git a/src/torchprune/torchprune/util/models/cnn/cifar.py b/src/torchprune/torchprune/util/models/cnn/cifar.py deleted file mode 100644 index 510e474..0000000 --- a/src/torchprune/torchprune/util/models/cnn/cifar.py +++ /dev/null @@ -1,350 +0,0 @@ -''' -Training script for CIFAR-10/100 -Copyright (c) Wei YANG, 2017 -''' -from __future__ import print_function - -import argparse -import os -import shutil -import time -import random - -import torch -import torch.nn as nn -import torch.nn.parallel -import torch.backends.cudnn as cudnn -import torch.optim as optim -import torch.utils.data as data -import 
torchvision.transforms as transforms -import torchvision.datasets as datasets -import models.cifar as models - -from utils import Bar, Logger, AverageMeter, accuracy, mkdir_p, savefig - - -model_names = sorted(name for name in models.__dict__ - if name.islower() and not name.startswith("__") - and callable(models.__dict__[name])) - -parser = argparse.ArgumentParser(description='PyTorch CIFAR10/100 Training') -# Datasets -parser.add_argument('-d', '--dataset', default='cifar10', type=str) -parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', - help='number of data loading workers (default: 4)') -# Optimization options -parser.add_argument('--epochs', default=300, type=int, metavar='N', - help='number of total epochs to run') -parser.add_argument('--start-epoch', default=0, type=int, metavar='N', - help='manual epoch number (useful on restarts)') -parser.add_argument('--train-batch', default=128, type=int, metavar='N', - help='train batchsize') -parser.add_argument('--test-batch', default=100, type=int, metavar='N', - help='test batchsize') -parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, - metavar='LR', help='initial learning rate') -parser.add_argument('--drop', '--dropout', default=0, type=float, - metavar='Dropout', help='Dropout ratio') -parser.add_argument('--schedule', type=int, nargs='+', default=[150, 225], - help='Decrease learning rate at these epochs.') -parser.add_argument('--gamma', type=float, default=0.1, help='LR is multiplied by gamma on schedule.') -parser.add_argument('--momentum', default=0.9, type=float, metavar='M', - help='momentum') -parser.add_argument('--weight-decay', '--wd', default=5e-4, type=float, - metavar='W', help='weight decay (default: 1e-4)') -# Checkpoints -parser.add_argument('-c', '--checkpoint', default='checkpoint', type=str, metavar='PATH', - help='path to save checkpoint (default: checkpoint)') -parser.add_argument('--resume', default='', type=str, metavar='PATH', - help='path to latest checkpoint (default: none)') -# Architecture -parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet20', - choices=model_names, - help='model architecture: ' + - ' | '.join(model_names) + - ' (default: resnet18)') -parser.add_argument('--depth', type=int, default=29, help='Model depth.') -parser.add_argument('--cardinality', type=int, default=8, help='Model cardinality (group).') -parser.add_argument('--widen-factor', type=int, default=4, help='Widen factor. 4 -> 64, 8 -> 128, ...') -parser.add_argument('--growthRate', type=int, default=12, help='Growth rate for DenseNet.') -parser.add_argument('--compressionRate', type=int, default=2, help='Compression Rate (theta) for DenseNet.') -# Miscs -parser.add_argument('--manualSeed', type=int, help='manual seed') -parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', - help='evaluate model on validation set') -#Device options -parser.add_argument('--gpu-id', default='0', type=str, - help='id(s) for CUDA_VISIBLE_DEVICES') - -args = parser.parse_args() -state = {k: v for k, v in args._get_kwargs()} - -# Validate dataset -assert args.dataset == 'cifar10' or args.dataset == 'cifar100', 'Dataset can only be cifar10 or cifar100.' 
- -# Use CUDA -os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id -use_cuda = torch.cuda.is_available() - -# Random seed -if args.manualSeed is None: - args.manualSeed = random.randint(1, 10000) -random.seed(args.manualSeed) -torch.manual_seed(args.manualSeed) -if use_cuda: - torch.cuda.manual_seed_all(args.manualSeed) - -best_acc = 0 # best test accuracy - -def main(): - global best_acc - start_epoch = args.start_epoch # start from epoch 0 or last checkpoint epoch - - if not os.path.isdir(args.checkpoint): - mkdir_p(args.checkpoint) - - - - # Data - print('==> Preparing dataset %s' % args.dataset) - transform_train = transforms.Compose([ - transforms.RandomCrop(32, padding=4), - transforms.RandomHorizontalFlip(), - transforms.ToTensor(), - transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), - ]) - - transform_test = transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), - ]) - if args.dataset == 'cifar10': - dataloader = datasets.CIFAR10 - num_classes = 10 - else: - dataloader = datasets.CIFAR100 - num_classes = 100 - - - trainset = dataloader(root='./data', train=True, download=True, transform=transform_train) - trainloader = data.DataLoader(trainset, batch_size=args.train_batch, shuffle=True, num_workers=args.workers) - - testset = dataloader(root='./data', train=False, download=False, transform=transform_test) - testloader = data.DataLoader(testset, batch_size=args.test_batch, shuffle=False, num_workers=args.workers) - - # Model - print("==> creating model '{}'".format(args.arch)) - if args.arch.startswith('resnext'): - model = models.__dict__[args.arch]( - cardinality=args.cardinality, - num_classes=num_classes, - depth=args.depth, - widen_factor=args.widen_factor, - dropRate=args.drop, - ) - elif args.arch.startswith('densenet'): - model = models.__dict__[args.arch]( - num_classes=num_classes, - depth=args.depth, - growthRate=args.growthRate, - compressionRate=args.compressionRate, - dropRate=args.drop, - ) - elif args.arch.startswith('wrn'): - model = models.__dict__[args.arch]( - num_classes=num_classes, - depth=args.depth, - widen_factor=args.widen_factor, - dropRate=args.drop, - ) - elif args.arch.endswith('resnet'): - model = models.__dict__[args.arch]( - num_classes=num_classes, - depth=args.depth, - ) - else: - model = models.__dict__[args.arch](num_classes=num_classes) - - model = torch.nn.DataParallel(model).cuda() - cudnn.benchmark = True - print(' Total params: %.2fM' % (sum(p.numel() for p in model.parameters())/1000000.0)) - - criterion = nn.CrossEntropyLoss() - optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) - - # Resume - title = 'cifar-10-' + args.arch - if args.resume: - # Load checkpoint. - print('==> Resuming from checkpoint..') - assert os.path.isfile(args.resume), 'Error: no checkpoint directory found!' 
- args.checkpoint = os.path.dirname(args.resume) - checkpoint = torch.load(args.resume) - best_acc = checkpoint['best_acc'] - start_epoch = checkpoint['epoch'] - model.load_state_dict(checkpoint['state_dict']) - optimizer.load_state_dict(checkpoint['optimizer']) - logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title, resume=True) - else: - logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title) - logger.set_names(['Learning Rate', 'Train Loss', 'Valid Loss', 'Train Acc.', 'Valid Acc.']) - - - if args.evaluate: - print('\nEvaluation only') - test_loss, test_acc = test(testloader, model, criterion, start_epoch, use_cuda) - print(' Test Loss: %.8f, Test Acc: %.2f' % (test_loss, test_acc)) - return - - # Train and val - for epoch in range(start_epoch, args.epochs): - adjust_learning_rate(optimizer, epoch) - - print('\nEpoch: [%d | %d] LR: %f' % (epoch + 1, args.epochs, state['lr'])) - - train_loss, train_acc = train(trainloader, model, criterion, optimizer, epoch, use_cuda) - test_loss, test_acc = test(testloader, model, criterion, epoch, use_cuda) - - # append logger file - logger.append([state['lr'], train_loss, test_loss, train_acc, test_acc]) - - # save model - is_best = test_acc > best_acc - best_acc = max(test_acc, best_acc) - save_checkpoint({ - 'epoch': epoch + 1, - 'state_dict': model.state_dict(), - 'acc': test_acc, - 'best_acc': best_acc, - 'optimizer' : optimizer.state_dict(), - }, is_best, checkpoint=args.checkpoint) - - logger.close() - logger.plot() - savefig(os.path.join(args.checkpoint, 'log.eps')) - - print('Best acc:') - print(best_acc) - -def train(trainloader, model, criterion, optimizer, epoch, use_cuda): - # switch to train mode - model.train() - - batch_time = AverageMeter() - data_time = AverageMeter() - losses = AverageMeter() - top1 = AverageMeter() - top5 = AverageMeter() - end = time.time() - - bar = Bar('Processing', max=len(trainloader)) - for batch_idx, (inputs, targets) in enumerate(trainloader): - # measure data loading time - data_time.update(time.time() - end) - - if use_cuda: - inputs, targets = inputs.cuda(), targets.cuda(async=True) - inputs, targets = torch.autograd.Variable(inputs), torch.autograd.Variable(targets) - - # compute output - outputs = model(inputs) - loss = criterion(outputs, targets) - - # measure accuracy and record loss - prec1, prec5 = accuracy(outputs.data, targets.data, topk=(1, 5)) - losses.update(loss.data[0], inputs.size(0)) - top1.update(prec1[0], inputs.size(0)) - top5.update(prec5[0], inputs.size(0)) - - # compute gradient and do SGD step - optimizer.zero_grad() - loss.backward() - optimizer.step() - - # measure elapsed time - batch_time.update(time.time() - end) - end = time.time() - - # plot progress - bar.suffix = '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f} | top1: {top1: .4f} | top5: {top5: .4f}'.format( - batch=batch_idx + 1, - size=len(trainloader), - data=data_time.avg, - bt=batch_time.avg, - total=bar.elapsed_td, - eta=bar.eta_td, - loss=losses.avg, - top1=top1.avg, - top5=top5.avg, - ) - bar.next() - bar.finish() - return (losses.avg, top1.avg) - -def test(testloader, model, criterion, epoch, use_cuda): - global best_acc - - batch_time = AverageMeter() - data_time = AverageMeter() - losses = AverageMeter() - top1 = AverageMeter() - top5 = AverageMeter() - - # switch to evaluate mode - model.eval() - - end = time.time() - bar = Bar('Processing', max=len(testloader)) - for batch_idx, (inputs, targets) in enumerate(testloader): - 
# measure data loading time - data_time.update(time.time() - end) - - if use_cuda: - inputs, targets = inputs.cuda(), targets.cuda() - inputs, targets = torch.autograd.Variable(inputs, volatile=True), torch.autograd.Variable(targets) - - # compute output - outputs = model(inputs) - loss = criterion(outputs, targets) - - # measure accuracy and record loss - prec1, prec5 = accuracy(outputs.data, targets.data, topk=(1, 5)) - losses.update(loss.data[0], inputs.size(0)) - top1.update(prec1[0], inputs.size(0)) - top5.update(prec5[0], inputs.size(0)) - - # measure elapsed time - batch_time.update(time.time() - end) - end = time.time() - - # plot progress - bar.suffix = '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f} | top1: {top1: .4f} | top5: {top5: .4f}'.format( - batch=batch_idx + 1, - size=len(testloader), - data=data_time.avg, - bt=batch_time.avg, - total=bar.elapsed_td, - eta=bar.eta_td, - loss=losses.avg, - top1=top1.avg, - top5=top5.avg, - ) - bar.next() - bar.finish() - return (losses.avg, top1.avg) - -def save_checkpoint(state, is_best, checkpoint='checkpoint', filename='checkpoint.pth.tar'): - filepath = os.path.join(checkpoint, filename) - torch.save(state, filepath) - if is_best: - shutil.copyfile(filepath, os.path.join(checkpoint, 'model_best.pth.tar')) - -def adjust_learning_rate(optimizer, epoch): - global state - if epoch in args.schedule: - state['lr'] *= args.gamma - for param_group in optimizer.param_groups: - param_group['lr'] = state['lr'] - -if __name__ == '__main__': - main() diff --git a/src/torchprune/torchprune/util/models/cnn/imagenet.py b/src/torchprune/torchprune/util/models/cnn/imagenet.py deleted file mode 100644 index b1dace6..0000000 --- a/src/torchprune/torchprune/util/models/cnn/imagenet.py +++ /dev/null @@ -1,344 +0,0 @@ -''' -Training script for ImageNet -Copyright (c) Wei YANG, 2017 -''' -from __future__ import print_function - -import argparse -import os -import shutil -import time -import random - -import torch -import torch.nn as nn -import torch.nn.parallel -import torch.backends.cudnn as cudnn -import torch.optim as optim -import torch.utils.data as data -import torchvision.transforms as transforms -import torchvision.datasets as datasets -import torchvision.models as models -import models.imagenet as customized_models - -from utils import Bar, Logger, AverageMeter, accuracy, mkdir_p, savefig - -# Models -default_model_names = sorted(name for name in models.__dict__ - if name.islower() and not name.startswith("__") - and callable(models.__dict__[name])) - -customized_models_names = sorted(name for name in customized_models.__dict__ - if name.islower() and not name.startswith("__") - and callable(customized_models.__dict__[name])) - -for name in customized_models.__dict__: - if name.islower() and not name.startswith("__") and callable(customized_models.__dict__[name]): - models.__dict__[name] = customized_models.__dict__[name] - -model_names = default_model_names + customized_models_names - -# Parse arguments -parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') - -# Datasets -parser.add_argument('-d', '--data', default='path to dataset', type=str) -parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', - help='number of data loading workers (default: 4)') -# Optimization options -parser.add_argument('--epochs', default=90, type=int, metavar='N', - help='number of total epochs to run') -parser.add_argument('--start-epoch', default=0, type=int, metavar='N', - 
help='manual epoch number (useful on restarts)') -parser.add_argument('--train-batch', default=256, type=int, metavar='N', - help='train batchsize (default: 256)') -parser.add_argument('--test-batch', default=200, type=int, metavar='N', - help='test batchsize (default: 200)') -parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, - metavar='LR', help='initial learning rate') -parser.add_argument('--drop', '--dropout', default=0, type=float, - metavar='Dropout', help='Dropout ratio') -parser.add_argument('--schedule', type=int, nargs='+', default=[150, 225], - help='Decrease learning rate at these epochs.') -parser.add_argument('--gamma', type=float, default=0.1, help='LR is multiplied by gamma on schedule.') -parser.add_argument('--momentum', default=0.9, type=float, metavar='M', - help='momentum') -parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float, - metavar='W', help='weight decay (default: 1e-4)') -# Checkpoints -parser.add_argument('-c', '--checkpoint', default='checkpoint', type=str, metavar='PATH', - help='path to save checkpoint (default: checkpoint)') -parser.add_argument('--resume', default='', type=str, metavar='PATH', - help='path to latest checkpoint (default: none)') -# Architecture -parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18', - choices=model_names, - help='model architecture: ' + - ' | '.join(model_names) + - ' (default: resnet18)') -parser.add_argument('--depth', type=int, default=29, help='Model depth.') -parser.add_argument('--cardinality', type=int, default=32, help='ResNet cardinality (group).') -parser.add_argument('--base-width', type=int, default=4, help='ResNet base width.') -parser.add_argument('--widen-factor', type=int, default=4, help='Widen factor. 4 -> 64, 8 -> 128, ...') -# Miscs -parser.add_argument('--manualSeed', type=int, help='manual seed') -parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', - help='evaluate model on validation set') -parser.add_argument('--pretrained', dest='pretrained', action='store_true', - help='use pre-trained model') -#Device options -parser.add_argument('--gpu-id', default='0', type=str, - help='id(s) for CUDA_VISIBLE_DEVICES') - -args = parser.parse_args() -state = {k: v for k, v in args._get_kwargs()} - -# Use CUDA -os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id -use_cuda = torch.cuda.is_available() - -# Random seed -if args.manualSeed is None: - args.manualSeed = random.randint(1, 10000) -random.seed(args.manualSeed) -torch.manual_seed(args.manualSeed) -if use_cuda: - torch.cuda.manual_seed_all(args.manualSeed) - -best_acc = 0 # best test accuracy - -def main(): - global best_acc - start_epoch = args.start_epoch # start from epoch 0 or last checkpoint epoch - - if not os.path.isdir(args.checkpoint): - mkdir_p(args.checkpoint) - - # Data loading code - traindir = os.path.join(args.data, 'train') - valdir = os.path.join(args.data, 'val') - normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], - std=[0.229, 0.224, 0.225]) - - train_loader = torch.utils.data.DataLoader( - datasets.ImageFolder(traindir, transforms.Compose([ - transforms.RandomSizedCrop(224), - transforms.RandomHorizontalFlip(), - transforms.ToTensor(), - normalize, - ])), - batch_size=args.train_batch, shuffle=True, - num_workers=args.workers, pin_memory=True) - - val_loader = torch.utils.data.DataLoader( - datasets.ImageFolder(valdir, transforms.Compose([ - transforms.Scale(256), - transforms.CenterCrop(224), - transforms.ToTensor(), - normalize, - ])), - 
batch_size=args.test_batch, shuffle=False, - num_workers=args.workers, pin_memory=True) - - # create model - if args.pretrained: - print("=> using pre-trained model '{}'".format(args.arch)) - model = models.__dict__[args.arch](pretrained=True) - elif args.arch.startswith('resnext'): - model = models.__dict__[args.arch]( - baseWidth=args.base_width, - cardinality=args.cardinality, - ) - else: - print("=> creating model '{}'".format(args.arch)) - model = models.__dict__[args.arch]() - - if args.arch.startswith('alexnet') or args.arch.startswith('vgg'): - model.features = torch.nn.DataParallel(model.features) - model.cuda() - else: - model = torch.nn.DataParallel(model).cuda() - - cudnn.benchmark = True - print(' Total params: %.2fM' % (sum(p.numel() for p in model.parameters())/1000000.0)) - - # define loss function (criterion) and optimizer - criterion = nn.CrossEntropyLoss().cuda() - optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) - - # Resume - title = 'ImageNet-' + args.arch - if args.resume: - # Load checkpoint. - print('==> Resuming from checkpoint..') - assert os.path.isfile(args.resume), 'Error: no checkpoint directory found!' - args.checkpoint = os.path.dirname(args.resume) - checkpoint = torch.load(args.resume) - best_acc = checkpoint['best_acc'] - start_epoch = checkpoint['epoch'] - model.load_state_dict(checkpoint['state_dict']) - optimizer.load_state_dict(checkpoint['optimizer']) - logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title, resume=True) - else: - logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title) - logger.set_names(['Learning Rate', 'Train Loss', 'Valid Loss', 'Train Acc.', 'Valid Acc.']) - - - if args.evaluate: - print('\nEvaluation only') - test_loss, test_acc = test(val_loader, model, criterion, start_epoch, use_cuda) - print(' Test Loss: %.8f, Test Acc: %.2f' % (test_loss, test_acc)) - return - - # Train and val - for epoch in range(start_epoch, args.epochs): - adjust_learning_rate(optimizer, epoch) - - print('\nEpoch: [%d | %d] LR: %f' % (epoch + 1, args.epochs, state['lr'])) - - train_loss, train_acc = train(train_loader, model, criterion, optimizer, epoch, use_cuda) - test_loss, test_acc = test(val_loader, model, criterion, epoch, use_cuda) - - # append logger file - logger.append([state['lr'], train_loss, test_loss, train_acc, test_acc]) - - # save model - is_best = test_acc > best_acc - best_acc = max(test_acc, best_acc) - save_checkpoint({ - 'epoch': epoch + 1, - 'state_dict': model.state_dict(), - 'acc': test_acc, - 'best_acc': best_acc, - 'optimizer' : optimizer.state_dict(), - }, is_best, checkpoint=args.checkpoint) - - logger.close() - logger.plot() - savefig(os.path.join(args.checkpoint, 'log.eps')) - - print('Best acc:') - print(best_acc) - -def train(train_loader, model, criterion, optimizer, epoch, use_cuda): - # switch to train mode - model.train() - - batch_time = AverageMeter() - data_time = AverageMeter() - losses = AverageMeter() - top1 = AverageMeter() - top5 = AverageMeter() - end = time.time() - - bar = Bar('Processing', max=len(train_loader)) - for batch_idx, (inputs, targets) in enumerate(train_loader): - # measure data loading time - data_time.update(time.time() - end) - - if use_cuda: - inputs, targets = inputs.cuda(), targets.cuda(async=True) - inputs, targets = torch.autograd.Variable(inputs), torch.autograd.Variable(targets) - - # compute output - outputs = model(inputs) - loss = criterion(outputs, targets) - - # measure accuracy and 
record loss - prec1, prec5 = accuracy(outputs.data, targets.data, topk=(1, 5)) - losses.update(loss.data[0], inputs.size(0)) - top1.update(prec1[0], inputs.size(0)) - top5.update(prec5[0], inputs.size(0)) - - # compute gradient and do SGD step - optimizer.zero_grad() - loss.backward() - optimizer.step() - - # measure elapsed time - batch_time.update(time.time() - end) - end = time.time() - - # plot progress - bar.suffix = '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f} | top1: {top1: .4f} | top5: {top5: .4f}'.format( - batch=batch_idx + 1, - size=len(train_loader), - data=data_time.val, - bt=batch_time.val, - total=bar.elapsed_td, - eta=bar.eta_td, - loss=losses.avg, - top1=top1.avg, - top5=top5.avg, - ) - bar.next() - bar.finish() - return (losses.avg, top1.avg) - -def test(val_loader, model, criterion, epoch, use_cuda): - global best_acc - - batch_time = AverageMeter() - data_time = AverageMeter() - losses = AverageMeter() - top1 = AverageMeter() - top5 = AverageMeter() - - # switch to evaluate mode - model.eval() - - end = time.time() - bar = Bar('Processing', max=len(val_loader)) - for batch_idx, (inputs, targets) in enumerate(val_loader): - # measure data loading time - data_time.update(time.time() - end) - - if use_cuda: - inputs, targets = inputs.cuda(), targets.cuda() - inputs, targets = torch.autograd.Variable(inputs, volatile=True), torch.autograd.Variable(targets) - - # compute output - outputs = model(inputs) - loss = criterion(outputs, targets) - - # measure accuracy and record loss - prec1, prec5 = accuracy(outputs.data, targets.data, topk=(1, 5)) - losses.update(loss.data[0], inputs.size(0)) - top1.update(prec1[0], inputs.size(0)) - top5.update(prec5[0], inputs.size(0)) - - # measure elapsed time - batch_time.update(time.time() - end) - end = time.time() - - # plot progress - bar.suffix = '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f} | top1: {top1: .4f} | top5: {top5: .4f}'.format( - batch=batch_idx + 1, - size=len(val_loader), - data=data_time.avg, - bt=batch_time.avg, - total=bar.elapsed_td, - eta=bar.eta_td, - loss=losses.avg, - top1=top1.avg, - top5=top5.avg, - ) - bar.next() - bar.finish() - return (losses.avg, top1.avg) - -def save_checkpoint(state, is_best, checkpoint='checkpoint', filename='checkpoint.pth.tar'): - filepath = os.path.join(checkpoint, filename) - torch.save(state, filepath) - if is_best: - shutil.copyfile(filepath, os.path.join(checkpoint, 'model_best.pth.tar')) - -def adjust_learning_rate(optimizer, epoch): - global state - if epoch in args.schedule: - state['lr'] *= args.gamma - for param_group in optimizer.param_groups: - param_group['lr'] = state['lr'] - -if __name__ == '__main__': - main() diff --git a/src/torchprune/torchprune/util/models/cnn/models/cifar/__init__.py b/src/torchprune/torchprune/util/models/cnn/models/cifar/__init__.py deleted file mode 100644 index 3011b7a..0000000 --- a/src/torchprune/torchprune/util/models/cnn/models/cifar/__init__.py +++ /dev/null @@ -1,70 +0,0 @@ -from __future__ import absolute_import - -"""The models subpackage contains definitions for the following model for CIFAR10/CIFAR100 -architectures: - -- `AlexNet`_ -- `VGG`_ -- `ResNet`_ -- `SqueezeNet`_ -- `DenseNet`_ - -You can construct a model with random weights by calling its constructor: - -.. 
code:: python - - import torchvision.models as models - resnet18 = models.resnet18() - alexnet = models.alexnet() - squeezenet = models.squeezenet1_0() - densenet = models.densenet_161() - -We provide pre-trained models for the ResNet variants and AlexNet, using the -PyTorch :mod:`torch.utils.model_zoo`. These can constructed by passing -``pretrained=True``: - -.. code:: python - - import torchvision.models as models - resnet18 = models.resnet18(pretrained=True) - alexnet = models.alexnet(pretrained=True) - -ImageNet 1-crop error rates (224x224) - -======================== ============= ============= -Network Top-1 error Top-5 error -======================== ============= ============= -ResNet-18 30.24 10.92 -ResNet-34 26.70 8.58 -ResNet-50 23.85 7.13 -ResNet-101 22.63 6.44 -ResNet-152 21.69 5.94 -Inception v3 22.55 6.44 -AlexNet 43.45 20.91 -VGG-11 30.98 11.37 -VGG-13 30.07 10.75 -VGG-16 28.41 9.62 -VGG-19 27.62 9.12 -SqueezeNet 1.0 41.90 19.58 -SqueezeNet 1.1 41.81 19.38 -Densenet-121 25.35 7.83 -Densenet-169 24.00 7.00 -Densenet-201 22.80 6.43 -Densenet-161 22.35 6.20 -======================== ============= ============= - - -.. _AlexNet: https://arxiv.org/abs/1404.5997 -.. _VGG: https://arxiv.org/abs/1409.1556 -.. _ResNet: https://arxiv.org/abs/1512.03385 -.. _SqueezeNet: https://arxiv.org/abs/1602.07360 -.. _DenseNet: https://arxiv.org/abs/1608.06993 -""" - -from .alexnet import * -from .vgg import * -from .resnet import * -from .resnext import * -from .wrn import * -from .preresnet import * -from .densenet import * diff --git a/src/torchprune/torchprune/util/models/cnn/models/cifar/alexnet.py b/src/torchprune/torchprune/util/models/cnn/models/cifar/alexnet.py deleted file mode 100644 index 8c9407d..0000000 --- a/src/torchprune/torchprune/util/models/cnn/models/cifar/alexnet.py +++ /dev/null @@ -1,44 +0,0 @@ -'''AlexNet for CIFAR10. FC layers are removed. Paddings are adjusted. -Without BN, the start learning rate should be 0.01 -(c) YANG, Wei -''' -import torch.nn as nn - - -__all__ = ['alexnet'] - - -class AlexNet(nn.Module): - - def __init__(self, num_classes=10): - super(AlexNet, self).__init__() - self.features = nn.Sequential( - nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=5), - nn.ReLU(inplace=True), - nn.MaxPool2d(kernel_size=2, stride=2), - nn.Conv2d(64, 192, kernel_size=5, padding=2), - nn.ReLU(inplace=True), - nn.MaxPool2d(kernel_size=2, stride=2), - nn.Conv2d(192, 384, kernel_size=3, padding=1), - nn.ReLU(inplace=True), - nn.Conv2d(384, 256, kernel_size=3, padding=1), - nn.ReLU(inplace=True), - nn.Conv2d(256, 256, kernel_size=3, padding=1), - nn.ReLU(inplace=True), - nn.MaxPool2d(kernel_size=2, stride=2), - ) - self.classifier = nn.Linear(256, num_classes) - - def forward(self, x): - x = self.features(x) - x = x.view(x.size(0), -1) - x = self.classifier(x) - return x - - -def alexnet(**kwargs): - r"""AlexNet model architecture from the - `"One weird trick..." `_ paper. 
- """ - model = AlexNet(**kwargs) - return model diff --git a/src/torchprune/torchprune/util/models/cnn/models/cifar/densenet.py b/src/torchprune/torchprune/util/models/cnn/models/cifar/densenet.py deleted file mode 100644 index cf35774..0000000 --- a/src/torchprune/torchprune/util/models/cnn/models/cifar/densenet.py +++ /dev/null @@ -1,171 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -import math - - -__all__ = ['densenet', 'densenet22', 'densenet100', 'densenet190'] - - -from torch.autograd import Variable - -class Bottleneck(nn.Module): - def __init__(self, inplanes, expansion=4, growthRate=12, dropRate=0): - super(Bottleneck, self).__init__() - planes = expansion * growthRate - self.bn1 = nn.BatchNorm2d(inplanes) - self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) - self.bn2 = nn.BatchNorm2d(planes) - self.conv2 = nn.Conv2d(planes, growthRate, kernel_size=3, - padding=1, bias=False) - self.relu = nn.ReLU(inplace=True) - self.dropRate = dropRate - - def forward(self, x): - out = self.bn1(x) - out = self.relu(out) - out = self.conv1(out) - out = self.bn2(out) - out = self.relu(out) - out = self.conv2(out) - if self.dropRate > 0: - out = F.dropout(out, p=self.dropRate, training=self.training) - - out = torch.cat((x, out), 1) - - return out - - -class BasicBlock(nn.Module): - def __init__(self, inplanes, expansion=1, growthRate=12, dropRate=0): - super(BasicBlock, self).__init__() - planes = expansion * growthRate - self.bn1 = nn.BatchNorm2d(inplanes) - self.conv1 = nn.Conv2d(inplanes, growthRate, kernel_size=3, - padding=1, bias=False) - self.relu = nn.ReLU(inplace=True) - self.dropRate = dropRate - - def forward(self, x): - out = self.bn1(x) - out = self.relu(out) - out = self.conv1(out) - if self.dropRate > 0: - out = F.dropout(out, p=self.dropRate, training=self.training) - - out = torch.cat((x, out), 1) - - return out - - -class Transition(nn.Module): - def __init__(self, inplanes, outplanes): - super(Transition, self).__init__() - self.bn1 = nn.BatchNorm2d(inplanes) - self.conv1 = nn.Conv2d(inplanes, outplanes, kernel_size=1, - bias=False) - self.relu = nn.ReLU(inplace=True) - - def forward(self, x): - out = self.bn1(x) - out = self.relu(out) - out = self.conv1(out) - out = F.avg_pool2d(out, 2) - return out - - -class DenseNet(nn.Module): - - def __init__(self, depth=22, block=Bottleneck, - dropRate=0, num_classes=10, growthRate=12, compressionRate=2): - super(DenseNet, self).__init__() - - assert (depth - 4) % 3 == 0, 'depth should be 3n+4' - n = (depth - 4) / 3 if block == BasicBlock else (depth - 4) // 6 - - self.growthRate = growthRate - self.dropRate = dropRate - - # self.inplanes is a global variable used across multiple - # helper functions - self.inplanes = growthRate * 2 - self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, padding=1, - bias=False) - self.dense1 = self._make_denseblock(block, n) - self.trans1 = self._make_transition(compressionRate) - self.dense2 = self._make_denseblock(block, n) - self.trans2 = self._make_transition(compressionRate) - self.dense3 = self._make_denseblock(block, n) - self.bn = nn.BatchNorm2d(self.inplanes) - self.relu = nn.ReLU(inplace=True) - self.avgpool = nn.AvgPool2d(8) - self.fc = nn.Linear(self.inplanes, num_classes) - - # Weight initialization - for m in self.modules(): - if isinstance(m, nn.Conv2d): - n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels - m.weight.data.normal_(0, math.sqrt(2. 
/ n)) - elif isinstance(m, nn.BatchNorm2d): - m.weight.data.fill_(1) - m.bias.data.zero_() - - def _make_denseblock(self, block, blocks): - layers = [] - for i in range(blocks): - # Currently we fix the expansion ratio as the default value - layers.append(block(self.inplanes, growthRate=self.growthRate, dropRate=self.dropRate)) - self.inplanes += self.growthRate - - return nn.Sequential(*layers) - - def _make_transition(self, compressionRate): - inplanes = self.inplanes - outplanes = int(math.floor(self.inplanes // compressionRate)) - self.inplanes = outplanes - return Transition(inplanes, outplanes) - - - def forward(self, x): - x = self.conv1(x) - - x = self.trans1(self.dense1(x)) - x = self.trans2(self.dense2(x)) - x = self.dense3(x) - x = self.bn(x) - x = self.relu(x) - - x = self.avgpool(x) - x = x.view(x.size(0), -1) - x = self.fc(x) - - return x - - -def densenet(**kwargs): - """ - Constructs a ResNet model. - """ - return DenseNet(**kwargs) - -def densenet22(**kwargs): - """ - DenseNet-BC (L=22, k=12) - https://github.com/bearpaw/pytorch-classification/blob/master/TRAINING.md - """ - return densenet(**kwargs) - -def densenet100(**kwargs): - """ - DenseNet-BC (L=100, k=12) - https://github.com/bearpaw/pytorch-classification/blob/master/TRAINING.md - """ - return densenet(depth=100, growthRate=12, **kwargs) - - -def densenet190(**kwargs): - """ - DenseNet-BC (L=190, k=40) - https://github.com/bearpaw/pytorch-classification/blob/master/TRAINING.md - """ - return densenet(depth=190, growthRate=40, **kwargs) \ No newline at end of file diff --git a/src/torchprune/torchprune/util/models/cnn/models/cifar/preresnet.py b/src/torchprune/torchprune/util/models/cnn/models/cifar/preresnet.py deleted file mode 100644 index b69aeb6..0000000 --- a/src/torchprune/torchprune/util/models/cnn/models/cifar/preresnet.py +++ /dev/null @@ -1,158 +0,0 @@ -from __future__ import absolute_import - -'''Resnet for cifar dataset. 
-Ported form -https://github.com/facebook/fb.resnet.torch -and -https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py -(c) YANG, Wei -''' -import torch.nn as nn -import math - - -__all__ = ['preresnet'] - -def conv3x3(in_planes, out_planes, stride=1): - "3x3 convolution with padding" - return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, - padding=1, bias=False) - - -class BasicBlock(nn.Module): - expansion = 1 - - def __init__(self, inplanes, planes, stride=1, downsample=None): - super(BasicBlock, self).__init__() - self.bn1 = nn.BatchNorm2d(inplanes) - self.relu = nn.ReLU(inplace=True) - self.conv1 = conv3x3(inplanes, planes, stride) - self.bn2 = nn.BatchNorm2d(planes) - self.conv2 = conv3x3(planes, planes) - self.downsample = downsample - self.stride = stride - - def forward(self, x): - residual = x - - out = self.bn1(x) - out = self.relu(out) - out = self.conv1(out) - - out = self.bn2(out) - out = self.relu(out) - out = self.conv2(out) - - if self.downsample is not None: - residual = self.downsample(x) - - out += residual - - return out - - -class Bottleneck(nn.Module): - expansion = 4 - - def __init__(self, inplanes, planes, stride=1, downsample=None): - super(Bottleneck, self).__init__() - self.bn1 = nn.BatchNorm2d(inplanes) - self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) - self.bn2 = nn.BatchNorm2d(planes) - self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, - padding=1, bias=False) - self.bn3 = nn.BatchNorm2d(planes) - self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) - self.relu = nn.ReLU(inplace=True) - self.downsample = downsample - self.stride = stride - - def forward(self, x): - residual = x - - out = self.bn1(x) - out = self.relu(out) - out = self.conv1(out) - - out = self.bn2(out) - out = self.relu(out) - out = self.conv2(out) - - out = self.bn3(out) - out = self.relu(out) - out = self.conv3(out) - - if self.downsample is not None: - residual = self.downsample(x) - - out += residual - - return out - - -class PreResNet(nn.Module): - - def __init__(self, depth, num_classes=1000): - super(PreResNet, self).__init__() - # Model type specifies number of layers for CIFAR-10 model - assert (depth - 2) % 6 == 0, 'depth should be 6n+2' - n = (depth - 2) // 6 - - block = Bottleneck if depth >=44 else BasicBlock - - self.inplanes = 16 - self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1, - bias=False) - self.layer1 = self._make_layer(block, 16, n) - self.layer2 = self._make_layer(block, 32, n, stride=2) - self.layer3 = self._make_layer(block, 64, n, stride=2) - self.bn = nn.BatchNorm2d(64 * block.expansion) - self.relu = nn.ReLU(inplace=True) - self.avgpool = nn.AvgPool2d(8) - self.fc = nn.Linear(64 * block.expansion, num_classes) - - for m in self.modules(): - if isinstance(m, nn.Conv2d): - n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels - m.weight.data.normal_(0, math.sqrt(2. 
/ n)) - elif isinstance(m, nn.BatchNorm2d): - m.weight.data.fill_(1) - m.bias.data.zero_() - - def _make_layer(self, block, planes, blocks, stride=1): - downsample = None - if stride != 1 or self.inplanes != planes * block.expansion: - downsample = nn.Sequential( - nn.Conv2d(self.inplanes, planes * block.expansion, - kernel_size=1, stride=stride, bias=False), - ) - - layers = [] - layers.append(block(self.inplanes, planes, stride, downsample)) - self.inplanes = planes * block.expansion - for i in range(1, blocks): - layers.append(block(self.inplanes, planes)) - - return nn.Sequential(*layers) - - def forward(self, x): - x = self.conv1(x) - - x = self.layer1(x) # 32x32 - x = self.layer2(x) # 16x16 - x = self.layer3(x) # 8x8 - x = self.bn(x) - x = self.relu(x) - - x = self.avgpool(x) - x = x.view(x.size(0), -1) - x = self.fc(x) - - return x - - -def preresnet(**kwargs): - """ - Constructs a ResNet model. - """ - return PreResNet(**kwargs) diff --git a/src/torchprune/torchprune/util/models/cnn/models/cifar/resnet.py b/src/torchprune/torchprune/util/models/cnn/models/cifar/resnet.py deleted file mode 100644 index 294b590..0000000 --- a/src/torchprune/torchprune/util/models/cnn/models/cifar/resnet.py +++ /dev/null @@ -1,180 +0,0 @@ -from __future__ import absolute_import - -'''Resnet for cifar dataset. -Ported form -https://github.com/facebook/fb.resnet.torch -and -https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py -(c) YANG, Wei -''' -import torch.nn as nn -import math - - -__all__ = ['resnet', 'resnet20', 'resnet32', 'resnet44', 'resnet56', - 'resnet110'] - -def conv3x3(in_planes, out_planes, stride=1): - "3x3 convolution with padding" - return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, - padding=1, bias=False) - - -class BasicBlock(nn.Module): - expansion = 1 - - def __init__(self, inplanes, planes, stride=1, downsample=None): - super(BasicBlock, self).__init__() - self.conv1 = conv3x3(inplanes, planes, stride) - self.bn1 = nn.BatchNorm2d(planes) - self.relu = nn.ReLU(inplace=True) - self.conv2 = conv3x3(planes, planes) - self.bn2 = nn.BatchNorm2d(planes) - self.downsample = downsample - self.stride = stride - - def forward(self, x): - residual = x - - out = self.conv1(x) - out = self.bn1(out) - out = self.relu(out) - - out = self.conv2(out) - out = self.bn2(out) - - if self.downsample is not None: - residual = self.downsample(x) - - out += residual - out = self.relu(out) - - return out - - -class Bottleneck(nn.Module): - expansion = 4 - - def __init__(self, inplanes, planes, stride=1, downsample=None): - super(Bottleneck, self).__init__() - self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) - self.bn1 = nn.BatchNorm2d(planes) - self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, - padding=1, bias=False) - self.bn2 = nn.BatchNorm2d(planes) - self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) - self.bn3 = nn.BatchNorm2d(planes * 4) - self.relu = nn.ReLU(inplace=True) - self.downsample = downsample - self.stride = stride - - def forward(self, x): - residual = x - - out = self.conv1(x) - out = self.bn1(out) - out = self.relu(out) - - out = self.conv2(out) - out = self.bn2(out) - out = self.relu(out) - - out = self.conv3(out) - out = self.bn3(out) - - if self.downsample is not None: - residual = self.downsample(x) - - out += residual - out = self.relu(out) - - return out - - -class ResNet(nn.Module): - - def __init__(self, depth, num_classes=1000): - super(ResNet, self).__init__() - # Model type 
-        assert (depth - 2) % 6 == 0, 'depth should be 6n+2'
-        n = (depth - 2) // 6
-
-        block = Bottleneck if depth >= 44 else BasicBlock
-
-        self.inplanes = 16
-        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1,
-                               bias=False)
-        self.bn1 = nn.BatchNorm2d(16)
-        self.relu = nn.ReLU(inplace=True)
-        self.layer1 = self._make_layer(block, 16, n)
-        self.layer2 = self._make_layer(block, 32, n, stride=2)
-        self.layer3 = self._make_layer(block, 64, n, stride=2)
-        self.avgpool = nn.AvgPool2d(8)
-        self.fc = nn.Linear(64 * block.expansion, num_classes)
-
-        for m in self.modules():
-            if isinstance(m, nn.Conv2d):
-                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
-                m.weight.data.normal_(0, math.sqrt(2. / n))
-            elif isinstance(m, nn.BatchNorm2d):
-                m.weight.data.fill_(1)
-                m.bias.data.zero_()
-
-    def _make_layer(self, block, planes, blocks, stride=1):
-        downsample = None
-        if stride != 1 or self.inplanes != planes * block.expansion:
-            downsample = nn.Sequential(
-                nn.Conv2d(self.inplanes, planes * block.expansion,
-                          kernel_size=1, stride=stride, bias=False),
-                nn.BatchNorm2d(planes * block.expansion),
-            )
-
-        layers = []
-        layers.append(block(self.inplanes, planes, stride, downsample))
-        self.inplanes = planes * block.expansion
-        for i in range(1, blocks):
-            layers.append(block(self.inplanes, planes))
-
-        return nn.Sequential(*layers)
-
-    def forward(self, x):
-        x = self.conv1(x)
-        x = self.bn1(x)
-        x = self.relu(x) # 32x32
-
-        x = self.layer1(x) # 32x32
-        x = self.layer2(x) # 16x16
-        x = self.layer3(x) # 8x8
-
-        x = self.avgpool(x)
-        x = x.view(x.size(0), -1)
-        x = self.fc(x)
-
-        return x
-
-
-def resnet(**kwargs):
-    """
-    Constructs a ResNet model.
-    """
-    return ResNet(**kwargs)
-
-
-def resnet20(**kwargs):
-    return ResNet(20, **kwargs)
-
-
-def resnet32(**kwargs):
-    return ResNet(32, **kwargs)
-
-
-def resnet44(**kwargs):
-    return ResNet(44, **kwargs)
-
-
-def resnet56(**kwargs):
-    return ResNet(56, **kwargs)
-
-
-def resnet110(**kwargs):
-    return ResNet(110, **kwargs)
\ No newline at end of file
diff --git a/src/torchprune/torchprune/util/models/cnn/models/cifar/resnext.py b/src/torchprune/torchprune/util/models/cnn/models/cifar/resnext.py
deleted file mode 100644
index 50040ed..0000000
--- a/src/torchprune/torchprune/util/models/cnn/models/cifar/resnext.py
+++ /dev/null
@@ -1,126 +0,0 @@
-from __future__ import division
-"""
-Creates a ResNeXt Model as defined in:
-Xie, S., Girshick, R., Dollar, P., Tu, Z., & He, K. (2016).
-Aggregated residual transformations for deep neural networks.
-arXiv preprint arXiv:1611.05431.
-import from https://github.com/prlz77/ResNeXt.pytorch/blob/master/models/model.py
-"""
-import torch.nn as nn
-import torch.nn.functional as F
-from torch.nn import init
-
-__all__ = ['resnext']
-
-class ResNeXtBottleneck(nn.Module):
-    """
-    ResNeXt bottleneck type C (https://github.com/facebookresearch/ResNeXt/blob/master/models/resnext.lua)
-    """
-    def __init__(self, in_channels, out_channels, stride, cardinality, widen_factor):
-        """ Constructor
-        Args:
-            in_channels: input channel dimensionality
-            out_channels: output channel dimensionality
-            stride: conv stride. Replaces pooling layer.
-            cardinality: num of convolution groups.
-            widen_factor: factor to reduce the input dimensionality before convolution.
- """ - super(ResNeXtBottleneck, self).__init__() - D = cardinality * out_channels // widen_factor - self.conv_reduce = nn.Conv2d(in_channels, D, kernel_size=1, stride=1, padding=0, bias=False) - self.bn_reduce = nn.BatchNorm2d(D) - self.conv_conv = nn.Conv2d(D, D, kernel_size=3, stride=stride, padding=1, groups=cardinality, bias=False) - self.bn = nn.BatchNorm2d(D) - self.conv_expand = nn.Conv2d(D, out_channels, kernel_size=1, stride=1, padding=0, bias=False) - self.bn_expand = nn.BatchNorm2d(out_channels) - - self.shortcut = nn.Sequential() - if in_channels != out_channels: - self.shortcut.add_module('shortcut_conv', nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, padding=0, bias=False)) - self.shortcut.add_module('shortcut_bn', nn.BatchNorm2d(out_channels)) - - def forward(self, x): - bottleneck = self.conv_reduce.forward(x) - bottleneck = F.relu(self.bn_reduce.forward(bottleneck), inplace=True) - bottleneck = self.conv_conv.forward(bottleneck) - bottleneck = F.relu(self.bn.forward(bottleneck), inplace=True) - bottleneck = self.conv_expand.forward(bottleneck) - bottleneck = self.bn_expand.forward(bottleneck) - residual = self.shortcut.forward(x) - return F.relu(residual + bottleneck, inplace=True) - - -class CifarResNeXt(nn.Module): - """ - ResNext optimized for the Cifar dataset, as specified in - https://arxiv.org/pdf/1611.05431.pdf - """ - def __init__(self, cardinality, depth, num_classes, widen_factor=4, dropRate=0): - """ Constructor - Args: - cardinality: number of convolution groups. - depth: number of layers. - num_classes: number of classes - widen_factor: factor to adjust the channel dimensionality - """ - super(CifarResNeXt, self).__init__() - self.cardinality = cardinality - self.depth = depth - self.block_depth = (self.depth - 2) // 9 - self.widen_factor = widen_factor - self.num_classes = num_classes - self.output_size = 64 - self.stages = [64, 64 * self.widen_factor, 128 * self.widen_factor, 256 * self.widen_factor] - - self.conv_1_3x3 = nn.Conv2d(3, 64, 3, 1, 1, bias=False) - self.bn_1 = nn.BatchNorm2d(64) - self.stage_1 = self.block('stage_1', self.stages[0], self.stages[1], 1) - self.stage_2 = self.block('stage_2', self.stages[1], self.stages[2], 2) - self.stage_3 = self.block('stage_3', self.stages[2], self.stages[3], 2) - self.classifier = nn.Linear(1024, num_classes) - init.kaiming_normal(self.classifier.weight) - - for key in self.state_dict(): - if key.split('.')[-1] == 'weight': - if 'conv' in key: - init.kaiming_normal(self.state_dict()[key], mode='fan_out') - if 'bn' in key: - self.state_dict()[key][...] = 1 - elif key.split('.')[-1] == 'bias': - self.state_dict()[key][...] = 0 - - def block(self, name, in_channels, out_channels, pool_stride=2): - """ Stack n bottleneck modules where n is inferred from the depth of the network. - Args: - name: string name of the current block. - in_channels: number of input channels - out_channels: number of output channels - pool_stride: factor to reduce the spatial dimensionality in the first bottleneck of the block. - Returns: a Module consisting of n sequential bottlenecks. 
- """ - block = nn.Sequential() - for bottleneck in range(self.block_depth): - name_ = '%s_bottleneck_%d' % (name, bottleneck) - if bottleneck == 0: - block.add_module(name_, ResNeXtBottleneck(in_channels, out_channels, pool_stride, self.cardinality, - self.widen_factor)) - else: - block.add_module(name_, - ResNeXtBottleneck(out_channels, out_channels, 1, self.cardinality, self.widen_factor)) - return block - - def forward(self, x): - x = self.conv_1_3x3.forward(x) - x = F.relu(self.bn_1.forward(x), inplace=True) - x = self.stage_1.forward(x) - x = self.stage_2.forward(x) - x = self.stage_3.forward(x) - x = F.avg_pool2d(x, 8, 1) - x = x.view(-1, 1024) - return self.classifier(x) - -def resnext(**kwargs): - """Constructs a ResNeXt. - """ - model = CifarResNeXt(**kwargs) - return model \ No newline at end of file diff --git a/src/torchprune/torchprune/util/models/cnn/models/cifar/vgg.py b/src/torchprune/torchprune/util/models/cnn/models/cifar/vgg.py deleted file mode 100644 index 89b1785..0000000 --- a/src/torchprune/torchprune/util/models/cnn/models/cifar/vgg.py +++ /dev/null @@ -1,138 +0,0 @@ -'''VGG for CIFAR10. FC layers are removed. -(c) YANG, Wei -''' -import torch.nn as nn -import torch.utils.model_zoo as model_zoo -import math - - -__all__ = [ - 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', - 'vgg19_bn', 'vgg19', -] - - -model_urls = { - 'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth', - 'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth', - 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth', - 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth', -} - - -class VGG(nn.Module): - - def __init__(self, features, num_classes=1000): - super(VGG, self).__init__() - self.features = features - self.classifier = nn.Linear(512, num_classes) - self._initialize_weights() - - def forward(self, x): - x = self.features(x) - x = x.view(x.size(0), -1) - x = self.classifier(x) - return x - - def _initialize_weights(self): - for m in self.modules(): - if isinstance(m, nn.Conv2d): - n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels - m.weight.data.normal_(0, math.sqrt(2. 
/ n)) - if m.bias is not None: - m.bias.data.zero_() - elif isinstance(m, nn.BatchNorm2d): - m.weight.data.fill_(1) - m.bias.data.zero_() - elif isinstance(m, nn.Linear): - n = m.weight.size(1) - m.weight.data.normal_(0, 0.01) - m.bias.data.zero_() - - -def make_layers(cfg, batch_norm=False): - layers = [] - in_channels = 3 - for v in cfg: - if v == 'M': - layers += [nn.MaxPool2d(kernel_size=2, stride=2)] - else: - conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) - if batch_norm: - layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] - else: - layers += [conv2d, nn.ReLU(inplace=True)] - in_channels = v - return nn.Sequential(*layers) - - -cfg = { - 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], - 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], - 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], - 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], -} - - -def vgg11(**kwargs): - """VGG 11-layer model (configuration "A") - - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - """ - model = VGG(make_layers(cfg['A']), **kwargs) - return model - - -def vgg11_bn(**kwargs): - """VGG 11-layer model (configuration "A") with batch normalization""" - model = VGG(make_layers(cfg['A'], batch_norm=True), **kwargs) - return model - - -def vgg13(**kwargs): - """VGG 13-layer model (configuration "B") - - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - """ - model = VGG(make_layers(cfg['B']), **kwargs) - return model - - -def vgg13_bn(**kwargs): - """VGG 13-layer model (configuration "B") with batch normalization""" - model = VGG(make_layers(cfg['B'], batch_norm=True), **kwargs) - return model - - -def vgg16(**kwargs): - """VGG 16-layer model (configuration "D") - - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - """ - model = VGG(make_layers(cfg['D']), **kwargs) - return model - - -def vgg16_bn(**kwargs): - """VGG 16-layer model (configuration "D") with batch normalization""" - model = VGG(make_layers(cfg['D'], batch_norm=True), **kwargs) - return model - - -def vgg19(**kwargs): - """VGG 19-layer model (configuration "E") - - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - """ - model = VGG(make_layers(cfg['E']), **kwargs) - return model - - -def vgg19_bn(**kwargs): - """VGG 19-layer model (configuration 'E') with batch normalization""" - model = VGG(make_layers(cfg['E'], batch_norm=True), **kwargs) - return model diff --git a/src/torchprune/torchprune/util/models/cnn/models/cifar/wrn.py b/src/torchprune/torchprune/util/models/cnn/models/cifar/wrn.py deleted file mode 100644 index 49eff18..0000000 --- a/src/torchprune/torchprune/util/models/cnn/models/cifar/wrn.py +++ /dev/null @@ -1,128 +0,0 @@ -import math -import torch -import torch.nn as nn -import torch.nn.functional as F - -__all__ = ['wrn', 'wrn16_8', 'wrn28_10', 'wrn40_1', 'wrn40_4'] - -class BasicBlock(nn.Module): - def __init__(self, in_planes, out_planes, stride, dropRate=0.0): - super(BasicBlock, self).__init__() - self.bn1 = nn.BatchNorm2d(in_planes) - self.relu1 = nn.ReLU(inplace=True) - self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, - padding=1, bias=False) - self.bn2 = nn.BatchNorm2d(out_planes) - self.relu2 = nn.ReLU(inplace=True) - self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, - padding=1, 
bias=False)
-        self.droprate = dropRate
-        self.equalInOut = (in_planes == out_planes)
-        self.downsample = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,
-                                                              padding=0, bias=False) or None
-        if self.downsample is not None:
-            self.downsample = nn.Sequential(*[self.downsample])
-
-    def forward(self, x):
-        if not self.equalInOut:
-            x = self.relu1(self.bn1(x))
-        else:
-            out = self.relu1(self.bn1(x))
-        out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x)))
-        if self.droprate > 0:
-            out = F.dropout(out, p=self.droprate, training=self.training)
-        out = self.conv2(out)
-        return torch.add(x if self.equalInOut else self.downsample(x), out)
-
-class NetworkBlock(nn.Module):
-    def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0):
-        super(NetworkBlock, self).__init__()
-        self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate)
-    def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate):
-        layers = []
-        for i in range(nb_layers):
-            layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate))
-        return nn.Sequential(*layers)
-    def forward(self, x):
-        return self.layer(x)
-
-class WideResNet(nn.Module):
-    def __init__(self, depth, num_classes, widen_factor=1, dropRate=0.0):
-        super(WideResNet, self).__init__()
-        nChannels = [16, 16*widen_factor, 32*widen_factor, 64*widen_factor]
-        assert (depth - 4) % 6 == 0, 'depth should be 6n+4'
-        n = (depth - 4) // 6
-        block = BasicBlock
-        # 1st conv before any network block
-        self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1,
-                               padding=1, bias=False)
-        # 1st block
-        self.layer1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1,
-                                   dropRate).layer
-        # 2nd block
-        self.layer2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2,
-                                   dropRate).layer
-        # 3rd block
-        self.layer3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2,
-                                   dropRate).layer
-        # global average pooling and classifier
-        self.bn1 = nn.BatchNorm2d(nChannels[3])
-        self.relu = nn.ReLU(inplace=True)
-        self.fc = nn.Linear(nChannels[3], num_classes)
-        self.nChannels = nChannels[3]
-
-        for m in self.modules():
-            if isinstance(m, nn.Conv2d):
-                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
-                m.weight.data.normal_(0, math.sqrt(2. / n))
-            elif isinstance(m, nn.BatchNorm2d):
-                m.weight.data.fill_(1)
-                m.bias.data.zero_()
-            elif isinstance(m, nn.Linear):
-                m.bias.data.zero_()
-
-    def forward(self, x):
-        out = self.conv1(x)
-        out = self.layer1(out)
-        out = self.layer2(out)
-        out = self.layer3(out)
-        out = self.relu(self.bn1(out))
-        out = F.avg_pool2d(out, 8)
-        out = out.view(-1, self.nChannels)
-        return self.fc(out)
-
-
-def wrn(**kwargs):
-    """
-    Constructs a Wide Residual Network.
- """ - model = WideResNet(**kwargs) - return model - - -def wrn40_1(**kwargs): - """ - WRN, depth 40, widening factor 1 - """ - return WideResNet(depth=40, widen_factor=1, **kwargs) - - -def wrn40_4(**kwargs): - """ - WRN, depth 40, widening factor 4 - """ - return WideResNet(depth=40, widen_factor=4, **kwargs) - - -def wrn16_8(**kwargs): - """ - WRN, depth 16, widening factor 8 - """ - return WideResNet(depth=16, widen_factor=8, **kwargs) - - -def wrn28_10(**kwargs): - """ - WRN, depth 28, widening factor 10 - """ - return WideResNet(depth=28, widen_factor=10, **kwargs) diff --git a/src/torchprune/torchprune/util/models/cnn/models/imagenet/__init__.py b/src/torchprune/torchprune/util/models/cnn/models/imagenet/__init__.py deleted file mode 100644 index 5c0978e..0000000 --- a/src/torchprune/torchprune/util/models/cnn/models/imagenet/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from __future__ import absolute_import - -from .resnext import * diff --git a/src/torchprune/torchprune/util/models/cnn/models/imagenet/resnext.py b/src/torchprune/torchprune/util/models/cnn/models/imagenet/resnext.py deleted file mode 100644 index 181b9bf..0000000 --- a/src/torchprune/torchprune/util/models/cnn/models/imagenet/resnext.py +++ /dev/null @@ -1,173 +0,0 @@ -from __future__ import division -""" -Creates a ResNeXt Model as defined in: -Xie, S., Girshick, R., Dollar, P., Tu, Z., & He, K. (2016). -Aggregated residual transformations for deep neural networks. -arXiv preprint arXiv:1611.05431. -import from https://github.com/facebookresearch/ResNeXt/blob/master/models/resnext.lua -""" -import math -import torch.nn as nn -import torch.nn.functional as F -from torch.nn import init -import torch - -__all__ = ['resnext50', 'resnext101', 'resnext152'] - -class Bottleneck(nn.Module): - """ - RexNeXt bottleneck type C - """ - expansion = 4 - - def __init__(self, inplanes, planes, baseWidth, cardinality, stride=1, downsample=None): - """ Constructor - Args: - inplanes: input channel dimensionality - planes: output channel dimensionality - baseWidth: base width. - cardinality: num of convolution groups. - stride: conv stride. Replaces pooling layer. - """ - super(Bottleneck, self).__init__() - - D = int(math.floor(planes * (baseWidth / 64))) - C = cardinality - - self.conv1 = nn.Conv2d(inplanes, D*C, kernel_size=1, stride=1, padding=0, bias=False) - self.bn1 = nn.BatchNorm2d(D*C) - self.conv2 = nn.Conv2d(D*C, D*C, kernel_size=3, stride=stride, padding=1, groups=C, bias=False) - self.bn2 = nn.BatchNorm2d(D*C) - self.conv3 = nn.Conv2d(D*C, planes * 4, kernel_size=1, stride=1, padding=0, bias=False) - self.bn3 = nn.BatchNorm2d(planes * 4) - self.relu = nn.ReLU(inplace=True) - - self.downsample = downsample - - def forward(self, x): - residual = x - - out = self.conv1(x) - out = self.bn1(out) - out = self.relu(out) - - out = self.conv2(out) - out = self.bn2(out) - out = self.relu(out) - - out = self.conv3(out) - out = self.bn3(out) - - if self.downsample is not None: - residual = self.downsample(x) - - out += residual - out = self.relu(out) - - return out - - -class ResNeXt(nn.Module): - """ - ResNext optimized for the ImageNet dataset, as specified in - https://arxiv.org/pdf/1611.05431.pdf - """ - def __init__(self, baseWidth, cardinality, layers, num_classes): - """ Constructor - Args: - baseWidth: baseWidth for ResNeXt. - cardinality: number of convolution groups. 
- layers: config of layers, e.g., [3, 4, 6, 3] - num_classes: number of classes - """ - super(ResNeXt, self).__init__() - block = Bottleneck - - self.cardinality = cardinality - self.baseWidth = baseWidth - self.num_classes = num_classes - self.inplanes = 64 - self.output_size = 64 - - self.conv1 = nn.Conv2d(3, 64, 7, 2, 3, bias=False) - self.bn1 = nn.BatchNorm2d(64) - self.relu = nn.ReLU(inplace=True) - self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) - self.layer1 = self._make_layer(block, 64, layers[0]) - self.layer2 = self._make_layer(block, 128, layers[1], 2) - self.layer3 = self._make_layer(block, 256, layers[2], 2) - self.layer4 = self._make_layer(block, 512, layers[3], 2) - self.avgpool = nn.AvgPool2d(7) - self.fc = nn.Linear(512 * block.expansion, num_classes) - - for m in self.modules(): - if isinstance(m, nn.Conv2d): - n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels - m.weight.data.normal_(0, math.sqrt(2. / n)) - elif isinstance(m, nn.BatchNorm2d): - m.weight.data.fill_(1) - m.bias.data.zero_() - - def _make_layer(self, block, planes, blocks, stride=1): - """ Stack n bottleneck modules where n is inferred from the depth of the network. - Args: - block: block type used to construct ResNext - planes: number of output channels (need to multiply by block.expansion) - blocks: number of blocks to be built - stride: factor to reduce the spatial dimensionality in the first bottleneck of the block. - Returns: a Module consisting of n sequential bottlenecks. - """ - downsample = None - if stride != 1 or self.inplanes != planes * block.expansion: - downsample = nn.Sequential( - nn.Conv2d(self.inplanes, planes * block.expansion, - kernel_size=1, stride=stride, bias=False), - nn.BatchNorm2d(planes * block.expansion), - ) - - layers = [] - layers.append(block(self.inplanes, planes, self.baseWidth, self.cardinality, stride, downsample)) - self.inplanes = planes * block.expansion - for i in range(1, blocks): - layers.append(block(self.inplanes, planes, self.baseWidth, self.cardinality)) - - return nn.Sequential(*layers) - - def forward(self, x): - x = self.conv1(x) - x = self.bn1(x) - x = self.relu(x) - x = self.maxpool1(x) - x = self.layer1(x) - x = self.layer2(x) - x = self.layer3(x) - x = self.layer4(x) - x = self.avgpool(x) - x = x.view(x.size(0), -1) - x = self.fc(x) - - return x - - -def resnext50(baseWidth, cardinality): - """ - Construct ResNeXt-50. - """ - model = ResNeXt(baseWidth, cardinality, [3, 4, 6, 3], 1000) - return model - - -def resnext101(baseWidth, cardinality): - """ - Construct ResNeXt-101. - """ - model = ResNeXt(baseWidth, cardinality, [3, 4, 23, 3], 1000) - return model - - -def resnext152(baseWidth, cardinality): - """ - Construct ResNeXt-152. 
- """ - model = ResNeXt(baseWidth, cardinality, [3, 8, 36, 3], 1000) - return model diff --git a/src/torchprune/torchprune/util/models/cnn/utils/__init__.py b/src/torchprune/torchprune/util/models/cnn/utils/__init__.py deleted file mode 100644 index 848436b..0000000 --- a/src/torchprune/torchprune/util/models/cnn/utils/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -"""Useful utils -""" -from .misc import * -from .logger import * -from .visualize import * -from .eval import * - -# progress bar -import os, sys -sys.path.append(os.path.join(os.path.dirname(__file__), "progress")) -from progress.bar import Bar as Bar \ No newline at end of file diff --git a/src/torchprune/torchprune/util/models/cnn/utils/eval.py b/src/torchprune/torchprune/util/models/cnn/utils/eval.py deleted file mode 100644 index 5051350..0000000 --- a/src/torchprune/torchprune/util/models/cnn/utils/eval.py +++ /dev/null @@ -1,18 +0,0 @@ -from __future__ import print_function, absolute_import - -__all__ = ['accuracy'] - -def accuracy(output, target, topk=(1,)): - """Computes the precision@k for the specified values of k""" - maxk = max(topk) - batch_size = target.size(0) - - _, pred = output.topk(maxk, 1, True, True) - pred = pred.t() - correct = pred.eq(target.view(1, -1).expand_as(pred)) - - res = [] - for k in topk: - correct_k = correct[:k].view(-1).float().sum(0) - res.append(correct_k.mul_(100.0 / batch_size)) - return res \ No newline at end of file diff --git a/src/torchprune/torchprune/util/models/cnn/utils/images/cifar.png b/src/torchprune/torchprune/util/models/cnn/utils/images/cifar.png deleted file mode 100644 index 752409a..0000000 Binary files a/src/torchprune/torchprune/util/models/cnn/utils/images/cifar.png and /dev/null differ diff --git a/src/torchprune/torchprune/util/models/cnn/utils/images/imagenet.png b/src/torchprune/torchprune/util/models/cnn/utils/images/imagenet.png deleted file mode 100644 index eb63b70..0000000 Binary files a/src/torchprune/torchprune/util/models/cnn/utils/images/imagenet.png and /dev/null differ diff --git a/src/torchprune/torchprune/util/models/cnn/utils/logger.py b/src/torchprune/torchprune/util/models/cnn/utils/logger.py deleted file mode 100644 index 7eb5c67..0000000 --- a/src/torchprune/torchprune/util/models/cnn/utils/logger.py +++ /dev/null @@ -1,127 +0,0 @@ -# A simple torch style logger -# (C) Wei YANG 2017 -from __future__ import absolute_import -import matplotlib.pyplot as plt -import os -import sys -import numpy as np - -__all__ = ['Logger', 'LoggerMonitor', 'savefig'] - -def savefig(fname, dpi=None): - dpi = 150 if dpi == None else dpi - plt.savefig(fname, dpi=dpi) - -def plot_overlap(logger, names=None): - names = logger.names if names == None else names - numbers = logger.numbers - for _, name in enumerate(names): - x = np.arange(len(numbers[name])) - plt.plot(x, np.asarray(numbers[name])) - return [logger.title + '(' + name + ')' for name in names] - -class Logger(object): - '''Save training process to log file with simple plot function.''' - def __init__(self, fpath, title=None, resume=False): - self.file = None - self.resume = resume - self.title = '' if title == None else title - if fpath is not None: - if resume: - self.file = open(fpath, 'r') - name = self.file.readline() - self.names = name.rstrip().split('\t') - self.numbers = {} - for _, name in enumerate(self.names): - self.numbers[name] = [] - - for numbers in self.file: - numbers = numbers.rstrip().split('\t') - for i in range(0, len(numbers)): - self.numbers[self.names[i]].append(numbers[i]) - 
self.file.close()
-                self.file = open(fpath, 'a')
-            else:
-                self.file = open(fpath, 'w')
-
-    def set_names(self, names):
-        if self.resume:
-            pass
-        # initialize numbers as empty list
-        self.numbers = {}
-        self.names = names
-        for _, name in enumerate(self.names):
-            self.file.write(name)
-            self.file.write('\t')
-            self.numbers[name] = []
-        self.file.write('\n')
-        self.file.flush()
-
-
-    def append(self, numbers):
-        assert len(self.names) == len(numbers), 'Numbers do not match names'
-        for index, num in enumerate(numbers):
-            self.file.write("{0:.6f}".format(num))
-            self.file.write('\t')
-            self.numbers[self.names[index]].append(num)
-        self.file.write('\n')
-        self.file.flush()
-
-    def plot(self, names=None):
-        names = self.names if names == None else names
-        numbers = self.numbers
-        for _, name in enumerate(names):
-            x = np.arange(len(numbers[name]))
-            plt.plot(x, np.asarray(numbers[name]))
-        plt.legend([self.title + '(' + name + ')' for name in names])
-        plt.grid(True)
-
-    def close(self):
-        if self.file is not None:
-            self.file.close()
-
-class LoggerMonitor(object):
-    '''Load and visualize multiple logs.'''
-    def __init__(self, paths):
-        '''paths is a dictionary with {name:filepath} pair'''
-        self.loggers = []
-        for title, path in paths.items():
-            logger = Logger(path, title=title, resume=True)
-            self.loggers.append(logger)
-
-    def plot(self, names=None):
-        plt.figure()
-        plt.subplot(121)
-        legend_text = []
-        for logger in self.loggers:
-            legend_text += plot_overlap(logger, names)
-        plt.legend(legend_text, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
-        plt.grid(True)
-
-if __name__ == '__main__':
-    # # Example
-    # logger = Logger('test.txt')
-    # logger.set_names(['Train loss', 'Valid loss','Test loss'])
-
-    # length = 100
-    # t = np.arange(length)
-    # train_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
-    # valid_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
-    # test_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
-
-    # for i in range(0, length):
-    #     logger.append([train_loss[i], valid_loss[i], test_loss[i]])
-    # logger.plot()
-
-    # Example: logger monitor
-    paths = {
-        'resadvnet20':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet20/log.txt',
-        'resadvnet32':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet32/log.txt',
-        'resadvnet44':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet44/log.txt',
-    }
-
-    field = ['Valid Acc.']
-
-    monitor = LoggerMonitor(paths)
-    monitor.plot(names=field)
-    savefig('test.eps')
\ No newline at end of file
diff --git a/src/torchprune/torchprune/util/models/cnn/utils/misc.py b/src/torchprune/torchprune/util/models/cnn/utils/misc.py
deleted file mode 100644
index d387f59..0000000
--- a/src/torchprune/torchprune/util/models/cnn/utils/misc.py
+++ /dev/null
@@ -1,76 +0,0 @@
-'''Some helper functions for PyTorch, including:
-    - get_mean_and_std: calculate the mean and std value of dataset.
-    - msr_init: net parameter initialization.
-    - progress_bar: progress bar mimic xlua.progress.
-''' -import errno -import os -import sys -import time -import math - -import torch.nn as nn -import torch.nn.init as init -from torch.autograd import Variable - -__all__ = ['get_mean_and_std', 'init_params', 'mkdir_p', 'AverageMeter'] - - -def get_mean_and_std(dataset): - '''Compute the mean and std value of dataset.''' - dataloader = trainloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2) - - mean = torch.zeros(3) - std = torch.zeros(3) - print('==> Computing mean and std..') - for inputs, targets in dataloader: - for i in range(3): - mean[i] += inputs[:,i,:,:].mean() - std[i] += inputs[:,i,:,:].std() - mean.div_(len(dataset)) - std.div_(len(dataset)) - return mean, std - -def init_params(net): - '''Init layer parameters.''' - for m in net.modules(): - if isinstance(m, nn.Conv2d): - init.kaiming_normal(m.weight, mode='fan_out') - if m.bias: - init.constant(m.bias, 0) - elif isinstance(m, nn.BatchNorm2d): - init.constant(m.weight, 1) - init.constant(m.bias, 0) - elif isinstance(m, nn.Linear): - init.normal(m.weight, std=1e-3) - if m.bias: - init.constant(m.bias, 0) - -def mkdir_p(path): - '''make dir if not exist''' - try: - os.makedirs(path) - except OSError as exc: # Python >2.5 - if exc.errno == errno.EEXIST and os.path.isdir(path): - pass - else: - raise - -class AverageMeter(object): - """Computes and stores the average and current value - Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262 - """ - def __init__(self): - self.reset() - - def reset(self): - self.val = 0 - self.avg = 0 - self.sum = 0 - self.count = 0 - - def update(self, val, n=1): - self.val = val - self.sum += val * n - self.count += n - self.avg = self.sum / self.count \ No newline at end of file diff --git a/src/torchprune/torchprune/util/models/cnn/utils/progress/LICENSE b/src/torchprune/torchprune/util/models/cnn/utils/progress/LICENSE deleted file mode 100644 index 059cc05..0000000 --- a/src/torchprune/torchprune/util/models/cnn/utils/progress/LICENSE +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright (c) 2012 Giorgos Verigakis -# -# Permission to use, copy, modify, and distribute this software for any -# purpose with or without fee is hereby granted, provided that the above -# copyright notice and this permission notice appear in all copies. -# -# THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES -# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF -# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR -# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES -# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN -# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF -# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 
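For orientation, the deleted ``accuracy`` helper (eval.py) and ``AverageMeter`` (misc.py) above were typically combined in a validation loop. The following sketch is hypothetical and assumes both helpers are in scope; the logits and labels are random stand-ins:

    import torch

    # Hypothetical validation sketch using the deleted helpers above.
    meter = AverageMeter()
    for _ in range(10):                      # stand-in for a validation loader
        logits = torch.randn(8, 10)          # batch of 8 samples, 10 classes
        labels = torch.randint(0, 10, (8,))  # random ground-truth labels
        top1, = accuracy(logits, labels, topk=(1,))
        meter.update(top1.item(), n=labels.size(0))
    print("average top-1 accuracy: %.2f%%" % meter.avg)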
diff --git a/src/torchprune/torchprune/util/models/cnn/utils/progress/MANIFEST.in b/src/torchprune/torchprune/util/models/cnn/utils/progress/MANIFEST.in
deleted file mode 100644
index 0c73842..0000000
--- a/src/torchprune/torchprune/util/models/cnn/utils/progress/MANIFEST.in
+++ /dev/null
@@ -1 +0,0 @@
-include README.rst LICENSE
diff --git a/src/torchprune/torchprune/util/models/cnn/utils/progress/README.rst b/src/torchprune/torchprune/util/models/cnn/utils/progress/README.rst
deleted file mode 100644
index 3f3be76..0000000
--- a/src/torchprune/torchprune/util/models/cnn/utils/progress/README.rst
+++ /dev/null
@@ -1,131 +0,0 @@
-Easy progress reporting for Python
-==================================
-
-|pypi|
-
-|demo|
-
-.. |pypi| image:: https://img.shields.io/pypi/v/progress.svg
-.. |demo| image:: https://raw.github.com/verigak/progress/master/demo.gif
-   :alt: Demo
-
-Bars
-----
-
-There are 7 progress bars to choose from:
-
-- ``Bar``
-- ``ChargingBar``
-- ``FillingSquaresBar``
-- ``FillingCirclesBar``
-- ``IncrementalBar``
-- ``PixelBar``
-- ``ShadyBar``
-
-To use them, just call ``next`` to advance and ``finish`` to finish:
-
-.. code-block:: python
-
-    from progress.bar import Bar
-
-    bar = Bar('Processing', max=20)
-    for i in range(20):
-        # Do some work
-        bar.next()
-    bar.finish()
-
-The result will be a bar like the following: ::
-
-    Processing |#############                   | 42/100
-
-To simplify the common case where the work is done in an iterator, you can
-use the ``iter`` method:
-
-.. code-block:: python
-
-    for i in Bar('Processing').iter(it):
-        # Do some work
-
-Progress bars are very customizable; you can change their width, their fill
-character, their suffix and more:
-
-.. code-block:: python
-
-    bar = Bar('Loading', fill='@', suffix='%(percent)d%%')
-
-This will produce a bar like the following: ::
-
-    Loading |@@@@@@@@@@@@@                   | 42%
-
-You can use a number of template arguments in ``message`` and ``suffix``:
-
-========== ================================
-Name       Value
-========== ================================
-index      current value
-max        maximum value
-remaining  max - index
-progress   index / max
-percent    progress * 100
-avg        simple moving average time per item (in seconds)
-elapsed    elapsed time in seconds
-elapsed_td elapsed as a timedelta (useful for printing as a string)
-eta        avg * remaining
-eta_td     eta as a timedelta (useful for printing as a string)
-========== ================================
-
-Instead of passing all configuration options on instantiation, you can create
-your custom subclass:
-
-.. code-block:: python
-
-    class FancyBar(Bar):
-        message = 'Loading'
-        fill = '*'
-        suffix = '%(percent).1f%% - %(eta)ds'
-
-You can also override any of the arguments or create your own:
-
-.. code-block:: python

-    class SlowBar(Bar):
-        suffix = '%(remaining_hours)d hours remaining'
-        @property
-        def remaining_hours(self):
-            return self.eta // 3600
-
-
-Spinners
-========
-
-For actions with an unknown number of steps you can use a spinner:
-
-.. code-block:: python
-
-    from progress.spinner import Spinner
-
-    spinner = Spinner('Loading ')
-    while state != 'FINISHED':
-        # Do some work
-        spinner.next()
-
-There are 5 predefined spinners:
-
-- ``Spinner``
-- ``PieSpinner``
-- ``MoonSpinner``
-- ``LineSpinner``
-- ``PixelSpinner``
-
-
-Other
-=====
-
-There are a number of other classes available too, please check the source or
-subclass one of them to create your own.
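Following the subclassing pattern above, a counter that writes the overall percentage could look like this (a hypothetical sketch, not part of the package):

.. code-block:: python

    from progress.counter import Countdown

    class PercentCounter(Countdown):
        # ``Countdown`` builds on ``Progress``, so ``percent`` is available.
        def update(self):
            self.write('%d%%' % self.percent)

    counter = PercentCounter('Progress ', max=50)
    for _ in range(50):
        counter.next()
    counter.finish()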
- - -License -======= - -progress is licensed under ISC diff --git a/src/torchprune/torchprune/util/models/cnn/utils/progress/demo.gif b/src/torchprune/torchprune/util/models/cnn/utils/progress/demo.gif deleted file mode 100644 index 64b1e95..0000000 Binary files a/src/torchprune/torchprune/util/models/cnn/utils/progress/demo.gif and /dev/null differ diff --git a/src/torchprune/torchprune/util/models/cnn/utils/progress/progress/__init__.py b/src/torchprune/torchprune/util/models/cnn/utils/progress/progress/__init__.py deleted file mode 100644 index 09dfc1e..0000000 --- a/src/torchprune/torchprune/util/models/cnn/utils/progress/progress/__init__.py +++ /dev/null @@ -1,127 +0,0 @@ -# Copyright (c) 2012 Giorgos Verigakis -# -# Permission to use, copy, modify, and distribute this software for any -# purpose with or without fee is hereby granted, provided that the above -# copyright notice and this permission notice appear in all copies. -# -# THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES -# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF -# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR -# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES -# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN -# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF -# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. - -from __future__ import division - -from collections import deque -from datetime import timedelta -from math import ceil -from sys import stderr -from time import time - - -__version__ = '1.3' - - -class Infinite(object): - file = stderr - sma_window = 10 # Simple Moving Average window - - def __init__(self, *args, **kwargs): - self.index = 0 - self.start_ts = time() - self.avg = 0 - self._ts = self.start_ts - self._xput = deque(maxlen=self.sma_window) - for key, val in kwargs.items(): - setattr(self, key, val) - - def __getitem__(self, key): - if key.startswith('_'): - return None - return getattr(self, key, None) - - @property - def elapsed(self): - return int(time() - self.start_ts) - - @property - def elapsed_td(self): - return timedelta(seconds=self.elapsed) - - def update_avg(self, n, dt): - if n > 0: - self._xput.append(dt / n) - self.avg = sum(self._xput) / len(self._xput) - - def update(self): - pass - - def start(self): - pass - - def finish(self): - pass - - def next(self, n=1): - now = time() - dt = now - self._ts - self.update_avg(n, dt) - self._ts = now - self.index = self.index + n - self.update() - - def iter(self, it): - try: - for x in it: - yield x - self.next() - finally: - self.finish() - - -class Progress(Infinite): - def __init__(self, *args, **kwargs): - super(Progress, self).__init__(*args, **kwargs) - self.max = kwargs.get('max', 100) - - @property - def eta(self): - return int(ceil(self.avg * self.remaining)) - - @property - def eta_td(self): - return timedelta(seconds=self.eta) - - @property - def percent(self): - return self.progress * 100 - - @property - def progress(self): - return min(1, self.index / self.max) - - @property - def remaining(self): - return max(self.max - self.index, 0) - - def start(self): - self.update() - - def goto(self, index): - incr = index - self.index - self.next(incr) - - def iter(self, it): - try: - self.max = len(it) - except TypeError: - pass - - try: - for x in it: - yield x - self.next() - finally: - self.finish() diff --git a/src/torchprune/torchprune/util/models/cnn/utils/progress/progress/bar.py 
b/src/torchprune/torchprune/util/models/cnn/utils/progress/progress/bar.py deleted file mode 100644 index 5ee968f..0000000 --- a/src/torchprune/torchprune/util/models/cnn/utils/progress/progress/bar.py +++ /dev/null @@ -1,88 +0,0 @@ -# -*- coding: utf-8 -*- - -# Copyright (c) 2012 Giorgos Verigakis -# -# Permission to use, copy, modify, and distribute this software for any -# purpose with or without fee is hereby granted, provided that the above -# copyright notice and this permission notice appear in all copies. -# -# THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES -# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF -# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR -# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES -# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN -# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF -# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. - -from __future__ import unicode_literals -from . import Progress -from .helpers import WritelnMixin - - -class Bar(WritelnMixin, Progress): - width = 32 - message = '' - suffix = '%(index)d/%(max)d' - bar_prefix = ' |' - bar_suffix = '| ' - empty_fill = ' ' - fill = '#' - hide_cursor = True - - def update(self): - filled_length = int(self.width * self.progress) - empty_length = self.width - filled_length - - message = self.message % self - bar = self.fill * filled_length - empty = self.empty_fill * empty_length - suffix = self.suffix % self - line = ''.join([message, self.bar_prefix, bar, empty, self.bar_suffix, - suffix]) - self.writeln(line) - - -class ChargingBar(Bar): - suffix = '%(percent)d%%' - bar_prefix = ' ' - bar_suffix = ' ' - empty_fill = '∙' - fill = '█' - - -class FillingSquaresBar(ChargingBar): - empty_fill = '▢' - fill = '▣' - - -class FillingCirclesBar(ChargingBar): - empty_fill = '◯' - fill = '◉' - - -class IncrementalBar(Bar): - phases = (' ', '▏', '▎', '▍', '▌', '▋', '▊', '▉', '█') - - def update(self): - nphases = len(self.phases) - filled_len = self.width * self.progress - nfull = int(filled_len) # Number of full chars - phase = int((filled_len - nfull) * nphases) # Phase of last char - nempty = self.width - nfull # Number of empty chars - - message = self.message % self - bar = self.phases[-1] * nfull - current = self.phases[phase] if phase > 0 else '' - empty = self.empty_fill * max(0, nempty - len(current)) - suffix = self.suffix % self - line = ''.join([message, self.bar_prefix, bar, current, empty, - self.bar_suffix, suffix]) - self.writeln(line) - - -class PixelBar(IncrementalBar): - phases = ('⡀', '⡄', '⡆', '⡇', '⣇', '⣧', '⣷', '⣿') - - -class ShadyBar(IncrementalBar): - phases = (' ', '░', '▒', '▓', '█') diff --git a/src/torchprune/torchprune/util/models/cnn/utils/progress/progress/counter.py b/src/torchprune/torchprune/util/models/cnn/utils/progress/progress/counter.py deleted file mode 100644 index 6b45a1e..0000000 --- a/src/torchprune/torchprune/util/models/cnn/utils/progress/progress/counter.py +++ /dev/null @@ -1,48 +0,0 @@ -# -*- coding: utf-8 -*- - -# Copyright (c) 2012 Giorgos Verigakis -# -# Permission to use, copy, modify, and distribute this software for any -# purpose with or without fee is hereby granted, provided that the above -# copyright notice and this permission notice appear in all copies. 
-# -# THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES -# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF -# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR -# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES -# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN -# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF -# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. - -from __future__ import unicode_literals -from . import Infinite, Progress -from .helpers import WriteMixin - - -class Counter(WriteMixin, Infinite): - message = '' - hide_cursor = True - - def update(self): - self.write(str(self.index)) - - -class Countdown(WriteMixin, Progress): - hide_cursor = True - - def update(self): - self.write(str(self.remaining)) - - -class Stack(WriteMixin, Progress): - phases = (' ', '▁', '▂', '▃', '▄', '▅', '▆', '▇', '█') - hide_cursor = True - - def update(self): - nphases = len(self.phases) - i = min(nphases - 1, int(self.progress * nphases)) - self.write(self.phases[i]) - - -class Pie(Stack): - phases = ('○', '◔', '◑', '◕', '●') diff --git a/src/torchprune/torchprune/util/models/cnn/utils/progress/progress/helpers.py b/src/torchprune/torchprune/util/models/cnn/utils/progress/progress/helpers.py deleted file mode 100644 index 9ed90b2..0000000 --- a/src/torchprune/torchprune/util/models/cnn/utils/progress/progress/helpers.py +++ /dev/null @@ -1,91 +0,0 @@ -# Copyright (c) 2012 Giorgos Verigakis -# -# Permission to use, copy, modify, and distribute this software for any -# purpose with or without fee is hereby granted, provided that the above -# copyright notice and this permission notice appear in all copies. -# -# THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES -# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF -# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR -# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES -# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN -# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF -# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 
- -from __future__ import print_function - - -HIDE_CURSOR = '\x1b[?25l' -SHOW_CURSOR = '\x1b[?25h' - - -class WriteMixin(object): - hide_cursor = False - - def __init__(self, message=None, **kwargs): - super(WriteMixin, self).__init__(**kwargs) - self._width = 0 - if message: - self.message = message - - if self.file.isatty(): - if self.hide_cursor: - print(HIDE_CURSOR, end='', file=self.file) - print(self.message, end='', file=self.file) - self.file.flush() - - def write(self, s): - if self.file.isatty(): - b = '\b' * self._width - c = s.ljust(self._width) - print(b + c, end='', file=self.file) - self._width = max(self._width, len(s)) - self.file.flush() - - def finish(self): - if self.file.isatty() and self.hide_cursor: - print(SHOW_CURSOR, end='', file=self.file) - - -class WritelnMixin(object): - hide_cursor = False - - def __init__(self, message=None, **kwargs): - super(WritelnMixin, self).__init__(**kwargs) - if message: - self.message = message - - if self.file.isatty() and self.hide_cursor: - print(HIDE_CURSOR, end='', file=self.file) - - def clearln(self): - if self.file.isatty(): - print('\r\x1b[K', end='', file=self.file) - - def writeln(self, line): - if self.file.isatty(): - self.clearln() - print(line, end='', file=self.file) - self.file.flush() - - def finish(self): - if self.file.isatty(): - print(file=self.file) - if self.hide_cursor: - print(SHOW_CURSOR, end='', file=self.file) - - -from signal import signal, SIGINT -from sys import exit - - -class SigIntMixin(object): - """Registers a signal handler that calls finish on SIGINT""" - - def __init__(self, *args, **kwargs): - super(SigIntMixin, self).__init__(*args, **kwargs) - signal(SIGINT, self._sigint_handler) - - def _sigint_handler(self, signum, frame): - self.finish() - exit(0) diff --git a/src/torchprune/torchprune/util/models/cnn/utils/progress/progress/spinner.py b/src/torchprune/torchprune/util/models/cnn/utils/progress/progress/spinner.py deleted file mode 100644 index 464c7b2..0000000 --- a/src/torchprune/torchprune/util/models/cnn/utils/progress/progress/spinner.py +++ /dev/null @@ -1,44 +0,0 @@ -# -*- coding: utf-8 -*- - -# Copyright (c) 2012 Giorgos Verigakis -# -# Permission to use, copy, modify, and distribute this software for any -# purpose with or without fee is hereby granted, provided that the above -# copyright notice and this permission notice appear in all copies. -# -# THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES -# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF -# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR -# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES -# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN -# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF -# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. - -from __future__ import unicode_literals -from . 
import Infinite -from .helpers import WriteMixin - - -class Spinner(WriteMixin, Infinite): - message = '' - phases = ('-', '\\', '|', '/') - hide_cursor = True - - def update(self): - i = self.index % len(self.phases) - self.write(self.phases[i]) - - -class PieSpinner(Spinner): - phases = ['◷', '◶', '◵', '◴'] - - -class MoonSpinner(Spinner): - phases = ['◑', '◒', '◐', '◓'] - - -class LineSpinner(Spinner): - phases = ['⎺', '⎻', '⎼', '⎽', '⎼', '⎻'] - -class PixelSpinner(Spinner): - phases = ['⣾','⣷', '⣯', '⣟', '⡿', '⢿', '⣻', '⣽'] diff --git a/src/torchprune/torchprune/util/models/cnn/utils/progress/setup.py b/src/torchprune/torchprune/util/models/cnn/utils/progress/setup.py deleted file mode 100755 index c877781..0000000 --- a/src/torchprune/torchprune/util/models/cnn/utils/progress/setup.py +++ /dev/null @@ -1,29 +0,0 @@ -#!/usr/bin/env python - -from setuptools import setup - -import progress - - -setup( - name='progress', - version=progress.__version__, - description='Easy to use progress bars', - long_description=open('README.rst').read(), - author='Giorgos Verigakis', - author_email='verigak@gmail.com', - url='http://github.com/verigak/progress/', - license='ISC', - packages=['progress'], - classifiers=[ - 'Environment :: Console', - 'Intended Audience :: Developers', - 'License :: OSI Approved :: ISC License (ISCL)', - 'Programming Language :: Python :: 2.6', - 'Programming Language :: Python :: 2.7', - 'Programming Language :: Python :: 3.3', - 'Programming Language :: Python :: 3.4', - 'Programming Language :: Python :: 3.5', - 'Programming Language :: Python :: 3.6', - ] -) diff --git a/src/torchprune/torchprune/util/models/cnn/utils/progress/test_progress.py b/src/torchprune/torchprune/util/models/cnn/utils/progress/test_progress.py deleted file mode 100755 index 0f68b01..0000000 --- a/src/torchprune/torchprune/util/models/cnn/utils/progress/test_progress.py +++ /dev/null @@ -1,48 +0,0 @@ -#!/usr/bin/env python - -from __future__ import print_function - -import random -import time - -from progress.bar import (Bar, ChargingBar, FillingSquaresBar, - FillingCirclesBar, IncrementalBar, PixelBar, - ShadyBar) -from progress.spinner import (Spinner, PieSpinner, MoonSpinner, LineSpinner, - PixelSpinner) -from progress.counter import Counter, Countdown, Stack, Pie - - -def sleep(): - t = 0.01 - t += t * random.uniform(-0.1, 0.1) # Add some variance - time.sleep(t) - - -for bar_cls in (Bar, ChargingBar, FillingSquaresBar, FillingCirclesBar): - suffix = '%(index)d/%(max)d [%(elapsed)d / %(eta)d / %(eta_td)s]' - bar = bar_cls(bar_cls.__name__, suffix=suffix) - for i in bar.iter(range(200)): - sleep() - -for bar_cls in (IncrementalBar, PixelBar, ShadyBar): - suffix = '%(percent)d%% [%(elapsed_td)s / %(eta)d / %(eta_td)s]' - bar = bar_cls(bar_cls.__name__, suffix=suffix) - for i in bar.iter(range(200)): - sleep() - -for spin in (Spinner, PieSpinner, MoonSpinner, LineSpinner, PixelSpinner): - for i in spin(spin.__name__ + ' ').iter(range(100)): - sleep() - print() - -for singleton in (Counter, Countdown, Stack, Pie): - for i in singleton(singleton.__name__ + ' ').iter(range(100)): - sleep() - print() - -bar = IncrementalBar('Random', suffix='%(index)d') -for i in range(100): - bar.goto(random.randint(0, 100)) - sleep() -bar.finish() diff --git a/src/torchprune/torchprune/util/models/cnn/utils/visualize.py b/src/torchprune/torchprune/util/models/cnn/utils/visualize.py deleted file mode 100644 index 51abeed..0000000 --- a/src/torchprune/torchprune/util/models/cnn/utils/visualize.py +++ /dev/null @@ 
-1,110 +0,0 @@ -import matplotlib.pyplot as plt -import torch -import torch.nn as nn -import torchvision -import torchvision.transforms as transforms -import numpy as np -from .misc import * - -__all__ = ['make_image', 'show_batch', 'show_mask', 'show_mask_single'] - -# functions to show an image -def make_image(img, mean=(0,0,0), std=(1,1,1)): - for i in range(0, 3): - img[i] = img[i] * std[i] + mean[i] # unnormalize - npimg = img.numpy() - return np.transpose(npimg, (1, 2, 0)) - -def gauss(x,a,b,c): - return torch.exp(-torch.pow(torch.add(x,-b),2).div(2*c*c)).mul(a) - -def colorize(x): - ''' Converts a one-channel grayscale image to a color heatmap image ''' - if x.dim() == 2: - torch.unsqueeze(x, 0, out=x) - if x.dim() == 3: - cl = torch.zeros([3, x.size(1), x.size(2)]) - cl[0] = gauss(x,.5,.6,.2) + gauss(x,1,.8,.3) - cl[1] = gauss(x,1,.5,.3) - cl[2] = gauss(x,1,.2,.3) - cl[cl.gt(1)] = 1 - elif x.dim() == 4: - cl = torch.zeros([x.size(0), 3, x.size(2), x.size(3)]) - cl[:,0,:,:] = gauss(x,.5,.6,.2) + gauss(x,1,.8,.3) - cl[:,1,:,:] = gauss(x,1,.5,.3) - cl[:,2,:,:] = gauss(x,1,.2,.3) - return cl - -def show_batch(images, Mean=(2, 2, 2), Std=(0.5,0.5,0.5)): - images = make_image(torchvision.utils.make_grid(images), Mean, Std) - plt.imshow(images) - plt.show() - - -def show_mask_single(images, mask, Mean=(2, 2, 2), Std=(0.5,0.5,0.5)): - im_size = images.size(2) - - # save for adding mask - im_data = images.clone() - for i in range(0, 3): - im_data[:,i,:,:] = im_data[:,i,:,:] * Std[i] + Mean[i] # unnormalize - - images = make_image(torchvision.utils.make_grid(images), Mean, Std) - plt.subplot(2, 1, 1) - plt.imshow(images) - plt.axis('off') - - # for b in range(mask.size(0)): - # mask[b] = (mask[b] - mask[b].min())/(mask[b].max() - mask[b].min()) - mask_size = mask.size(2) - # print('Max %f Min %f' % (mask.max(), mask.min())) - mask = (upsampling(mask, scale_factor=im_size/mask_size)) - # mask = colorize(upsampling(mask, scale_factor=im_size/mask_size)) - # for c in range(3): - # mask[:,c,:,:] = (mask[:,c,:,:] - Mean[c])/Std[c] - - # print(mask.size()) - mask = make_image(torchvision.utils.make_grid(0.3*im_data+0.7*mask.expand_as(im_data))) - # mask = make_image(torchvision.utils.make_grid(0.3*im_data+0.7*mask), Mean, Std) - plt.subplot(2, 1, 2) - plt.imshow(mask) - plt.axis('off') - -def show_mask(images, masklist, Mean=(2, 2, 2), Std=(0.5,0.5,0.5)): - im_size = images.size(2) - - # save for adding mask - im_data = images.clone() - for i in range(0, 3): - im_data[:,i,:,:] = im_data[:,i,:,:] * Std[i] + Mean[i] # unnormalize - - images = make_image(torchvision.utils.make_grid(images), Mean, Std) - plt.subplot(1+len(masklist), 1, 1) - plt.imshow(images) - plt.axis('off') - - for i in range(len(masklist)): - mask = masklist[i].data.cpu() - # for b in range(mask.size(0)): - # mask[b] = (mask[b] - mask[b].min())/(mask[b].max() - mask[b].min()) - mask_size = mask.size(2) - # print('Max %f Min %f' % (mask.max(), mask.min())) - mask = (upsampling(mask, scale_factor=im_size/mask_size)) - # mask = colorize(upsampling(mask, scale_factor=im_size/mask_size)) - # for c in range(3): - # mask[:,c,:,:] = (mask[:,c,:,:] - Mean[c])/Std[c] - - # print(mask.size()) - mask = make_image(torchvision.utils.make_grid(0.3*im_data+0.7*mask.expand_as(im_data))) - # mask = make_image(torchvision.utils.make_grid(0.3*im_data+0.7*mask), Mean, Std) - plt.subplot(1+len(masklist), 1, i+2) - plt.imshow(mask) - plt.axis('off') - - - -# x = torch.zeros(1, 3, 3) -# out = colorize(x) -# out_im = make_image(out) -# plt.imshow(out_im) 
-# plt.show()
\ No newline at end of file
diff --git a/src/torchprune/torchprune/util/models/fcnet.py b/src/torchprune/torchprune/util/models/fcnet.py
index 6c45f2d..8c26b19 100644
--- a/src/torchprune/torchprune/util/models/fcnet.py
+++ b/src/torchprune/torchprune/util/models/fcnet.py
@@ -71,3 +71,11 @@ def lenet300_100(num_classes, **kwargs):
 def lenet500_300_100(num_classes, **kwargs):
     """Initialize a LeNet500-300-100 with the FCNet class."""
     return FCNet([784, 500, 300, 100, num_classes], False)
+
+
+def fcnet_nettrim(num_classes, **kwargs):
+    """Return an FC architecture according to Net-Trim.
+
+    Net-Trim paper: https://epubs.siam.org/doi/pdf/10.1137/19M1246468
+    """
+    return FCNet([784, 300, 1000, 100, num_classes], False)
diff --git a/src/torchprune/torchprune/util/models/ffjord.py b/src/torchprune/torchprune/util/models/ffjord.py
new file mode 100644
index 0000000..15cc920
--- /dev/null
+++ b/src/torchprune/torchprune/util/models/ffjord.py
@@ -0,0 +1,626 @@
+"""Module containing various FFjord NODE configurations with torchdyn lib."""
+
+import torch
+import torch.nn as nn
+from torch.distributions.utils import lazy_property
+from torchdyn.models import CNF, hutch_trace, NeuralODE
+from torchdyn.nn import Augmenter
+
+
+def distribution_module(cls):
+    """Return a class representing a "modulified" torch distribution."""
+    # get the actual distribution class from the mro()
+    cls_distribution = []
+    for parent in cls.mro()[1:]:
+        if (
+            issubclass(parent, torch.distributions.Distribution)
+            and parent != torch.distributions.Distribution
+        ):
+            cls_distribution.append(parent)
+    assert len(cls_distribution) == 1
+    cls_distribution = cls_distribution[0]
+
+    class DistributionModule:
+        """A class prototype for modulifying a desired distribution."""
+
+        def __init__(self, *args, **kwargs):
+            """Init distribution and register plain tensors as buffers."""
+            super().__init__(*args, **kwargs)
+
+            # after initializing re-register distribution parameters as buffers
+            k_tensors = []
+            for k in self.__dict__:
+
+                # check if it's a lazy_property
+                # in this case we should simply move on instead of evaluating it.
+                if self._is_lazy_property(k):
+                    continue
+
+                # check if it's a "plain" tensor
+                if isinstance(getattr(self, k), torch.Tensor):
+                    k_tensors.append(k)
+
+            for k in k_tensors:
+                # for a "plain" tensor we will now register it as a buffer.
+                val = getattr(self, k)
+                delattr(self, k)
+                self.register_buffer(k, val)
+
+        def __getattribute__(self, name):
+            """Return attribute with lazy_property special check."""
+            if type(self)._is_lazy_property(name):
+                # deleting the attribute from the instance will simply "reset"
+                # the lazy_property
+                delattr(self, name)
+            return super().__getattribute__(name)
+
+        @classmethod
+        def _is_lazy_property(cls, name):
+            return isinstance(getattr(cls, name, object), lazy_property)
+
+    return type(
+        cls.__name__, (DistributionModule, cls_distribution, nn.Module), {}
+    )
+
+
+@distribution_module
+class MultiVariateNormalModule(torch.distributions.MultivariateNormal):
+    """A "modulified" MultivariateNormal distribution."""
+
+
+class NeuralODEClassic(NeuralODE):
+    """A wrapper exposing NeuralODE through the interface we are used to."""
+
+    @property
+    def defunc(self):
+        """Return an old-school defunc which is now a vector field."""
+        return self.vf.vf
+
+    def __init__(self, cnf, s_span, sensitivity, solver, atol, rtol):
+        """Initialize the wrapper."""
+        super().__init__(
+            cnf,
+            sensitivity=sensitivity,
+            solver=solver,
+            atol=atol,
+            rtol=rtol,
+            atol_adjoint=atol,
+            rtol_adjoint=rtol,
+        )
+
+        # fake assign s_span as before
+        self.register_buffer("s_span", s_span)
+
+        # make sure classic "defunc" has "m" field
+        self.vf.vf.m = self.vf.vf.vf
+
+    def forward(self, x, s_span=None):
+        """Forward in the classic style."""
+        if s_span is None:
+            s_span = self.s_span
+        return super().forward(x, s_span)[1][1]
+
+    def trajectory(self, x, s_span):
+        """Compute trajectory in the classic style."""
+        return super().trajectory(x, s_span)
+
+
+class Ffjord(nn.Module):
+    """Neural ODE for CNFs via the FFJORD Hutchinson trace estimator."""
+
+    @property
+    def trace_estimator(self):
+        """Return the desired trace estimator."""
+        return hutch_trace
+
+    def __init__(
+        self,
+        num_in,
+        num_layers,
+        hidden_size,
+        module_activate,
+        s_span=torch.linspace(0, 1, 20),
+        sensitivity="autograd",
+        solver="rk4",
+        atol=1e-5,
+        rtol=1e-5,
+    ):
+        """Initialize ffjord with the desired parameterization."""
+        super().__init__()
+        if num_layers < 2:
+            raise ValueError("Node must be initialized with min 2 layers.")
+        if not issubclass(module_activate, nn.Module):
+            raise ValueError("Please provide valid module as activation.")
+
+        # build up layers
+        layers = [nn.Linear(num_in, hidden_size), module_activate()]
+        for _ in range(num_layers - 2):
+            layers.append(nn.Linear(hidden_size, hidden_size))
+            layers.append(module_activate())
+        layers.append(nn.Linear(hidden_size, num_in))
+
+        # wrap in sequential module
+        self.f_forward = nn.Sequential(*layers)
+
+        # get prior
+        self.prior = MultiVariateNormalModule(
+            torch.zeros(num_in), torch.eye(num_in), validate_args=False
+        )
+
+        # wrap in cnf
+        cnf = CNF(
+            self.f_forward,
+            trace_estimator=self.trace_estimator,
+            noise_dist=self.prior,
+        )
+
+        # wrap in neural ode
+        nde = NeuralODEClassic(
+            cnf,
+            s_span=s_span,
+            sensitivity=sensitivity,
+            solver=solver,
+            atol=atol,
+            rtol=rtol,
+        )
+
+        # wrap in augmenter
+        self.model = nn.Sequential(
+            Augmenter(augment_idx=1, augment_dims=1), nde
+        )
+
+    def forward(self, x):
+        """Forward by passing it on to NeuralODE.
+
+        We wrap the output into a dictionary and also return the prior so that
+        downstream tasks (e.g. loss and metrics) have access to the prior.
+ """ + self.model[1].nfe = 0 + return {"out": self.model(x), "prior": self.prior} + + +def ffjord_l4_h64_sigmoid(num_classes): + """Return a ffjord with 4 layers, 64 neurons, and sigmoid.""" + return Ffjord( + num_in=num_classes, + num_layers=4, + hidden_size=64, + module_activate=nn.Sigmoid, + ) + + +def ffjord_l4_h64_softplus(num_classes): + """Return a ffjord with 4 layers, 64 neurons, and softplus.""" + return Ffjord( + num_in=num_classes, + num_layers=4, + hidden_size=64, + module_activate=nn.Softplus, + ) + + +def ffjord_l4_h64_tanh(num_classes): + """Return a ffjord with 4 layers, 64 neurons, and tanh.""" + return Ffjord( + num_in=num_classes, + num_layers=4, + hidden_size=64, + module_activate=nn.Tanh, + ) + + +def ffjord_l4_h64_relu(num_classes): + """Return a ffjord with 4 layers, 64 neurons, and relu.""" + return Ffjord( + num_in=num_classes, + num_layers=4, + hidden_size=64, + module_activate=nn.ReLU, + ) + + +def ffjord_l8_h64_sigmoid(num_classes): + """Return a ffjord with 8 layers, 64 neurons, and sigmoid.""" + return Ffjord( + num_in=num_classes, + num_layers=8, + hidden_size=64, + module_activate=nn.Sigmoid, + ) + + +def ffjord_l2_h128_sigmoid(num_classes): + """Return a ffjord with 2 layers, 128 neurons, and sigmoid.""" + return Ffjord( + num_in=num_classes, + num_layers=2, + hidden_size=128, + module_activate=nn.Sigmoid, + ) + + +def ffjord_l2_h64_sigmoid(num_classes): + """Return a ffjord with 2 layers, 64 neurons, and sigmoid.""" + return Ffjord( + num_in=num_classes, + num_layers=2, + hidden_size=64, + module_activate=nn.Sigmoid, + ) + + +def ffjord_l4_h64_sigmoid_dopri_adjoint(num_classes): + """Return a ffjord with 4 layers, 64 neurons, and sigmoid.""" + return Ffjord( + num_in=num_classes, + num_layers=4, + hidden_size=64, + module_activate=nn.Sigmoid, + s_span=torch.linspace(0, 1, 2), + sensitivity="adjoint", + solver="dopri5", + ) + + +def ffjord_l4_h64_sigmoid_dopri_autograd(num_classes): + """Return a ffjord with 4 layers, 64 neurons, and sigmoid.""" + return Ffjord( + num_in=num_classes, + num_layers=4, + hidden_size=64, + module_activate=nn.Sigmoid, + s_span=torch.linspace(0, 1, 2), + sensitivity="autograd", + solver="dopri5", + ) + + +def ffjord_l4_h64_sigmoid_rk4_autograd(num_classes): + """Return a ffjord with 4 layers, 64 neurons, and sigmoid.""" + return Ffjord( + num_in=num_classes, + num_layers=4, + hidden_size=64, + module_activate=nn.Sigmoid, + s_span=torch.linspace(0, 1, 20), + sensitivity="autograd", + solver="rk4", + ) + + +def ffjord_l4_h64_sigmoid_rk4_adjoint(num_classes): + """Return a ffjord with 4 layers, 64 neurons, and sigmoid.""" + return Ffjord( + num_in=num_classes, + num_layers=4, + hidden_size=64, + module_activate=nn.Sigmoid, + s_span=torch.linspace(0, 1, 20), + sensitivity="adjoint", + solver="rk4", + ) + + +def ffjord_l4_h64_sigmoid_euler_autograd(num_classes): + """Return a ffjord with 4 layers, 64 neurons, and sigmoid.""" + return Ffjord( + num_in=num_classes, + num_layers=4, + hidden_size=64, + module_activate=nn.Sigmoid, + s_span=torch.linspace(0, 1, 100), + sensitivity="autograd", + solver="euler", + ) + + +def ffjord_l4_h64_sigmoid_euler_adjoint(num_classes): + """Return a ffjord with 4 layers, 64 neurons, and sigmoid.""" + return Ffjord( + num_in=num_classes, + num_layers=4, + hidden_size=64, + module_activate=nn.Sigmoid, + s_span=torch.linspace(0, 1, 100), + sensitivity="adjoint", + solver="euler", + ) + + +def ffjord_l4_h64_sigmoid_da(num_classes): + """Return a ffjord with 4 layers, 64 neurons, and sigmoid.""" + return 
Ffjord( + num_in=num_classes, + num_layers=4, + hidden_size=64, + module_activate=nn.Sigmoid, + s_span=torch.linspace(0, 1, 2), + sensitivity="adjoint", + solver="dopri5", + ) + + +def ffjord_l4_h64_sigmoid_da_autograd(num_classes): + """Return a ffjord with 4 layers, 64 neurons, and sigmoid.""" + return Ffjord( + num_in=num_classes, + num_layers=4, + hidden_size=64, + module_activate=nn.Sigmoid, + s_span=torch.linspace(0, 1, 2), + sensitivity="autograd", + solver="dopri5", + ) + + +def ffjord_l4_h64_softplus_da(num_classes): + """Return a ffjord with 4 layers, 64 neurons, and softplus.""" + return Ffjord( + num_in=num_classes, + num_layers=4, + hidden_size=64, + module_activate=nn.Softplus, + s_span=torch.linspace(0, 1, 2), + sensitivity="adjoint", + solver="dopri5", + ) + + +def ffjord_l4_h64_tanh_da(num_classes): + """Return a ffjord with 4 layers, 64 neurons, and tanh.""" + return Ffjord( + num_in=num_classes, + num_layers=4, + hidden_size=64, + module_activate=nn.Tanh, + s_span=torch.linspace(0, 1, 2), + sensitivity="adjoint", + solver="dopri5", + ) + + +def ffjord_l4_h64_relu_da(num_classes): + """Return a ffjord with 4 layers, 64 neurons, and relu.""" + return Ffjord( + num_in=num_classes, + num_layers=4, + hidden_size=64, + module_activate=nn.ReLU, + s_span=torch.linspace(0, 1, 2), + sensitivity="adjoint", + solver="dopri5", + ) + + +def ffjord_l8_h64_sigmoid_da(num_classes): + """Return a ffjord with 8 layers, 64 neurons, and sigmoid.""" + return Ffjord( + num_in=num_classes, + num_layers=8, + hidden_size=64, + module_activate=nn.Sigmoid, + s_span=torch.linspace(0, 1, 2), + sensitivity="adjoint", + solver="dopri5", + ) + + +def ffjord_l8_h37_sigmoid_da(num_classes): + """Return a ffjord with 8 layers, 37 neurons, and sigmoid.""" + return Ffjord( + num_in=num_classes, + num_layers=8, + hidden_size=37, + module_activate=nn.Sigmoid, + s_span=torch.linspace(0, 1, 2), + sensitivity="adjoint", + solver="dopri5", + ) + + +def ffjord_l8_h18_sigmoid_da(num_classes): + """Return a ffjord with 8 layers, 18 neurons, and sigmoid.""" + return Ffjord( + num_in=num_classes, + num_layers=8, + hidden_size=18, + module_activate=nn.Sigmoid, + s_span=torch.linspace(0, 1, 2), + sensitivity="adjoint", + solver="dopri5", + ) + + +def ffjord_l8_h10_sigmoid_da(num_classes): + """Return a ffjord with 8 layers, 10 neurons, and sigmoid.""" + return Ffjord( + num_in=num_classes, + num_layers=8, + hidden_size=10, + module_activate=nn.Sigmoid, + s_span=torch.linspace(0, 1, 2), + sensitivity="adjoint", + solver="dopri5", + ) + + +def ffjord_l6_h45_sigmoid_da(num_classes): + """Return a ffjord with 6 layers, 45 neurons, and sigmoid.""" + return Ffjord( + num_in=num_classes, + num_layers=6, + hidden_size=45, + module_activate=nn.Sigmoid, + s_span=torch.linspace(0, 1, 2), + sensitivity="adjoint", + solver="dopri5", + ) + + +def ffjord_l6_h22_sigmoid_da(num_classes): + """Return a ffjord with 6 layers, 22 neurons, and sigmoid.""" + return Ffjord( + num_in=num_classes, + num_layers=6, + hidden_size=22, + module_activate=nn.Sigmoid, + s_span=torch.linspace(0, 1, 2), + sensitivity="adjoint", + solver="dopri5", + ) + + +def ffjord_l6_h12_sigmoid_da(num_classes): + """Return a ffjord with 6 layers, 12 neurons, and sigmoid.""" + return Ffjord( + num_in=num_classes, + num_layers=6, + hidden_size=12, + module_activate=nn.Sigmoid, + s_span=torch.linspace(0, 1, 2), + sensitivity="adjoint", + solver="dopri5", + ) + + +def ffjord_l4_h128_sigmoid_da(num_classes): + """Return a ffjord with 4 layers, 128 neurons, and sigmoid.""" + 
return Ffjord( + num_in=num_classes, + num_layers=4, + hidden_size=128, + module_activate=nn.Sigmoid, + s_span=torch.linspace(0, 1, 2), + sensitivity="adjoint", + solver="dopri5", + ) + + +def ffjord_l4_h30_sigmoid_da(num_classes): + """Return a ffjord with 4 layers, 30 neurons, and sigmoid.""" + return Ffjord( + num_in=num_classes, + num_layers=4, + hidden_size=30, + module_activate=nn.Sigmoid, + s_span=torch.linspace(0, 1, 2), + sensitivity="adjoint", + solver="dopri5", + ) + + +def ffjord_l4_h17_sigmoid_da(num_classes): + """Return a ffjord with 4 layers, 17 neurons, and sigmoid.""" + return Ffjord( + num_in=num_classes, + num_layers=4, + hidden_size=17, + module_activate=nn.Sigmoid, + s_span=torch.linspace(0, 1, 2), + sensitivity="adjoint", + solver="dopri5", + ) + + +def ffjord_l3_h90_sigmoid_da(num_classes): + """Return a ffjord with 3 layers, 90 neurons, and sigmoid.""" + return Ffjord( + num_in=num_classes, + num_layers=3, + hidden_size=90, + module_activate=nn.Sigmoid, + s_span=torch.linspace(0, 1, 2), + sensitivity="adjoint", + solver="dopri5", + ) + + +def ffjord_l3_h43_sigmoid_da(num_classes): + """Return a ffjord with 3 layers, 43 neurons, and sigmoid.""" + return Ffjord( + num_in=num_classes, + num_layers=3, + hidden_size=43, + module_activate=nn.Sigmoid, + s_span=torch.linspace(0, 1, 2), + sensitivity="adjoint", + solver="dopri5", + ) + + +def ffjord_l3_h23_sigmoid_da(num_classes): + """Return a ffjord with 3 layers, 23 neurons, and sigmoid.""" + return Ffjord( + num_in=num_classes, + num_layers=3, + hidden_size=23, + module_activate=nn.Sigmoid, + s_span=torch.linspace(0, 1, 2), + sensitivity="adjoint", + solver="dopri5", + ) + + +def ffjord_l2_h1700_sigmoid_da(num_classes): + """Return a ffjord with 2 layers, 1700 neurons, and sigmoid.""" + return Ffjord( + num_in=num_classes, + num_layers=2, + hidden_size=1700, + module_activate=nn.Sigmoid, + s_span=torch.linspace(0, 1, 2), + sensitivity="adjoint", + solver="dopri5", + ) + + +def ffjord_l2_h400_sigmoid_da(num_classes): + """Return a ffjord with 2 layers, 400 neurons, and sigmoid.""" + return Ffjord( + num_in=num_classes, + num_layers=2, + hidden_size=400, + module_activate=nn.Sigmoid, + s_span=torch.linspace(0, 1, 2), + sensitivity="adjoint", + solver="dopri5", + ) + + +def ffjord_l2_h128_sigmoid_da(num_classes): + """Return a ffjord with 2 layers, 128 neurons, and sigmoid.""" + return Ffjord( + num_in=num_classes, + num_layers=2, + hidden_size=128, + module_activate=nn.Sigmoid, + s_span=torch.linspace(0, 1, 2), + sensitivity="adjoint", + solver="dopri5", + ) + + +def ffjord_l2_h128_sigmoid_da_autograd(num_classes): + """Return a ffjord with 2 layers, 128 neurons, and sigmoid.""" + return Ffjord( + num_in=num_classes, + num_layers=2, + hidden_size=128, + module_activate=nn.Sigmoid, + s_span=torch.linspace(0, 1, 2), + sensitivity="autograd", + solver="dopri5", + ) + + +def ffjord_l2_h64_sigmoid_da(num_classes): + """Return a ffjord with 2 layers, 64 neurons, and sigmoid.""" + return Ffjord( + num_in=num_classes, + num_layers=2, + hidden_size=64, + module_activate=nn.Sigmoid, + s_span=torch.linspace(0, 1, 2), + sensitivity="adjoint", + solver="dopri5", + ) diff --git a/src/torchprune/torchprune/util/models/ffjord_cnf.py b/src/torchprune/torchprune/util/models/ffjord_cnf.py new file mode 100644 index 0000000..8b338b2 --- /dev/null +++ b/src/torchprune/torchprune/util/models/ffjord_cnf.py @@ -0,0 +1,435 @@ +"""Module containing various FFjord CNF configurations from original code.""" + +import torch +import torch.nn as nn 
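+
+# The original FFJORD implementation is vendored under util/external/ffjord;
+# the imports below pull in its model builders and training utilities.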
+from ..external.ffjord import train_misc +from ..external.ffjord.lib import layers +from ..external.ffjord.lib import odenvp +from ..external.ffjord.lib import multiscale_parallel + + +class FfjordCNF(nn.Module): + """A wrapper class for FFJORD models in the CNF (images) setting.""" + + def __init__(self, model, regularization_fns, regularization_coeffs): + """Initialize with the original model.""" + super().__init__() + self.model = model + self._regularization_fns = regularization_fns + self._regularization_coeffs = regularization_coeffs + + def forward(self, x): + """Run vanilla forward and some extra args for loss computation.""" + # do a forward pass over the network + zero = torch.zeros(x.shape[0], 1).to(x) + z_out, delta_logp = self.model(x, zero) + output = { + "out": z_out, + "delta_logp": delta_logp, + "nelement": x.nelement(), + } + + # add regularizer loss to output + if len(self._regularization_coeffs) > 0: + reg_states = train_misc.get_regularization( + self.model, self._regularization_coeffs + ) + output["reg_loss"] = sum( + reg_state * coeff + for reg_state, coeff in zip( + reg_states, self._regularization_coeffs + ) + if coeff != 0 + ) + + # return output dictionary + return output + + +class FfjordCNFConfig: + """A class containing the required configurations for FFJORD models.""" + + @property + def dims(self): + """Return dims. + + Check out external/ffjord/train_cnf.py for more info. + """ + return self._dims + + @property + def strides(self): + """Return strides. + + Check out external/ffjord/train_cnf.py for more info. + """ + return self._strides + + @property + def num_blocks(self): + """Return num_blocks. + + Check out external/ffjord/train_cnf.py for more info. + """ + return self._num_blocks + + @property + def conv(self): + """Return conv. + + Check out external/ffjord/train_cnf.py for more info. + """ + return True + + @property + def layer_type(self): + """Return layer_type. + + Check out external/ffjord/train_cnf.py for more info. + """ + return "concat" + + @property + def divergence_fn(self): + """Return divergence_fn. + + Check out external/ffjord/train_cnf.py for more info. + """ + return "approximate" + + @property + def nonlinearity(self): + """Return nonlinearity. + + Check out external/ffjord/train_cnf.py for more info. + """ + return self._nonlinearity + + @property + def solver(self): + """Return solver. + + Check out external/ffjord/train_cnf.py for more info. + """ + return self._solver + + @property + def atol(self): + """Return atol. + + Check out external/ffjord/train_cnf.py for more info. + """ + return self._atol + + @property + def rtol(self): + """Return rtol. + + Check out external/ffjord/train_cnf.py for more info. + """ + return self._rtol + + @property + def step_size(self): + """Return step_size. + + Check out external/ffjord/train_cnf.py for more info. + """ + return None + + @property + def test_solver(self): + """Return test_solver. + + Check out external/ffjord/train_cnf.py for more info. + """ + return None + + @property + def test_atol(self): + """Return test_atol. + + Check out external/ffjord/train_cnf.py for more info. + """ + return None + + @property + def test_rtol(self): + """Return test_rtol. + + Check out external/ffjord/train_cnf.py for more info. + """ + return None + + @property + def imagesize(self): + """Return imagesize. + + Check out external/ffjord/train_cnf.py for more info. + """ + return None + + @property + def alpha(self): + """Return alpha. + + Check out external/ffjord/train_cnf.py for more info. 
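+
+        Here, alpha is the smoothing constant of the logit transform applied
+        to the image data (cf. ``LogitTransform(alpha=...)`` in
+        ``_create_model`` below).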
+ """ + return 1e-6 + + @property + def time_length(self): + """Return time_length. + + Check out external/ffjord/train_cnf.py for more info. + """ + return 1.0 + + @property + def train_T(self): # pylint: disable=C0103 + """Return train_T. + + Check out external/ffjord/train_cnf.py for more info. + """ + return True + + @property + def batch_size(self): + """Return batch_size. + + Check out external/ffjord/train_cnf.py for more info. + + This must be hard-coded here unfortunately. + """ + return 200 + + @property + def batch_norm(self): + """Return batch_norm. + + Check out external/ffjord/train_cnf.py for more info. + """ + return False + + @property + def residual(self): + """Return residual. + + Check out external/ffjord/train_cnf.py for more info. + """ + return False + + @property + def autoencode(self): + """Return autoencode. + + Check out external/ffjord/train_cnf.py for more info. + """ + return False + + @property + def rademacher(self): + """Return rademacher. + + Check out external/ffjord/train_cnf.py for more info. + """ + return False + + @property + def multiscale(self): + """Return multiscale. + + Check out external/ffjord/train_cnf.py for more info. + """ + return True + + @property + def parallel(self): + """Return parallel. + + Check out external/ffjord/train_cnf.py for more info. + """ + return False + + @property + def l1int(self): + """Return l1int. + + Check out external/ffjord/train_cnf.py for more info. + """ + return None + + @property + def l2int(self): + """Return l2int. + + Check out external/ffjord/train_cnf.py for more info. + """ + return None + + @property + def dl2int(self): + """Return dl2int. + + Check out external/ffjord/train_cnf.py for more info. + """ + return None + + @property + def JFrobint(self): # pylint: disable=C0103 + """Return JFrobint. + + Check out external/ffjord/train_cnf.py for more info. + """ + return None + + @property + def JdiagFrobint(self): # pylint: disable=C0103 + """Return JdiagFrobint. + + Check out external/ffjord/train_cnf.py for more info. + """ + return None + + @property + def JoffdiagFrobint(self): # pylint: disable=C0103 + """Return JoffdiagFrobint. + + Check out external/ffjord/train_cnf.py for more info. 
+ """ + return None + + def __init__( + self, + output_size, + data_shape, + dims="64,64,64", + strides="1,1,1,1", + num_blocks=2, + nonlinearity="softplus", + solver="dopri5", + atol=1e-5, + rtol=1e-5, + ): + """Initialize with the variable properties.""" + self._output_size = output_size + self._data_shape = data_shape + self._dims = dims + self._strides = strides + self._num_blocks = num_blocks + self._nonlinearity = nonlinearity + self._solver = solver + self._atol = atol + self._rtol = rtol + + def _create_model(self, args, data_shape, regularization_fns): + """Create model simulating the cnf function.""" + hidden_dims = tuple(map(int, args.dims.split(","))) + strides = tuple(map(int, args.strides.split(","))) + + if args.multiscale: + model = odenvp.ODENVP( + (args.batch_size, *data_shape), + n_blocks=args.num_blocks, + intermediate_dims=hidden_dims, + nonlinearity=args.nonlinearity, + alpha=args.alpha, + cnf_kwargs={ + "T": args.time_length, + "train_T": args.train_T, + "regularization_fns": regularization_fns, + }, + ) + elif args.parallel: + model = multiscale_parallel.MultiscaleParallelCNF( + (args.batch_size, *data_shape), + n_blocks=args.num_blocks, + intermediate_dims=hidden_dims, + alpha=args.alpha, + time_length=args.time_length, + ) + else: + if args.autoencode: + + def build_cnf(): + autoencoder_diffeq = layers.AutoencoderDiffEqNet( + hidden_dims=hidden_dims, + input_shape=data_shape, + strides=strides, + conv=args.conv, + layer_type=args.layer_type, + nonlinearity=args.nonlinearity, + ) + odefunc = layers.AutoencoderODEfunc( + autoencoder_diffeq=autoencoder_diffeq, + divergence_fn=args.divergence_fn, + residual=args.residual, + rademacher=args.rademacher, + ) + cnf = layers.CNF( + odefunc=odefunc, + T=args.time_length, + regularization_fns=regularization_fns, + solver=args.solver, + ) + return cnf + + else: + + def build_cnf(): + diffeq = layers.ODEnet( + hidden_dims=hidden_dims, + input_shape=data_shape, + strides=strides, + conv=args.conv, + layer_type=args.layer_type, + nonlinearity=args.nonlinearity, + ) + odefunc = layers.ODEfunc( + diffeq=diffeq, + divergence_fn=args.divergence_fn, + residual=args.residual, + rademacher=args.rademacher, + ) + cnf = layers.CNF( + odefunc=odefunc, + T=args.time_length, + train_T=args.train_T, + regularization_fns=regularization_fns, + solver=args.solver, + ) + return cnf + + chain = ( + [layers.LogitTransform(alpha=args.alpha)] + if args.alpha > 0 + else [layers.ZeroMeanTransform()] + ) + chain = chain + [build_cnf() for _ in range(args.num_blocks)] + if args.batch_norm: + chain.append(layers.MovingBatchNorm2d(data_shape[0])) + model = layers.SequentialFlow(chain) + return model + + def get_model(self): + """Construct and return model.""" + ( + regularization_fns, + regularization_coeffs, + ) = train_misc.create_regularization_fns(self) + model = self._create_model(self, self._data_shape, regularization_fns) + + # wrap model into our tabular Ffjord model for correct output dict + return FfjordCNF(model, regularization_fns, regularization_coeffs) + + +def ffjord_multiscale_cifar(num_classes): + """Return the FFJORD multiscale architecture for CIFAR10.""" + config = FfjordCNFConfig(output_size=num_classes, data_shape=(3, 32, 32)) + return config.get_model() + + +def ffjord_multiscale_mnist(num_classes, **kwargs): + """Return the FFJORD multiscale architecture for MNIST.""" + config = FfjordCNFConfig(output_size=num_classes, data_shape=(1, 28, 28)) + return config.get_model() diff --git 
a/src/torchprune/torchprune/util/models/ffjord_tabular.py b/src/torchprune/torchprune/util/models/ffjord_tabular.py new file mode 100644 index 0000000..a0218d0 --- /dev/null +++ b/src/torchprune/torchprune/util/models/ffjord_tabular.py @@ -0,0 +1,383 @@ +"""Module containing various FFjord NODE configurations from original code.""" + +import torch +import torch.nn as nn +from ..external.ffjord import train_misc + + +class FfjordTabular(nn.Module): + """A wrapper class for FFJORD models in the tabular data setting.""" + + def __init__( + self, + model, + output_size, + hdim_factor, + regularization_fns, + regularization_coeffs, + ): + """Initialize with the original model.""" + super().__init__() + self.model = model + self._num_weights_threshold = output_size * hdim_factor + self._regularization_fns = regularization_fns + self._regularization_coeffs = regularization_coeffs + + def forward(self, x): + """Run vanilla forward and some extra args for loss computation.""" + # do a forward pass over the network + zero = torch.zeros(x.shape[0], 1).to(x) + z_out, delta_logp = self.model(x, zero) + output = {"out": z_out, "delta_logp": delta_logp} + + # add regularizer loss to output + if len(self._regularization_coeffs) > 0: + reg_states = train_misc.get_regularization( + self.model, self._regularization_coeffs + ) + output["reg_loss"] = sum( + reg_state * coeff + for reg_state, coeff in zip( + reg_states, self._regularization_coeffs + ) + if coeff != 0 + ) + + # return output dictionary + return output + + def is_compressible(self, module): + """Return True if the provided module is compressible.""" + return module.weight.numel() > self._num_weights_threshold + + +class FfjordTabularConfig: + """A class containing the configurations for FFJORD tabular models.""" + + @property + def layer_type(self): + """Return layer_type. + + Check out external/ffjord/train_tabular.py for more info. + """ + return "concatsquash" + + @property + def hdim_factor(self): + """Return hdim_factor. + + Check out external/ffjord/train_tabular.py for more info. + """ + return self._hdim_factor + + @property + def nhidden(self): + """Return nhidden. + + Check out external/ffjord/train_tabular.py for more info. + """ + return self._nhidden + + @property + def dims(self): + """Return dims. + + Check out external/ffjord/train_tabular.py for more info. + """ + return "-".join( + [str(self.hdim_factor * self._output_size)] * self.nhidden + ) + + @property + def num_blocks(self): + """Return num_blocks. + + Check out external/ffjord/train_tabular.py for more info. + """ + return self._num_blocks + + @property + def time_length(self): + """Return time_length. + + Check out external/ffjord/train_tabular.py for more info. + """ + return 1.0 + + @property + def train_T(self): # pylint: disable=C0103 + """Return train_T. + + Check out external/ffjord/train_tabular.py for more info. + """ + return True + + @property + def divergence_fn(self): + """Return divergence_fn. + + Check out external/ffjord/train_tabular.py for more info. + """ + return "approximate" + + @property + def nonlinearity(self): + """Return nonlinearity. + + Check out external/ffjord/train_tabular.py for more info. + """ + return self._nonlinearity + + @property + def solver(self): + """Return solver. + + Check out external/ffjord/train_tabular.py for more info. + """ + return self._solver + + @property + def atol(self): + """Return atol. + + Check out external/ffjord/train_tabular.py for more info. 
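+
+        Defaults to 1e-8 for the tabular models (see ``__init__``), tighter
+        than the 1e-5 used for the image CNFs.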
+ """ + return self._atol + + @property + def rtol(self): + """Return rtol. + + Check out external/ffjord/train_tabular.py for more info. + """ + return self._rtol + + @property + def step_size(self): + """Return step_size. + + Check out external/ffjord/train_tabular.py for more info. + """ + return None + + @property + def test_solver(self): + """Return test_solver. + + Check out external/ffjord/train_tabular.py for more info. + """ + return None + + @property + def test_atol(self): + """Return test_atol. + + Check out external/ffjord/train_tabular.py for more info. + """ + return None + + @property + def test_rtol(self): + """Return test_rtol. + + Check out external/ffjord/train_tabular.py for more info. + """ + return None + + @property + def residual(self): + """Return residual. + + Check out external/ffjord/train_tabular.py for more info. + """ + return False + + @property + def rademacher(self): + """Return rademacher. + + Check out external/ffjord/train_tabular.py for more info. + """ + return False + + @property + def batch_norm(self): + """Return batch_norm. + + Check out external/ffjord/train_tabular.py for more info. + """ + return False + + @property + def bn_lag(self): + """Return bn_lag. + + Check out external/ffjord/train_tabular.py for more info. + """ + return 0.0 + + @property + def l1int(self): + """Return l1int. + + Check out external/ffjord/train_tabular.py for more info. + """ + return None + + @property + def l2int(self): + """Return l2int. + + Check out external/ffjord/train_tabular.py for more info. + """ + return None + + @property + def dl2int(self): + """Return dl2int. + + Check out external/ffjord/train_tabular.py for more info. + """ + return None + + @property + def JFrobint(self): # pylint: disable=C0103 + """Return JFrobint. + + Check out external/ffjord/train_tabular.py for more info. + """ + return None + + @property + def JdiagFrobint(self): # pylint: disable=C0103 + """Return JdiagFrobint. + + Check out external/ffjord/train_tabular.py for more info. + """ + return None + + @property + def JoffdiagFrobint(self): # pylint: disable=C0103 + """Return JoffdiagFrobint. + + Check out external/ffjord/train_tabular.py for more info. + """ + return None + + def __init__( + self, + output_size, + hdim_factor=10, + nhidden=1, + num_blocks=1, + nonlinearity="softplus", + solver="dopri5", + atol=1e-8, + rtol=1e-6, + ): + """Initialize with the variable properties.""" + self._output_size = output_size + self._hdim_factor = hdim_factor + self._nhidden = nhidden + self._num_blocks = num_blocks + self._nonlinearity = nonlinearity + self._solver = solver + self._atol = atol + self._rtol = rtol + + def get_model(self): + """Construct and return model.""" + ( + regularization_fns, + regularization_coeffs, + ) = train_misc.create_regularization_fns(self) + model = train_misc.build_model_tabular( + self, self._output_size, regularization_fns + ) + + # wrap model into our tabular Ffjord model for correct output dict + return FfjordTabular( + model, + self._output_size, + self.hdim_factor, + regularization_fns, + regularization_coeffs, + ) + + +def ffjord_l3_hm10_f5_tanh(num_classes): + """Return a ffjord with 3 layers, hidden factor 10, 5 flows, tanh. + + This is the standard configuration for the tabular POWER dataset. 
+ """ + config = FfjordTabularConfig( + output_size=num_classes, + nhidden=3, + hdim_factor=10, + num_blocks=5, + nonlinearity="tanh", + ) + return config.get_model() + + +# I think the factor is always hdim_factor * output_size + + +def ffjord_l3_hm20_f5_tanh(num_classes): + """Return a ffjord with 3 layers, hidden factor 20, 5 flows, tanh. + + This is the standard configuration for the tabular GAS dataset. + """ + config = FfjordTabularConfig( + output_size=num_classes, + nhidden=3, + hdim_factor=20, + num_blocks=5, + nonlinearity="tanh", + ) + return config.get_model() + + +def ffjord_l2_hm10_f10_softplus(num_classes): + """Return a ffjord with 2 layers, hidden factor 10, 10 flows, softplus. + + This is the standard configuration for the tabular HEPMASS dataset. + """ + config = FfjordTabularConfig( + output_size=num_classes, + nhidden=2, + hdim_factor=10, + num_blocks=10, + nonlinearity="softplus", + ) + return config.get_model() + + +def ffjord_l2_hm20_f1_softplus(num_classes): + """Return a ffjord with 2 layers, hidden factor 20, 1 flow steps, softplus. + + This is the standard configuration for the tabular MINIBOONE dataset. + """ + config = FfjordTabularConfig( + output_size=num_classes, + nhidden=2, + hdim_factor=20, + num_blocks=1, + nonlinearity="softplus", + ) + return config.get_model() + + +def ffjord_l3_hm20_f2_softplus(num_classes): + """Return a ffjord with 3 layers, hidden factor 20, 2 flow steps, softplus. + + This is the standard configuration for the tabular BSDS300 dataset. + """ + config = FfjordTabularConfig( + output_size=num_classes, + nhidden=3, + hdim_factor=20, + num_blocks=2, + nonlinearity="softplus", + ) + return config.get_model() diff --git a/src/torchprune/torchprune/util/models/node.py b/src/torchprune/torchprune/util/models/node.py new file mode 100644 index 0000000..19a870c --- /dev/null +++ b/src/torchprune/torchprune/util/models/node.py @@ -0,0 +1,369 @@ +"""Module containing various Neural ODE configurations.""" + +import torch +import torch.nn as nn +from .ffjord import NeuralODEClassic + + +class Node(nn.Module): + """Neural ODE for classification.""" + + def __init__( + self, + num_in, + num_out, + num_layers, + hidden_size, + module_activate, + s_span=torch.linspace(0, 1, 20), + sensitivity="autograd", + solver="rk4", + atol=1e-4, + rtol=1e-4, + ): + """Initialize a neural ode with the desired parameterization.""" + super().__init__() + if num_layers < 2: + raise ValueError("Node must be initialized with min 2 layers.") + if not issubclass(module_activate, nn.Module): + raise ValueError("Please provide valid module as activation.") + + # build up layers + layers = [nn.Linear(num_in, hidden_size), module_activate()] + for _ in range(num_layers - 2): + layers.append(nn.Linear(hidden_size, hidden_size)) + layers.append(module_activate()) + layers.append(nn.Linear(hidden_size, num_out)) + + # wrap in sequential module + self.f_forward = nn.Sequential(*layers) + + # wrap in torchdyn NeuralODE + self.model = NeuralODEClassic( + self.f_forward, + s_span=s_span, + sensitivity=sensitivity, + solver=solver, + atol=atol, + rtol=rtol, + ) + + def forward(self, x): + """Forward by passing it on to NeuralODE.""" + return self.model(x) + + +def node_l2_h64_tanh(num_classes): + """Return a classification node with 2 layers, 64 neurons, and tanh.""" + return Node( + num_in=num_classes, + num_out=num_classes, + num_layers=2, + hidden_size=64, + module_activate=nn.Tanh, + ) + + +def node_l2_h64_softplus(num_classes): + """Return a classification node with 2 
layers, 64 neurons, and softplus.""" + return Node( + num_in=num_classes, + num_out=num_classes, + num_layers=2, + hidden_size=64, + module_activate=nn.Softplus, + ) + + +def node_l2_h64_sigmoid(num_classes): + """Return a classification node with 2 layers, 64 neurons, and sigmoid.""" + return Node( + num_in=num_classes, + num_out=num_classes, + num_layers=2, + hidden_size=64, + module_activate=nn.Sigmoid, + ) + + +def node_l2_h64_relu(num_classes): + """Return a classification node with 2 layers, 64 neurons, and relu.""" + return Node( + num_in=num_classes, + num_out=num_classes, + num_layers=2, + hidden_size=64, + module_activate=nn.ReLU, + ) + + +def node_l4_h32_tanh(num_classes): + """Return a classification node with 4 layers, 32 neurons, and tanh.""" + return Node( + num_in=num_classes, + num_out=num_classes, + num_layers=4, + hidden_size=32, + module_activate=nn.Tanh, + ) + + +def node_l2_h32_tanh(num_classes): + """Return a classification node with 2 layers, 32 neurons, and tanh.""" + return Node( + num_in=num_classes, + num_out=num_classes, + num_layers=2, + hidden_size=32, + module_activate=nn.Tanh, + ) + + +def node_l2_h128_tanh(num_classes): + """Return a classification node with 2 layers, 128 neurons, and tanh.""" + return Node( + num_in=num_classes, + num_out=num_classes, + num_layers=2, + hidden_size=128, + module_activate=nn.Tanh, + ) + + +def node_l4_h128_tanh(num_classes): + """Return a classification node with 4 layers, 128 neurons, and tanh.""" + return Node( + num_in=num_classes, + num_out=num_classes, + num_layers=4, + hidden_size=128, + module_activate=nn.Tanh, + ) + + +def node_l4_h32_tanh_dopri_adjoint(num_classes): + """Return a classification node with 4 layers, 32 neurons, and tanh. + + We also modify solver options for this one + """ + return Node( + num_in=num_classes, + num_out=num_classes, + num_layers=4, + hidden_size=32, + module_activate=nn.Tanh, + s_span=torch.linspace(0, 1, 2), + sensitivity="adjoint", + solver="dopri5", + ) + + +def node_l4_h32_tanh_dopri_autograd(num_classes): + """Return a classification node with 4 layers, 32 neurons, and tanh. + + We also modify solver options for this one + """ + return Node( + num_in=num_classes, + num_out=num_classes, + num_layers=4, + hidden_size=32, + module_activate=nn.Tanh, + s_span=torch.linspace(0, 1, 2), + sensitivity="autograd", + solver="dopri5", + ) + + +def node_l4_h32_tanh_rk4_adjoint(num_classes): + """Return a classification node with 4 layers, 32 neurons, and tanh. + + We also modify solver options for this one + """ + return Node( + num_in=num_classes, + num_out=num_classes, + num_layers=4, + hidden_size=32, + module_activate=nn.Tanh, + s_span=torch.linspace(0, 1, 20), + sensitivity="adjoint", + solver="rk4", + ) + + +def node_l4_h32_tanh_rk4_autograd(num_classes): + """Return a classification node with 4 layers, 32 neurons, and tanh. + + We also modify solver options for this one + """ + return Node( + num_in=num_classes, + num_out=num_classes, + num_layers=4, + hidden_size=32, + module_activate=nn.Tanh, + s_span=torch.linspace(0, 1, 20), + sensitivity="autograd", + solver="rk4", + ) + + +def node_l4_h32_tanh_euler_adjoint(num_classes): + """Return a classification node with 4 layers, 32 neurons, and tanh. 
+ + We also modify solver options for this one + """ + return Node( + num_in=num_classes, + num_out=num_classes, + num_layers=4, + hidden_size=32, + module_activate=nn.Tanh, + s_span=torch.linspace(0, 1, 80), + sensitivity="adjoint", + solver="euler", + ) + + +def node_l4_h32_tanh_euler_autograd(num_classes): + """Return a classification node with 4 layers, 32 neurons, and tanh. + + We also modify solver options for this one + """ + return Node( + num_in=num_classes, + num_out=num_classes, + num_layers=4, + hidden_size=32, + module_activate=nn.Tanh, + s_span=torch.linspace(0, 1, 80), + sensitivity="autograd", + solver="euler", + ) + + +def node_l2_h64_tanh_da(num_classes): + """Return a classification node with 2 layers, 64 neurons, and tanh.""" + return Node( + num_in=num_classes, + num_out=num_classes, + num_layers=2, + hidden_size=64, + module_activate=nn.Tanh, + s_span=torch.linspace(0, 1, 2), + sensitivity="adjoint", + solver="dopri5", + ) + + +def node_l2_h64_softplus_da(num_classes): + """Return a classification node with 2 layers, 64 neurons, and softplus.""" + return Node( + num_in=num_classes, + num_out=num_classes, + num_layers=2, + hidden_size=64, + module_activate=nn.Softplus, + s_span=torch.linspace(0, 1, 2), + sensitivity="adjoint", + solver="dopri5", + ) + + +def node_l2_h64_sigmoid_da(num_classes): + """Return a classification node with 2 layers, 64 neurons, and sigmoid.""" + return Node( + num_in=num_classes, + num_out=num_classes, + num_layers=2, + hidden_size=64, + module_activate=nn.Sigmoid, + s_span=torch.linspace(0, 1, 2), + sensitivity="adjoint", + solver="dopri5", + ) + + +def node_l2_h64_relu_da(num_classes): + """Return a classification node with 2 layers, 64 neurons, and relu.""" + return Node( + num_in=num_classes, + num_out=num_classes, + num_layers=2, + hidden_size=64, + module_activate=nn.ReLU, + s_span=torch.linspace(0, 1, 2), + sensitivity="adjoint", + solver="dopri5", + ) + + +def node_l2_h3_tanh_da(num_classes): + """Return a classification node with 2 layers, 3 neurons, and tanh.""" + return Node( + num_in=num_classes, + num_out=num_classes, + num_layers=2, + hidden_size=3, + module_activate=nn.Tanh, + s_span=torch.linspace(0, 1, 2), + sensitivity="adjoint", + solver="dopri5", + ) + + +def node_l2_h32_tanh_da(num_classes): + """Return a classification node with 2 layers, 32 neurons, and tanh.""" + return Node( + num_in=num_classes, + num_out=num_classes, + num_layers=2, + hidden_size=32, + module_activate=nn.Tanh, + s_span=torch.linspace(0, 1, 2), + sensitivity="adjoint", + solver="dopri5", + ) + + +def node_l2_h128_tanh_da(num_classes): + """Return a classification node with 2 layers, 128 neurons, and tanh.""" + return Node( + num_in=num_classes, + num_out=num_classes, + num_layers=2, + hidden_size=128, + module_activate=nn.Tanh, + s_span=torch.linspace(0, 1, 2), + sensitivity="adjoint", + solver="dopri5", + ) + + +def node_l4_h32_tanh_da(num_classes): + """Return a classification node with 4 layers, 32 neurons, and tanh.""" + return Node( + num_in=num_classes, + num_out=num_classes, + num_layers=4, + hidden_size=32, + module_activate=nn.Tanh, + s_span=torch.linspace(0, 1, 2), + sensitivity="adjoint", + solver="dopri5", + ) + + +def node_l4_h128_tanh_da(num_classes): + """Return a classification node with 4 layers, 128 neurons, and tanh.""" + return Node( + num_in=num_classes, + num_out=num_classes, + num_layers=4, + hidden_size=128, + module_activate=nn.Tanh, + s_span=torch.linspace(0, 1, 2), + sensitivity="adjoint", + solver="dopri5", + ) diff --git 
a/src/torchprune/torchprune/util/net.py b/src/torchprune/torchprune/util/net.py index f0081a5..1f7590f 100644 --- a/src/torchprune/torchprune/util/net.py +++ b/src/torchprune/torchprune/util/net.py @@ -60,6 +60,10 @@ def register_compressible_layers(self): continue if hasattr(module, "groups") and module.groups > 1: continue + if hasattr(self.torchnet, "is_compressible") and not self.torchnet.is_compressible( + module + ): + continue self.compressible_layers.append(module) self.num_weights.append(module.weight.data.numel()) @@ -76,9 +80,7 @@ def flops(self): flops = 0 if len(self.num_patches) == self.num_compressible_layers: for ell, module in enumerate(self.compressible_layers): - flops += ( - module.weight != 0.0 - ).sum().item() * self.num_patches[ell] + flops += (module.weight != 0.0).sum().item() * self.num_patches[ell] return flops def compressible_size(self): diff --git a/src/torchprune/torchprune/util/nn_loss.py b/src/torchprune/torchprune/util/nn_loss.py index ab4936b..47bdc95 100644 --- a/src/torchprune/torchprune/util/nn_loss.py +++ b/src/torchprune/torchprune/util/nn_loss.py @@ -1,8 +1,12 @@ """A module summarizing all the custom losses and the torch.nn losses.""" +import numpy as np +import torch import torch.nn as nn from torch.nn import CrossEntropyLoss, MSELoss # noqa: F403,F401 +from .external.ffjord.train_misc import standard_normal_logprob + class CrossEntropyLossWithAuxiliary(nn.CrossEntropyLoss): """Cross-entropy loss that can add auxiliary loss if present.""" @@ -24,3 +28,64 @@ class LossFromInput(nn.Module): def forward(self, input, target): """Return loss from the inputs and ignore targets.""" return input["loss"] if isinstance(input, dict) else input[0] + + +class NLLPriorLoss(nn.Module): + """Loss corresponding to NLL between output and prior.""" + + def forward(self, input, target): + """Return average NLL.""" + prior = input["prior"] + out = input["out"] + logprob = prior.log_prob(out[:, 1:]).to(out) - out[:, 0] + return -logprob.mean() + + +class NLLNatsLoss(nn.Module): + """Loss corresponding to standard normal logprob loss. + + Check out util.external.ffjord.train_tabular.compute_loss for more info. 
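+
+    In short: with logpz the standard-normal log-density of the output z and
+    delta_logp the accumulated change in log-density, the loss is
+    ``-mean(logpz - delta_logp)``, i.e., the average NLL in nats.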
+ """ + + def _compute_logprob(self, input): + """Compute and return standard normal log prob.""" + z_out = input["out"] + delta_logp = input["delta_logp"] + logpz = ( + standard_normal_logprob(z_out) + .view(z_out.shape[0], -1) + .sum(1, keepdim=True) + ) + logpx = logpz - delta_logp + loss = -torch.mean(logpx) + return loss + + def forward(self, input, target): + """Return average standard normal logprob loss.""" + loss = self._compute_logprob(input) + + # add regularizer loss if needed + if "reg_loss" in input: + loss += input["reg_loss"] + + # return overall loss + return loss + + +class NLLBitsLoss(NLLNatsLoss): + """Loss corresponding to logprob loss expressed in "bits/dim".""" + + def _compute_logprob(self, input): + """Compute and return log prob normalized as bits/dim.""" + z_out = input["out"] + delta_logp = input["delta_logp"] + logpz = ( + standard_normal_logprob(z_out) + .view(z_out.shape[0], -1) + .sum(1, keepdim=True) + ) + # averaged over batches + logpx = logpz - delta_logp + logpx_per_dim = torch.sum(logpx) / input["nelement"] + bits_per_dim = -(logpx_per_dim - np.log(256)) / np.log(2) + return bits_per_dim diff --git a/src/torchprune/torchprune/util/transforms.py b/src/torchprune/torchprune/util/transforms.py index 4f86818..eb2e036 100644 --- a/src/torchprune/torchprune/util/transforms.py +++ b/src/torchprune/torchprune/util/transforms.py @@ -132,3 +132,18 @@ def __call__(self, image, target): """Normalize image but not target.""" image = F.normalize(image, mean=self.mean, std=self.std) return image, target + + +class RandomNoise(object): + """Random uniform noise transformation.""" + + def __init__(self, normalization): + """Initialize with normalization constant.""" + self._normalization = normalization + + def __call__(self, image): + """Return noise image from the current image.""" + noise = image.new().resize_as_(image).uniform_() + image = image * self._normalization + noise + image = image / (self._normalization + 1) + return image