Skip to content

Commit

Permalink
fixup! Add support for TensorFlow 2.16
Browse files Browse the repository at this point in the history
  • Loading branch information
drasmuss committed Mar 28, 2024
1 parent 6e92b97 commit 3db6004
Show file tree
Hide file tree
Showing 8 changed files with 24 additions and 25 deletions.
6 changes: 1 addition & 5 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,10 @@ jobs:
- script: test
python-version: "3.11"
coverage-name: latest
- script: test
tf-version: tensorflow==2.4.4
python-version: "3.8"
coverage-name: oldest
- script: test
tf-version: tensorflow~=2.6.0
python-version: "3.8"
coverage-name: tf-2.6
coverage-name: oldest
- script: remote-docs
tf-version: tensorflow==2.10
python-version: "3.9"
Expand Down
6 changes: 3 additions & 3 deletions .nengobones.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ setup_py:
install_req:
- packaging>=20.9
- scipy>=1.0.0
- tensorflow>=2.4.4
- tensorflow>=2.6.0
tests_req:
- pytest>=6.1.0
- pytest-rng>=1.0.0
Expand Down Expand Up @@ -109,6 +109,6 @@ pyproject_toml: {}
version_py:
type: semver
major: 0
minor: 7
patch: 1
minor: 8
patch: 0
release: false
4 changes: 2 additions & 2 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ Release history
- Removed
- Fixed
0.7.1 (unreleased)
0.8.0 (unreleased)
==================

*Compatible with TensorFlow 2.4 - 2.16*
*Compatible with TensorFlow 2.6 - 2.16*

0.7.0 (July 20, 2023)
=====================
Expand Down
13 changes: 6 additions & 7 deletions keras_lmu/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,7 @@

# pylint: disable=ungrouped-imports
tf_version = version.parse(tf.__version__)
if tf_version < version.parse("2.6.0rc0"):
from tensorflow.python.keras.layers.recurrent import DropoutRNNCellMixin
elif tf_version < version.parse("2.9.0rc0"):
if tf_version < version.parse("2.9.0rc0"):
from keras.layers.recurrent import DropoutRNNCellMixin
elif tf_version < version.parse("2.13.0rc0"):
from keras.layers.rnn.dropout_rnn_cell_mixin import DropoutRNNCellMixin
Expand All @@ -32,7 +30,7 @@
from keras.layers import Layer as BaseRandomLayer


@keras.utils.register_keras_serializable("keras-lmu")
@tf.keras.utils.register_keras_serializable("keras-lmu")
class LMUCell(
DropoutRNNCellMixin, BaseRandomLayer
): # pylint: disable=too-many-ancestors
Expand Down Expand Up @@ -158,7 +156,8 @@ def __init__(
self.dropout = dropout
self.recurrent_dropout = recurrent_dropout
self.seed = seed
self.seed_generator = keras.random.SeedGenerator(seed)
if tf_version >= version.parse("2.16.0"):
self.seed_generator = keras.random.SeedGenerator(seed)

self.kernel = None
self.recurrent_kernel = None
Expand Down Expand Up @@ -503,7 +502,7 @@ def from_config(cls, config):
return super().from_config(config)


@keras.utils.register_keras_serializable("keras-lmu")
@tf.keras.utils.register_keras_serializable("keras-lmu")
class LMU(keras.layers.Layer): # pylint: disable=too-many-ancestors,abstract-method
"""
A layer of trainable low-dimensional delay systems.
Expand Down Expand Up @@ -771,7 +770,7 @@ def from_config(cls, config):
return super().from_config(config)


@keras.utils.register_keras_serializable("keras-lmu")
@tf.keras.utils.register_keras_serializable("keras-lmu")
class LMUFeedforward(
keras.layers.Layer
): # pylint: disable=too-many-ancestors,abstract-method
Expand Down
1 change: 1 addition & 0 deletions keras_lmu/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@
def pytest_configure(config):
if version.parse(tf.__version__) >= version.parse("2.7.0"):
tf.debugging.disable_traceback_filtering()
if version.parse(tf.__version__) >= version.parse("2.16.0"):
keras.config.disable_traceback_filtering()
15 changes: 9 additions & 6 deletions keras_lmu/tests/test_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,9 +205,14 @@ def test_save_load_serialization(mode, tmp_path, trainable_theta, discretizer):

model = keras.Model(inp, out)

model.save(tmp_path / "model.keras")
model_path = (
tmp_path
if version.parse(tf.__version__) < version.parse("2.16.0")
else tmp_path / "model.keras"
)
model.save(model_path)

model_load = keras.models.load_model(tmp_path / "model.keras")
model_load = keras.models.load_model(model_path)

assert np.allclose(
model.predict(np.ones((32, 10, 32))), model_load.predict(np.ones((32, 10, 32)))
Expand Down Expand Up @@ -512,7 +517,7 @@ def test_fit(feedforward, discretizer, trainable_theta):
y_test = tf.ones((5, 1))
model.compile(
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer=keras.optimizers.Adam(),
optimizer="adam",
metrics=["accuracy"],
)

Expand Down Expand Up @@ -613,9 +618,7 @@ def test_theta_update(discretizer, trainable_theta, tmp_path):
lmu = keras.layers.RNN(lmu_cell)(inputs)
model = keras.Model(inputs=inputs, outputs=lmu)

model.compile(
loss=keras.losses.MeanSquaredError(), optimizer=keras.optimizers.Adam()
)
model.compile(loss=keras.losses.MeanSquaredError(), optimizer="adam")

# make sure theta_inv is set correctly to initial value
assert np.allclose(lmu_cell.theta_inv.numpy(), 1 / theta)
Expand Down
2 changes: 1 addition & 1 deletion keras_lmu/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
tagged with the version.
"""

version_info = (0, 7, 1)
version_info = (0, 8, 0)

name = "keras-lmu"
dev = 0
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def read(*filenames, **kwargs):
install_req = [
"packaging>=20.9",
"scipy>=1.0.0",
"tensorflow>=2.4.4",
"tensorflow>=2.6.0",
]
docs_req = [
"matplotlib>=3.0.2,<3.4.3",
Expand Down

0 comments on commit 3db6004

Please sign in to comment.