diff --git a/braintools/file/msg_checkpoint.py b/braintools/file/msg_checkpoint.py index f0c2fe6..97f59b8 100644 --- a/braintools/file/msg_checkpoint.py +++ b/braintools/file/msg_checkpoint.py @@ -835,7 +835,7 @@ def msgpack_save( filename: str str or pathlib-like path to store checkpoint files in. target: Any - serializable flax object, usually a flax optimizer. + serializable object. overwrite: bool overwrite existing checkpoint files if a checkpoint at the current or a later step already exits (default: False). @@ -864,6 +864,9 @@ def msgpack_save( os.makedirs(os.path.dirname(filename), exist_ok=True) if not overwrite and os.path.exists(filename): raise InvalidCheckpointPath(filename) + + if isinstance(target, bst.util.FlattedDict): + target = target.to_nest() target = to_bytes(target) # Save the files via I/O sync or async. @@ -943,6 +946,9 @@ def msgpack_load( sys.stdout.flush() file_size = os.path.getsize(filename) + if isinstance(target, bst.util.FlattedDict): + target = target.to_nest() + with open(filename, 'rb') as fp: if parallel and fp.seekable(): buf_size = 128 << 20 # 128M buffer. diff --git a/braintools/metric/_firings.py b/braintools/metric/_firings.py index ae6990b..fdaeb15 100644 --- a/braintools/metric/_firings.py +++ b/braintools/metric/_firings.py @@ -18,7 +18,7 @@ from typing import Union import brainstate as bst -import brainunit as bu +import brainunit as u import jax.numpy as jnp import numpy as onp @@ -56,8 +56,8 @@ def raster_plot( def firing_rate( spikes: bst.typing.ArrayLike, - width: Union[float, bu.Quantity], - dt: Union[float, bu.Quantity] = None + width: Union[float, u.Quantity], + dt: Union[float, u.Quantity] = None ): r"""Calculate the mean firing rate over in a neuron group. @@ -86,5 +86,7 @@ def firing_rate( """ dt = bst.environ.get_dt() if (dt is None) else dt width1 = int(width / 2 / dt) * 2 + 1 - window = jnp.ones(width1) * 1000 / width + window = u.math.ones(width1) / width + if isinstance(window, u.Quantity): + window = window.to_decimal(u.Hz) return jnp.convolve(jnp.mean(spikes, axis=1), window, mode='same') diff --git a/braintools/visualize/style.py b/braintools/visualize/style.py index b880fea..5b7da57 100644 --- a/braintools/visualize/style.py +++ b/braintools/visualize/style.py @@ -13,10 +13,10 @@ # limitations under the License. # ============================================================================== -import matplotlib.pyplot as plt -from matplotlib import RcParams try: + import matplotlib.pyplot as plt + from matplotlib import RcParams import scienceplots # noqa: F401 def exclude(rc: RcParams, keys: list):