Skip to content

Commit

Permalink
small changes
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 committed Dec 15, 2024
1 parent f8a1a66 commit 444f558
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 7 deletions.
8 changes: 7 additions & 1 deletion braintools/file/msg_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -835,7 +835,7 @@ def msgpack_save(
filename: str
str or pathlib-like path to store checkpoint files in.
target: Any
serializable flax object, usually a flax optimizer.
serializable object.
overwrite: bool
overwrite existing checkpoint files if a checkpoint at the
current or a later step already exits (default: False).
Expand Down Expand Up @@ -864,6 +864,9 @@ def msgpack_save(
os.makedirs(os.path.dirname(filename), exist_ok=True)
if not overwrite and os.path.exists(filename):
raise InvalidCheckpointPath(filename)

if isinstance(target, bst.util.FlattedDict):
target = target.to_nest()
target = to_bytes(target)

# Save the files via I/O sync or async.
Expand Down Expand Up @@ -943,6 +946,9 @@ def msgpack_load(
sys.stdout.flush()
file_size = os.path.getsize(filename)

if isinstance(target, bst.util.FlattedDict):
target = target.to_nest()

with open(filename, 'rb') as fp:
if parallel and fp.seekable():
buf_size = 128 << 20 # 128M buffer.
Expand Down
10 changes: 6 additions & 4 deletions braintools/metric/_firings.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from typing import Union

import brainstate as bst
import brainunit as bu
import brainunit as u
import jax.numpy as jnp
import numpy as onp

Expand Down Expand Up @@ -56,8 +56,8 @@ def raster_plot(

def firing_rate(
spikes: bst.typing.ArrayLike,
width: Union[float, bu.Quantity],
dt: Union[float, bu.Quantity] = None
width: Union[float, u.Quantity],
dt: Union[float, u.Quantity] = None
):
r"""Calculate the mean firing rate over in a neuron group.
Expand Down Expand Up @@ -86,5 +86,7 @@ def firing_rate(
"""
dt = bst.environ.get_dt() if (dt is None) else dt
width1 = int(width / 2 / dt) * 2 + 1
window = jnp.ones(width1) * 1000 / width
window = u.math.ones(width1) / width
if isinstance(window, u.Quantity):
window = window.to_decimal(u.Hz)
return jnp.convolve(jnp.mean(spikes, axis=1), window, mode='same')
4 changes: 2 additions & 2 deletions braintools/visualize/style.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@
# limitations under the License.
# ==============================================================================

import matplotlib.pyplot as plt
from matplotlib import RcParams

try:
import matplotlib.pyplot as plt
from matplotlib import RcParams
import scienceplots # noqa: F401

def exclude(rc: RcParams, keys: list):
Expand Down

0 comments on commit 444f558

Please sign in to comment.