PyTorch implementation of PredNet.
This repo was originally created by leido and was a helpful starting point. My implementation fixes the known issues of leido's code, such as blurry, black-and-white predictions.
Here's an example plot generated using prednet_relu_bug with default hyperparameters:
This implementation includes features not present in the original code, such as the ability to toggle peephole connections on or off, switch between tied and untied bias weights, and switch between multiplicative and subtractive gating as developed by Costa et al.
"Deep Predictive Coding Networks for Video Prediction and Unsupervised Learning"(https://arxiv.org/abs/1605.08104)
The PredNet is a deep recurrent convolutional neural network inspired by the neuroscience concept of predictive coding (Rao and Ballard, 1999; Friston, 2005).
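The core idea is that each layer predicts its own input and passes only the prediction error upward. As a minimal sketch of the error unit described in the paper (the function name is illustrative, not this repo's API), the positive and negative parts of the error are rectified separately and stacked along the channel dimension:

```python
import torch
import torch.nn.functional as F

def error_unit(A, A_hat):
    # Rectify the positive and negative parts of the prediction error
    # separately and concatenate them along the channel dimension,
    # as in Lotter et al. (2016).
    return torch.cat([F.relu(A - A_hat), F.relu(A_hat - A)], dim=1)
```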
The original paper's code is written in Keras. Examples and the project website can be found here.
In leido's code, ConvLSTMCell is borrowed from here.
However, we significantly revamped this module, adding a proper implementation of peephole connections, selectable gating modes, and a more readable style.
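To illustrate the two toggles, here is a rough sketch of a ConvLSTM cell with optional peephole connections and a choice between standard multiplicative gating and the subtractive (subLSTM) gating of Costa et al. (2017). It is a simplified stand-in, not the repo's exact module:

```python
import torch
import torch.nn as nn

class ConvLSTMCellSketch(nn.Module):
    """Illustrative ConvLSTM cell; a sketch, not the repo's exact code."""

    def __init__(self, in_ch, hid_ch, ksize=3, peephole=False, gating_mode='mul'):
        super().__init__()
        # One convolution computes all four gate pre-activations.
        self.conv = nn.Conv2d(in_ch + hid_ch, 4 * hid_ch, ksize, padding=ksize // 2)
        self.peephole = peephole
        self.gating_mode = gating_mode
        if peephole:
            # Per-channel peephole weights, broadcast over spatial dims.
            self.w_ci = nn.Parameter(torch.zeros(1, hid_ch, 1, 1))
            self.w_cf = nn.Parameter(torch.zeros(1, hid_ch, 1, 1))
            self.w_co = nn.Parameter(torch.zeros(1, hid_ch, 1, 1))

    def forward(self, x, state):
        h, c = state
        gi, gf, go, gz = self.conv(torch.cat([x, h], dim=1)).chunk(4, dim=1)
        if self.peephole:
            # Input and forget gates also see the previous cell state.
            gi = gi + self.w_ci * c
            gf = gf + self.w_cf * c
        i, f = torch.sigmoid(gi), torch.sigmoid(gf)
        if self.gating_mode == 'mul':
            # Standard LSTM: gates scale the candidate multiplicatively.
            c = f * c + i * torch.tanh(gz)
            o = torch.sigmoid(go + self.w_co * c) if self.peephole else torch.sigmoid(go)
            h = o * torch.tanh(c)
        else:
            # subLSTM (Costa et al., 2017): gates act subtractively.
            c = f * c + torch.sigmoid(gz) - i
            o = torch.sigmoid(go + self.w_co * c) if self.peephole else torch.sigmoid(go)
            h = torch.sigmoid(c) - o
        return h, c
```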
Training a PredNet model is done via kitti_train.py. Feel free to adjust the following training and model hyperparameters within the script (a sketch of a typical configuration follows the list):
- num_epochs: default- 150
- batch_size: default- 4
- lr: learning rate, default- 0.001
- nt: length of video sequences, default- 10
- n_train_seq: number of video sequences per training epoch, default- 500
- n_val_seq: number of video sequences used for validation, default- 100
- loss_mode: 'L_0' (penalize only the lowest layer's error) or 'L_all' (weight errors from all layers), default- 'L_0'
- peephole: toggles inclusion of peephole connections within the ConvLSTM, default- False
- lstm_tied_bias: toggles tying of the biases within the ConvLSTM, default- False
- gating_mode: toggles between multiplicative 'mul' and subtractive 'sub' gating within the ConvLSTM, default- 'mul'
- A_channels & R_channels: number of channels within each layer of PredNet, default- (3, 48, 96, 192)
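As an example, the hyperparameter block in kitti_train.py looks roughly like the following (defaults taken from the list above; the script's actual variable layout may differ):

```python
# Training hyperparameters (defaults shown; edit these in kitti_train.py)
num_epochs = 150
batch_size = 4
lr = 0.001
nt = 10             # frames per video sequence
n_train_seq = 500   # sequences per training epoch
n_val_seq = 100     # sequences used for validation

# Model hyperparameters
loss_mode = 'L_0'        # 'L_0': lowest layer's error only; 'L_all': all layers
peephole = False         # peephole connections within the ConvLSTM
lstm_tied_bias = False   # tie the biases within the ConvLSTM
gating_mode = 'mul'      # 'mul' (multiplicative) or 'sub' (subtractive)
A_channels = (3, 48, 96, 192)
R_channels = (3, 48, 96, 192)
```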
After training is complete, the script saves two versions of the model: prednet-*-best.pt (the version with the lowest loss on the validation set) and prednet-*.pt (the version saved after the last epoch).
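The saving logic amounts to tracking the best validation loss across epochs, roughly like this sketch (train_one_epoch and evaluate are hypothetical helpers, and the configuration string that fills in the * is omitted):

```python
import torch

best_val_loss = float('inf')
for epoch in range(num_epochs):
    train_one_epoch(model, optimizer)   # hypothetical training helper
    val_loss = evaluate(model)          # hypothetical validation helper
    if val_loss < best_val_loss:
        # Keep a separate checkpoint of the best-on-validation weights.
        best_val_loss = val_loss
        torch.save(model.state_dict(), 'models/prednet-best.pt')
# Also save the weights from the final epoch.
torch.save(model.state_dict(), 'models/prednet.pt')
```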
To test your models with kitti_test.py, transfer them into your 'models' folder, set the testing and model hyperparameters accordingly, and run the script. It outputs the MSE between the ground-truth and predicted sequences, as well as the MSE obtained if the model simply predicted the previous frame at each time step.
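The two reported numbers can be computed roughly as follows, assuming sequences shaped (batch, nt, channels, height, width); X and X_hat are assumed names for the ground-truth and predicted batches:

```python
import torch

# X: ground-truth sequences, X_hat: model predictions,
# both shaped (batch, nt, channels, height, width).
# Frame 0 is excluded since there is no prediction for it.
mse_model = torch.mean((X[:, 1:] - X_hat[:, 1:]) ** 2)

# Copy-last-frame baseline: "predict" frame t with frame t-1.
mse_prev = torch.mean((X[:, 1:] - X[:, :-1]) ** 2)
```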
The default parameters listed above reproduce the results in the paper when using prednet_relu_bug. However, prednet underperforms with these parameters and overfits the data. After a coarse hyperparameter search, we found that shrinking the model helped alleviate the overfitting.
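Shrinking the model here means reducing the per-layer channel counts, for instance along these lines (the values below are purely hypothetical, not the configuration found in the search):

```python
# Hypothetical smaller model; not the exact configuration from the search.
A_channels = (3, 24, 48, 96)
R_channels = (3, 24, 48, 96)
```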
Acquiring the dataset requires multiple steps: 1) downloading the zip files and 2) extracting and processing the images. Step 1 is done by running the download_raw_data_.sh scripts found in kitti_raw_data\raw<category>. Step 2 is handled by running process_kitti.py.