keras-bcr : Batch Correlation Regularizer for TF2/Keras

The batch correlation regularization (BCR) technique adds a penalty to the loss when, for a given feature element, the input to a residual block and the block's output before the skip-connection are correlated. The correlation coefficients are computed separately for each feature element across the current batch.
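
To make the idea concrete, here is a minimal sketch in plain Python (no TensorFlow needed) of a per-feature correlation computed across a batch, turned into a penalty. It is an illustration under stated assumptions, not the package's implementation: the helper names `feature_correlations` and `bcr_penalty`, the use of the Pearson coefficient, and the aggregation as rate times mean absolute correlation are all hypothetical choices made for this example.

```python
import math


def feature_correlations(x, y, eps=1e-8):
    """Pearson correlation per feature column, computed across the batch.

    x, y: lists of rows with shape [batch_size][n_features].
    """
    n_rows, n_feats = len(x), len(x[0])
    corrs = []
    for j in range(n_feats):
        xj = [row[j] for row in x]
        yj = [row[j] for row in y]
        mx, my = sum(xj) / n_rows, sum(yj) / n_rows
        cov = sum((a - mx) * (b - my) for a, b in zip(xj, yj))
        sx = math.sqrt(sum((a - mx) ** 2 for a in xj))
        sy = math.sqrt(sum((b - my) ** 2 for b in yj))
        # eps guards against division by zero for constant columns
        corrs.append(cov / (sx * sy + eps))
    return corrs


def bcr_penalty(h, inputs, bcr_rate=0.1):
    # Assumption: penalty = rate * mean absolute correlation over features.
    corrs = feature_correlations(h, inputs)
    return bcr_rate * sum(abs(c) for c in corrs) / len(corrs)


# Batch of 4 samples with 2 features each.
inputs = [[1.0, 1.0], [2.0, -1.0], [3.0, -1.0], [4.0, 1.0]]
h = [[2.0, 1.0], [4.0, 1.0], [6.0, -1.0], [8.0, -1.0]]

corrs = feature_correlations(h, inputs)  # feature 0 correlated, feature 1 not
penalty = bcr_penalty(h, inputs, bcr_rate=0.1)
```

In this toy batch the first feature of `h` is a scaled copy of the input, so it contributes to the penalty, while the second feature is uncorrelated with its input and contributes nothing.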

Usage

from keras_bcr import BatchCorrRegularizer
import tensorflow as tf

# The BCR layer is added before the addition of the skip-connection
def build_resnet_block(inputs, units=64, activation="gelu",
                       dropout=0.4, bcr_rate=0.1):
    h = tf.keras.layers.Dense(units=units)(inputs)
    h = tf.keras.layers.Activation(activation=activation)(h)
    h = tf.keras.layers.Dropout(rate=dropout)(h)
    h = BatchCorrRegularizer(bcr_rate=bcr_rate)([h, inputs])  # << HERE
    outputs = tf.keras.layers.Add()([h, inputs])
    return outputs

# A model with 3 ResNet blocks
def build_model(input_dim):
    inputs = tf.keras.Input(shape=(input_dim,))
    h = build_resnet_block(inputs, units=input_dim)
    h = build_resnet_block(h, units=input_dim)
    outputs = build_resnet_block(h, units=input_dim)
    model = tf.keras.Model(inputs=inputs, outputs=outputs)
    return model

INPUT_DIM = 64
model = build_model(input_dim=INPUT_DIM)
model.compile(optimizer=tf.keras.optimizers.Adam(), loss="mean_squared_error")

BATCH_SZ = 128
X_train = tf.random.normal([BATCH_SZ, INPUT_DIM])
y_train = tf.random.normal([BATCH_SZ, INPUT_DIM])  # targets match the model's output shape

history = model.fit(X_train, y_train, verbose=1, epochs=2)

Appendix

Installation

keras-bcr is available as a PyPI package, or it can be installed directly from the git repo:

pip install keras-bcr
pip install git+ssh://git@github.com/ulf1/keras-bcr.git

Install a virtual environment

python3 -m venv .venv
source .venv/bin/activate
pip install --upgrade pip
pip install -r requirements.txt --no-cache-dir
pip install -r requirements-dev.txt --no-cache-dir
pip install -r requirements-demo.txt --no-cache-dir

(If your git repo is stored in a folder whose path contains whitespace, don't use the subfolder .venv. Use an absolute path without whitespace instead.)

Python commands

  • Jupyter for the examples: jupyter lab
  • Check syntax: flake8 --ignore=F401 --exclude=$(grep -v '^#' .gitignore | xargs | sed -e 's/ /,/g')
  • Run Unit Tests: PYTHONPATH=. pytest

Publish

python setup.py sdist 
twine upload -r pypi dist/*

Clean up

find . -type f -name "*.pyc" | xargs rm
find . -type d -name "__pycache__" | xargs rm -r
rm -r .pytest_cache
rm -r .venv

Support

Please open an issue for support.

Contributing

Please contribute using GitHub Flow. Create a branch, add commits, and open a pull request.

Acknowledgements

The "Evidence" project was funded by the Deutsche Forschungsgemeinschaft (DFG, German Research Foundation) - 433249742 (GU 798/27-1; GE 1119/11-1).

Maintenance

  • Until 31 Aug 2023 (v0.2.0), the code repository was maintained within the DFG project 433249742.
  • Since 01 Sep 2023 (v0.3.0), the code repository has been maintained by Ulf Hamster.
