Skip to content

Commit

Permalink
Merge pull request #341 from chaoming0625/master
Browse files Browse the repository at this point in the history
Updates
  • Loading branch information
chaoming0625 authored Mar 1, 2023
2 parents e0a3ee1 + cac4b1d commit 8692d3e
Show file tree
Hide file tree
Showing 75 changed files with 1,881 additions and 1,999 deletions.
14 changes: 10 additions & 4 deletions brainpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,11 @@
synapses, # synaptic dynamics
synouts, # synaptic output
synplast, # synaptic plasticity
experimental, # experimental model
syn,
)
from brainpy._src.dyn.base import not_pass_shargs
from brainpy._src.dyn.base import (DynamicalSystem as DynamicalSystem,
Module as Module,
from brainpy._src.dyn.base import not_pass_sha
from brainpy._src.dyn.base import (DynamicalSystem,
DynamicalSystemNS,
Container as Container,
Sequential as Sequential,
Network as Network,
Expand All @@ -77,6 +77,8 @@
from brainpy._src.dyn.transform import (NoSharedArg as NoSharedArg, # transformations
LoopOverTime as LoopOverTime,)
from brainpy._src.dyn.runners import (DSRunner as DSRunner) # runner
from brainpy._src.dyn.context import share
from brainpy._src.dyn.delay import Delay


# Part 4: Training #
Expand Down Expand Up @@ -240,3 +242,7 @@
dyn.__dict__['NMDA'] = compat.NMDA
del compat


from brainpy._src import checking
tools.__dict__['checking'] = checking
del checking
2 changes: 1 addition & 1 deletion brainpy/_src/analysis/highdim/slow_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -752,7 +752,7 @@ def f_cell(h: Dict):

# call update functions
args = (shared,) + self.args
target.update(*args)
target(*args)

# get new states
new_h = {k: (v.value if (v.batch_axis is None) else jnp.squeeze(v.value, axis=v.batch_axis))
Expand Down
6 changes: 3 additions & 3 deletions brainpy/_src/analysis/lowdim/lowdim_bifurcation.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,10 +225,10 @@ def plot_bifurcation(self, with_plot=True, show=False, with_return=False,

if self._can_convert_to_one_eq():
if self.convert_type() == C.x_by_y:
X = self.resolutions[self.y_var].value
X = bm.as_jax(self.resolutions[self.y_var])
else:
X = self.resolutions[self.x_var].value
pars = tuple(self.resolutions[p].value for p in self.target_par_names)
X = bm.as_jax(self.resolutions[self.x_var])
pars = tuple(bm.as_jax(self.resolutions[p]) for p in self.target_par_names)
mesh_values = jnp.meshgrid(*((X,) + pars))
mesh_values = tuple(jnp.moveaxis(v, 0, 1).flatten() for v in mesh_values)
candidates = mesh_values[0]
Expand Down
4 changes: 2 additions & 2 deletions brainpy/_src/analysis/lowdim/lowdim_phase_plane.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,9 +290,9 @@ def plot_fixed_point(self, with_plot=True, with_return=False, show=False,

if self._can_convert_to_one_eq():
if self.convert_type() == C.x_by_y:
candidates = self.resolutions[self.y_var].value
candidates = bm.as_jax(self.resolutions[self.y_var])
else:
candidates = self.resolutions[self.x_var].value
candidates = bm.as_jax(self.resolutions[self.x_var])
else:
if select_candidates == 'fx-nullcline':
candidates = [self.analyzed_results[key][0] for key in self.analyzed_results.keys()
Expand Down
File renamed without changes.
29 changes: 22 additions & 7 deletions brainpy/_src/checkpoints/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,13 @@
get_tensorstore_spec = None

from brainpy._src.math.ndarray import Array
from brainpy._src.math.object_transform.base import Collector
from brainpy.errors import (AlreadyExistsError,
MPACheckpointingRequiredError,
MPARestoreTargetRequiredError,
MPARestoreDataCorruptedError,
MPARestoreTypeNotMatchError,
InvalidCheckpointPath,
InvalidCheckpointError)
from brainpy.tools import DotDict
from brainpy.types import PyTree

__all__ = [
Expand Down Expand Up @@ -154,17 +152,27 @@ def from_state_dict(target, state: Dict[str, Any], name: str = '.'):
A copy of the object with the restored state.
"""
ty = _NamedTuple if _is_namedtuple(target) else type(target)
if ty not in _STATE_DICT_REGISTRY:
for t in _STATE_DICT_REGISTRY.keys():
if issubclass(ty, t):
ty = t
break
else:
return state
ty_from_state_dict = _STATE_DICT_REGISTRY[ty][1]
with _record_path(name):
return ty_from_state_dict(target, state)



def to_state_dict(target) -> Dict[str, Any]:
"""Returns a dictionary with the state of the given target."""
ty = _NamedTuple if _is_namedtuple(target) else type(target)
if ty not in _STATE_DICT_REGISTRY:

for t in _STATE_DICT_REGISTRY.keys():
if issubclass(ty, t):
ty = t
break
else:
return target

ty_to_state_dict = _STATE_DICT_REGISTRY[ty][0]
Expand Down Expand Up @@ -269,8 +277,9 @@ def _restore_namedtuple(xs, state_dict: Dict[str, Any]):

register_serialization_state(Array, _array_dict_state, _restore_array)
register_serialization_state(dict, _dict_state_dict, _restore_dict)
register_serialization_state(DotDict, _dict_state_dict, _restore_dict)
register_serialization_state(Collector, _dict_state_dict, _restore_dict)
# register_serialization_state(DotDict, _dict_state_dict, _restore_dict)
# register_serialization_state(Collector, _dict_state_dict, _restore_dict)
# register_serialization_state(ArrayCollector, _dict_state_dict, _restore_dict)
register_serialization_state(list, _list_state_dict, _restore_list)
register_serialization_state(tuple,
_list_state_dict,
Expand Down Expand Up @@ -1221,8 +1230,9 @@ def _save_main_ckpt_file2(target: bytes,
def save_pytree(
filename: str,
target: PyTree,
overwrite: bool = False,
overwrite: bool = True,
async_manager: Optional[AsyncManager] = None,
verbose: bool = True,
) -> None:
"""Save a checkpoint of the model. Suitable for single-host.
Expand Down Expand Up @@ -1250,12 +1260,16 @@ def save_pytree(
if defined, the save will run without blocking the main
thread. Only works for single host. Note that an ongoing save will still
block subsequent saves, to make sure overwrite/keep logic works correctly.
verbose: bool
Whether output the print information.
Returns
-------
out: str
Filename of saved checkpoint.
"""
if verbose:
print(f'Saving checkpoint into {filename}')
start_time = time.time()
# Make sure all saves are finished before the logic of checking and removing
# outdated checkpoints happens.
Expand Down Expand Up @@ -1284,6 +1298,7 @@ def save_main_ckpt_task():
end_time - start_time)



def multiprocess_save(
ckpt_dir: Union[str, os.PathLike],
target: PyTree,
Expand Down
1 change: 0 additions & 1 deletion brainpy/_src/dyn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
channels, neurons, rates, # neuron related
synapses, synouts, synplast, # synapse related
networks,
layers, # ANN related
runners,
transform,
)
Expand Down
Loading

0 comments on commit 8692d3e

Please sign in to comment.