diff --git a/brainpy/__init__.py b/brainpy/__init__.py index 9b3f8acb7..67073bee0 100644 --- a/brainpy/__init__.py +++ b/brainpy/__init__.py @@ -58,11 +58,11 @@ synapses, # synaptic dynamics synouts, # synaptic output synplast, # synaptic plasticity - experimental, # experimental model + syn, ) -from brainpy._src.dyn.base import not_pass_shargs -from brainpy._src.dyn.base import (DynamicalSystem as DynamicalSystem, - Module as Module, +from brainpy._src.dyn.base import not_pass_sha +from brainpy._src.dyn.base import (DynamicalSystem, + DynamicalSystemNS, Container as Container, Sequential as Sequential, Network as Network, @@ -77,6 +77,8 @@ from brainpy._src.dyn.transform import (NoSharedArg as NoSharedArg, # transformations LoopOverTime as LoopOverTime,) from brainpy._src.dyn.runners import (DSRunner as DSRunner) # runner +from brainpy._src.dyn.context import share +from brainpy._src.dyn.delay import Delay # Part 4: Training # @@ -240,3 +242,7 @@ dyn.__dict__['NMDA'] = compat.NMDA del compat + +from brainpy._src import checking +tools.__dict__['checking'] = checking +del checking diff --git a/brainpy/_src/analysis/highdim/slow_points.py b/brainpy/_src/analysis/highdim/slow_points.py index 4515b9d64..74388c04e 100644 --- a/brainpy/_src/analysis/highdim/slow_points.py +++ b/brainpy/_src/analysis/highdim/slow_points.py @@ -752,7 +752,7 @@ def f_cell(h: Dict): # call update functions args = (shared,) + self.args - target.update(*args) + target(*args) # get new states new_h = {k: (v.value if (v.batch_axis is None) else jnp.squeeze(v.value, axis=v.batch_axis)) diff --git a/brainpy/_src/analysis/lowdim/lowdim_bifurcation.py b/brainpy/_src/analysis/lowdim/lowdim_bifurcation.py index 068f92efa..076b03c48 100644 --- a/brainpy/_src/analysis/lowdim/lowdim_bifurcation.py +++ b/brainpy/_src/analysis/lowdim/lowdim_bifurcation.py @@ -225,10 +225,10 @@ def plot_bifurcation(self, with_plot=True, show=False, with_return=False, if self._can_convert_to_one_eq(): if self.convert_type() == C.x_by_y: - X = self.resolutions[self.y_var].value + X = bm.as_jax(self.resolutions[self.y_var]) else: - X = self.resolutions[self.x_var].value - pars = tuple(self.resolutions[p].value for p in self.target_par_names) + X = bm.as_jax(self.resolutions[self.x_var]) + pars = tuple(bm.as_jax(self.resolutions[p]) for p in self.target_par_names) mesh_values = jnp.meshgrid(*((X,) + pars)) mesh_values = tuple(jnp.moveaxis(v, 0, 1).flatten() for v in mesh_values) candidates = mesh_values[0] diff --git a/brainpy/_src/analysis/lowdim/lowdim_phase_plane.py b/brainpy/_src/analysis/lowdim/lowdim_phase_plane.py index a303c3f97..667c62ec8 100644 --- a/brainpy/_src/analysis/lowdim/lowdim_phase_plane.py +++ b/brainpy/_src/analysis/lowdim/lowdim_phase_plane.py @@ -290,9 +290,9 @@ def plot_fixed_point(self, with_plot=True, with_return=False, show=False, if self._can_convert_to_one_eq(): if self.convert_type() == C.x_by_y: - candidates = self.resolutions[self.y_var].value + candidates = bm.as_jax(self.resolutions[self.y_var]) else: - candidates = self.resolutions[self.x_var].value + candidates = bm.as_jax(self.resolutions[self.x_var]) else: if select_candidates == 'fx-nullcline': candidates = [self.analyzed_results[key][0] for key in self.analyzed_results.keys() diff --git a/brainpy/tools/checking.py b/brainpy/_src/checking.py similarity index 100% rename from brainpy/tools/checking.py rename to brainpy/_src/checking.py diff --git a/brainpy/_src/checkpoints/serialization.py b/brainpy/_src/checkpoints/serialization.py index d93c04600..7fecee944 
100644 --- a/brainpy/_src/checkpoints/serialization.py +++ b/brainpy/_src/checkpoints/serialization.py @@ -37,7 +37,6 @@ get_tensorstore_spec = None from brainpy._src.math.ndarray import Array -from brainpy._src.math.object_transform.base import Collector from brainpy.errors import (AlreadyExistsError, MPACheckpointingRequiredError, MPARestoreTargetRequiredError, @@ -45,7 +44,6 @@ MPARestoreTypeNotMatchError, InvalidCheckpointPath, InvalidCheckpointError) -from brainpy.tools import DotDict from brainpy.types import PyTree __all__ = [ @@ -154,17 +152,27 @@ def from_state_dict(target, state: Dict[str, Any], name: str = '.'): A copy of the object with the restored state. """ ty = _NamedTuple if _is_namedtuple(target) else type(target) - if ty not in _STATE_DICT_REGISTRY: + for t in _STATE_DICT_REGISTRY.keys(): + if issubclass(ty, t): + ty = t + break + else: return state ty_from_state_dict = _STATE_DICT_REGISTRY[ty][1] with _record_path(name): return ty_from_state_dict(target, state) + def to_state_dict(target) -> Dict[str, Any]: """Returns a dictionary with the state of the given target.""" ty = _NamedTuple if _is_namedtuple(target) else type(target) - if ty not in _STATE_DICT_REGISTRY: + + for t in _STATE_DICT_REGISTRY.keys(): + if issubclass(ty, t): + ty = t + break + else: return target ty_to_state_dict = _STATE_DICT_REGISTRY[ty][0] @@ -269,8 +277,9 @@ def _restore_namedtuple(xs, state_dict: Dict[str, Any]): register_serialization_state(Array, _array_dict_state, _restore_array) register_serialization_state(dict, _dict_state_dict, _restore_dict) -register_serialization_state(DotDict, _dict_state_dict, _restore_dict) -register_serialization_state(Collector, _dict_state_dict, _restore_dict) +# register_serialization_state(DotDict, _dict_state_dict, _restore_dict) +# register_serialization_state(Collector, _dict_state_dict, _restore_dict) +# register_serialization_state(ArrayCollector, _dict_state_dict, _restore_dict) register_serialization_state(list, _list_state_dict, _restore_list) register_serialization_state(tuple, _list_state_dict, @@ -1221,8 +1230,9 @@ def _save_main_ckpt_file2(target: bytes, def save_pytree( filename: str, target: PyTree, - overwrite: bool = False, + overwrite: bool = True, async_manager: Optional[AsyncManager] = None, + verbose: bool = True, ) -> None: """Save a checkpoint of the model. Suitable for single-host. @@ -1250,12 +1260,16 @@ def save_pytree( if defined, the save will run without blocking the main thread. Only works for single host. Note that an ongoing save will still block subsequent saves, to make sure overwrite/keep logic works correctly. + verbose: bool + Whether output the print information. Returns ------- out: str Filename of saved checkpoint. """ + if verbose: + print(f'Saving checkpoint into {filename}') start_time = time.time() # Make sure all saves are finished before the logic of checking and removing # outdated checkpoints happens. 
@@ -1284,6 +1298,7 @@ def save_main_ckpt_task(): end_time - start_time) + def multiprocess_save( ckpt_dir: Union[str, os.PathLike], target: PyTree, diff --git a/brainpy/_src/dyn/__init__.py b/brainpy/_src/dyn/__init__.py index 56d10ee90..6f03f4245 100644 --- a/brainpy/_src/dyn/__init__.py +++ b/brainpy/_src/dyn/__init__.py @@ -8,7 +8,6 @@ channels, neurons, rates, # neuron related synapses, synouts, synplast, # synapse related networks, - layers, # ANN related runners, transform, ) diff --git a/brainpy/_src/dyn/base.py b/brainpy/_src/dyn/base.py index b3eba41aa..078a0aba6 100644 --- a/brainpy/_src/dyn/base.py +++ b/brainpy/_src/dyn/base.py @@ -19,11 +19,13 @@ from brainpy.errors import NoImplementationError, UnsupportedError from brainpy.types import ArrayType, Shape +share = None + + __all__ = [ # general class 'DynamicalSystem', - 'Module', - 'FuncAsDynSys', + 'DynamicalSystemNS', # containers 'Container', 'Network', 'Sequential', 'System', @@ -48,7 +50,7 @@ SLICE_VARS = 'slice_vars' -def not_pass_shargs(func: Callable): +def not_pass_sha(func: Callable): """Label the update function as the one without passing shared arguments. The original update function explicitly requires shared arguments at the first place:: @@ -113,6 +115,8 @@ class DynamicalSystem(BrainPyObject): The model computation mode. It should be instance of :py:class:`~.Mode`. """ + pass_shared: bool = True + global_delay_data: Dict[str, Tuple[Union[bm.LengthDelay, None], Variable]] = dict() '''Global delay data, which stores the delay variables and corresponding delay targets. This variable is useful when the same target variable is used in multiple mappings, @@ -153,22 +157,32 @@ def __repr__(self): def __call__(self, *args, **kwargs): """The shortcut to call ``update`` methods.""" - if hasattr(self.update, '_new_style') and getattr(self.update, '_new_style'): - if len(args) and isinstance(args[0], dict): - for k, v in args[0].items(): - bm.share.save(k, v) - return self.update(*args[1:], **kwargs) + global share + if share is None: + from brainpy._src.dyn.context import share + + if self.pass_shared: + if hasattr(self.update, '_new_style') and getattr(self.update, '_new_style'): + if len(args) and isinstance(args[0], dict): + share.save(**args[0]) + return self.update(*args[1:], **kwargs) + else: + return self.update(*args, **kwargs) else: - return self.update(*args, **kwargs) + if len(args) and isinstance(args[0], dict): + return self.update(*args, **kwargs) + else: + # If first argument is not shared argument, + # we should get the shared arguments from the global context. + # However, users should set and update shared arguments + # in the global context when using this mode. + return self.update(share.get_shargs(), *args, **kwargs) else: - if len(args) and isinstance(args[0], dict): - return self.update(*args, **kwargs) + if len(args) and isinstance(args[0], dict): # it may be shared arguments + share.save(**args[0]) + return self.update(*args[1:], **kwargs) else: - # If first argument is not shared argument, - # we should get the shared arguments from the global context. - # However, users should set and update shared arguments - # in the global context when using this mode. 
- return self.update(bm.share.get_shargs(), *args, **kwargs) + return self.update(*args, **kwargs) def register_delay( self, @@ -299,18 +313,18 @@ def update(self, *args, **kwargs): """ raise NotImplementedError('Must implement "update" function by subclass self.') - def reset(self, batch_size=None): + def reset(self, *args, **kwargs): """Reset function which reset the whole variables in the model. """ - self.reset_state(batch_size) + self.reset_state(*args, **kwargs) - def reset_state(self, batch_size: Optional[int] = None): + def reset_state(self, *args, **kwargs): """Reset function which reset the states in the model. """ child_nodes = self.nodes(level=1, include_self=False).subset(DynamicalSystem).unique() if len(child_nodes) > 0: for node in child_nodes.values(): - node.reset_state(batch_size=batch_size) + node.reset_state(*args, **kwargs) self.reset_local_delays(child_nodes) else: raise NotImplementedError('Must implement "reset_state" function by subclass self. ' @@ -387,59 +401,10 @@ def clear_input(self): pass -Module = DynamicalSystem +class DynamicalSystemNS(DynamicalSystem): + """Dynamical system without the need of shared parameters passing into ``update()`` function.""" - -class FuncAsDynSys(DynamicalSystem): - """Transform a Python function as a :py:class:`~.DynamicalSystem` - - Parameters - ---------- - target : Callable - The function to wrap. - child_objs : optional, BrainPyObject, sequence of BrainPyObject, dict - The nodes in the defined function ``f``. - dyn_vars : optional, ndarray, sequence of ndarray, dict - The dynamically changed variables. - name : optional, str - The name of the transformed object. - mode: Mode - The computation mode. - """ - - def __init__( - self, - target: Callable, - child_objs: Union[BrainPyObject, Sequence[BrainPyObject], Dict[str, BrainPyObject]] = None, - dyn_vars: Union[Variable, Sequence[Variable], Dict[str, Variable]] = None, - name: Optional[str] = None, - mode: Optional[bm.Mode] = None - ): - super().__init__(name=name, mode=mode) - - self.target = target - if child_objs is not None: - self.register_implicit_nodes(child_objs, node_cls=DynamicalSystem) - if dyn_vars is not None: - self.register_implicit_vars(dyn_vars) - - def update(self, *args, **kwargs): - """Update function of the transformed dynamical system.""" - return self.target(*args, **kwargs) - - def clear_input(self): - """Function for clearing input in the wrapped children dynamical system.""" - for child in self.nodes(level=1, include_self=False).subset(DynamicalSystem).unique().values(): - child.clear_input() - - def __repr__(self): - name = self.__class__.__name__ - indent = " " * (len(name) + 1) - indent2 = indent + " " * len('nodes=') - nodes = [tools.repr_context(str(n), indent2) for n in self.implicit_nodes.values()] - node_string = ", \n".join(nodes) - return (f'{name}(nodes=[{node_string}],\n' + - f'{indent}num_of_vars={len(self.implicit_vars)})') + pass_shared = False @@ -519,7 +484,7 @@ def update(self, tdi, *args, **kwargs): """ nodes = self.nodes(level=1, include_self=False).subset(DynamicalSystem).unique() for node in nodes.values(): - node.update(tdi) + node(tdi) def __getitem__(self, item): """Overwrite the slice access (`self['']`). """ @@ -661,7 +626,9 @@ def update(self, s, x) -> ArrayType: The output tensor. 
""" for m in self._modules: - if isinstance(m, DynamicalSystem): + if isinstance(m, DynamicalSystemNS): + x = m(x) + elif isinstance(m, DynamicalSystem): x = m(s, x) else: x = m(x) @@ -698,6 +665,7 @@ def __init__( mode=mode, **ds_dict) + @not_pass_sha def update(self, *args, **kwargs): """Step function of a network. @@ -713,19 +681,18 @@ def update(self, *args, **kwargs): other_nodes = nodes - neuron_groups - synapse_groups - ds_views # shared arguments - shared = args[0] # update synapse nodes for node in synapse_groups.values(): - node.update(shared) + node() # update neuron nodes for node in neuron_groups.values(): - node.update(shared) + node() # update other types of nodes for node in other_nodes.values(): - node.update(shared) + node() # update delays self.update_local_delays(nodes) @@ -777,6 +744,8 @@ class NeuGroup(DynamicalSystem): .. versionadded:: 2.1.13 mode: Mode + The computing mode. + .. versionadded:: 2.2.0 """ @@ -825,13 +794,6 @@ def get_batch_shape(self, batch_size=None): def update(self, *args): """The function to specify the updating rule. - - Parameters - ---------- - tdi : DotDict - The shared arguments, especially time `t`, step `dt`, and iteration `i`. - x: Any - The input for a neuron group. """ raise NotImplementedError(f'Subclass of {self.__class__.__name__} must ' f'implement "update" function.') @@ -937,7 +899,7 @@ def update(self, tdi, pre_spike=None): raise NotImplementedError('Must implement "update" function by subclass self.') -class SynComponent(DynamicalSystem): +class _SynComponent(DynamicalSystem): """Base class for modeling synaptic components, including synaptic output, synaptic short-term plasticity, synaptic long-term plasticity, and others. """ @@ -946,7 +908,7 @@ class SynComponent(DynamicalSystem): master: SynConn def __init__(self, *args, **kwargs): - super(SynComponent, self).__init__(*args, **kwargs) + super(_SynComponent, self).__init__(*args, **kwargs) self._registered = False @@ -980,7 +942,7 @@ def __repr__(self): def __call__(self, *args, **kwargs): return self.filter(*args, **kwargs) - def clone(self) -> 'SynComponent': + def clone(self) -> '_SynComponent': """The function useful to clone a new object when it has been used.""" raise NotImplementedError @@ -988,7 +950,7 @@ def filter(self, g): raise NotImplementedError -class SynOut(SynComponent): +class SynOut(_SynComponent): """Base class for synaptic current output.""" def __init__( @@ -1023,14 +985,14 @@ def update(self, tdi): pass -class SynSTP(SynComponent): +class SynSTP(_SynComponent): """Base class for synaptic short-term plasticity.""" def update(self, tdi, pre_spike): pass -class SynLTP(SynComponent): +class SynLTP(_SynComponent): """Base class for synaptic long-term plasticity.""" def update(self, tdi, pre_spike): diff --git a/brainpy/_src/dyn/context.py b/brainpy/_src/dyn/context.py new file mode 100644 index 000000000..9293add7d --- /dev/null +++ b/brainpy/_src/dyn/context.py @@ -0,0 +1,121 @@ +""" +Context for brainpy computation. + +This context defines all shared data used in all modules in a computation. +""" + +from typing import Dict, Any, Union + +from brainpy._src.tools.dicts import DotDict +from brainpy._src.dyn.delay import Delay +from brainpy._src.math.environment import get_dt +from brainpy._src.math.object_transform.base import BrainPyObject, dyn_dict + +__all__ = [ + 'share', +] + + +class _ShareContext(BrainPyObject): + def __init__(self): + super().__init__() + + # Shared data across all nodes at current time step. 
+    # -------------
+    self._arguments = DotDict()
+    self._delays: Dict[str, Delay] = dyn_dict()
+  @property
+  def dt(self):
+    if 'dt' in self._arguments:
+      return self._arguments['dt']
+    else:
+      return get_dt()
+  @dt.setter
+  def dt(self, dt):
+    self.set_dt(dt)
+  def set_dt(self, dt: Union[int, float]):
+    self._arguments['dt'] = dt
+  def load(self, key, value: Any = None):
+    """Get the shared data by the ``key``.
+
+    Args:
+      key (str): the key to indicate the data.
+      value (Any): the default value when ``key`` is not defined in the shared data.
+    """
+    if key == 'dt':
+      return self.dt
+    if key in self._arguments:
+      return self._arguments[key]
+    if key in self._delays:
+      return self._delays[key]
+    if value is None:
+      raise KeyError(f'Cannot find the shared data of {key}.')
+    else:
+      return value
+  def save(self, *args, **kwargs) -> None:
+    """Save shared arguments in the global context."""
+    assert len(args) % 2 == 0
+    for i in range(0, len(args), 2):
+      identifier = args[i]
+      data = args[i + 1]
+      if isinstance(data, Delay):
+        if identifier in self._delays:
+          raise ValueError(f'{identifier} has been used. Please assign another name.')
+        self._delays[identifier] = data
+      else:
+        self._arguments[identifier] = data
+    for identifier, data in kwargs.items():
+      if isinstance(data, Delay):
+        if identifier in self._delays:
+          raise ValueError(f'{identifier} has been used. Please assign another name.')
+        self._delays[identifier] = data
+      else:
+        self._arguments[identifier] = data
+  def get_shargs(self) -> DotDict:
+    """Get all shared arguments in the global context."""
+    return self._arguments.copy()
+  def clear_delays(self, *delays) -> None:
+    """Clear all delay variables in this global context."""
+    if len(delays):
+      for d in delays:
+        self._delays.pop(d)
+    else:
+      self._delays.clear()
+  def clear_shargs(self, *args) -> None:
+    """Clear all shared arguments in the global context."""
+    if len(args) > 0:
+      for a in args:
+        self._arguments.pop(a)
+    else:
+      self._arguments.clear()
+  def clear(self) -> None:
+    """Clear all shared data in this computation context."""
+    self._arguments.clear()
+    self._delays.clear()
+  def __call__(self, *args, **kwargs):
+    return self.update(*args, **kwargs)
+  def update(self, *args, **kwargs):
+    for delay in self._delays.values():
+      delay.update()
+  def reset(self, batch_size: int = None):
+    self.reset_state(batch_size=batch_size)
+  def reset_state(self, batch_size: int = None):
+    for delay in self._delays.values():
+      delay.reset_state(batch_size)
+
+
+share = _ShareContext()
diff --git a/brainpy/_src/dyn/delay.py b/brainpy/_src/dyn/delay.py
new file mode 100644
index 000000000..35c6d33cc
--- /dev/null
+++ b/brainpy/_src/dyn/delay.py
@@ -0,0 +1,297 @@
+# -*- coding: utf-8 -*-
+
+from typing import Union, Callable, Optional, Dict
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+from jax.lax import stop_gradient
+
+from brainpy import check
+from brainpy import math as bm
+from brainpy._src.dyn.base import DynamicalSystemNS
+from brainpy._src.math.delayvars import ROTATE_UPDATE, CONCAT_UPDATE
+from brainpy.check import is_integer, jit_error_checking
+
+
+class Delay(DynamicalSystemNS):
+  """Delay variable which has a fixed delay length.
+
+  The data in this delay variable is arranged as::
+
+    delay = 0             [ data
+    delay = 1               data
+    delay = 2               data
+    ...                     ....
+    ...                     ....
+    delay = length-1        data
+    delay = length          data ]
+
+  Parameters
+  ----------
+  target: Variable
+    The initial delay data.
+  length: int
+    The delay data length.
+  before_t0: Any
+    The data to fill the delay buffer before time 0. It can be a Python number
+    (float, int, or bool), an array, or a callable function or an instance of ``Connector``.
+    Note that ``before_t0`` should be arranged in the following way::
+
+      delay = 1             [ data
+      delay = 2               data
+      ...                     ....
+      ...                     ....
+      delay = length-1        data
+      delay = length          data ]
+  method: str
+    The method used for updating delay.
+
+  """
+
+  data: Optional[bm.Variable]
+  idx: Optional[bm.Variable]
+  length: int
+
+  def __init__(
+      self,
+      target: bm.Variable,
+      length: int = 0,
+      before_t0: Union[float, int, bool, bm.Array, jax.Array, Callable] = None,
+      entries: Optional[Dict] = None,
+      name: str = None,
+      method: str = ROTATE_UPDATE,
+  ):
+
+    super().__init__(name=name)
+    if method is None:
+      if self.mode.is_a(bm.NonBatchingMode):
+        method = ROTATE_UPDATE
+      elif self.mode.is_parent_of(bm.TrainingMode):
+        method = CONCAT_UPDATE
+      else:
+        method = ROTATE_UPDATE
+    assert method in [ROTATE_UPDATE, CONCAT_UPDATE]
+    self.method = method
+
+    # target
+    self.target = target
+    if not isinstance(target, bm.Variable):
+      raise ValueError(f'Must be an instance of brainpy.math.Variable. But we got {type(target)}')
+
+    # delay length
+    self.length = is_integer(length, allow_none=False, min_bound=0)
+
+    # delay data
+    if before_t0 is not None:
+      assert isinstance(before_t0, (int, float, bool, bm.Array, jax.Array, Callable))
+    self._before_t0 = before_t0
+    if length > 0:
+      self._init_data(length)
+    else:
+      self.data = None
+
+    # time variables
+    if self.method == ROTATE_UPDATE:
+      self.idx = bm.Variable(stop_gradient(jnp.asarray(0, dtype=jnp.int32)))
+
+    # other info
+    self._access_to_step = dict()
+    for entry, value in (entries or dict()).items():
+      self.register_entry(entry, value)
+
+  def register_entry(
+      self,
+      entry: str,
+      delay_time: Optional[Union[float, bm.Array, Callable]] = None,
+      delay_step: Optional[Union[int, bm.Array, Callable]] = None,
+  ) -> 'Delay':
+    """Register an entry to access the data.
+
+    Args:
+      entry (str): The entry to access the delay data.
+      delay_step: The delay step of the entry (an integer number of time steps).
+      delay_time: The delay time of the entry (can be a float).
+
+    Returns:
+      The delay instance itself.
+    """
+    if entry in self._access_to_step:
+      raise KeyError(f'Entry {entry} has been registered.')
+
+    if delay_time is not None:
+      if delay_step is not None:
+        raise ValueError('Provide either "delay_time" or "delay_step". But you have given both.')
+      if callable(delay_time):
+        delay_time = bm.as_jax(delay_time(self.delay_target_shape))
+        delay_step = jnp.asarray(delay_time / bm.get_dt(), dtype=bm.get_int())
+      elif isinstance(delay_time, float):
+        delay_step = int(delay_time / bm.get_dt())
+      else:
+        delay_step = jnp.asarray(bm.as_jax(delay_time) / bm.get_dt(), dtype=bm.get_int())
+
+    # delay steps
+    if delay_step is None:
+      delay_type = 'none'
+    elif isinstance(delay_step, int):
+      delay_type = 'homo'
+    elif isinstance(delay_step, (bm.Array, jax.Array, np.ndarray)):
+      if delay_step.size == 1 and delay_step.ndim == 0:
+        delay_type = 'homo'
+      else:
+        delay_type = 'heter'
+        delay_step = bm.Array(delay_step)
+    elif callable(delay_step):
+      delay_step = delay_step(self.delay_target_shape)
+      delay_type = 'heter'
+    else:
+      raise ValueError(f'Unknown "delay_step" type {type(delay_step)}, only support '
+                       f'integer, array of integers, callable function, brainpy.init.Initializer.')
+    if delay_type == 'heter':
+      if delay_step.dtype not in [jnp.int32, jnp.int64]:
+        raise ValueError('Only support delay steps of int32, int64. If you '
+                         'provide a delay time, please divide it by "dt" and '
+                         'then provide the number of delay steps.')
+      if self.delay_target_shape[0] != delay_step.shape[0]:
+        raise ValueError(f'Shape is mismatched: {self.delay_target_shape[0]} != {delay_step.shape[0]}')
+    if delay_type == 'heter':
+      max_delay_step = int(max(delay_step))
+    elif delay_type == 'homo':
+      max_delay_step = delay_step
+    else:
+      max_delay_step = None
+
+    # delay variable
+    if max_delay_step is not None:
+      if self.length < max_delay_step:
+        self._init_data(max_delay_step)
+        self.length = max_delay_step
+    self._access_to_step[entry] = delay_step
+    return self
+
+  def at(self, entry: str, *indices) -> bm.Array:
+    """Get the data at the given entry.
+
+    Args:
+      entry (str): The entry to access the data.
+      *indices: The extra indices to slice the delayed data.
+
+    Returns:
+      The data.
+    """
+    assert isinstance(entry, str)
+    if entry not in self._access_to_step:
+      raise KeyError(f'Cannot find delay entry "{entry}".')
+    delay_step = self._access_to_step[entry]
+    if delay_step is None:
+      return self.target.value
+    else:
+      if self.data is None:
+        return self.target.value
+      else:
+        if isinstance(delay_step, slice):
+          return self.retrieve(delay_step, *indices)
+        elif np.ndim(delay_step) == 0:
+          return self.retrieve(delay_step, *indices)
+        else:
+          if len(indices) == 0 and len(delay_step) == self.target.shape[0]:
+            indices = (jnp.arange(delay_step.size),)
+          return self.retrieve(delay_step, *indices)
+
+  @property
+  def delay_target_shape(self):
+    """The data shape of the delay target."""
+    return self.target.shape
+
+  def __repr__(self):
+    name = self.__class__.__name__
+    return (f'{name}(num_delay_step={self.length}, '
+            f'delay_target_shape={self.delay_target_shape}, '
+            f'update_method={self.method})')
+
+  def _check_delay(self, delay_len):
+    raise ValueError(f'The requested delay length should be no greater than the '
+                     f'maximum delay {self.length}. '
+                     f'But we got {delay_len}')
+
+  def retrieve(self, delay_step, *indices):
+    """Retrieve the delay data according to the delay length.
+
+    Parameters
+    ----------
+    delay_step: int, ArrayType
+      The delay length used to retrieve the data.
+ """ + assert delay_step is not None + if check.is_checking(): + jit_error_checking(jnp.any(delay_step > self.length), self._check_delay, delay_step) + + if self.method == ROTATE_UPDATE: + delay_idx = (self.idx.value + delay_step) % (self.length + 1) + delay_idx = stop_gradient(delay_idx) + + elif self.method == CONCAT_UPDATE: + delay_idx = delay_step + + else: + raise ValueError(f'Unknown updating method "{self.method}"') + + # the delay index + if hasattr(delay_idx, 'dtype') and not jnp.issubdtype(delay_idx.dtype, jnp.integer): + raise ValueError(f'"delay_len" must be integer, but we got {delay_idx}') + indices = (delay_idx,) + tuple(indices) + + # the delay data + return self.data[indices] + + def update(self, latest_value: Optional[Union[bm.Array, jax.Array]] = None) -> None: + """Update delay variable with the new data. + """ + if self.data is not None: + # get the latest target value + if latest_value is None: + latest_value = self.target.value + + # update the delay data at the rotation index + if self.method == ROTATE_UPDATE: + self.idx.value = stop_gradient(bm.as_jax((self.idx - 1) % (self.length + 1))) + self.data[self.idx.value] = latest_value + + # update the delay data at the first position + elif self.method == CONCAT_UPDATE: + if self.length >= 2: + self.data.value = bm.vstack([latest_value, self.data[1:]]) + else: + self.data[0] = latest_value + + def reset_state(self, batch_size: int = None): + """Reset the delay data. + """ + # initialize delay data + if self.data is not None: + self._init_data(self.length, batch_size) + + # time variables + if self.method == ROTATE_UPDATE: + self.idx.value = stop_gradient(jnp.asarray(0, dtype=jnp.int32)) + + def _init_data(self, length, batch_size: int = None): + if batch_size is not None: + if self.target.batch_size != batch_size: + raise ValueError(f'The batch sizes of delay variable and target variable differ ' + f'({self.target.batch_size} != {batch_size}). ' + 'Please reset the target variable first, because delay data ' + 'depends on the target variable. 
') + + if self.target.batch_axis is None: + batch_axis = None + else: + batch_axis = self.target.batch_axis + 1 + self.data = bm.Variable(jnp.zeros((length + 1,) + self.target.shape, dtype=self.target.dtype), + batch_axis=batch_axis) + # update delay data + self.data[0] = self.target.value + if isinstance(self._before_t0, (bm.Array, jax.Array, float, int, bool)): + self.data[1:] = self._before_t0 + elif callable(self._before_t0): + self.data[1:] = self._before_t0((length,) + self.target.shape, dtype=self.target.dtype) diff --git a/brainpy/_src/dyn/layers/__init__.py b/brainpy/_src/dyn/layers/__init__.py index ef2224d3b..68ae252a3 100644 --- a/brainpy/_src/dyn/layers/__init__.py +++ b/brainpy/_src/dyn/layers/__init__.py @@ -2,7 +2,6 @@ from .base import * from .dropout import * -from .linear import * from .nvar import * from .reservoir import * from .rnncells import * diff --git a/brainpy/_src/dyn/layers/base.py b/brainpy/_src/dyn/layers/base.py index 830267557..992d6243c 100644 --- a/brainpy/_src/dyn/layers/base.py +++ b/brainpy/_src/dyn/layers/base.py @@ -1,40 +1,8 @@ -# -*- coding: utf-8 -*- +from brainpy._src.dyn.base import DynamicalSystemNS -from typing import Optional - -import brainpy.math as bm -from brainpy._src.dyn.base import DynamicalSystem, not_pass_shargs - -__all__ = [ - 'Layer' -] - - -class Layer(DynamicalSystem): +class Layer(DynamicalSystemNS): """Base class for a layer of artificial neural network.""" - def __init__(self, - name: Optional[str] = None, - mode: Optional[bm.Mode] = None): - super().__init__(name=name, mode=mode) - - def reset_state(self, batch_size: Optional[int] = None): - child_nodes = self.nodes(level=1, include_self=False).subset(DynamicalSystem).unique() - if len(child_nodes) > 0: - for node in child_nodes.values(): - node.reset_state(batch_size=batch_size) - self.reset_local_delays(child_nodes) - else: - pass - - def clear_input(self): - child_nodes = self.nodes(level=1, include_self=False).subset(DynamicalSystem).unique() - if len(child_nodes) > 0: - for node in child_nodes.values(): - node.clear_input() - else: - pass - - def update(self, *args): - raise NotImplementedError + def reset_state(self, *args, **kwargs): + pass diff --git a/brainpy/_src/dyn/layers/conv.py b/brainpy/_src/dyn/layers/conv.py index 6c1e51f7c..5fbf393fb 100644 --- a/brainpy/_src/dyn/layers/conv.py +++ b/brainpy/_src/dyn/layers/conv.py @@ -5,7 +5,7 @@ from jax import lax from brainpy import math as bm, tools, check -from brainpy._src.dyn.base import not_pass_shargs +from brainpy._src.dyn.base import not_pass_sha from brainpy._src.initialize import Initializer, XavierNormal, ZeroInit, parameter from brainpy.types import ArrayType from .base import Layer @@ -154,7 +154,7 @@ def _check_input_dim(self, x): raise ValueError(f"input channels={x.shape[-1]} needs to have " f"the same size as in_channels={self.in_channels}.") - @not_pass_shargs + @not_pass_sha def update(self, x): self._check_input_dim(x) w = self.w.value @@ -526,7 +526,7 @@ def __init__( def _check_input_dim(self, x): raise NotImplementedError - @not_pass_shargs + @not_pass_sha def update(self, x): self._check_input_dim(x) diff --git a/brainpy/_src/dyn/layers/dropout.py b/brainpy/_src/dyn/layers/dropout.py index f73988b5e..67d42387c 100644 --- a/brainpy/_src/dyn/layers/dropout.py +++ b/brainpy/_src/dyn/layers/dropout.py @@ -1,9 +1,9 @@ # -*- coding: utf-8 -*- +from brainpy._src.dyn.context import share from brainpy import math as bm, check from .base import Layer -from brainpy._src.dyn.base import not_pass_shargs 
__all__ = [ 'Dropout' @@ -48,10 +48,10 @@ def __init__( self.prob = check.is_float(prob, min_bound=0., max_bound=1.) self.rng = bm.random.default_rng(seed) - def update(self, s, x): - if s['fit']: + def update(self, x): + if share.load('fit'): keep_mask = self.rng.bernoulli(self.prob, x.shape) - return bm.where(bm.as_jax(keep_mask), x / self.prob, 0.) + return bm.where(keep_mask, x / self.prob, 0.) else: return x diff --git a/brainpy/_src/dyn/layers/function.py b/brainpy/_src/dyn/layers/function.py index 7f36179fc..b4a39f6f2 100644 --- a/brainpy/_src/dyn/layers/function.py +++ b/brainpy/_src/dyn/layers/function.py @@ -6,7 +6,6 @@ import brainpy.math as bm from brainpy import check from .base import Layer -from brainpy._src.dyn.base import not_pass_shargs __all__ = [ 'Activation', @@ -40,7 +39,6 @@ def __init__( self.activate_fun = activate_fun self.kwargs = kwargs - @not_pass_shargs def update(self, *args, **kwargs): return self.activate_fun(*args, **kwargs, **self.kwargs) @@ -64,7 +62,6 @@ def __init__( super().__init__(name, mode) check.is_subclass(self.mode, (bm.NonBatchingMode, bm.BatchingMode, bm.TrainingMode), self.name) - @not_pass_shargs def update(self, x): if isinstance(self.mode, bm.BatchingMode): return x.reshape((x.shape[0], -1)) @@ -84,6 +81,5 @@ def __init__( self._fun = fun self.kwargs = kwargs - @not_pass_shargs def update(self, *args, **kwargs): return self._fun(*args, **kwargs, **self.kwargs) diff --git a/brainpy/_src/dyn/layers/linear.py b/brainpy/_src/dyn/layers/linear.py index c5c21e4e3..711e12ce1 100644 --- a/brainpy/_src/dyn/layers/linear.py +++ b/brainpy/_src/dyn/layers/linear.py @@ -6,6 +6,7 @@ import jax.numpy as jnp from brainpy import math as bm +from brainpy._src.dyn.context import share from brainpy.algorithms import OnlineAlgorithm, OfflineAlgorithm from brainpy.check import is_initializer from brainpy.errors import MathError @@ -94,22 +95,19 @@ def __repr__(self): f'num_out={self.num_out}, ' f'mode={self.mode})') - def update(self, *args): - if len(args) == 1: - sha, x = dict(), bm.as_jax(args[0]) - else: - sha, x = args[0], bm.as_jax(args[1]) + def update(self, x): + x = bm.as_jax(x) res = x @ self.W if self.b is not None: res += self.b # online fitting data - if sha.get('fit', False) and self.online_fit_by is not None: + if share.load('fit', False) and self.online_fit_by is not None: self.fit_record['input'] = x self.fit_record['output'] = res # offline fitting data - if sha.get('fit', False) and self.offline_fit_by is not None: + if share.load('fit', False) and self.offline_fit_by is not None: self.fit_record['input'] = x self.fit_record['output'] = res return res @@ -207,5 +205,5 @@ class Identity(Layer): def __init__(self, *args, **kwargs) -> None: super(Identity, self).__init__(*args, **kwargs) - def update(self, *args): - return args[0] if len(args) == 1 else args[1] + def update(self, x): + return x diff --git a/brainpy/_src/dyn/layers/normalization.py b/brainpy/_src/dyn/layers/normalization.py index dc3c8e80f..6751e2bbe 100644 --- a/brainpy/_src/dyn/layers/normalization.py +++ b/brainpy/_src/dyn/layers/normalization.py @@ -4,6 +4,7 @@ from jax import lax, numpy as jnp +from brainpy._src.dyn.context import share from brainpy import math as bm, check from brainpy.initialize import ZeroInit, OneInit, Initializer, parameter from brainpy.types import ArrayType @@ -123,12 +124,12 @@ def __init__( def _check_input_dim(self, x): raise NotImplementedError - def update(self, sha, x): + def update(self, x): self._check_input_dim(x) x = bm.as_jax(x) - if 
sha['fit']: + if share.load('fit'): mean = jnp.mean(x, self.axis) mean_of_square = jnp.mean(_square(x), self.axis) if self.axis_name is not None: @@ -486,7 +487,7 @@ def __init__( self.bias = bm.TrainVar(parameter(self.bias_initializer, self.normalized_shape)) self.scale = bm.TrainVar(parameter(self.scale_initializer, self.normalized_shape)) - def update(self, sha, x): + def update(self,x): if x.shape[-len(self.normalized_shape):] != self.normalized_shape: raise ValueError(f'Expect the input shape should be (..., {", ".join(self.normalized_shape)}), ' f'but we got {x.shape}') @@ -571,7 +572,7 @@ def __init__( self.bias = bm.TrainVar(parameter(self.bias_initializer, self.num_channels)) self.scale = bm.TrainVar(parameter(self.scale_initializer, self.num_channels)) - def update(self, sha, x): + def update(self, x): assert x.shape[-1] == self.num_channels origin_shape, origin_dim = x.shape, x.ndim group_shape = (-1,) + x.shape[1:-1] + (self.num_groups, self.num_channels // self.num_groups) diff --git a/brainpy/_src/dyn/layers/nvar.py b/brainpy/_src/dyn/layers/nvar.py index 84c666748..43fc5c66f 100644 --- a/brainpy/_src/dyn/layers/nvar.py +++ b/brainpy/_src/dyn/layers/nvar.py @@ -9,7 +9,7 @@ import brainpy.math as bm from brainpy import check from .base import Layer -from brainpy._src.dyn.base import not_pass_shargs +from brainpy._src.dyn.base import not_pass_sha __all__ = [ 'NVAR' @@ -130,7 +130,6 @@ def reset_state(self, batch_size=None): else: self.store.value = jnp.zeros((self.num_delay, batch_size, self.num_in)) - @not_pass_shargs def update(self, x): all_parts = [] select_ids = (self.idx[0] - jnp.arange(0, self.num_delay, self.stride)) % self.num_delay diff --git a/brainpy/_src/dyn/layers/pooling.py b/brainpy/_src/dyn/layers/pooling.py index 0967e4bff..3ff24d8a4 100644 --- a/brainpy/_src/dyn/layers/pooling.py +++ b/brainpy/_src/dyn/layers/pooling.py @@ -8,7 +8,6 @@ from brainpy import math as bm, check from .base import Layer -from brainpy._src.dyn.base import not_pass_shargs __all__ = [ 'MaxPool', @@ -81,7 +80,6 @@ def __init__( f'padding should be sequence of Tuple[int, int]. {padding}' assert all([len(x) == 2 for x in padding]), f"each entry in padding {padding} must be length 2" - @not_pass_shargs def update(self, x): x = bm.as_jax(x) window_shape = self._infer_shape(x.ndim, self.kernel_size) @@ -258,7 +256,6 @@ def __init__( mode=mode, name=name) - @not_pass_shargs def update(self, x): x = bm.as_jax(x) window_shape = self._infer_shape(x.ndim, self.kernel_size) @@ -359,7 +356,6 @@ def __init__( # channel_axis self.channel_axis = check.is_integer(channel_axis, allow_none=True) - @not_pass_shargs def update(self, x): x = bm.as_jax(x) x_dim = self.pool_dim + (0 if self.channel_axis is None else 1) @@ -525,7 +521,6 @@ def __init__( class _AvgPoolNd(_MaxPoolNd): - @not_pass_shargs def update(self, x): x = bm.as_jax(x) x_dim = self.pool_dim + (0 if self.channel_axis is None else 1) @@ -763,7 +758,6 @@ def __init__( raise ValueError("`target_size` must either be an int or tuple of length " f"{num_spatial_dims} containing ints.") - @not_pass_shargs def update(self, x): """Input-output mapping. 
diff --git a/brainpy/_src/dyn/layers/reservoir.py b/brainpy/_src/dyn/layers/reservoir.py index feffa3854..cc11fc053 100644 --- a/brainpy/_src/dyn/layers/reservoir.py +++ b/brainpy/_src/dyn/layers/reservoir.py @@ -10,7 +10,6 @@ from brainpy.tools import to_size from brainpy.types import ArrayType from .base import Layer -from brainpy._src.dyn.base import not_pass_shargs __all__ = [ 'Reservoir', @@ -124,7 +123,7 @@ def __init__( assert num_out > 0, f'Must be a positive integer, but we got {num_out}' self.leaky_rate = leaky_rate check.is_float(leaky_rate, 'leaky_rate', 0., 1.) - self.activation = getattr(bm.activations, activation) if isinstance(activation, str) else activation + self.activation = getattr(bm, activation) if isinstance(activation, str) else activation check.is_callable(self.activation, allow_none=False) self.activation_type = activation_type check.is_string(activation_type, 'activation_type', ['internal', 'external']) @@ -192,7 +191,6 @@ def __init__( def reset_state(self, batch_size=None): self.state.value = variable(jnp.zeros, batch_size, self.output_shape) - @not_pass_shargs def update(self, x): """Feedforward output.""" # inputs diff --git a/brainpy/_src/dyn/layers/rnncells.py b/brainpy/_src/dyn/layers/rnncells.py index c99b33ab2..4d792533d 100644 --- a/brainpy/_src/dyn/layers/rnncells.py +++ b/brainpy/_src/dyn/layers/rnncells.py @@ -6,19 +6,17 @@ import jax.numpy as jnp import brainpy.math as bm +from .base import Layer +from brainpy.check import (is_integer, + is_initializer) from brainpy.initialize import (XavierNormal, ZeroInit, Orthogonal, parameter, variable, Initializer) -from brainpy.check import (is_integer, - is_initializer) from brainpy.types import ArrayType -from .base import Layer from .conv import _GeneralConv -from brainpy._src.dyn.base import not_pass_shargs - __all__ = [ 'RNNCell', 'GRUCell', 'LSTMCell', @@ -117,7 +115,6 @@ def reset_state(self, batch_size=None): self.state2train.value = parameter(self._state_initializer, self.num_out, allow_none=False) self.state[:] = self.state2train - @not_pass_shargs def update(self, x): h = x @ self.Wi h += self.state.value @ self.Wh @@ -228,7 +225,6 @@ def reset_state(self, batch_size=None): self.state2train.value = parameter(self._state_initializer, self.num_out, allow_none=False) self.state[:] = self.state2train - @not_pass_shargs def update(self, x): gates_x = jnp.matmul(x, bm.as_jax(self.Wi)) zr_x, a_x = jnp.split(gates_x, indices_or_sections=[2 * self.num_out], axis=-1) @@ -365,8 +361,7 @@ def reset_state(self, batch_size=None): self.state2train.value = parameter(self._state_initializer, self.num_out * 2, allow_none=False) self.state[:] = self.state2train - @not_pass_shargs - def update(self, sha, x): + def update(self, x): h, c = jnp.split(self.state.value, 2, axis=-1) gated = x @ self.Wi if self.b is not None: @@ -563,9 +558,7 @@ def reset_state(self, batch_size: int = 1): self.h[:] = self.h_to_train self.c[:] = self.c_to_train - @not_pass_shargs - def update(self, *args): - x = args[0] if len(args) == 1 else args[1] + def update(self, x): gates = self.input_to_hidden(x) + self.hidden_to_hidden(self.h) i, g, f, o = bm.split(gates, indices_or_sections=4, axis=-1) f = bm.sigmoid(f + 1) diff --git a/brainpy/_src/dyn/neurons/biological_models.py b/brainpy/_src/dyn/neurons/biological_models.py index 32d1c68aa..6238557a6 100644 --- a/brainpy/_src/dyn/neurons/biological_models.py +++ b/brainpy/_src/dyn/neurons/biological_models.py @@ -4,7 +4,8 @@ import brainpy.math as bm from brainpy import check -from 
brainpy._src.dyn.base import NeuGroup, not_pass_shargs +from brainpy._src.dyn.context import share +from brainpy._src.dyn.base import NeuGroup, not_pass_sha from brainpy._src.initialize import OneInit, Uniform, Initializer, parameter, noise as init_noise, variable_ from brainpy._src.integrators.joint_eq import JointEq from brainpy._src.integrators.ode.generic import odeint @@ -274,15 +275,15 @@ def __init__( def reset_state(self, batch_size=None): self.V = variable_(self._V_initializer, self.varshape, batch_size) if self._m_initializer is None: - self.m = self.m_inf(self.V.value) + self.m = bm.Variable(self.m_inf(self.V.value), batch_axis=self.V.batch_axis) else: self.m = variable_(self._m_initializer, self.varshape, batch_size) if self._h_initializer is None: - self.h = self.h_inf(self.V.value) + self.h = bm.Variable(self.h_inf(self.V.value), batch_axis=self.V.batch_axis) else: self.h = variable_(self._h_initializer, self.varshape, batch_size) if self._n_initializer is None: - self.n = self.n_inf(self.V.value) + self.n = bm.Variable(self.n_inf(self.V.value), batch_axis=self.V.batch_axis) else: self.n = variable_(self._n_initializer, self.varshape, batch_size) self.input = variable_(bm.zeros, self.varshape, batch_size) @@ -299,9 +300,9 @@ def dV(self, V, t, m, h, n): def derivative(self): return JointEq(self.dV, self.dm, self.dh, self.dn) - @not_pass_shargs + @not_pass_sha def update(self, x=None): - s = bm.share.get_shargs() + s = share.get_shargs() if x is not None: self.input += x V, m, h, n = self.integral(self.V, self.m, self.h, self.n, s['t'], s['dt']) self.spike.value = bm.logical_and(self.V < self.V_th, V >= self.V_th) diff --git a/brainpy/_src/dyn/neurons/input_groups.py b/brainpy/_src/dyn/neurons/input_groups.py index 5724ab8b7..e0532f208 100644 --- a/brainpy/_src/dyn/neurons/input_groups.py +++ b/brainpy/_src/dyn/neurons/input_groups.py @@ -3,9 +3,9 @@ from typing import Union, Sequence import jax.numpy as jnp - +from brainpy._src.dyn.context import share import brainpy.math as bm -from brainpy._src.dyn.base import NeuGroup, not_pass_shargs +from brainpy._src.dyn.base import NeuGroup, not_pass_sha from brainpy._src.initialize import Initializer, parameter, variable_ from brainpy.types import Shape, ArrayType @@ -41,7 +41,7 @@ def __init__( mode=mode) self.spike = None - @not_pass_shargs + @not_pass_sha def update(self, x): return x @@ -73,7 +73,7 @@ def __init__( mode=mode) self.spike = None - @not_pass_shargs + @not_pass_sha def update(self, x): return x @@ -135,22 +135,21 @@ def __init__( # data about times and indices self.times = jnp.asarray(times) self.indices = jnp.asarray(indices, dtype=bm.int_) - - # variables - self.i = bm.Variable(jnp.zeros(1, dtype=bm.int_)) - self.spike = variable_(lambda s: jnp.zeros(s, dtype=bool), self.varshape, mode) if need_sort: sort_idx = jnp.argsort(self.times) self.indices.value = self.indices[sort_idx] self.times.value = self.times[sort_idx] + # variables + self.reset_state(self.mode) + # functions def cond_fun(t): - i = self.i[0] + i = self.i.value return jnp.logical_and(i < self.num_times, t >= self.times[i]) def body_fun(t): - i = self.i[0] + i = self.i.value if isinstance(self.mode, bm.BatchingMode): self.spike[:, self.indices[i]] = True else: @@ -160,12 +159,14 @@ def body_fun(t): self._run = bm.make_while(cond_fun, body_fun, dyn_vars=self.vars()) def reset_state(self, batch_size=None): - self.i[0] = 1 - self.spike.value = variable_(lambda s: jnp.zeros(s, dtype=bool), self.varshape, batch_size) + self.i = bm.Variable(bm.asarray(0)) + 
self.spike = variable_(lambda s: jnp.zeros(s, dtype=bool), self.varshape, batch_size) - def update(self, tdi, x=None): - self.spike[:] = False - self._run(tdi['t']) + @not_pass_sha + def update(self): + self.spike.value = bm.zeros_like(self.spike) + self._run(share.load('t')) + return self.spike.value class PoissonGroup(NeuGroup): @@ -192,16 +193,18 @@ def __init__( self.freqs = parameter(freqs, self.num, allow_none=False) # variables - self.spike = variable_(lambda s: jnp.zeros(s, dtype=bool), self.varshape, self.mode) self.rng = bm.random.default_rng(seed) + self.reset_state(self.mode) - def update(self, tdi, x=None): - shape = (self.spike.shape[:1] + self.varshape) if isinstance(self.mode, bm.BatchingMode) else self.varshape - self.spike.update(self.rng.random(shape) <= (self.freqs * tdi['dt'] / 1000.)) + @not_pass_sha + def update(self, x=None): + spikes = self.rng.rand_like(self.spike) <= (self.freqs * share.dt / 1000.) + self.spike.value = spikes + return spikes def reset(self, batch_size=None): self.rng.value = bm.random.default_rng(self.seed) self.reset_state(batch_size) def reset_state(self, batch_size=None): - self.spike.value = variable_(lambda s: jnp.zeros(s, dtype=bool), self.varshape, batch_size) + self.spike = variable_(lambda s: jnp.zeros(s, dtype=bool), self.varshape, batch_size) diff --git a/brainpy/_src/dyn/neurons/noise_groups.py b/brainpy/_src/dyn/neurons/noise_groups.py index 01c2474aa..c6c9749f8 100644 --- a/brainpy/_src/dyn/neurons/noise_groups.py +++ b/brainpy/_src/dyn/neurons/noise_groups.py @@ -3,8 +3,9 @@ from typing import Union, Callable import jax.numpy as jnp +from brainpy._src.dyn.context import share from brainpy import math as bm, initialize as init -from brainpy._src.dyn.base import NeuGroup +from brainpy._src.dyn.base import NeuGroup, not_pass_sha from brainpy._src.initialize import Initializer from brainpy._src.integrators.sde.generic import sdeint from brainpy.types import ArrayType, Shape @@ -56,20 +57,19 @@ def __init__( ): super(OUProcess, self).__init__(size=size, name=name, keep_size=keep_size, mode=mode) - # parameters self.mean = init.parameter(mean, self.varshape, allow_none=False) self.sigma = init.parameter(sigma, self.varshape, allow_none=False) self.tau = init.parameter(tau, self.varshape, allow_none=False) # variables - self.x = init.variable_(lambda s: jnp.ones(s) * self.mean, self.varshape, self.mode) + self.reset_state(self.mode) # integral functions self.integral = sdeint(f=self.df, g=self.dg, method=method) def reset_state(self, batch_size=None): - self.x.value = init.variable_(lambda s: jnp.ones(s) * self.mean, self.varshape, batch_size) + self.x = init.variable_(lambda s: jnp.ones(s) * self.mean, self.varshape, batch_size) def df(self, x, t): return (self.mean - x) / self.tau @@ -77,5 +77,10 @@ def df(self, x, t): def dg(self, x, t): return self.sigma - def update(self, tdi): - self.x.value = self.integral(self.x, tdi['t'], tdi['dt']) + @not_pass_sha + def update(self): + t = share.load('t') + dt = share.load('dt') + self.x.value = self.integral(self.x, t, dt) + return self.x.value + diff --git a/brainpy/_src/dyn/neurons/reduced_models.py b/brainpy/_src/dyn/neurons/reduced_models.py index 28f80adee..419eb4599 100644 --- a/brainpy/_src/dyn/neurons/reduced_models.py +++ b/brainpy/_src/dyn/neurons/reduced_models.py @@ -6,9 +6,14 @@ from jax.lax import stop_gradient import brainpy.math as bm -from brainpy._src.dyn.base import NeuGroup, not_pass_shargs -from brainpy._src.initialize import (ZeroInit, OneInit, Initializer, - parameter, 
variable_, noise as init_noise) +from brainpy._src.dyn.base import NeuGroup, not_pass_sha +from brainpy._src.dyn.context import share +from brainpy._src.initialize import (ZeroInit, + OneInit, + Initializer, + parameter, + variable_, + noise as init_noise) from brainpy._src.integrators import sdeint, odeint, JointEq from brainpy.check import is_initializer, is_callable, is_subclass from brainpy.types import Shape, ArrayType @@ -79,14 +84,15 @@ def __init__( noise: Union[float, ArrayType, Initializer, Callable] = None, # other parameter + input_var: bool = True, name: str = None, mode: bm.Mode = None, method: str = 'exp_auto', ): - super(LeakyIntegrator, self).__init__(size=size, - mode=mode, - keep_size=keep_size, - name=name) + super().__init__(size=size, + mode=mode, + keep_size=keep_size, + name=name) is_subclass(self.mode, (bm.TrainingMode, bm.NonBatchingMode)) # parameters @@ -94,10 +100,10 @@ def __init__( self.tau = parameter(tau, self.varshape, allow_none=False) self.R = parameter(R, self.varshape, allow_none=False) self.noise = init_noise(noise, self.varshape) + self.input_var = input_var # initializers - is_initializer(V_initializer, 'V_initializer') - self._V_initializer = V_initializer + self._V_initializer = is_initializer(V_initializer) # integral if self.noise is None: @@ -113,15 +119,25 @@ def derivative(self, V, t, I_ext): def reset_state(self, batch_size=None): self.V = variable_(self._V_initializer, self.varshape, batch_size) - self.input = variable_(bm.zeros, self.varshape, batch_size) - - def update(self, tdi, x=None): - if x is not None: - self.input += x - self.V.value = self.integral(self.V.value, tdi.t, self.input.value, tdi.dt) + if self.input_var: + self.input = variable_(bm.zeros, self.varshape, batch_size) + + @not_pass_sha + def update(self, x=None): + t = share.load('t') + dt = share.load('dt') + if self.input_var: + if x is not None: + self.input += x + x = self.input.value + else: + x = 0. if x is None else x + self.V.value = self.integral(self.V.value, t, x, dt) + return self.V.value def clear_input(self): - self.input[:] = 0. + if self.input_var: + self.input[:] = 0. 
class LIF(NeuGroup): @@ -200,14 +216,16 @@ def __init__( spike_fun: Callable = bm.surrogate.inv_square_grad, # other parameters + input_var: bool = True, + ref_var: bool = False, method: str = 'exp_auto', name: Optional[str] = None, ): # initialization - super(LIF, self).__init__(size=size, - name=name, - keep_size=keep_size, - mode=mode) + super().__init__(size=size, + name=name, + keep_size=keep_size, + mode=mode) is_subclass(self.mode, (bm.TrainingMode, bm.NonBatchingMode), self.name) # parameters @@ -219,19 +237,14 @@ def __init__( self.tau_ref = parameter(tau_ref, self.varshape, allow_none=True) self.noise = init_noise(noise, self.varshape) self.spike_fun = is_callable(spike_fun, 'spike_fun') + self.input_var = input_var + self.ref_var = ref_var # initializers - is_initializer(V_initializer, 'V_initializer') - self._V_initializer = V_initializer + self._V_initializer = is_initializer(V_initializer) # variables - self.V = variable_(self._V_initializer, self.varshape, self.mode) - self.input = variable_(bm.zeros, self.varshape, self.mode) - sp_type = bm.float_ if isinstance(self.mode, bm.TrainingMode) else bool # the gradient of spike is a float - self.spike = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, self.mode) - if self.tau_ref is not None: - self.t_last_spike = variable_(lambda s: bm.ones(s) * -1e7, self.varshape, self.mode) - self.refractory = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, self.mode) + self.reset_state(self.mode) # integral if self.noise is None: @@ -243,20 +256,29 @@ def derivative(self, V, t, I_ext): return (-V + self.V_rest + self.R * I_ext) / self.tau def reset_state(self, batch_size=None): - self.V.value = variable_(self._V_initializer, self.varshape, batch_size) - self.input.value = variable_(bm.zeros, self.varshape, batch_size) - sp_type = bm.float_ if isinstance(self.mode, bm.TrainingMode) else bool - self.spike.value = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, batch_size) + self.V = variable_(self._V_initializer, self.varshape, batch_size) + if self.input_var: + self.input = variable_(bm.zeros, self.varshape, batch_size) + sp_type = bm.float_ if isinstance(self.mode, bm.TrainingMode) else bool # the gradient of spike is a float + self.spike = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, batch_size) if self.tau_ref is not None: - self.t_last_spike.value = variable_(lambda s: bm.ones(s) * -1e7, self.varshape, batch_size) - self.refractory.value = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, batch_size) - - def update(self, tdi, x=None): - t, dt = tdi.t, tdi.dt - if x is not None: self.input += x + self.t_last_spike = variable_(lambda s: bm.ones(s) * -1e7, self.varshape, batch_size) + if self.ref_var: + self.refractory = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, batch_size) + + @not_pass_sha + def update(self, x=None): + t = share.load('t') + dt = share.load('dt') + if self.input_var: + if x is not None: + self.input += x + x = self.input.value + else: + x = 0. if x is None else x # integrate membrane potential - V = self.integral(self.V.value, t, self.input.value, dt) + V = self.integral(self.V.value, t, x, dt) if self.tau_ref is not None: # refractory @@ -272,16 +294,17 @@ def update(self, tdi, x=None): V += (self.V_reset - V) * spike_no_grad spike_ = spike_no_grad > 0. 
# will be used in other place, like Delta Synapse, so stop its gradient - refractory = stop_gradient(bm.logical_or(refractory, spike_).value) + if self.ref_var: + self.refractory.value = stop_gradient(bm.logical_or(refractory, spike_).value) t_last_spike = stop_gradient(bm.where(spike_, t, self.t_last_spike.value)) else: spike = V >= self.V_th V = bm.where(spike, self.V_reset, V) - refractory = bm.logical_or(refractory, spike) + if self.ref_var: + self.refractory.value = bm.logical_or(refractory, spike) t_last_spike = bm.where(spike, t, self.t_last_spike.value) self.V.value = V self.spike.value = spike - self.refractory.value = refractory self.t_last_spike.value = t_last_spike else: @@ -295,9 +318,11 @@ def update(self, tdi, x=None): V = bm.where(spike, self.V_reset, V) self.V.value = V self.spike.value = spike + return spike def clear_input(self): - self.input[:] = 0. + if self.input_var: + self.input[:] = 0. class ExpIF(NeuGroup): @@ -412,6 +437,8 @@ def __init__( V_initializer: Union[Initializer, Callable, ArrayType] = ZeroInit(), noise: Union[float, ArrayType, Initializer, Callable] = None, keep_size: bool = False, + input_var: bool = True, + ref_var: bool = False, mode: bm.Mode = None, method: str = 'exp_auto', name: str = None @@ -433,19 +460,14 @@ def __init__( self.tau = parameter(tau, self.varshape, allow_none=False) self.R = parameter(R, self.varshape, allow_none=False) self.noise = init_noise(noise, self.varshape) + self.input_var = input_var + self.ref_var = ref_var # initializers - is_initializer(V_initializer, 'V_initializer') - self._V_initializer = V_initializer + self._V_initializer = is_initializer(V_initializer) # variables - self.V = variable_(V_initializer, self.varshape, self.mode) - self.input = variable_(bm.zeros, self.varshape, self.mode) - sp_type = bm.float_ if isinstance(self.mode, bm.TrainingMode) else bool - self.spike = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, self.mode) - self.t_last_spike = variable_(lambda s: bm.ones(s) * -1e7, self.varshape, self.mode) - if self.tau_ref is not None: - self.refractory = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, self.mode) + self.reset_state(self.mode) # integral if self.noise is None: @@ -454,42 +476,50 @@ def __init__( self.integral = sdeint(method=method, f=self.derivative, g=self.noise) def reset_state(self, batch_size=None): - self.V.value = variable_(self._V_initializer, self.varshape, batch_size) - self.input.value = variable_(bm.zeros, self.varshape, batch_size) + self.V = variable_(self._V_initializer, self.varshape, batch_size) + if self.input_var: + self.input = variable_(bm.zeros, self.varshape, batch_size) sp_type = bm.float_ if isinstance(self.mode, bm.TrainingMode) else bool - self.spike.value = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, batch_size) - self.t_last_spike.value = variable_(lambda s: bm.ones(s) * -1e7, self.varshape, batch_size) + self.spike = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, batch_size) if self.tau_ref is not None: - self.refractory.value = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, batch_size) + self.t_last_spike = variable_(lambda s: bm.ones(s) * -1e7, self.varshape, batch_size) + if self.ref_var: + self.refractory = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, batch_size) def derivative(self, V, t, I_ext): exp_v = self.delta_T * bm.exp((V - self.V_T) / self.delta_T) dvdt = (- (V - self.V_rest) + exp_v + self.R * I_ext) / self.tau return dvdt - def update(self, tdi, 
x=None): - t, dt = tdi.t, tdi.dt - if x is not None: self.input += x - V = self.integral(self.V.value, t, self.input.value, dt) - + @not_pass_sha + def update(self, x=None): + t = share.load('t') + dt = share.load('dt') + if self.input_var: + if x is not None: + self.input += x + x = self.input.value + else: + x = 0. if x is None else x + V = self.integral(self.V.value, t, x, dt) if self.tau_ref is not None: refractory = (t - self.t_last_spike) <= self.tau_ref V = bm.where(refractory, self.V.value, V) spike = self.V_th <= V - t_last_spike = bm.where(spike, t, self.t_last_spike.value) V = bm.where(spike, self.V_reset, V) - self.refractory.value = bm.logical_or(refractory, spike) + self.t_last_spike.value = bm.where(spike, t, self.t_last_spike) + if self.ref_var: + self.refractory.value = bm.logical_or(refractory, spike) else: spike = self.V_th <= V - t_last_spike = bm.where(spike, t, self.t_last_spike.value) V = bm.where(spike, self.V_reset, V) - self.V.value = V self.spike.value = spike - self.t_last_spike.value = t_last_spike + return spike def clear_input(self): - self.input[:] = 0. + if self.input_var: + self.input[:] = 0. class AdExIF(NeuGroup): @@ -586,6 +616,7 @@ def __init__( noise: Optional[Union[float, ArrayType, Initializer, Callable]] = None, method: str = 'exp_auto', keep_size: bool = False, + input_var: bool = True, mode: bm.Mode = None, name: Optional[str] = None ): @@ -608,22 +639,14 @@ def __init__( self.tau_ref = parameter(tau_ref, self.varshape, allow_none=True) self.delta_T = parameter(delta_T, self.varshape, allow_none=False) self.noise = init_noise(noise, self.varshape, num_vars=2) + self.input_var = input_var # initializers - is_initializer(V_initializer, 'V_initializer') - is_initializer(w_initializer, 'w_initializer') - self._V_initializer = V_initializer - self._w_initializer = w_initializer + self._V_initializer = is_initializer(V_initializer) + self._w_initializer = is_initializer(w_initializer) # variables - self.V = variable_(V_initializer, self.varshape, self.mode) - self.w = variable_(w_initializer, self.varshape, self.mode) - self.input = variable_(bm.zeros, self.varshape, self.mode) - sp_type = bm.float_ if isinstance(self.mode, bm.BatchingMode) else bool - self.spike = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, self.mode) - if self.tau_ref is not None: - self.refractory = variable_(partial(bm.zeros, dtype=bool), self.varshape, self.mode) - self.t_last_spike = variable_(lambda s: bm.ones(s) * -1e8, self.varshape, self.mode) + self.reset_state(self.mode) # functions if self.noise is None: @@ -632,18 +655,15 @@ def __init__( self.integral = sdeint(method=method, f=self.derivative, g=self.noise) def reset_state(self, batch_size=None): - self.V.value = variable_(self._V_initializer, self.varshape, batch_size) - self.w.value = variable_(self._w_initializer, self.varshape, batch_size) - self.input.value = variable_(bm.zeros, self.varshape, batch_size) - self.spike.value = variable_( - lambda s: bm.zeros(s, dtype=(bm.float_ - if isinstance(self.mode, bm.TrainingMode) - else bool)), - self.varshape, batch_size - ) + self.V = variable_(self._V_initializer, self.varshape, batch_size) + self.w = variable_(self._w_initializer, self.varshape, batch_size) + if self.input_var: + self.input = variable_(bm.zeros, self.varshape, batch_size) + sp_type = bm.float_ if isinstance(self.mode, bm.BatchingMode) else bool + self.spike = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, batch_size) if self.tau_ref is not None: - self.refractory.value = 
variable_(partial(bm.zeros, dtype=bool), self.varshape, batch_size) - self.t_last_spike.value = variable_(lambda s: bm.ones(s) * -1e8, self.varshape, batch_size) + self.refractory = variable_(partial(bm.zeros, dtype=bool), self.varshape, batch_size) + self.t_last_spike = variable_(lambda s: bm.ones(s) * -1e8, self.varshape, batch_size) def dV(self, V, t, w, I_ext): exp = self.delta_T * bm.exp((V - self.V_T) / self.delta_T) @@ -658,10 +678,17 @@ def dw(self, w, t, V): def derivative(self): return JointEq([self.dV, self.dw]) - def update(self, tdi, x=None): - t, dt = tdi.t, tdi.dt - if x is not None: self.input += x - V, w = self.integral(self.V.value, self.w.value, t, self.input.value, dt) + @not_pass_sha + def update(self, x=None): + t = share.load('t') + dt = share.load('dt') + if self.input_var: + if x is not None: + self.input += x + x = self.input.value + else: + x = 0. if x is None else x + V, w = self.integral(self.V.value, self.w.value, t, x, dt) if self.tau_ref is not None: refractory = (t - self.t_last_spike) <= self.tau_ref V = bm.where(refractory, self.V.value, V) @@ -672,9 +699,11 @@ def update(self, tdi, x=None): if self.tau_ref is not None: self.refractory.value = bm.logical_or(refractory, spike) self.t_last_spike.value = bm.where(spike, t, self.t_last_spike.value) + return spike def clear_input(self): - self.input[:] = 0. + if self.input_var: + self.input[:] = 0. class QuaIF(NeuGroup): @@ -758,6 +787,7 @@ def __init__( V_initializer: Union[Initializer, Callable, ArrayType] = ZeroInit(), noise: Union[float, ArrayType, Initializer, Callable] = None, keep_size: bool = False, + input_var: bool = True, mode: bm.Mode = None, method: str = 'exp_auto', name: str = None @@ -779,19 +809,13 @@ def __init__( self.tau = parameter(tau, self.varshape, allow_none=False) self.tau_ref = parameter(tau_ref, self.varshape, allow_none=True) self.noise = init_noise(noise, self.varshape, num_vars=1) + self.input_var = input_var # initializers - is_initializer(V_initializer, '_V_initializer', allow_none=False) - self._V_initializer = V_initializer + self._V_initializer = is_initializer(V_initializer) # variables - self.V = variable_(V_initializer, self.varshape, self.mode) - self.input = variable_(bm.zeros, self.varshape, self.mode) - sp_type = bm.float_ if isinstance(self.mode, bm.TrainingMode) else bool - self.spike = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, self.mode) - self.t_last_spike = variable_(lambda s: bm.ones(s) * -1e7, self.varshape, self.mode) - if self.tau_ref is not None: - self.refractory = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, self.mode) + self.reset_state(self.mode) # integral if self.noise is None: @@ -800,22 +824,30 @@ def __init__( self.integral = sdeint(method=method, f=self.derivative, g=self.noise) def reset_state(self, batch_size=None): - self.V.value = variable_(self._V_initializer, self.varshape, batch_size) - self.input.value = variable_(bm.zeros, self.varshape, batch_size) + self.V = variable_(self._V_initializer, self.varshape, batch_size) + if self.input_var: + self.input = variable_(bm.zeros, self.varshape, batch_size) sp_type = bm.float_ if isinstance(self.mode, bm.TrainingMode) else bool - self.spike.value = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, batch_size) - self.t_last_spike.value = variable_(lambda s: bm.ones(s) * -1e7, self.varshape, batch_size) + self.spike = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, batch_size) if self.tau_ref is not None: - self.refractory.value = 
variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, batch_size) + self.t_last_spike = variable_(lambda s: bm.ones(s) * -1e7, self.varshape, batch_size) + self.refractory = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, batch_size) def derivative(self, V, t, I_ext): dVdt = (self.c * (V - self.V_rest) * (V - self.V_c) + self.R * I_ext) / self.tau return dVdt - def update(self, tdi, x=None): - t, dt = tdi.t, tdi.dt - if x is not None: self.input += x - V = self.integral(self.V.value, t, self.input.value, dt) + @not_pass_sha + def update(self, x=None): + t = share.load('t') + dt = share.load('dt') + if self.input_var: + if x is not None: + self.input += x + x = self.input.value + else: + x = 0. if x is None else x + V = self.integral(self.V.value, t, x, dt) if self.tau_ref is not None: refractory = (t - self.t_last_spike) <= self.tau_ref V = bm.where(refractory, self.V.value, V) @@ -823,16 +855,16 @@ def update(self, tdi, x=None): t_last_spike = bm.where(spike, t, self.t_last_spike.value) V = bm.where(spike, self.V_reset, V) self.refractory.value = bm.logical_or(refractory, spike) + self.t_last_spike.value = t_last_spike else: spike = self.V_th <= V - t_last_spike = bm.where(spike, t, self.t_last_spike.value) V = bm.where(spike, self.V_reset, V) self.V.value = V self.spike.value = spike - self.t_last_spike.value = t_last_spike def clear_input(self): - self.input[:] = 0. + if self.input_var: + self.input[:] = 0. class AdQuaIF(NeuGroup): @@ -929,6 +961,7 @@ def __init__( noise: Union[float, ArrayType, Initializer, Callable] = None, method: str = 'exp_auto', keep_size: bool = False, + input_var: bool = True, mode: bm.Mode = None, name: str = None ): @@ -949,20 +982,14 @@ def __init__( self.tau = parameter(tau, self.varshape, allow_none=False) self.tau_w = parameter(tau_w, self.varshape, allow_none=False) self.noise = init_noise(noise, self.varshape, num_vars=2) + self.input_var = input_var # initializers - is_initializer(V_initializer, 'V_initializer', allow_none=False) - is_initializer(w_initializer, 'w_initializer', allow_none=False) - self._V_initializer = V_initializer - self._w_initializer = w_initializer + self._V_initializer = is_initializer(V_initializer) + self._w_initializer = is_initializer(w_initializer) # variables - self.V = variable_(V_initializer, self.varshape, self.mode) - self.w = variable_(w_initializer, self.varshape, self.mode) - self.input = variable_(bm.zeros, self.varshape, self.mode) - sp_type = bm.float_ if isinstance(self.mode, bm.TrainingMode) else bool - self.spike = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, self.mode) - self.refractory = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, self.mode) + self.reset_state(self.mode) # integral if self.noise is None: @@ -971,12 +998,13 @@ def __init__( self.integral = sdeint(method=method, f=self.derivative, g=self.noise) def reset_state(self, batch_size=None): - self.V.value = variable_(self._V_initializer, self.varshape, batch_size) - self.w.value = variable_(self._w_initializer, self.varshape, batch_size) - self.input.value = variable_(bm.zeros, self.varshape, batch_size) + self.V = variable_(self._V_initializer, self.varshape, batch_size) + self.w = variable_(self._w_initializer, self.varshape, batch_size) + if self.input_var: + self.input = variable_(bm.zeros, self.varshape, batch_size) sp_type = bm.float_ if isinstance(self.mode, bm.TrainingMode) else bool - self.spike.value = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, batch_size) - 
self.refractory.value = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, batch_size) + self.spike = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, batch_size) + self.refractory = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, batch_size) def dV(self, V, t, w, I_ext): dVdt = (self.c * (V - self.V_rest) * (V - self.V_c) - w + I_ext) / self.tau @@ -990,17 +1018,26 @@ def dw(self, w, t, V): def derivative(self): return JointEq([self.dV, self.dw]) - def update(self, tdi, x=None): - t, dt = tdi.t, tdi.dt - if x is not None: self.input += x - V, w = self.integral(self.V.value, self.w.value, t, self.input.value, dt) + @not_pass_sha + def update(self, x=None): + t = share.load('t') + dt = share.load('dt') + if self.input_var: + if x is not None: + self.input += x + x = self.input.value + else: + x = 0. if x is None else x + V, w = self.integral(self.V.value, self.w.value, t, x, dt) spike = self.V_th <= V self.V.value = bm.where(spike, self.V_reset, V) self.w.value = bm.where(spike, w + self.b, w) self.spike.value = spike + return spike def clear_input(self): - self.input[:] = 0. + if self.input_var: + self.input[:] = 0. class GIF(NeuGroup): @@ -1109,17 +1146,18 @@ def __init__( noise: Union[float, ArrayType, Initializer, Callable] = None, method: str = 'exp_auto', keep_size: bool = False, + input_var: bool = True, name: str = None, # parameter for training mode: bm.Mode = None, - spike_fun: Callable = bm.spike_with_sigmoid_grad, + spike_fun: Callable = bm.surrogate.sigmoid, ): # initialization - super(GIF, self).__init__(size=size, - keep_size=keep_size, - name=name, - mode=mode) + super().__init__(size=size, + keep_size=keep_size, + name=name, + mode=mode) is_subclass(self.mode, (bm.TrainingMode, bm.NonBatchingMode)) # params @@ -1139,25 +1177,16 @@ def __init__( self.A2 = parameter(A2, self.varshape, allow_none=False) self.noise = init_noise(noise, self.varshape, num_vars=4) self.spike_fun = is_callable(spike_fun, 'spike_fun') + self.input_var = input_var # initializers - is_initializer(V_initializer, 'V_initializer') - is_initializer(I1_initializer, 'I1_initializer') - is_initializer(I2_initializer, 'I2_initializer') - is_initializer(Vth_initializer, 'Vth_initializer') - self._V_initializer = V_initializer - self._I1_initializer = I1_initializer - self._I2_initializer = I2_initializer - self._Vth_initializer = Vth_initializer + self._V_initializer = is_initializer(V_initializer, 'V_initializer') + self._I1_initializer = is_initializer(I1_initializer, 'I1_initializer') + self._I2_initializer = is_initializer(I2_initializer, 'I2_initializer') + self._Vth_initializer = is_initializer(Vth_initializer, 'Vth_initializer') # variables - self.I1 = variable_(I1_initializer, self.varshape, self.mode) - self.I2 = variable_(I2_initializer, self.varshape, self.mode) - self.V_th = variable_(Vth_initializer, self.varshape, self.mode) - self.V = variable_(V_initializer, self.varshape, self.mode) - self.input = variable_(bm.zeros, self.varshape, self.mode) - sp_type = bm.float_ if isinstance(self.mode, bm.TrainingMode) else bool - self.spike = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, self.mode) + self.reset_state(self.mode) # integral if self.noise is None: @@ -1166,13 +1195,14 @@ def __init__( self.integral = sdeint(method=method, f=self.derivative, g=self.noise) def reset_state(self, batch_size=None): - self.I1.value = variable_(self._I1_initializer, self.varshape, batch_size) - self.I2.value = variable_(self._I2_initializer, self.varshape, 
batch_size) - self.V_th.value = variable_(self._Vth_initializer, self.varshape, batch_size) - self.V.value = variable_(self._V_initializer, self.varshape, batch_size) - self.input.value = variable_(bm.zeros, self.varshape, batch_size) - sp_type = bm.float_ if isinstance(self.mode, bm.TrainingMode) else bool - self.spike.value = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, batch_size) + self.V = variable_(self._V_initializer, self.varshape, batch_size) + self.I1 = variable_(self._I1_initializer, self.varshape, batch_size) + self.I2 = variable_(self._I2_initializer, self.varshape, batch_size) + self.V_th = variable_(self._Vth_initializer, self.varshape, batch_size) + if self.input_var: + self.input = variable_(bm.zeros, self.varshape, batch_size) + sp_type = bm.float_ if self.mode.is_a(bm.TrainingMode) else bool + self.spike = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, batch_size) def dI1(self, I1, t): return - self.k1 * I1 @@ -1184,18 +1214,23 @@ def dVth(self, V_th, t, V): return self.a * (V - self.V_rest) - self.b * (V_th - self.V_th_inf) def dV(self, V, t, I1, I2, I_ext): - return (- (V - self.V_rest) + self.R * I_ext + self.R * I1 + self.R * I2) / self.tau + return (- (V - self.V_rest) + self.R * (I_ext + I1 + I2)) / self.tau @property def derivative(self): return JointEq([self.dI1, self.dI2, self.dVth, self.dV]) - def update(self, tdi, x=None): - t, dt = tdi.t, tdi.dt - - # integral - if x is not None: self.input += x - I1, I2, V_th, V = self.integral(self.I1, self.I2, self.V_th, self.V, t, self.input, dt=dt) + @not_pass_sha + def update(self, x=None): + t = share.load('t') + dt = share.load('dt') + if self.input_var: + if x is not None: + self.input += x + x = self.input.value + else: + x = 0. if x is None else x + I1, I2, V_th, V = self.integral(self.I1.value, self.I2.value, self.V_th.value, self.V.value, t, x, dt) # spike and resets if isinstance(self.mode, bm.TrainingMode): @@ -1210,16 +1245,17 @@ def update(self, tdi, x=None): V = bm.where(spike, self.V_reset, V) I1 = bm.where(spike, self.R1 * I1 + self.A1, I1) I2 = bm.where(spike, self.R2 * I2 + self.A2, I2) - reset_th = bm.logical_and(V_th < self.V_th_reset, spike) - V_th = bm.where(reset_th, self.V_th_reset, V_th) + V_th = bm.where(spike, bm.maximum(self.V_th_reset, V_th), V_th) self.spike.value = spike self.I1.value = I1 self.I2.value = I2 self.V_th.value = V_th self.V.value = V + return spike def clear_input(self): - self.input[:] = 0. + if self.input_var: + self.input[:] = 0. 
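Taken together, the hunks above replace the old `update(tdi, x)` signature with a shared-context interface: `t` and `dt` are read from `share`, the external current becomes an optional positional argument, spikes are returned from `update`, and the new `input_var` / `ref_var` flags make the `input` and `refractory` variables optional. A minimal, hypothetical sketch of driving one refactored group directly, assuming `bp.neurons.LIF` and `bp.share` are the public aliases of the classes and context object touched in this patch:

import brainpy as bp
import brainpy.math as bm

neu = bp.neurons.LIF(10, input_var=False)   # no internal `input` buffer is created
bp.share.save(t=0., dt=bm.get_dt())         # shared arguments replace the old `tdi` dict
spikes = neu(5.)                            # external current passed positionally; spikes returned
neu.reset_state()                           # variables are re-created rather than re-assigned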
class ALIFBellec2020(NeuGroup): @@ -1274,6 +1310,7 @@ def __init__( # parameter for training spike_fun: Callable = bm.surrogate.relu_grad, + input_var: bool = True, # other parameters method: str = 'exp_auto', @@ -1281,10 +1318,10 @@ def __init__( mode: bm.Mode = None, eprop: bool = False ): - super(ALIFBellec2020, self).__init__(name=name, - size=size, - keep_size=keep_size, - mode=mode) + super().__init__(name=name, + size=size, + keep_size=keep_size, + mode=mode) is_subclass(self.mode, (bm.TrainingMode, bm.NonBatchingMode)) # parameters @@ -1298,22 +1335,14 @@ def __init__( self.noise = init_noise(noise, self.varshape, num_vars=2) self.spike_fun = is_callable(spike_fun, 'spike_fun') self.eprop = eprop + self.input_var = input_var # initializers - is_initializer(V_initializer, 'V_initializer') - is_initializer(a_initializer, 'a_initializer') - self._V_initializer = V_initializer - self._a_initializer = a_initializer + self._V_initializer = is_initializer(V_initializer, 'V_initializer') + self._a_initializer = is_initializer(a_initializer, 'a_initializer') # variables - self.a = variable_(a_initializer, self.varshape, self.mode) - self.V = variable_(V_initializer, self.varshape, self.mode) - self.input = variable_(bm.zeros, self.varshape, self.mode) - sp_type = bm.float_ if isinstance(self.mode, bm.TrainingMode) else bool - self.spike = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, self.mode) - if self.tau_ref is not None: - self.t_last_spike = variable_(lambda s: bm.ones(s) * -1e7, self.varshape, self.mode) - self.refractory = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, self.mode) + self.reset_state(self.mode) # integral if self.noise is None: @@ -1332,21 +1361,27 @@ def derivative(self): return JointEq([self.dV, self.da]) def reset_state(self, batch_size=None): - self.a.value = variable_(self._a_initializer, self.varshape, batch_size) - self.V.value = variable_(self._V_initializer, self.varshape, batch_size) - self.input.value = variable_(bm.zeros, self.varshape, batch_size) + self.a = variable_(self._a_initializer, self.varshape, batch_size) + self.V = variable_(self._V_initializer, self.varshape, batch_size) + if self.input_var: + self.input = variable_(bm.zeros, self.varshape, batch_size) sp_type = bm.float_ if isinstance(self.mode, bm.TrainingMode) else bool - self.spike.value = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, batch_size) + self.spike = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, batch_size) if self.tau_ref is not None: - self.t_last_spike.value = variable_(lambda s: bm.ones(s) * -1e7, self.varshape, batch_size) - self.refractory.value = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, batch_size) - - def update(self, tdi, x=None): - t, dt = tdi.t, tdi.dt - - # integral - if x is not None: self.input += x - V, a = self.integral(self.V, self.a, t, self.input, dt) + self.t_last_spike = variable_(lambda s: bm.ones(s) * -1e7, self.varshape, batch_size) + self.refractory = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, batch_size) + + @not_pass_sha + def update(self, x=None): + t = share.load('t') + dt = share.load('dt') + if self.input_var: + if x is not None: + self.input += x + x = self.input.value + else: + x = 0. 
if x is None else x + V, a = self.integral(self.V, self.a, t, x, dt) if self.tau_ref is not None: # refractory @@ -1381,9 +1416,11 @@ def update(self, tdi, x=None): self.spike.value = spike self.V.value = V self.a.value = a + spike + return spike def clear_input(self): - self.input[:] = 0. + if self.input_var: + self.input[:] = 0. class Izhikevich(NeuGroup): @@ -1463,20 +1500,22 @@ def __init__( d: Union[float, ArrayType, Initializer, Callable] = 8., V_th: Union[float, ArrayType, Initializer, Callable] = 30., tau_ref: Union[float, ArrayType, Initializer, Callable] = None, - V_initializer: Union[Initializer, Callable, ArrayType] = ZeroInit(), - u_initializer: Union[Initializer, Callable, ArrayType] = OneInit(), + V_initializer: Union[Initializer, Callable, ArrayType] = None, + u_initializer: Union[Initializer, Callable, ArrayType] = None, noise: Union[float, ArrayType, Initializer, Callable] = None, method: str = 'exp_auto', mode: bm.Mode = None, spike_fun: Callable = bm.surrogate.inv_square_grad, keep_size: bool = False, + input_var: bool = True, + ref_var: bool = False, name: str = None ): # initialization - super(Izhikevich, self).__init__(size=size, - keep_size=keep_size, - name=name, - mode=mode) + super().__init__(size=size, + keep_size=keep_size, + name=name, + mode=mode) is_subclass(self.mode, (bm.TrainingMode, bm.NonBatchingMode)) # params @@ -1488,22 +1527,15 @@ def __init__( self.tau_ref = parameter(tau_ref, self.varshape, allow_none=True) self.noise = init_noise(noise, self.varshape, num_vars=2) self.spike_fun = is_callable(spike_fun, 'spike_fun') + self.input_var = input_var + self.ref_var = ref_var # initializers - is_initializer(V_initializer, 'V_initializer', allow_none=False) - is_initializer(u_initializer, 'u_initializer', allow_none=False) - self._V_initializer = V_initializer - self._u_initializer = u_initializer + self._V_initializer = is_initializer(V_initializer, allow_none=True) + self._u_initializer = is_initializer(u_initializer, allow_none=True) # variables - self.u = variable_(u_initializer, self.varshape, self.mode) - self.V = variable_(V_initializer, self.varshape, self.mode) - self.input = variable_(bm.zeros, self.varshape, self.mode) - sp_type = bm.float_ if isinstance(self.mode, bm.TrainingMode) else bool - self.spike = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, self.mode) - if self.tau_ref is not None: - self.t_last_spike = variable_(lambda s: bm.ones(s) * -1e7, self.varshape, self.mode) - self.refractory = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, self.mode) + self.reset_state(self.mode) # functions if self.noise is None: @@ -1512,14 +1544,18 @@ def __init__( self.integral = sdeint(method=method, f=JointEq([self.dV, self.du]), g=self.noise) def reset_state(self, batch_size=None): - self.V.value = variable_(self._V_initializer, self.varshape, batch_size) - self.u.value = variable_(self._u_initializer, self.varshape, batch_size) - self.input.value = variable_(bm.zeros, self.varshape, batch_size) + v_init = OneInit(-70.) 
if self._V_initializer is None else self._V_initializer + self.V = variable_(v_init, self.varshape, batch_size) + u_init = OneInit(self.b * self.V) if self._u_initializer is None else self._u_initializer + self.u = variable_(u_init, self.varshape, batch_size) + if self.input_var: + self.input = variable_(bm.zeros, self.varshape, batch_size) sp_type = bm.float_ if isinstance(self.mode, bm.TrainingMode) else bool - self.spike.value = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, batch_size) + self.spike = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, batch_size) if self.tau_ref is not None: - self.t_last_spike.value = variable_(lambda s: bm.ones(s) * -1e7, self.varshape, batch_size) - self.refractory.value = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, batch_size) + self.t_last_spike = variable_(lambda s: bm.ones(s) * -1e7, self.varshape, batch_size) + if self.ref_var: + self.refractory = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, batch_size) def dV(self, V, t, u, I_ext): dVdt = 0.04 * V * V + 5 * V + 140 - u + I_ext @@ -1529,17 +1565,21 @@ def du(self, u, t, V): dudt = self.a * (self.b * V - u) return dudt - def update(self, tdi, x=None): - t, dt = tdi.t, tdi.dt - - # integrate membrane potential - if x is not None: self.input += x - V, u = self.integral(self.V, self.u, t, self.input, dt) + @not_pass_sha + def update(self, x=None): + t = share.load('t') + dt = share.load('dt') + if self.input_var: + if x is not None: + self.input += x + x = self.input.value + else: + x = 0. if x is None else x + V, u = self.integral(self.V.value, self.u.value, t, x, dt) if self.tau_ref is not None: - refractory = (t - self.t_last_spike) <= self.tau_ref - if isinstance(self.mode, bm.TrainingMode): - refractory = stop_gradient(refractory) + refractory = bm.as_jax((t - self.t_last_spike) <= self.tau_ref) + refractory = stop_gradient(refractory) V = bm.where(refractory, self.V.value, V) # spike, refractory, and reset membrane potential @@ -1548,16 +1588,16 @@ def update(self, tdi, x=None): spike_no_grad = stop_gradient(spike) V += spike_no_grad * (self.c - self.V_th) u += spike_no_grad * self.d - spike_ = spike_no_grad > 0. - refractory = stop_gradient(bm.logical_or(refractory, spike_)) - t_last_spike = stop_gradient(bm.where(spike_, t, self.t_last_spike.value)) + t_last_spike = stop_gradient(bm.where(spike_no_grad, t, self.t_last_spike.value)) + if self.ref_var: + self.refractory.value = stop_gradient(bm.logical_or(refractory, spike_no_grad > 0.)) else: spike = self.V_th <= V V = bm.where(spike, self.c, V) u = bm.where(spike, u + self.d, u) - refractory = bm.logical_or(refractory, spike) t_last_spike = bm.where(spike, t, self.t_last_spike.value) - self.refractory.value = refractory + if self.ref_var: + self.refractory.value = bm.logical_or(refractory, spike) self.t_last_spike.value = t_last_spike else: @@ -1576,9 +1616,11 @@ def update(self, tdi, x=None): self.V.value = V self.u.value = u self.spike.value = spike + return spike def clear_input(self): - self.input[:] = 0. + if self.input_var: + self.input[:] = 0. 
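The Izhikevich changes above make both initializers optional and resolve them lazily in `reset_state`. A hedged illustration of the new defaults (the public path `bp.neurons.Izhikevich` is assumed):

import brainpy as bp

izh = bp.neurons.Izhikevich(8)   # V_initializer / u_initializer left as None
izh.reset_state()                # falls back to V = -70 mV and u = b * V
print(izh.V.shape, izh.u.shape)  # both follow the group shape, e.g. (8,) (8,)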
class HindmarshRose(NeuGroup): @@ -1697,6 +1739,7 @@ def __init__( noise: Union[float, ArrayType, Initializer, Callable] = None, method: str = 'exp_auto', keep_size: bool = False, + input_var: bool = True, name: str = None, # parameters for training @@ -1721,6 +1764,7 @@ def __init__( self.V_rest = parameter(V_rest, self.varshape, allow_none=False) self.noise = init_noise(noise, self.varshape, num_vars=3) self.spike_fun = is_callable(spike_fun, 'spike_fun') + self.input_var = input_var # variables is_initializer(V_initializer, 'V_initializer', allow_none=False) @@ -1731,12 +1775,7 @@ def __init__( self._z_initializer = z_initializer # variables - self.V = variable_(self._V_initializer, self.varshape, self.mode) - self.y = variable_(self._y_initializer, self.varshape, self.mode) - self.z = variable_(self._z_initializer, self.varshape, self.mode) - self.input = variable_(bm.zeros, self.varshape, self.mode) - sp_type = bm.float_ if isinstance(self.mode, bm.TrainingMode) else bool - self.spike = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, self.mode) + self.reset_state(self.mode) # integral if self.noise is None: @@ -1745,12 +1784,13 @@ def __init__( self.integral = sdeint(method=method, f=self.derivative, g=self.noise) def reset_state(self, batch_size=None): - self.V.value = variable_(self._V_initializer, self.varshape, batch_size) - self.y.value = variable_(self._y_initializer, self.varshape, batch_size) - self.z.value = variable_(self._z_initializer, self.varshape, batch_size) - self.input.value = variable_(bm.zeros, self.varshape, batch_size) + self.V = variable_(self._V_initializer, self.varshape, batch_size) + self.y = variable_(self._y_initializer, self.varshape, batch_size) + self.z = variable_(self._z_initializer, self.varshape, batch_size) + if self.input_var: + self.input = variable_(bm.zeros, self.varshape, batch_size) sp_type = bm.float_ if isinstance(self.mode, bm.TrainingMode) else bool - self.spike.value = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, batch_size) + self.spike = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, batch_size) def dV(self, V, t, y, z, I_ext): return y - self.a * V * V * V + self.b * V * V - z + I_ext @@ -1765,10 +1805,17 @@ def dz(self, z, t, V): def derivative(self): return JointEq([self.dV, self.dy, self.dz]) - def update(self, tdi, x=None): - t, dt = tdi.t, tdi.dt - if x is not None: self.input += x - V, y, z = self.integral(self.V, self.y, self.z, t, self.input, dt=dt) + @not_pass_sha + def update(self, x=None): + t = share.load('t') + dt = share.load('dt') + if self.input_var: + if x is not None: + self.input += x + x = self.input.value + else: + x = 0. if x is None else x + V, y, z = self.integral(self.V, self.y, self.z, t, x, dt=dt) if isinstance(self.mode, bm.TrainingMode): self.spike.value = self.spike_fun(V - self.V_th, self.V - self.V_th) else: @@ -1776,9 +1823,11 @@ def update(self, tdi, x=None): self.V.value = V self.y.value = y self.z.value = z + return self.spike.value def clear_input(self): - self.input[:] = 0. + if self.input_var: + self.input[:] = 0. 
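Because every group above keeps `input_var=True` by default, the familiar runner workflow (accumulate currents into `self.input`, clear them each step) still applies. A rough end-to-end sketch using the standard `bp.DSRunner` input syntax; the runner itself now populates the shared `t`/`dt`/`i` values, as shown in the runner hunks later in this patch:

import brainpy as bp

group = bp.neurons.HindmarshRose(10)             # input_var=True by default
runner = bp.DSRunner(group,
                     monitors=['V', 'spike'],
                     inputs=(group.input, 1.5))  # constant current into `self.input`
runner.run(200.)                                 # simulate 200 ms
print(runner.mon['V'].shape)                     # roughly (num_steps, 10)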
class FHN(NeuGroup): @@ -1876,6 +1925,7 @@ def __init__( noise: Union[float, ArrayType, Initializer, Callable] = None, method: str = 'exp_auto', keep_size: bool = False, + input_var: bool = True, name: str = None, # parameters for training @@ -1894,6 +1944,7 @@ def __init__( self.tau = parameter(tau, self.varshape, allow_none=False) self.Vth = parameter(Vth, self.varshape, allow_none=False) self.noise = init_noise(noise, self.varshape, num_vars=2) + self.input_var = input_var # initializers is_initializer(V_initializer, 'V_initializer') @@ -1902,10 +1953,7 @@ def __init__( self._w_initializer = w_initializer # variables - self.V = variable_(self._V_initializer, self.varshape, self.mode) - self.w = variable_(self._w_initializer, self.varshape, self.mode) - self.input = variable_(bm.zeros, self.varshape, self.mode) - self.spike = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, self.mode) + self.reset_state(self.mode) # integral if self.noise is None: @@ -1914,10 +1962,11 @@ def __init__( self.integral = sdeint(method=method, f=self.derivative, g=self.noise) def reset_state(self, batch_size=None): - self.V.value = variable_(self._V_initializer, self.varshape, batch_size) - self.w.value = variable_(self._w_initializer, self.varshape, batch_size) - self.input.value = variable_(bm.zeros, self.varshape, batch_size) - self.spike.value = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, batch_size) + self.V = variable_(self._V_initializer, self.varshape, batch_size) + self.w = variable_(self._w_initializer, self.varshape, batch_size) + if self.input_var: + self.input = variable_(bm.zeros, self.varshape, batch_size) + self.spike = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, batch_size) def dV(self, V, t, w, I_ext): return V - V * V * V / 3 - w + I_ext @@ -1929,13 +1978,148 @@ def dw(self, w, t, V): def derivative(self): return JointEq([self.dV, self.dw]) - def update(self, tdi, x=None): - t, dt = tdi.t, tdi.dt - if x is not None: self.input += x - V, w = self.integral(self.V.value, self.w.value, t, self.input.value, dt=dt) + @not_pass_sha + def update(self, x=None): + t = share.load('t') + dt = share.load('dt') + if self.input_var: + if x is not None: + self.input += x + x = self.input.value + else: + x = 0. if x is None else x + V, w = self.integral(self.V.value, self.w.value, t, x, dt=dt) self.spike.value = bm.logical_and(V >= self.Vth, self.V < self.Vth) self.V.value = V self.w.value = w + return self.spike.value def clear_input(self): - self.input[:] = 0. + if self.input_var: + self.input[:] = 0. + + +class LIF_SFA_Bellec2020(NeuGroup): + r"""Leaky Integrate-and-Fire model with SFA [1]_. + + This model is similar to the GLIF2 model in the Technical White Paper + on generalized LIF (GLIF) models from AllenInstitute [2]_. + + Formally, this model is given by: + + .. math:: + + \tau \dot{V} = -(V - V_{\mathrm{rest}}) + R*I \\ + \tau_a \dot{a} = -a + + Once a spike is induced by :math:`V(t) > V_{\mathrm{th}} + \beta a`, then + + .. math:: + + V \gets V - V_{\mathrm{th}} \\ + a \gets a + 1 + + + References + ---------- + .. [1] Bellec, Guillaume, et al. "A solution to the learning dilemma for + recurrent networks of spiking neurons." + Nature communications 11.1 (2020): 1-15. + .. [2] Allen Institute: Cell Types Database. © 2018 Allen Institute for + Brain Science. Allen Cell Types Database, cell feature search. + Available from: celltypes.brain-map.org/data (2018). 
+ """ + + def __init__( + self, + size: Shape, + keep_size: bool = False, + + # model parameters + V_rest: Union[float, ArrayType, Initializer, Callable] = -70., + V_th: Union[float, ArrayType, Initializer, Callable] = -60., + R: Union[float, ArrayType, Initializer, Callable] = 1., + beta: Union[float, ArrayType, Initializer, Callable] = 1.6, + tau: Union[float, ArrayType, Initializer, Callable] = 20., + tau_a: Union[float, ArrayType, Initializer, Callable] = 2000., + tau_ref: Union[float, ArrayType, Initializer, Callable] = None, + + # initializers + V_initializer: Union[Initializer, Callable, ArrayType] = OneInit(-70.), + a_initializer: Union[Initializer, Callable, ArrayType] = OneInit(-50.), + + # parameter for training + spike_fun: Callable = bm.surrogate.relu_grad, + + # other parameters + method: str = 'exp_auto', + name: str = None, + mode: bm.Mode = None, + ): + super().__init__(name=name, size=size, keep_size=keep_size, mode=mode) + is_subclass(self.mode, (bm.TrainingMode, bm.NonBatchingMode)) + + # parameters + self.V_rest = parameter(V_rest, self.varshape, allow_none=False) + self.V_th = parameter(V_th, self.varshape, allow_none=False) + self.R = parameter(R, self.varshape, allow_none=False) + self.beta = parameter(beta, self.varshape, allow_none=False) + self.tau = parameter(tau, self.varshape, allow_none=False) + self.tau_a = parameter(tau_a, self.varshape, allow_none=False) + self.tau_ref = parameter(tau_ref, self.varshape, allow_none=True) + self.spike_fun = is_callable(spike_fun, 'spike_fun') + + # initializers + self._V_initializer = is_initializer(V_initializer) + self._a_initializer = is_initializer(a_initializer) + + # variables + self.reset_state(self.mode) + + # integral + self.integral = odeint(method=method, f=self.derivative) + + def da(self, a, t): + return -a / self.tau_a + + def dV(self, V, t, I_ext): + return (- (V - self.V_rest) + self.R * I_ext) / self.tau + + @property + def derivative(self): + return JointEq([self.dV, self.da]) + + def reset_state(self, batch_size=None): + self.a = variable_(self._a_initializer, self.varshape, batch_size) + self.V = variable_(self._V_initializer, self.varshape, batch_size) + self.spike = variable_(bm.zeros, self.varshape, batch_size) + if self.tau_ref is not None: + self.t_last_spike = variable_(OneInit(-1e7), self.varshape, batch_size) + + @not_pass_sha + def update(self, x=None): + t = share.load('t') + dt = share.load('dt') + x = 0. 
if x is None else x + + # integral + V, a = self.integral(self.V, self.a, t, x, dt) + + if self.tau_ref is not None: + # refractory + refractory = stop_gradient((t - self.t_last_spike) <= self.tau_ref) + V = bm.where(refractory, self.V.value, V) + # spike and reset + spike = self.spike_fun((V - self.V_th - self.beta * self.a) / self.V_th) + V -= self.V_th * spike + t_last_spike = stop_gradient(bm.where(spike, t, self.t_last_spike.value)) + self.t_last_spike.value = t_last_spike + + else: + # spike and reset + spike = self.spike_fun((V - self.V_th - self.beta * self.a) / self.V_th) + V -= self.V_th * spike + self.spike.value = spike + self.V.value = V + self.a.value = a + spike + return spike diff --git a/brainpy/_src/dyn/rates/populations.py b/brainpy/_src/dyn/rates/populations.py index 170f8a385..fea21c514 100644 --- a/brainpy/_src/dyn/rates/populations.py +++ b/brainpy/_src/dyn/rates/populations.py @@ -164,13 +164,12 @@ def update(self, tdi, x=None): t, dt = tdi['t'], tdi['dt'] # input - if x is not None: self.input += x + if x is not None: + self.input += x if self.x_ou is not None: - self.input += self.x_ou.x - self.x_ou.update(tdi) + self.input += self.x_ou() if self.y_ou is not None: - self.input_y += self.y_ou.x - self.y_ou.update(tdi) + self.input_y += self.y_ou() # integral x, y = self.integral(self.x, self.y, t, x_ext=self.input, y_ext=self.input_y, dt=dt) @@ -362,11 +361,9 @@ def update(self, tdi, x=None): if x is not None: self.input += x if self.x_ou is not None: - self.input += self.x_ou.x - self.x_ou.update(tdi) + self.input += self.x_ou() if self.y_ou is not None: - self.input_y += self.y_ou.x - self.y_ou.update(tdi) + self.input_y += self.y_ou() x, y = self.integral(self.x, self.y, t, x_ext=self.input, y_ext=self.input_y, dt=dt) self.x.value = x @@ -545,11 +542,9 @@ def update(self, tdi, x=None): if x is not None: self.input += x if self.x_ou is not None: - self.input += self.x_ou.x - self.x_ou.update(tdi) + self.input += self.x_ou() if self.y_ou is not None: - self.input_y += self.y_ou.x - self.y_ou.update(tdi) + self.input_y += self.y_ou() x, y = self.integral(self.x, self.y, t=t, x_ext=self.input, y_ext=self.input_y, dt=dt) self.x.value = x @@ -680,11 +675,9 @@ def update(self, tdi, x=None): if x is not None: self.input += x if self.x_ou is not None: - self.input += self.x_ou.x - self.x_ou.update(tdi) + self.input += self.x_ou() if self.y_ou is not None: - self.input_y += self.y_ou.x - self.y_ou.update(tdi) + self.input_y += self.y_ou() x, y = self.integral(self.x, self.y, @@ -845,11 +838,9 @@ def update(self, tdi, x=None): t, dt = tdi['t'], tdi['dt'] if x is not None: self.input += x if self.x_ou is not None: - self.input += self.x_ou.x - self.x_ou.update(tdi) + self.input += self.x_ou() if self.y_ou is not None: - self.input_y += self.y_ou.x - self.y_ou.update(tdi) + self.input_y += self.y_ou() x, y = self.integral(self.x, self.y, t, x_ext=self.input, y_ext=self.input_y, dt=dt) self.x.value = x self.y.value = y diff --git a/brainpy/_src/dyn/runners.py b/brainpy/_src/dyn/runners.py index dd59e473c..30fe76dc8 100644 --- a/brainpy/_src/dyn/runners.py +++ b/brainpy/_src/dyn/runners.py @@ -15,6 +15,7 @@ from brainpy import math as bm, tools from brainpy._src.dyn.base import DynamicalSystem +from brainpy._src.dyn.context import share from brainpy._src.running.runner import Runner from brainpy.check import is_float, serialize_kwargs from brainpy.errors import RunningError, NoLongerSupportError @@ -451,6 +452,7 @@ def predict( warnings.warn('"inputs" has already has the time 
information. ' 'Therefore there no longer need to provide "duration".', UserWarning) + duration = None num_step = self._get_input_time_step(duration, inputs) description = f'Predict {num_step} steps: ' @@ -615,13 +617,12 @@ def _step_func_predict(self, shared_args, t, i, x): # input step shared = tools.DotDict(t=t, i=i, dt=self.dt) shared.update(shared_args) - for k, v in shared.items(): - bm.share.save(k, v) + share.save(**shared) self.target.clear_input() self._step_func_input(shared) # dynamics update step - args = (shared,) if x is None else (shared, x) + args = () if x is None else (x,) out = self.target(*args) # monitor step @@ -631,6 +632,7 @@ def _step_func_predict(self, shared_args, t, i, x): # finally if self.progress_bar: id_tap(lambda *arg: self._pbar.update(), ()) + share.clear_shargs() return out, mon def _get_f_predict(self, shared_args: Dict = None): diff --git a/brainpy/_src/dyn/synapses/abstract_models.py b/brainpy/_src/dyn/synapses/abstract_models.py index b5e162a9f..f83268ffc 100644 --- a/brainpy/_src/dyn/synapses/abstract_models.py +++ b/brainpy/_src/dyn/synapses/abstract_models.py @@ -9,11 +9,11 @@ from brainpy._src import tools from brainpy._src.connect import TwoEndConnector, All2All, One2One from brainpy._src.dyn.base import NeuGroup, SynOut, SynSTP, TwoEndConn, SynConn +from brainpy._src.dyn.synouts import CUBA, MgBlock from brainpy._src.initialize import Initializer, variable_ from brainpy._src.integrators import odeint, JointEq from brainpy.check import is_integer, is_float, is_subclass from brainpy.types import ArrayType -from ..synouts import CUBA, MgBlock __all__ = [ 'Delta', diff --git a/brainpy/_src/experimental/__init__.py b/brainpy/_src/dyn/synapses_v2/__init__.py similarity index 100% rename from brainpy/_src/experimental/__init__.py rename to brainpy/_src/dyn/synapses_v2/__init__.py diff --git a/brainpy/_src/experimental/synapses.py b/brainpy/_src/dyn/synapses_v2/abstract_models.py similarity index 50% rename from brainpy/_src/experimental/synapses.py rename to brainpy/_src/dyn/synapses_v2/abstract_models.py index 1424e3902..25f1de478 100644 --- a/brainpy/_src/experimental/synapses.py +++ b/brainpy/_src/dyn/synapses_v2/abstract_models.py @@ -1,129 +1,21 @@ -from typing import Union, Dict, Callable, Optional, Tuple +# -*- coding: utf-8 -*- -import jax -import numpy as np +from typing import Union, Dict, Callable, Optional + +from jax import vmap import brainpy.math as bm -from brainpy import check from brainpy._src import tools -from brainpy._src.connect import TwoEndConnector, All2All, One2One, MatConn, IJConn -from brainpy._src.dyn.base import DynamicalSystem, not_pass_shargs -from brainpy._src.initialize import Initializer, variable_, parameter +from brainpy._src.connect import TwoEndConnector, All2All, One2One +from brainpy._src.dyn.context import share +from brainpy._src.dyn.synapses_v2.base import SynConn, SynOut, SynSTP +from brainpy._src.initialize import Initializer, variable_ from brainpy._src.integrators import odeint +from brainpy.check import is_float from brainpy.types import ArrayType -from .synout import SynOut -from .synstp import SynSTP - - -class Synapse(DynamicalSystem): - def __init__( - self, - conn: TwoEndConnector, - out: Optional[SynOut] = None, - stp: Optional[SynSTP] = None, - name: str = None, - mode: bm.Mode = None, - ): - super().__init__(name=name, mode=mode) - - # parameters - assert isinstance(conn, TwoEndConnector) - self.conn = self._init_conn(conn) - self.pre_size = conn.pre_size - self.post_size = conn.post_size - 
self.pre_num = conn.pre_num - self.post_num = conn.post_num - assert out is None or isinstance(out, SynOut) - assert stp is None or isinstance(stp, SynSTP) - self.out = out - self.stp = stp - def _init_conn(self, conn): - if isinstance(conn, TwoEndConnector): - pass - elif isinstance(conn, (bm.ndarray, np.ndarray, jax.Array)): - if (self.pre_num, self.post_num) != conn.shape: - raise ValueError(f'"conn" is provided as a matrix, and it is expected ' - f'to be an array with shape of (self.pre_num, self.post_num) = ' - f'{(self.pre_num, self.post_num)}, however we got {conn.shape}') - conn = MatConn(conn_mat=conn) - elif isinstance(conn, dict): - if not ('i' in conn and 'j' in conn): - raise ValueError(f'"conn" is provided as a dict, and it is expected to ' - f'be a dictionary with "i" and "j" specification, ' - f'however we got {conn}') - conn = IJConn(i=conn['i'], j=conn['j']) - elif conn is None: - conn = None - else: - raise ValueError(f'Unknown "conn" type: {conn}') - return conn - - def _init_weights( - self, - weight: Union[float, ArrayType, Initializer, Callable], - comp_method: str, - data_if_sparse: str = 'csr' - ) -> Tuple[Union[float, ArrayType], ArrayType]: - if comp_method not in ['sparse', 'dense']: - raise ValueError(f'"comp_method" must be in "sparse" and "dense", but we got {comp_method}') - if data_if_sparse not in ['csr', 'ij', 'coo']: - raise ValueError(f'"sparse_data" must be in "csr" and "ij", but we got {data_if_sparse}') - # connections and weights - if isinstance(self.conn, One2One): - weight = parameter(weight, (self.pre_num,), allow_none=False) - conn_mask = None - - elif isinstance(self.conn, All2All): - weight = parameter(weight, (self.pre_num, self.post_num), allow_none=False) - conn_mask = None - - else: - if comp_method == 'sparse': - if data_if_sparse == 'csr': - conn_mask = self.conn.require('pre2post') - elif data_if_sparse in ['ij', 'coo']: - conn_mask = self.conn.require('post_ids', 'pre_ids') - else: - ValueError(f'Unknown sparse data type: {data_if_sparse}') - weight = parameter(weight, conn_mask[0].shape, allow_none=False) - elif comp_method == 'dense': - weight = parameter(weight, (self.pre_num, self.post_num), allow_none=False) - conn_mask = self.conn.require('conn_mat') - else: - raise ValueError(f'Unknown connection type: {comp_method}') - - # training weights - if isinstance(self.mode, bm.TrainingMode): - weight = bm.TrainVar(weight) - return weight, conn_mask - - def _syn2post_with_all2all(self, syn_value, syn_weight, include_self): - if bm.ndim(syn_weight) == 0: - if isinstance(self.mode, bm.BatchingMode): - post_vs = bm.sum(syn_value, keepdims=True, axis=tuple(range(syn_value.ndim))[1:]) - else: - post_vs = bm.sum(syn_value) - if not include_self: - post_vs = post_vs - syn_value - post_vs = syn_weight * post_vs - else: - post_vs = syn_value @ syn_weight - return post_vs - - def _syn2post_with_one2one(self, syn_value, syn_weight): - return syn_value * syn_weight - - def _syn2post_with_dense(self, syn_value, syn_weight, conn_mat): - if bm.ndim(syn_weight) == 0: - post_vs = (syn_weight * syn_value) @ conn_mat - else: - post_vs = syn_value @ (syn_weight * conn_mat) - return post_vs - - -class Exponential(Synapse): +class Exponential(SynConn): r"""Exponential decay synapse model. 
**Model Descriptions** @@ -189,15 +81,15 @@ def __init__( name: str = None, mode: bm.Mode = None, ): - super(Exponential, self).__init__(conn=conn, - out=out, - stp=stp, - name=name, - mode=mode) + super().__init__(conn=conn, + out=out, + stp=stp, + name=name, + mode=mode) # parameters self.comp_method = comp_method - self.tau = check.is_float(tau, allow_int=True) + self.tau = is_float(tau, allow_int=True) # connections and weights self.g_max, self.conn_mask = self._init_weights(g_max, comp_method, data_if_sparse='csr') @@ -215,7 +107,6 @@ def reset_state(self, batch_size=None): if self.stp is not None: self.stp.reset_state(batch_size) - @not_pass_shargs def update(self, pre_spike): if self.stp is not None: syn_value = self.stp(pre_spike) * pre_spike @@ -239,7 +130,7 @@ def update(self, pre_spike): shape=(self.pre_num, self.post_num), transpose=True) if isinstance(self.mode, bm.BatchingMode): - f = jax.vmap(f) + f = vmap(f) post_vs = f(pre_spike) else: f = lambda s: bl.sparse_ops.cusparse_csr_matvec(self.g_max, @@ -249,17 +140,16 @@ def update(self, pre_spike): shape=(self.pre_num, self.post_num), transpose=True) if isinstance(self.mode, bm.BatchingMode): - f = jax.vmap(f) + f = vmap(f) post_vs = f(syn_value) else: post_vs = self._syn2post_with_dense(syn_value, self.g_max, self.conn_mask) # updates - self.g.value = self.integral(self.g.value, bm.share.load('t'), bm.dt) + post_vs + self.g.value = self.integral(self.g.value, share.load('t'), bm.dt) + post_vs # outputs if self.out is not None: return self.out(self.g.value) else: return self.g.value - diff --git a/brainpy/_src/dyn/synapses_v2/base.py b/brainpy/_src/dyn/synapses_v2/base.py new file mode 100644 index 000000000..bcced8c0b --- /dev/null +++ b/brainpy/_src/dyn/synapses_v2/base.py @@ -0,0 +1,133 @@ +from typing import Union, Callable, Optional, Tuple + +import jax +import numpy as np + +import brainpy.math as bm +from brainpy._src.connect import TwoEndConnector, All2All, One2One, MatConn, IJConn +from brainpy._src.dyn.base import DynamicalSystemNS +from brainpy._src.initialize import Initializer, parameter +from brainpy.types import ArrayType + + +class SynConn(DynamicalSystemNS): + def __init__( + self, + conn: TwoEndConnector, + out: Optional['SynOut'] = None, + stp: Optional['SynSTP'] = None, + name: str = None, + mode: bm.Mode = None, + ): + super().__init__(name=name, mode=mode) + + # parameters + assert isinstance(conn, TwoEndConnector) + self.conn = self._init_conn(conn) + self.pre_size = conn.pre_size + self.post_size = conn.post_size + self.pre_num = conn.pre_num + self.post_num = conn.post_num + assert out is None or isinstance(out, SynOut) + assert stp is None or isinstance(stp, SynSTP) + self.out = out + self.stp = stp + + def _init_conn(self, conn): + if isinstance(conn, TwoEndConnector): + pass + elif isinstance(conn, (bm.ndarray, np.ndarray, jax.Array)): + if (self.pre_num, self.post_num) != conn.shape: + raise ValueError(f'"conn" is provided as a matrix, and it is expected ' + f'to be an array with shape of (self.pre_num, self.post_num) = ' + f'{(self.pre_num, self.post_num)}, however we got {conn.shape}') + conn = MatConn(conn_mat=conn) + elif isinstance(conn, dict): + if not ('i' in conn and 'j' in conn): + raise ValueError(f'"conn" is provided as a dict, and it is expected to ' + f'be a dictionary with "i" and "j" specification, ' + f'however we got {conn}') + conn = IJConn(i=conn['i'], j=conn['j']) + elif conn is None: + conn = None + else: + raise ValueError(f'Unknown "conn" type: {conn}') + return conn + + def 
_init_weights( + self, + weight: Union[float, ArrayType, Initializer, Callable], + comp_method: str, + data_if_sparse: str = 'csr' + ) -> Tuple[Union[float, ArrayType], ArrayType]: + if comp_method not in ['sparse', 'dense']: + raise ValueError(f'"comp_method" must be in "sparse" and "dense", but we got {comp_method}') + if data_if_sparse not in ['csr', 'ij', 'coo']: + raise ValueError(f'"sparse_data" must be in "csr" and "ij", but we got {data_if_sparse}') + + # connections and weights + if isinstance(self.conn, One2One): + weight = parameter(weight, (self.pre_num,), allow_none=False) + conn_mask = None + + elif isinstance(self.conn, All2All): + weight = parameter(weight, (self.pre_num, self.post_num), allow_none=False) + conn_mask = None + + else: + if comp_method == 'sparse': + if data_if_sparse == 'csr': + conn_mask = self.conn.require('pre2post') + elif data_if_sparse in ['ij', 'coo']: + conn_mask = self.conn.require('post_ids', 'pre_ids') + else: + ValueError(f'Unknown sparse data type: {data_if_sparse}') + weight = parameter(weight, conn_mask[0].shape, allow_none=False) + elif comp_method == 'dense': + weight = parameter(weight, (self.pre_num, self.post_num), allow_none=False) + conn_mask = self.conn.require('conn_mat') + else: + raise ValueError(f'Unknown connection type: {comp_method}') + + # training weights + if isinstance(self.mode, bm.TrainingMode): + weight = bm.TrainVar(weight) + return weight, conn_mask + + def _syn2post_with_all2all(self, syn_value, syn_weight, include_self): + if bm.ndim(syn_weight) == 0: + if isinstance(self.mode, bm.BatchingMode): + post_vs = bm.sum(syn_value, keepdims=True, axis=tuple(range(syn_value.ndim))[1:]) + else: + post_vs = bm.sum(syn_value) + if not include_self: + post_vs = post_vs - syn_value + post_vs = syn_weight * post_vs + else: + post_vs = syn_value @ syn_weight + return post_vs + + def _syn2post_with_one2one(self, syn_value, syn_weight): + return syn_value * syn_weight + + def _syn2post_with_dense(self, syn_value, syn_weight, conn_mat): + if bm.ndim(syn_weight) == 0: + post_vs = (syn_weight * syn_value) @ conn_mat + else: + post_vs = syn_value @ (syn_weight * conn_mat) + return post_vs + + +class SynOut(DynamicalSystemNS): + def update(self, post_g): + raise NotImplementedError + + def reset_state(self, batch_size: Optional[int] = None): + pass + + +class SynSTP(DynamicalSystemNS): + """Base class for synaptic short-term plasticity.""" + + def update(self, pre_spike): + raise NotImplementedError diff --git a/brainpy/_src/experimental/synout.py b/brainpy/_src/dyn/synapses_v2/syn_outs.py similarity index 78% rename from brainpy/_src/experimental/synout.py rename to brainpy/_src/dyn/synapses_v2/syn_outs.py index c93eb7907..9a783f8a1 100644 --- a/brainpy/_src/experimental/synout.py +++ b/brainpy/_src/dyn/synapses_v2/syn_outs.py @@ -1,70 +1,49 @@ -from typing import Union, Optional +# -*- coding: utf-8 -*- -import brainpy.math as bm -from brainpy._src.dyn.base import DynamicalSystem, not_pass_shargs -from brainpy.types import ArrayType +from typing import Union +from brainpy.math import Variable, exp +from brainpy.types import ArrayType +from brainpy._src.dyn.synapses_v2.base import SynOut -class SynOut(DynamicalSystem): - @not_pass_shargs - def update(self, g): - raise NotImplementedError - def reset_state(self, batch_size: Optional[int] = None): - pass +__all__ = [ + 'COBA', + 'CUBA', +] -class MgBlock(SynOut): - r"""Synaptic output based on Magnesium blocking. +class COBA(SynOut): + r"""Conductance-based synaptic output. 
Given the synaptic conductance, the model output the post-synaptic current with .. math:: - I_{syn}(t) = g_{\mathrm{syn}}(t) (E - V(t)) g_{\infty}(V,[{Mg}^{2+}]_{o}) - - where The fraction of channels :math:`g_{\infty}` that are not blocked by magnesium can be fitted to - - .. math:: - - g_{\infty}(V,[{Mg}^{2+}]_{o}) = (1+{e}^{-\alpha V} \frac{[{Mg}^{2+}]_{o}} {\beta})^{-1} - - Here :math:`[{Mg}^{2+}]_{o}` is the extracellular magnesium concentration. + I_{syn}(t) = g_{\mathrm{syn}}(t) (E - V(t)) Parameters ---------- - E: float, ArrayType - The reversal potential for the synaptic current. [mV] - alpha: float, ArrayType - Binding constant. Default 0.062 - beta: float, ArrayType - Unbinding constant. Default 3.57 - cc_Mg: float, ArrayType - Concentration of Magnesium ion. Default 1.2 [mM]. + E: float, ArrayType, ndarray + The reversal potential. name: str The model name. + + See Also + -------- + CUBA """ - def __init__( - self, - post_potential: bm.Variable, - E: Union[float, ArrayType] = 0., - cc_Mg: Union[float, ArrayType] = 1.2, - alpha: Union[float, ArrayType] = 0.062, - beta: Union[float, ArrayType] = 3.57, - name: str = None, - ): - super(MgBlock, self).__init__(name=name) - assert isinstance(post_potential, bm.Variable) - self.post_potential = post_potential + def __init__(self, + post_potential: Variable, + E: Union[float, ArrayType] = 0., + name: str = None, ): + super().__init__(name=name) self.E = E - self.cc_Mg = cc_Mg - self.alpha = alpha - self.beta = beta + self.post_potential = post_potential - @not_pass_shargs def update(self, g): - I = g * (self.E - self.post_potential) / (1 + self.cc_Mg / self.beta * bm.exp(-self.alpha * self.post_potential)) + I = g * (self.E - self.post_potential) return I @@ -89,43 +68,61 @@ class CUBA(SynOut): """ def __init__(self, name: str = None, ): - super(CUBA, self).__init__(name=name) + super().__init__(name=name) - @not_pass_shargs - def update(self, V, g): + def update(self, g): return g -class COBA(SynOut): - r"""Conductance-based synaptic output. +class MgBlock(SynOut): + r"""Synaptic output based on Magnesium blocking. Given the synaptic conductance, the model output the post-synaptic current with .. math:: - I_{syn}(t) = g_{\mathrm{syn}}(t) (E - V(t)) + I_{syn}(t) = g_{\mathrm{syn}}(t) (E - V(t)) g_{\infty}(V,[{Mg}^{2+}]_{o}) + + where The fraction of channels :math:`g_{\infty}` that are not blocked by magnesium can be fitted to + + .. math:: + + g_{\infty}(V,[{Mg}^{2+}]_{o}) = (1+{e}^{-\alpha V} \frac{[{Mg}^{2+}]_{o}} {\beta})^{-1} + + Here :math:`[{Mg}^{2+}]_{o}` is the extracellular magnesium concentration. Parameters ---------- - E: float, ArrayType, ndarray - The reversal potential. + E: float, ArrayType + The reversal potential for the synaptic current. [mV] + alpha: float, ArrayType + Binding constant. Default 0.062 + beta: float, ArrayType + Unbinding constant. Default 3.57 + cc_Mg: float, ArrayType + Concentration of Magnesium ion. Default 1.2 [mM]. name: str The model name. 
- - See Also - -------- - CUBA """ - def __init__(self, - post_potential: bm.Variable, - E: Union[float, ArrayType] = 0., - name: str = None, ): - super(COBA, self).__init__(name=name) - self.E = E + def __init__( + self, + post_potential: Variable, + E: Union[float, ArrayType] = 0., + cc_Mg: Union[float, ArrayType] = 1.2, + alpha: Union[float, ArrayType] = 0.062, + beta: Union[float, ArrayType] = 3.57, + name: str = None, + ): + super().__init__(name=name) + assert isinstance(post_potential, Variable) self.post_potential = post_potential + self.E = E + self.cc_Mg = cc_Mg + self.alpha = alpha + self.beta = beta - @not_pass_shargs def update(self, g): - I = g * (self.E - self.post_potential) + I = g * (self.E - self.post_potential) / (1 + self.cc_Mg / self.beta * exp(-self.alpha * self.post_potential)) return I + diff --git a/brainpy/_src/experimental/synstp.py b/brainpy/_src/dyn/synapses_v2/syn_plasticity.py similarity index 89% rename from brainpy/_src/experimental/synstp.py rename to brainpy/_src/dyn/synapses_v2/syn_plasticity.py index 7401ed6ea..e011cc8a1 100644 --- a/brainpy/_src/experimental/synstp.py +++ b/brainpy/_src/dyn/synapses_v2/syn_plasticity.py @@ -4,8 +4,9 @@ import jax.numpy as jnp +from brainpy._src.dyn.context import share from brainpy import math as bm, tools -from brainpy._src.dyn.base import DynamicalSystem, not_pass_shargs +from brainpy._src.dyn.synapses_v2.base import SynSTP from brainpy._src.initialize import variable_, OneInit, parameter from brainpy._src.integrators import odeint, JointEq from brainpy.types import ArrayType, Shape @@ -16,14 +17,6 @@ ] -class SynSTP(DynamicalSystem): - """Base class for synaptic short-term plasticity.""" - - @not_pass_shargs - def update(self, pre_spike, post_g): - raise NotImplementedError - - class STD(SynSTP): r"""Synaptic output with short-term depression. 
@@ -66,7 +59,7 @@ def __init__( method: str = 'exp_auto', name: str = None ): - super(STD, self).__init__(name=name) + super().__init__(name=name) # parameters self.pre_size = tools.to_size(pre_size) @@ -84,9 +77,8 @@ def __init__( def reset_state(self, batch_size=None): self.x = variable_(jnp.ones, self.num, batch_size) - @not_pass_shargs def update(self, pre_spike): - x = self.integral(self.x.value, bm.share.load('t'), bm.share.load('dt')) + x = self.integral(self.x.value, share.load('t'), share.load('dt')) self.x.value = bm.where(pre_spike, x - self.U * self.x, x) return self.x.value @@ -144,7 +136,7 @@ def __init__( method: str = 'exp_auto', name: str = None ): - super(STP, self).__init__(name=name) + super().__init__(name=name) # parameters self.pre_size = tools.to_size(pre_size) @@ -167,12 +159,10 @@ def reset_state(self, batch_size=None): du = lambda self, u, t: self.U - u / self.tau_f dx = lambda self, x, t: (1 - x) / self.tau_d - @not_pass_shargs def update(self, pre_spike): - u, x = self.integral(self.u.value, self.x.value, bm.share.load('t'), bm.get_dt()) + u, x = self.integral(self.u.value, self.x.value, share.load('t'), bm.get_dt()) u = bm.where(pre_spike, u + self.U * (1 - self.u), u) x = bm.where(pre_spike, x - u * self.x, x) self.x.value = x self.u.value = u return self.x.value * self.u.value - diff --git a/brainpy/_src/dyn/synouts/conductances.py b/brainpy/_src/dyn/synouts/conductances.py index 9d3a4eb9f..bf060d291 100644 --- a/brainpy/_src/dyn/synouts/conductances.py +++ b/brainpy/_src/dyn/synouts/conductances.py @@ -7,6 +7,7 @@ from brainpy._src.initialize import parameter, Initializer from brainpy.types import ArrayType + __all__ = [ 'COBA', 'CUBA', @@ -104,3 +105,65 @@ def filter(self, g): V = self.membrane_var.value I = g * (self.E - V) return super(COBA, self).filter(I) + + +class eCOBA(SynOut): + r"""Conductance-based synaptic output. + + Given the synaptic conductance, the model output the post-synaptic current with + + .. math:: + + I_{syn}(t) = g_{\mathrm{syn}}(t) (E - V(t)) + + Parameters + ---------- + E: float, ArrayType, ndarray + The reversal potential. + name: str + The model name. + + See Also + -------- + CUBA + """ + + def __init__(self, + post_potential: Variable, + E: Union[float, ArrayType] = 0., + name: str = None, ): + super().__init__(name=name) + self.E = E + self.post_potential = post_potential + + def update(self, g): + I = g * (self.E - self.post_potential) + return I + + +class eCUBA(SynOut): + r"""Current-based synaptic output. + + Given the conductance, this model outputs the post-synaptic current with a identity function: + + .. math:: + + I_{\mathrm{syn}}(t) = g_{\mathrm{syn}}(t) + + Parameters + ---------- + name: str + The model name. + + + See Also + -------- + COBA + """ + + def __init__(self, name: str = None, ): + super().__init__(name=name) + + def update(self, g): + return g + diff --git a/brainpy/_src/dyn/synouts/ions.py b/brainpy/_src/dyn/synouts/ions.py index 1daea70eb..c7b1f7579 100644 --- a/brainpy/_src/dyn/synouts/ions.py +++ b/brainpy/_src/dyn/synouts/ions.py @@ -9,6 +9,7 @@ from brainpy._src.initialize import parameter, Initializer from brainpy.types import ArrayType + __all__ = [ 'MgBlock', ] @@ -92,3 +93,56 @@ def clone(self): beta=self._beta, target_var=self._target_var, membrane_var=self._membrane_var) + + +class eMgBlock(SynOut): + r"""Synaptic output based on Magnesium blocking. + + Given the synaptic conductance, the model output the post-synaptic current with + + .. 
math:: + + I_{syn}(t) = g_{\mathrm{syn}}(t) (E - V(t)) g_{\infty}(V,[{Mg}^{2+}]_{o}) + + where The fraction of channels :math:`g_{\infty}` that are not blocked by magnesium can be fitted to + + .. math:: + + g_{\infty}(V,[{Mg}^{2+}]_{o}) = (1+{e}^{-\alpha V} \frac{[{Mg}^{2+}]_{o}} {\beta})^{-1} + + Here :math:`[{Mg}^{2+}]_{o}` is the extracellular magnesium concentration. + + Parameters + ---------- + E: float, ArrayType + The reversal potential for the synaptic current. [mV] + alpha: float, ArrayType + Binding constant. Default 0.062 + beta: float, ArrayType + Unbinding constant. Default 3.57 + cc_Mg: float, ArrayType + Concentration of Magnesium ion. Default 1.2 [mM]. + name: str + The model name. + """ + + def __init__( + self, + post_potential: bm.Variable, + E: Union[float, ArrayType] = 0., + cc_Mg: Union[float, ArrayType] = 1.2, + alpha: Union[float, ArrayType] = 0.062, + beta: Union[float, ArrayType] = 3.57, + name: str = None, + ): + super().__init__(name=name) + assert isinstance(post_potential, bm.Variable) + self.post_potential = post_potential + self.E = E + self.cc_Mg = cc_Mg + self.alpha = alpha + self.beta = beta + + def update(self, g): + I = g * (self.E - self.post_potential) / (1 + self.cc_Mg / self.beta * bm.exp(-self.alpha * self.post_potential)) + return I diff --git a/brainpy/_src/dyn/synplast/short_term_plasticity.py b/brainpy/_src/dyn/synplast/short_term_plasticity.py index f374f234b..93c3d3e13 100644 --- a/brainpy/_src/dyn/synplast/short_term_plasticity.py +++ b/brainpy/_src/dyn/synplast/short_term_plasticity.py @@ -4,11 +4,14 @@ import jax.numpy as jnp +from brainpy._src.dyn.context import share +from brainpy import math as bm, tools from brainpy._src.dyn.base import SynSTP from brainpy._src.initialize import variable +from brainpy._src.initialize import variable_, OneInit, parameter from brainpy._src.integrators import odeint, JointEq from brainpy.check import is_float -from brainpy.types import ArrayType +from brainpy.types import ArrayType, Shape __all__ = [ 'STD', @@ -181,3 +184,4 @@ def filter(self, g): if jnp.shape(g) != self.x.shape: raise ValueError('Shape does not match.') return g * self.x * self.u + diff --git a/brainpy/_src/dyn/transform.py b/brainpy/_src/dyn/transform.py index cf7dc5aef..6ca13ff49 100644 --- a/brainpy/_src/dyn/transform.py +++ b/brainpy/_src/dyn/transform.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- - +import functools from typing import Union, Optional, Dict, Sequence import jax.numpy as jnp @@ -7,9 +7,10 @@ from brainpy._src.math.object_transform.base import BrainPyObject from brainpy import tools, math as bm -from brainpy.check import is_float +from brainpy._src.dyn.context import share +from brainpy.check import is_float, is_integer from brainpy.types import PyTree -from .base import DynamicalSystem, Sequential +from .base import DynamicalSystem, Sequential, DynamicalSystemNS __all__ = [ 'LoopOverTime', @@ -41,7 +42,7 @@ def __repr__(self): return f"{name}({tools.repr_context(str(self.target), ' ' * len(name))})" -class LoopOverTime(DynSysToBPObj): +class LoopOverTime(DynamicalSystemNS): """Transform a single step :py:class:`~.DynamicalSystem` into a multiple-step forward propagation :py:class:`~.BrainPyObject`. 
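The `share.save(...)` / `share.load('t')` calls that replace the old `bm.share` and the per-call `shared` dictionaries in the hunks above and below can be exercised on their own. A minimal sketch, using the internal import path that appears in this diff (no public alias is assumed here):

```python
import brainpy.math as bm
from brainpy._src.dyn.context import share   # internal path, as imported in this diff

# roughly what LoopOverTime and the trainers do once per step:
share.save(i=0, t=0., dt=bm.get_dt())

# roughly what an update() body (e.g. STD.update above) reads back:
t = share.load('t')
dt = share.load('dt')
print(t, dt)
```

Because the shared context is global, `update()` signatures shrink to their real data inputs (`update(self, pre_spike)`, `update(self, g)`), which is what the removal of the `@not_pass_shargs` decorators in these hunks accomplishes.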
@@ -68,17 +69,17 @@ class LoopOverTime(DynSysToBPObj): >>> model = bp.Sequential(l1=bp.layers.RNNCell(n_in, 20), >>> l2=bm.relu, >>> l3=bp.layers.RNNCell(20, 2)) - >>> over_time = bp.LoopOverTime(model) + >>> over_time = bp.LoopOverTime(model, data_first_axis='T') >>> over_time.reset_state(n_batch) (30, 128, 2) >>> - >>> hist_l3 = over_time(bm.random.rand(n_time, n_batch, n_in), data_first_axis='T') + >>> hist_l3 = over_time(bm.random.rand(n_time, n_batch, n_in)) >>> print(hist_l3.shape) >>> >>> # monitor the "l1" layer state - >>> over_time = bp.LoopOverTime(model, out_vars=model['l1'].state) + >>> over_time = bp.LoopOverTime(model, out_vars=model['l1'].state, data_first_axis='T') >>> over_time.reset_state(n_batch) - >>> hist_l3, hist_l1 = over_time(bm.random.rand(n_time, n_batch, n_in), data_first_axis='T') + >>> hist_l3, hist_l1 = over_time(bm.random.rand(n_time, n_batch, n_in)) >>> print(hist_l3.shape) (30, 128, 2) >>> print(hist_l1.shape) @@ -123,6 +124,19 @@ class LoopOverTime(DynSysToBPObj): out_vars: PyTree The variables to monitor over the time loop. + t0: float, optional + The start time to run the system. If None, ``t`` will be no longer generated in the loop. + i0: int, optional + The start index to run the system. If None, ``i`` will be no longer generated in the loop. + dt: float + The time step. + shared_arg: dict + The shared arguments across the nodes. + For instance, `shared_arg={'fit': False}` for the prediction phase. + data_first_axis: str + Denote whether the input data is time major. + If so, we treat the data as `(time, batch, ...)` when the `target` is in Batching mode. + Default is True. name: str The transformed object name. """ @@ -132,9 +146,38 @@ def __init__( target: DynamicalSystem, out_vars: Union[bm.Variable, Sequence[bm.Variable], Dict[str, bm.Variable]] = None, no_state: bool = False, - name: str = None + t0: Optional[float] = 0., + i0: Optional[int] = 0, + dt: Optional[float] = None, + shared_arg: Optional[Dict] = None, + data_first_axis: str = 'T', + name: str = None, + jit: bool = True, + remat: bool = False, ): - super().__init__(target=target, name=name) + super().__init__(name=name) + assert data_first_axis in ['B', 'T'] + is_integer(i0, 'i0', allow_none=True) + is_float(t0, 't0', allow_none=True) + is_float(dt, 'dt', allow_none=True) + dt = share.dt if dt is None else dt + if shared_arg is None: + shared_arg = dict(dt=dt) + else: + assert isinstance(shared_arg, dict) + shared_arg['dt'] = dt + self.dt = dt + self.t0 = t0 + self.i0 = i0 + + self.jit = jit + self.remat = remat + self.shared_arg = shared_arg + self.data_first_axis = data_first_axis + self.target = target + if not isinstance(target, DynamicalSystem): + raise TypeError(f'Must be instance of {DynamicalSystem.__name__}, ' + f'but we got {type(target)}') self.no_state = no_state self.out_vars = out_vars if out_vars is not None: @@ -146,10 +189,6 @@ def __init__( def __call__( self, duration_or_xs: Union[float, PyTree], - t0: float = 0., - dt: Optional[float] = None, - shared_arg: Optional[Dict] = None, - data_first_axis: str = 'T' ): """Forward propagation along the time or inputs. @@ -158,38 +197,19 @@ def __call__( duration_or_xs: float, PyTree If `float`, it indicates a running duration. If a PyTree, it is the given inputs. - t0: float - The start time to run the system. - dt: float - The time step. - shared_arg: dict - The shared arguments across the nodes. - For instance, `shared_arg={'fit': False}` for the prediction phase. 
- data_first_axis: str - Denote whether the input data is time major. - If so, we treat the data as `(time, batch, ...)` when the `target` is in Batching mode. - Default is True. Returns ------- out: PyTree The accumulated outputs over time. """ - assert data_first_axis in ['B', 'T'] - - is_float(t0, 't0') - is_float(dt, 'dt', allow_none=True) - dt = bm.get_dt() if dt is None else dt - if shared_arg is None: - shared_arg = dict(dt=dt) - else: - assert isinstance(shared_arg, dict) - shared_arg['dt'] = dt - # inputs if isinstance(duration_or_xs, float): - shared = tools.DotDict(t=jnp.arange(t0, duration_or_xs, dt)) - shared['i'] = jnp.arange(0, shared['t'].shape[0]) + shared = tools.DotDict() + if self.t0 is not None: + shared['t'] = jnp.arange(self.t0, duration_or_xs, self.dt) + if self.i0 is not None: + shared['i'] = jnp.arange(0, shared['t'].shape[0]) xs = None if self.no_state: raise ValueError('Under the `no_state=True` setting, input cannot be a duration.') @@ -200,8 +220,8 @@ def __call__( 'of (B, T, ...) or (T, B, ...) with `data_first_axis="T"`, ' 'where B the batch size and T the time length.') xs, tree = tree_flatten(duration_or_xs, lambda a: isinstance(a, bm.Array)) - if isinstance(self.target.mode, bm.BatchingMode): - b_idx, t_idx = (1, 0) if data_first_axis == 'T' else (0, 1) + if self.target.mode.is_child_of(bm.BatchingMode): + b_idx, t_idx = (1, 0) if self.data_first_axis == 'T' else (0, 1) try: batch = tuple(set([x.shape[b_idx] for x in xs])) @@ -225,12 +245,12 @@ def __call__( f'but we got {tree_unflatten(tree, length)}.') if self.no_state: - xs = [jnp.reshape(x, (length[0] * batch[0],) + x.shape[2:]) for x in xs] + xs = [bm.reshape(x, (length[0] * batch[0],) + x.shape[2:]) for x in xs] else: - if data_first_axis == 'B': + if self.data_first_axis == 'B': xs = [jnp.moveaxis(x, 0, 1) for x in xs] xs = tree_unflatten(tree, xs) - origin_shape = (length[0], batch[0]) if data_first_axis == 'T' else (batch[0], length[0]) + origin_shape = (length[0], batch[0]) if self.data_first_axis == 'T' else (batch[0], length[0]) else: @@ -247,36 +267,34 @@ def __call__( # computation if self.no_state: - outputs = self.target(tools.DotDict(shared_arg), xs) + share.save(**self.shared_arg) + outputs = self._run(self.shared_arg, dict(), xs) return tree_map(lambda a: jnp.reshape(a, origin_shape + a.shape[1:]), outputs) else: - shared = tools.DotDict(t=jnp.arange(t0, dt * length[0], dt), - i=jnp.arange(0, length[0])) + shared = tools.DotDict() + shared['t'] = jnp.arange(self.t0, self.dt * length[0], self.dt) + shared['i'] = jnp.arange(0, length[0]) assert not self.no_state - - # function - @bm.to_object(child_objs=self.target) - def f(sha, x): - sha['dt'] = dt - sha.update(shared_arg) - outs = self.target(sha, x) - if self.out_vars is not None: - outs = (outs, tree_map(bm.as_jax, self.out_vars)) - self.target.clear_input() - return outs - - return bm.for_loop(f, (shared, xs)) - - def reset(self, batch_size=None): - """Reset function which reset the whole variables in the model. 
- """ - self.target.reset(batch_size) + return bm.for_loop(functools.partial(self._run, self.shared_arg), + (shared, xs), + child_objs=(self.target, share), + jit=self.jit, + remat=self.remat) def reset_state(self, batch_size=None): self.target.reset_state(batch_size) + def _run(self, static_sh, dyn_sh, x): + share.save(**static_sh) + share.save(**dyn_sh) + outs = self.target(x) + if self.out_vars is not None: + outs = (outs, tree_map(bm.as_jax, self.out_vars)) + self.target.clear_input() + return outs + class NoSharedArg(DynSysToBPObj): """Transform an instance of :py:class:`~.DynamicalSystem` into a callable diff --git a/brainpy/_src/experimental/delay.py b/brainpy/_src/experimental/delay.py deleted file mode 100644 index 35185ee9b..000000000 --- a/brainpy/_src/experimental/delay.py +++ /dev/null @@ -1,47 +0,0 @@ -# -*- coding: utf-8 -*- - -from typing import Union, Callable, Optional, Dict - -import jax - -from brainpy import math as bm -from brainpy._src.dyn.base import DynamicalSystem, not_pass_shargs -from brainpy._src.math.delayvars import DelayVariable, ROTATE_UPDATE, CONCAT_UPDATE - - -class Delay(DynamicalSystem, DelayVariable): - """Delay for dynamical systems which has a fixed delay length. - - Detailed docstring please see :py:class:`~.DelayVariable`. - """ - - def __init__( - self, - target: bm.Variable, - length: int = 0, - before_t0: Union[float, int, bool, bm.Array, jax.Array, Callable] = None, - entries: Optional[Dict] = None, - method: str = ROTATE_UPDATE, - mode: bm.Mode = None, - name: str = None, - ): - DynamicalSystem.__init__(self, mode=mode) - if method is None: - if self.mode.is_a(bm.NonBatchingMode): - method = ROTATE_UPDATE - elif self.mode.is_parent_of(bm.TrainingMode): - method = CONCAT_UPDATE - else: - method = ROTATE_UPDATE - DelayVariable.__init__(self, - target=target, - length=length, - before_t0=before_t0, - entries=entries, - method=method, - name=name) - - @not_pass_shargs - def update(self, *args, **kwargs): - return DelayVariable.update(self, *args, **kwargs) - diff --git a/brainpy/_src/experimental/neurons.py b/brainpy/_src/experimental/neurons.py deleted file mode 100644 index e3c31f55b..000000000 --- a/brainpy/_src/experimental/neurons.py +++ /dev/null @@ -1,155 +0,0 @@ -from typing import Union, Callable, Optional - -from jax.lax import stop_gradient - -import brainpy.math as bm -from brainpy._src.dyn.base import NeuGroup, not_pass_shargs -from brainpy._src.initialize import (ZeroInit, OneInit, Initializer, parameter, variable_) -from brainpy._src.integrators import odeint -from brainpy.check import is_initializer, is_callable, is_subclass -from brainpy.types import Shape, ArrayType - - -class LIF(NeuGroup): - r"""Leaky integrate-and-fire neuron model. - - **Model Descriptions** - - The formal equations of a LIF model [1]_ is given by: - - .. math:: - - \tau \frac{dV}{dt} = - (V(t) - V_{rest}) + RI(t) \\ - \text{after} \quad V(t) \gt V_{th}, V(t) = V_{reset} \quad - \text{last} \quad \tau_{ref} \quad \text{ms} - - where :math:`V` is the membrane potential, :math:`V_{rest}` is the resting - membrane potential, :math:`V_{reset}` is the reset membrane potential, - :math:`V_{th}` is the spike threshold, :math:`\tau` is the time constant, - :math:`\tau_{ref}` is the refractory time period, - and :math:`I` is the time-variant synaptic inputs. - - **Model Examples** - - - `(Brette, Romain. 2004) LIF phase locking `_ - - - Parameters - ---------- - size: sequence of int, int - The size of the neuron group. 
- V_rest: float, ArrayType, Initializer, callable - Resting membrane potential. - V_reset: float, ArrayType, Initializer, callable - Reset potential after spike. - V_th: float, ArrayType, Initializer, callable - Threshold potential of spike. - R: float, ArrayType, Initializer, callable - Membrane resistance. - tau: float, ArrayType, Initializer, callable - Membrane time constant. - tau_ref: float, ArrayType, Initializer, callable - Refractory period length.(ms) - V_initializer: ArrayType, Initializer, callable - The initializer of membrane potential. - method: str - The numerical integration method. - name: str - The group name. - - References - ---------- - - .. [1] Abbott, Larry F. "Lapicque’s introduction of the integrate-and-fire model - neuron (1907)." Brain research bulletin 50, no. 5-6 (1999): 303-304. - """ - - def __init__( - self, - size: Shape, - keep_size: bool = False, - - # neuron parameter - V_rest: Union[float, ArrayType, Initializer, Callable] = 0., - V_reset: Union[float, ArrayType, Initializer, Callable] = -5., - V_th: Union[float, ArrayType, Initializer, Callable] = 20., - R: Union[float, ArrayType, Initializer, Callable] = 1., - tau: Union[float, ArrayType, Initializer, Callable] = 10., - tau_ref: Optional[Union[float, ArrayType, Initializer, Callable]] = None, - V_initializer: Union[Initializer, Callable, ArrayType] = ZeroInit(), - - # training parameter - mode: Optional[bm.Mode] = None, - spike_fun: Callable = bm.surrogate.inv_square_grad, - - # other parameters - method: str = 'exp_auto', - name: Optional[str] = None, - ): - # initialization - super(LIF, self).__init__(size=size, - name=name, - keep_size=keep_size, - mode=mode) - is_subclass(self.mode, (bm.TrainingMode, bm.NonBatchingMode), self.name) - - # parameters - self.V_rest = parameter(V_rest, self.varshape, allow_none=False) - self.V_reset = parameter(V_reset, self.varshape, allow_none=False) - self.V_th = parameter(V_th, self.varshape, allow_none=False) - self.tau = parameter(tau, self.varshape, allow_none=False) - self.R = parameter(R, self.varshape, allow_none=False) - self.tau_ref = parameter(tau_ref, self.varshape, allow_none=True) - self.spike_fun = is_callable(spike_fun, 'spike_fun') - - # initializers - is_initializer(V_initializer, 'V_initializer') - self._V_initializer = V_initializer - - # integral - self.integral = odeint(method=method, f=self.derivative) - - # variables - self.reset_state(self.mode) - - def derivative(self, V, t, I_ext): - return (-V + self.V_rest + self.R * I_ext) / self.tau - - def reset_state(self, batch_size=None): - self.V = variable_(self._V_initializer, self.varshape, batch_size) - self.spike = variable_(bm.zeros, self.varshape, batch_size) - if self.tau_ref is not None: - self.t_last_spike = variable_(OneInit(-1e7), self.varshape, batch_size) - - @not_pass_shargs - def update(self, current): - t = bm.share.load('t') - - # integrate membrane potential - V = self.integral(self.V.value, t, current, bm.dt) - - if self.tau_ref is not None: - refractory = stop_gradient((t - self.t_last_spike) <= self.tau_ref) - V = bm.where(refractory, self.V.value, V) - - # spike, refractory, spiking time, and membrane potential reset - spike = self.spike_fun(V - self.V_th) - spike_no_grad = stop_gradient(spike) - V += (self.V_reset - V) * spike_no_grad - t_last_spike = bm.where(spike_no_grad, t, self.t_last_spike) - - # updates - self.V.value = V - self.spike.value = spike - self.t_last_spike.value = stop_gradient(t_last_spike) - - else: - # spike, spiking time, and membrane potential reset 
- spike = self.spike_fun(V - self.V_th) - V += (self.V_reset - V) * stop_gradient(spike) - - # updates - self.V.value = V - self.spike.value = spike - - return spike diff --git a/brainpy/_src/losses/comparison.py b/brainpy/_src/losses/comparison.py index 7a95e7b4b..7f3f1385f 100644 --- a/brainpy/_src/losses/comparison.py +++ b/brainpy/_src/losses/comparison.py @@ -97,7 +97,7 @@ def cross_entropy_loss(predicts, targets, weight=None, reduction='mean'): in the case of K-dimensional loss. """ def _cel(_pred, _tar): - if jnp.ndim(_tar) + 1 == jnp.ndim(_pred): + if bm.ndim(_tar) + 1 == bm.ndim(_pred): _tar = bm.one_hot(_tar, _pred.shape[-1]) loss = logsumexp(bm.as_jax(_pred), axis=-1) - (_pred * _tar).sum(axis=-1) if weight is not None: diff --git a/brainpy/_src/math/__init__.py b/brainpy/_src/math/__init__.py index 1a6fac48f..e852fc710 100644 --- a/brainpy/_src/math/__init__.py +++ b/brainpy/_src/math/__init__.py @@ -57,5 +57,4 @@ # environment settings from .modes import * from .environment import * -from .context import share diff --git a/brainpy/_src/math/_utils.py b/brainpy/_src/math/_utils.py index 6c4379a21..fa943ffc2 100644 --- a/brainpy/_src/math/_utils.py +++ b/brainpy/_src/math/_utils.py @@ -19,7 +19,9 @@ def _is_leaf(a): def _compatible_with_brainpy_array( fun: Callable, module: str = '' ): - @functools.wraps(fun) + func_to_wrap = fun.__np_wrapped__ if hasattr(fun, '__np_wrapped__') else fun + + @functools.wraps(func_to_wrap) def new_fun(*args, **kwargs): args = tree_map(_as_jax_array_, args, is_leaf=_is_leaf) out = None diff --git a/brainpy/_src/math/compat_pytorch.py b/brainpy/_src/math/compat_pytorch.py index 82f4b99c4..70031a17a 100644 --- a/brainpy/_src/math/compat_pytorch.py +++ b/brainpy/_src/math/compat_pytorch.py @@ -13,7 +13,24 @@ 'Tensor', 'flatten', 'cat', - + 'abs', + 'absolute', + 'acos', + 'arccos', + 'acosh', + 'arccosh', + 'add', + 'addcdiv', + 'addcmul', + 'angle', + 'asin', + 'arcsin', + 'asinh', + 'arcsin', + 'atan', + 'arctan', + 'atan2', + 'atanh', ] diff --git a/brainpy/_src/math/context.py b/brainpy/_src/math/context.py deleted file mode 100644 index a4110901e..000000000 --- a/brainpy/_src/math/context.py +++ /dev/null @@ -1,117 +0,0 @@ -""" -Context for brainpy computation. - -This context defines all shared data used in all modules in a computation. -""" - -from typing import Dict, Any - -from brainpy._src.tools.dicts import DotDict -from .delayvars import DelayVariable -from .object_transform.base import BrainPyObject -from .environment import get_dt as _get_dt_ - -__all__ = [ - 'share', -] - - -class DelayEntry: - def __init__(self, target: str, time=None, step=None): - if time is None and step is None: - raise ValueError('Please provide time or step.') - self.target = target - self.time = time - self.step = step - - -class _ShareContext(BrainPyObject): - def __init__(self): - super().__init__() - - # Shared data across all nodes at current time step. - # ------------- - - self._arguments = DotDict() - self._delays: Dict[str, DelayVariable] = DotDict() - self._delay_entries: Dict[str, str] = DotDict() - self._identifiers = set() - - @property - def dt(self): - if 'dt' in self._arguments: - return self._arguments['dt'] - else: - return _get_dt_() - - def load(self, key): - """Get the shared data by the ``key``. - - Args: - key (str): the key to indicate the data. 
- """ - if key in self._arguments: - return self._arguments[key] - if key in self._delays: - return self._delays[key] - if key in self._delay_entries: - entry = key - delay = self._delay_entries[entry] - return self._delays[delay].at(entry) - raise KeyError(f'Cannot found shared data of {key}.') - - def save(self, identifier: str, data: Any) -> None: - """Save shared arguments in the global context.""" - assert isinstance(identifier, str) - - if isinstance(data, DelayVariable): - if identifier in self._identifiers: - raise ValueError(f'{identifier} has been used. Please assign another name.') - self._delays[identifier] = data - # elif isinstance(data, DelayEntry): - # if isinstance(data.target, DelayVariable): - # delay_key = f'delay{id(data)}' - # self.save(delay_key, data.target) - # delay = data.target - # elif isinstance(data.target, str): - # if data.target not in self._delays: - # raise ValueError(f'Delay target {data.target} has not been registered.') - # delay = self._delays[data.target] - # delay_key = data.target - # else: - # raise ValueError(f'Unknown delay target. {type(data.target)}') - # delay.register_entry(identifier, delay_time=data.time, delay_step=data.step) - # self._delay_entries[identifier] = delay_key - else: - self._arguments[identifier] = data - self._identifiers.add(identifier) - - def get_shargs(self) -> DotDict: - """Get all shared arguments in the global context.""" - return self._arguments.copy() - - def remove_shargs(self, *args) -> None: - """Clear all shared arguments in the global context.""" - if len(args) > 0: - for a in args: - self._arguments.pop(a) - else: - self._arguments.clear() - - def clear(self) -> None: - """Clear all shared data in this computation context.""" - self._arguments.clear() - self._delays.clear() - self._delay_entries.clear() - self._identifiers.clear() - - def update(self): - for delay in self._delays.values(): - delay.update() - - def reset_state(self, batch_axis: int = None): - for delay in self._delays.values(): - delay.reset_state(batch_axis) - - -share = _ShareContext() diff --git a/brainpy/_src/math/delayvars.py b/brainpy/_src/math/delayvars.py index fa0fd193e..6d8051d45 100644 --- a/brainpy/_src/math/delayvars.py +++ b/brainpy/_src/math/delayvars.py @@ -21,7 +21,6 @@ 'AbstractDelay', 'TimeDelay', 'LengthDelay', 'NeuTimeDelay', 'NeuLenDelay', - 'DelayVariable', 'ROTATE_UPDATE', 'CONCAT_UPDATE', ] @@ -471,281 +470,6 @@ def update(self, value: Union[float, int, bool, Array, jnp.DeviceArray]): raise ValueError(f'Unknown updating method "{self.update_method}"') -class DelayVariable(AbstractDelay): - """Delay variable which has a fixed delay length. - - The data in this delay variable is arranged as:: - - delay = 0 [ data - delay = 1 data - delay = 2 data - ... .... - ... .... - delay = length-1 data - delay = length data ] - - Parameters - ---------- - target: Variable - The initial delay data. - length: int - The delay data length. - before_t0: Any - The delay data. It can be a Python number, like float, int, boolean values. - It can also be arrays. Or a callable function or instance of ``Connector``. - Note that ``initial_delay_data`` should be arranged as the following way:: - - delay = 1 [ data - delay = 2 data - ... .... - ... .... - delay = length-1 data - delay = length data ] - method: str - The method used for updating delay. 
- - """ - - data: Optional[Variable] - idx: Optional[Variable] - length: int - - def __init__( - self, - target: Variable, - length: int = 0, - before_t0: Union[float, int, bool, Array, jax.Array, Callable] = None, - entries: Optional[Dict] = None, - name: str = None, - method: str = ROTATE_UPDATE, - ): - BrainPyObject.__init__(self, name=name) - assert method in [ROTATE_UPDATE, CONCAT_UPDATE] - self.method = method - - # target - self.target = target - if not isinstance(target, Variable): - raise ValueError(f'Must be an instance of brainpy.math.Variable. But we got {type(target)}') - - # delay length - self.length = is_integer(length, allow_none=False, min_bound=0) - - # delay data - if before_t0 is not None: - assert isinstance(before_t0, (int, float, bool, Array, jax.Array, Callable)) - self._before_t0 = before_t0 - if length > 0: - self._init_data(length) - else: - self.data = None - - # time variables - if self.method == ROTATE_UPDATE: - self.idx = Variable(stop_gradient(jnp.asarray(0, dtype=jnp.int32))) - - # other info - self._access_to_step = dict() - for entry, value in entries.items(): - self.register_entry(entry, value) - - def register_entry( - self, - entry: str, - delay_time: Optional[Union[float, Array, Callable]] = None, - delay_step: Optional[Union[int, Array, Callable]] = None, - ) -> 'DelayVariable': - """Register an entry to access the data. - - Args: - entry (str): The entry to access the delay data. - delay_step: The delay step of the entry (must be an integer, denoting the delay step). - delay_time: The delay time of the entry (can be a float). - - Returns: - Return the self. - """ - if entry in self._access_to_step: - raise KeyError(f'Entry {entry} has been registered.') - - if delay_time is not None: - if delay_step is not None: - raise ValueError('Provide either "delay_time" or "delay_step". Both you have given both.') - if callable(delay_time): - delay_time = as_jax(delay_time(self.delay_target_shape)) - delay_step = jnp.asarray(delay_time / get_dt(), dtype=get_int()) - elif isinstance(delay_time, float): - delay_step = int(delay_time / get_dt()) - else: - delay_step = jnp.asarray(as_jax(delay_time) / get_dt(), dtype=get_int()) - - # delay steps - if delay_step is None: - delay_type = 'none' - elif isinstance(delay_step, int): - delay_type = 'homo' - elif isinstance(delay_step, (Array, jax.Array, np.ndarray)): - if delay_step.size == 1 and delay_step.ndim == 0: - delay_type = 'homo' - else: - delay_type = 'heter' - delay_step = Array(delay_step) - elif callable(delay_step): - delay_step = delay_step(self.delay_target_shape) - delay_type = 'heter' - else: - raise ValueError(f'Unknown "delay_steps" type {type(delay_step)}, only support ' - f'integer, array of integers, callable function, brainpy.init.Initializer.') - if delay_type == 'heter': - if delay_step.dtype not in [jnp.int32, jnp.int64]: - raise ValueError('Only support delay steps of int32, int64. 
If your ' - 'provide delay time length, please divide the "dt" ' - 'then provide us the number of delay steps.') - if self.delay_target_shape[0] != delay_step.shape[0]: - raise ValueError(f'Shape is mismatched: {self.delay_target_shape[0]} != {delay_step.shape[0]}') - if delay_type == 'heter': - max_delay_step = int(max(delay_step)) - elif delay_type == 'homo': - max_delay_step = delay_step - else: - max_delay_step = None - - # delay variable - if max_delay_step is not None: - if self.length < max_delay_step: - self._init_data(max_delay_step) - self.length = max_delay_step - self._access_to_step[entry] = delay_step - return self - - def at(self, entry: str, *indices) -> Array: - """Get the data at the given entry. - - Args: - entry (str): The entry to access the data. - *indices: - - Returns: - The data. - """ - assert isinstance(entry, str) - if entry not in self._access_to_step: - raise KeyError(f'Does not find delay entry "{entry}".') - delay_step = self._access_to_step[entry] - if delay_step is None: - return self.target.value - else: - if self.data is None: - return self.target.value - else: - if isinstance(delay_step, slice): - return self.retrieve(delay_step, *indices) - elif np.ndim(delay_step) == 0: - return self.retrieve(delay_step, *indices) - else: - if len(indices) == 0 and len(delay_step) == self.target.shape[0]: - indices = (jnp.arange(delay_step.size),) - return self.retrieve(delay_step, *indices) - - @property - def delay_target_shape(self): - """The data shape of the delay target.""" - return self.target.shape - - def __repr__(self): - name = self.__class__.__name__ - return (f'{name}(num_delay_step={self.length}, ' - f'delay_target_shape={self.delay_target_shape}, ' - f'update_method={self.method})') - - def _check_delay(self, delay_len): - raise ValueError(f'The request delay length should be less than the ' - f'maximum delay {self.length}. ' - f'But we got {delay_len}') - - def retrieve(self, delay_step, *indices): - """Retrieve the delay data according to the delay length. - - Parameters - ---------- - delay_step: int, ArrayType - The delay length used to retrieve the data. - """ - assert delay_step is not None - if check.is_checking(): - jit_error_checking(jnp.any(delay_step > self.length), self._check_delay, delay_step) - - if self.method == ROTATE_UPDATE: - delay_idx = (self.idx.value + delay_step) % (self.length + 1) - delay_idx = stop_gradient(delay_idx) - - elif self.method == CONCAT_UPDATE: - delay_idx = delay_step - - else: - raise ValueError(f'Unknown updating method "{self.method}"') - - # the delay index - if hasattr(delay_idx, 'dtype') and not jnp.issubdtype(delay_idx.dtype, jnp.integer): - raise ValueError(f'"delay_len" must be integer, but we got {delay_idx}') - indices = (delay_idx,) + tuple(indices) - - # the delay data - return self.data[indices] - - def update(self, latest_value: Optional[Union[Array, jax.Array]] = None) -> None: - """Update delay variable with the new data. 
- """ - if self.data is not None: - # get the latest target value - if latest_value is None: - latest_value = self.target.value - - # update the delay data at the rotation index - if self.method == ROTATE_UPDATE: - self.idx.value = stop_gradient(as_jax((self.idx - 1) % (self.length + 1))) - self.data[self.idx.value] = latest_value - - # update the delay data at the first position - elif self.method == CONCAT_UPDATE: - if self.length >= 2: - self.data.value = vstack([latest_value, self.data[1:]]) - else: - self.data[0] = latest_value - - def reset_state(self, batch_size: int = None): - """Reset the delay data. - """ - # initialize delay data - if self.data is not None: - self._init_data(self.length, batch_size) - - # time variables - if self.method == ROTATE_UPDATE: - self.idx.value = stop_gradient(jnp.asarray(0, dtype=jnp.int32)) - - def _init_data(self, length, batch_size: int = None): - if batch_size is not None: - if self.target.batch_size != batch_size: - raise ValueError(f'The batch sizes of delay variable and target variable differ ' - f'({self.target.batch_size} != {batch_size}). ' - 'Please reset the target variable first, because delay data ' - 'depends on the target variable. ') - - if self.target.batch_axis is None: - batch_axis = None - else: - batch_axis = self.target.batch_axis + 1 - self.data = Variable(jnp.zeros((length + 1,) + self.target.shape, dtype=self.target.dtype), - batch_axis=batch_axis) - # update delay data - self.data[0] = self.target.value - if isinstance(self._before_t0, (Array, jax.Array, float, int, bool)): - self.data[1:] = self._before_t0 - elif callable(self._before_t0): - self.data[1:] = self._before_t0((length,) + self.target.shape, dtype=self.target.dtype) - - class NeuLenDelay(LengthDelay): """Neutral Length Delay. Alias of :py:class:`~.LengthDelay`.""" pass diff --git a/brainpy/_src/math/environment.py b/brainpy/_src/math/environment.py index c9eeec3b7..16eac59a0 100644 --- a/brainpy/_src/math/environment.py +++ b/brainpy/_src/math/environment.py @@ -485,7 +485,6 @@ class training_environment(environment): >>> with bm.environment(mode=bm.training_mode): >>> pass - """ def __init__(self, diff --git a/brainpy/_src/math/object_transform/base.py b/brainpy/_src/math/object_transform/base.py index 041bb8a70..5b1ad169f 100644 --- a/brainpy/_src/math/object_transform/base.py +++ b/brainpy/_src/math/object_transform/base.py @@ -64,11 +64,11 @@ def __init__(self, name=None): # Used to wrap the implicit variables # which cannot be accessed by self.xxx - self.implicit_vars = ArrayCollector() + self.implicit_vars: dyn_dict = dyn_dict() # Used to wrap the implicit children nodes # which cannot be accessed by self.xxx - self.implicit_nodes = Collector() + self.implicit_nodes: dyn_dict = dyn_dict() def __setattr__(self, key: str, value: Any) -> None: """Overwrite `__setattr__` method for change Variable values. 
@@ -262,17 +262,15 @@ def _find_nodes(self, method='absolute', level=-1, include_self=True, _lid=0, _p nodes = [] for k, v in self.__dict__.items(): if isinstance(v, BrainPyObject): - path = (id(self), id(v)) - if path not in _paths: - _paths.add(path) - gather[v.name] = v - nodes.append(v) - for node in self.implicit_nodes.values(): - path = (id(self), id(node)) - if path not in _paths: - _paths.add(path) - gather[node.name] = node - nodes.append(node) + _add_node2(self, v, _paths, gather, nodes) + elif isinstance(v, dyn_seq): + for v2 in v: + if isinstance(v2, BrainPyObject): + _add_node2(self, v2, _paths, gather, nodes) + elif isinstance(v, dyn_dict): + for v2 in v.values(): + if isinstance(v2, BrainPyObject): + _add_node2(self, v2, _paths, gather, nodes) for v in nodes: gather.update(v._find_nodes(method=method, level=level, @@ -284,17 +282,15 @@ def _find_nodes(self, method='absolute', level=-1, include_self=True, _lid=0, _p nodes = [] for k, v in self.__dict__.items(): if isinstance(v, BrainPyObject): - path = (id(self), id(v)) - if path not in _paths: - _paths.add(path) - gather[k] = v - nodes.append((k, v)) - for key, node in self.implicit_nodes.items(): - path = (id(self), id(node)) - if path not in _paths: - _paths.add(path) - gather[key] = node - nodes.append((key, node)) + _add_node1(self, k, v, _paths, gather, nodes) + elif isinstance(v, dyn_seq): + for i, v2 in enumerate(v): + if isinstance(v, BrainPyObject): + _add_node1(self, k + '-' + str(i), v2, _paths, gather, nodes) + elif isinstance(v, dyn_dict): + for k2, v2 in v.items(): + if isinstance(v2, BrainPyObject): + _add_node1(self, k + '.' + k2, v2, _paths, gather, nodes) for k1, v1 in nodes: for k2, v2 in v1._find_nodes(method=method, _paths=_paths, @@ -351,10 +347,10 @@ def unique_name(self, name=None, type_=None): check_name_uniqueness(name=name, obj=self) return name - def __state_dict__(self) -> dict: - return self.vars(include_self=True, level=0).unique() + def __save_state__(self) -> dict: + return self.vars(include_self=True, level=0).unique().dict() - def __load_state_dict__(self, state_dict: dict) -> Optional[Tuple[Sequence[str], Sequence[str]]]: + def __load_state__(self, state_dict: dict) -> Optional[Tuple[Sequence[str], Sequence[str]]]: variables = self.vars(include_self=True, level=0).unique() keys1 = set(state_dict.keys()) keys2 = set(variables.keys()) @@ -373,7 +369,7 @@ def state_dict(self) -> dict: A dictionary containing a whole state of the module. 
""" nodes = self.nodes() # retrieve all nodes - return {key: node.__state_dict__() for key, node in nodes.items()} + return {key: node.__save_state__() for key, node in nodes.items()} def load_state_dict(self, state_dict: Dict[str, Any], warn: bool = True, compatible='v2'): """Copy parameters and buffers from :attr:`state_dict` into @@ -407,7 +403,7 @@ def load_state_dict(self, state_dict: Dict[str, Any], warn: bool = True, compati missing_keys = [] unexpected_keys = [] for name, node in nodes.items(): - missing, unexpected = node.__load_state_dict__(state_dict[name]) + missing, unexpected = node.__load_state__(state_dict[name]) missing_keys.extend([f'{name}.{key}' for key in missing]) unexpected_keys.extend([f'{name}.{key}' for key in unexpected]) else: @@ -494,6 +490,22 @@ def tpu(self): return self.to(device=jax.devices('tpu')[0]) +def _add_node2(self, v, _paths, gather, nodes): + path = (id(self), id(v)) + if path not in _paths: + _paths.add(path) + gather[v.name] = v + nodes.append(v) + + +def _add_node1(self, k, v, _paths, gather, nodes): + path = (id(self), id(v)) + if path not in _paths: + _paths.add(path) + gather[k] = v + nodes.append((k, v)) + + Base = BrainPyObject diff --git a/brainpy/_src/math/object_transform/controls.py b/brainpy/_src/math/object_transform/controls.py index e55008fe8..ec827a22c 100644 --- a/brainpy/_src/math/object_transform/controls.py +++ b/brainpy/_src/math/object_transform/controls.py @@ -10,7 +10,8 @@ from jax.errors import UnexpectedTracerError from brainpy import errors, tools, check -from brainpy._src.math.ndarray import (Array, Variable, +from brainpy._src.math.ndarray import (Array, + Variable, add_context, del_context) from brainpy._src.math.arrayinterporate import as_jax @@ -664,6 +665,7 @@ def for_loop( reverse: bool = False, unroll: int = 1, remat: bool = False, + jit: bool = True, ): """``for-loop`` control flow with :py:class:`~.Variable`. @@ -759,9 +761,6 @@ def for_loop( dyn_vars.update(obj.vars().unique()) dyn_vars = list(ArrayCollector(dyn_vars).unique().values()) outs, _ = tree_flatten(out_vars, lambda s: isinstance(s, Variable)) - for v in outs: - if v not in dyn_vars: - dyn_vars.append(v) # functions def fun2scan(carry, x): @@ -785,11 +784,12 @@ def fun2scan(carry, x): # functions try: add_context(name) - dyn_vals, out_vals = lax.scan(f=fun2scan, - init=[v.value for v in dyn_vars], - xs=operands, - reverse=reverse, - unroll=unroll) + with jax.disable_jit(not jit): + dyn_vals, out_vals = lax.scan(f=fun2scan, + init=[v.value for v in dyn_vars], + xs=operands, + reverse=reverse, + unroll=unroll) del_context(name) except UnexpectedTracerError as e: del_context(name) diff --git a/brainpy/_src/math/object_transform/function.py b/brainpy/_src/math/object_transform/function.py index 785478942..79907b015 100644 --- a/brainpy/_src/math/object_transform/function.py +++ b/brainpy/_src/math/object_transform/function.py @@ -9,7 +9,6 @@ __all__ = [ 'Partial', 'to_object', - 'to_dynsys', 'function', ] @@ -71,43 +70,6 @@ def wrap(func) -> FunAsObject: return FunAsObject(target=f, child_objs=child_objs, dyn_vars=dyn_vars, name=name) -def to_dynsys( - f: Callable = None, - child_objs: Union[Callable, BrainPyObject, Sequence[BrainPyObject], Dict[str, BrainPyObject]] = None, - dyn_vars: Union[Variable, Sequence[Variable], Dict[str, Variable]] = None, - name: str = None -): - """Transform a Python function to a :py:class:`~.DynamicalSystem`. - - Parameters - ---------- - f: function, callable - The python function. 
- child_objs: Callable, DynamicalSystem, sequence of DynamicalSystem, dict of DynamicalSystem - The children objects used in this Python function. - dyn_vars: Variable, sequence of Variable, dict of Variable - The `Variable` instance used in the Python function. - name: str - The name of the created ``BrainPyObject``. - - Returns - ------- - func: FunAsDynSys - The instance of ``DynamicalSystem``. - """ - from brainpy._src.dyn.base import FuncAsDynSys - - if f is None: - def wrap(func) -> FuncAsDynSys: - return FuncAsDynSys(target=func, child_objs=child_objs, dyn_vars=dyn_vars, name=name) - - return wrap - else: - if child_objs is None: - raise ValueError(f'"child_objs" cannot be None when "f" is provided.') - return FuncAsDynSys(target=f, child_objs=child_objs, dyn_vars=dyn_vars, name=name) - - def function( f: Callable = None, nodes: Union[Callable, BrainPyObject, Sequence[BrainPyObject], Dict[str, BrainPyObject]] = None, diff --git a/brainpy/_src/math/remove_vmap.py b/brainpy/_src/math/remove_vmap.py index bde0c8f1b..6075b9452 100644 --- a/brainpy/_src/math/remove_vmap.py +++ b/brainpy/_src/math/remove_vmap.py @@ -5,6 +5,7 @@ from jax.abstract_arrays import ShapedArray from jax.core import Primitive from jax.interpreters import batching, mlir, xla +from .ndarray import Array __all__ = [ 'remove_vmap' @@ -12,6 +13,8 @@ def remove_vmap(x, op='any'): + if isinstance(x, Array): + x = x.value if op == 'any': return _any_without_vmap(x) elif op == 'all': diff --git a/brainpy/_src/tools/naming.py b/brainpy/_src/tools/naming.py index 52de795d0..3ab51a860 100644 --- a/brainpy/_src/tools/naming.py +++ b/brainpy/_src/tools/naming.py @@ -2,6 +2,7 @@ import warnings from brainpy import errors +from brainpy import check __all__ = [ diff --git a/brainpy/_src/train/back_propagation.py b/brainpy/_src/train/back_propagation.py index ac76b93a5..083169806 100644 --- a/brainpy/_src/train/back_propagation.py +++ b/brainpy/_src/train/back_propagation.py @@ -1,20 +1,20 @@ # -*- coding: utf-8 -*- -import sys import time from collections.abc import Iterable from functools import partial -from typing import Union, Dict, Callable, Sequence, Any, Optional -from tqdm import tqdm +from typing import Union, Dict, Callable, Sequence, Optional import jax.numpy as jnp import numpy as np from jax.tree_util import tree_map +from tqdm import tqdm import brainpy.losses as losses import brainpy.math as bm from brainpy import tools, optim from brainpy._src.dyn.base import DynamicalSystem +from brainpy._src.dyn.context import share from brainpy._src.math.object_transform.base import BrainPyObject from brainpy._src.running import constants as c from brainpy.check import serialize_kwargs @@ -75,7 +75,6 @@ def __init__( optimizer: optim.Optimizer = None, # optimizer loss_has_aux: bool = False, # loss auxiliary loss_auto_run: bool = True, # loss auxiliary - logger: Optional[Any] = None, # ------------- # API deprecated @@ -141,9 +140,6 @@ def __init__( self._f_loss_compiled = dict() self._f_grad_compiled = dict() - # others - self.logger = logger - def __repr__(self): name = self.__class__.__name__ prefix = ' ' * len(name) @@ -269,7 +265,10 @@ def fit( fit_t0 = time.time() fit_epoch_metric = dict(loss=[]) _training_data = train_data() if callable(train_data) else train_data - bar = tqdm(total=len(_training_data) if hasattr(_training_data, '__len__') else None) + if hasattr(_training_data, '__len__'): + bar = tqdm(total=len(_training_data)) + else: + bar = None for x, y in _training_data: # reset state @@ -289,7 +288,8 @@ def fit( 
if k not in fit_epoch_metric: fit_epoch_metric[k] = [] fit_epoch_metric[k].append(v) - bar.update(1) + if bar is not None: + bar.update(1) # report fit_i += 1 @@ -306,9 +306,10 @@ def fit( v.clear() _report = (f'Train {fit_i} steps, use {fit_t + fit_t1 - fit_t0:.4f} s' + ', {}'.format(", ".join([f"{k} {v}" for k, v in aux.items()]))) - bar.set_description(_report, refresh=True) - if self.logger is not None: - self.logger.write(_report + '\n') + if bar is not None: + bar.set_description(_report, refresh=True) + else: + print(_report) if fun_after_report is not None: fun_after_report(fit_i, aux, 'fit') fit_t0 = time.time() @@ -327,22 +328,26 @@ def fit( v.clear() _report = (f'Train {epoch_idx} epoch, use {fit_t1 - fit_t0:.4f} s' + ', {}'.format(", ".join([f"{k} {v}" for k, v in aux.items()]))) - bar.set_description(_report, refresh=True) - if self.logger is not None: - self.logger.write(_report + '\n') + if bar is not None: + bar.set_description(_report, refresh=True) + else: + print(_report) if fun_after_report is not None: fun_after_report(epoch_idx, aux, 'fit') else: fit_t = time.time() - fit_t0 self.optimizer.lr.step_epoch() - bar.close() + if bar is not None: bar.close() # testing set if test_data is not None: test_t0 = time.time() test_epoch_metric = dict(loss=[]) _testing_data = test_data() if callable(test_data) else test_data - bar = tqdm(total=len(_testing_data) if hasattr(_testing_data, '__len__') else None) + if hasattr(_testing_data, '__len__'): + bar = tqdm(total=len(_testing_data)) + else: + bar = None for x, y in _testing_data: # reset state if reset_state: @@ -364,7 +369,7 @@ def fit( else: test_epoch_metric['loss'].append(res) - bar.update(1) + if bar is not None: bar.update(1) # report test_i += 1 @@ -381,9 +386,10 @@ def fit( v.clear() _report = (f'Test {test_i} steps, use {test_t + test_t1 - test_t0:.4f} s' + ', {}'.format(", ".join([f"{k} {v}" for k, v in aux.items()]))) - bar.set_description(_report, refresh=True) - if self.logger is not None: - self.logger.write(_report + '\n') + if bar is not None: + bar.set_description(_report, refresh=True) + else: + print(_report) if fun_after_report is not None: fun_after_report(test_i, aux, 'test') test_t0 = time.time() @@ -402,15 +408,16 @@ def fit( v.clear() _report = (f'Test {epoch_idx} epoch, use {test_t1 - test_t0:.4f} s' + ', {}'.format(", ".join([f"{k} {v}" for k, v in aux.items()]))) - bar.set_description(_report, refresh=True) - if self.logger is not None: - self.logger.write(_report + '\n') + if bar is not None: + bar.set_description(_report, refresh=True) + else: + print(_report) if fun_after_report is not None: fun_after_report(epoch_idx, aux, 'test') else: test_t = time.time() - test_t0 - bar.close() + if bar is not None: bar.close() # finally self._report_train_metrics = {k: np.asarray(v) for k, v in report_train_metric.items()} @@ -567,19 +574,19 @@ def _step_func_fit(self, shared_args, inputs, targets): def _step_func_predict(self, shared, x=None): assert self.data_first_axis == 'B', f'There is no time dimension when using the trainer {self.__class__.__name__}.' 
for k, v in shared.items(): - bm.share.save(k, v) + share.save(k, v) # input step self.target.clear_input() self._step_func_input(shared) # dynamics update step - args = (shared,) if x is None else (shared, x) + args = () if x is None else (x, ) out = self.target(*args) # monitor step mon = self._step_func_monitor(shared) - bm.share.remove_shargs(shared) + share.clear_shargs() return out, mon def _get_f_predict(self, shared_args: Dict = None, jit: bool = True): @@ -655,108 +662,3 @@ def predict( return (t1 - t0, outs) if eval_time else outs -class _OnlineBPTT(BPTT): - def _step_func_loss(self, shared_args, t, i, input_, target_): - outputs, mon = self._get_f_predict_one_step(shared_args)(t, i, input_) - predicts = (outputs, mon) if len(mon) > 0 else outputs - return self._loss_func(predicts, target_) - - def _get_f_loss(self, shared_args=None, jit=True) -> Callable: - """Get loss function.""" - if shared_args is None: shared_args = dict() - - shared_args2 = {k: v for k, v in shared_args.items()} - shared_args2['_local_jit_'] = jit - shared_args_str = serialize_kwargs(shared_args2) - if shared_args_str not in self._f_loss_compiled: - - self._f_loss_compiled[shared_args_str] = partial(self._step_func_loss, shared_args) - if self.jit[c.LOSS_PHASE] and jit: - dyn_vars = self.target.vars() - dyn_vars.update(self.dyn_vars) - dyn_vars = dyn_vars - dyn_vars.subset(bm.VariableView) - self._f_loss_compiled[shared_args_str] = bm.jit(self._f_loss_compiled[shared_args_str], - dyn_vars=dyn_vars) - return self._f_loss_compiled[shared_args_str] - - def _get_f_train(self, shared_args=None) -> Callable: - """Get training function.""" - if shared_args is None: shared_args = dict() - if not isinstance(shared_args, dict): - raise ValueError(f'Only supports dict for "shared_args". 
' - f'But got {type(shared_args)}: {shared_args}') - shared_args_str = serialize_kwargs(shared_args) - if shared_args_str not in self._f_fit_compiled: - - def train_step(*x): - # t, i, input_, target_ = x - res = self._get_f_grad(shared_args)(*x) - self.optimizer.update(res[0]) - return res[1:] - - if self.jit[c.FIT_PHASE]: - dyn_vars = self.target.vars() - dyn_vars.update(self.dyn_vars) - dyn_vars = dyn_vars - dyn_vars.subset(bm.VariableView) - run_func = lambda all_inputs: bm.for_loop(train_step, all_inputs, dyn_vars=dyn_vars.unique()) - - else: - def run_func(xs): - times, indices, inputs, targets = xs - losses = [] - for i in range(times.shape[0]): - # data at time i - x = tree_map(lambda x: x[i], inputs, is_leaf=_is_brainpy_array) - y = tree_map(lambda x: x[i], targets, is_leaf=_is_brainpy_array) - # step at the i - loss = train_step(times[i], indices[i], x, y) - # append output and monitor - losses.append(loss) - return bm.asarray(losses) - - def train_fun(inputs, targets): - times, indices, inputs, num_step, _, duration, _ = self._format_xs( - None, inputs, inputs_are_batching=True, move_axis=True) - targets = tree_map(lambda x: bm.moveaxis(x, 0, 1), targets, is_leaf=_is_brainpy_array) - ls = run_func([times, indices, inputs, targets]) - self.i0 += num_step - self.t0 += duration - return ls - - self._f_fit_compiled[shared_args_str] = train_fun - return self._f_fit_compiled[shared_args_str] - - def _get_f_predict_one_step(self, shared_args: Dict = None, jit: bool = False): - if shared_args is None: shared_args = tools.DotDict() - if not isinstance(shared_args, dict): - raise ValueError(f'"shared_args" must be a dict, ' - f'but got {type(shared_args)}') - - shared_args2 = {k: v for k, v in shared_args.items()} - shared_args2['_local_jit_'] = jit - shared_args2['_one_step_'] = True - shared_args_str = serialize_kwargs(shared_args) - if shared_args_str not in self._f_predict_compiled: - - monitor_func = self._build_monitors(self._mon_info[0], self._mon_info[1], shared_args) - - def run_func(t, i, x): - shared = tools.DotDict(t=t, i=i, dt=self.dt) - shared.update(shared_args) - self.target.clear_input() - outs = self.target(shared, x) - hist = monitor_func(shared) - return outs, hist - - if self.jit[c.FIT_PHASE] and jit: - dyn_vars = self.target.vars() - dyn_vars.update(self.dyn_vars) - dyn_vars = dyn_vars - dyn_vars.subset(bm.VariableView) - self._f_predict_compiled[shared_args_str] = bm.jit(run_func, dyn_vars=dyn_vars.unique()) - else: - self._f_predict_compiled[shared_args_str] = run_func - return self._f_predict_compiled[shared_args_str] - - -class OTTT(BPTrainer): - pass diff --git a/brainpy/_src/train/online.py b/brainpy/_src/train/online.py index 7f22fbc3d..9b21b25fe 100644 --- a/brainpy/_src/train/online.py +++ b/brainpy/_src/train/online.py @@ -11,6 +11,7 @@ from brainpy import math as bm, tools from brainpy._src.dyn.base import DynamicalSystem +from brainpy._src.dyn.context import share from brainpy.algorithms.online import get, OnlineAlgorithm, RLS from brainpy.check import serialize_kwargs from brainpy.errors import NoImplementationError @@ -252,14 +253,15 @@ def run_func(all_inputs): def _step_func_fit(self, shared_args, t, i, x, ys): shared = tools.DotDict(t=t, dt=self.dt, i=i) + shared.update(shared_args) + share.save(**shared) # input step self.target.clear_input() self._step_func_input(shared) # update step - shared.update(shared_args) - args = (shared,) if x is None else (shared, x) + args = () if x is None else (x, ) out = self.target(*args) # monitor step diff --git 
a/brainpy/check.py b/brainpy/check.py index 6e7704e72..763efa4a4 100644 --- a/brainpy/check.py +++ b/brainpy/check.py @@ -3,7 +3,6 @@ from functools import wraps, partial from typing import Union, Sequence, Dict, Callable, Tuple, Type, Optional, Any -import jax import numpy as np import numpy as onp from jax import numpy as jnp @@ -44,6 +43,7 @@ ] _check = True +_name_check = True def is_checking(): @@ -63,6 +63,8 @@ def turn_off(): _check = False +# def turn_off_name_check + def is_shape_consistency(shapes, free_axes=None, return_format_shapes=False): assert isinstance(shapes, (tuple, list)), f'Must be a sequence of shape. While we got {shapes}.' for shape in shapes: diff --git a/brainpy/experimental.py b/brainpy/experimental.py deleted file mode 100644 index f7540ea37..000000000 --- a/brainpy/experimental.py +++ /dev/null @@ -1,30 +0,0 @@ - -# synaptic delays -from brainpy._src.experimental.delay import ( - Delay as Delay, -) - -# synapse plasticity -from brainpy._src.experimental.synstp import ( - STP as STP, - STD as STD, -) - -# synapse outputs -from brainpy._src.experimental.synout import ( - COBA as COBA, - CUBA as CUBA, - MgBlock as MgBlock, -) - -# Synapses -from brainpy._src.experimental.synapses import ( - Exponential as Exponential, -) - -# neurons -from brainpy._src.experimental.neurons import ( - LIF as LIF, -) - - diff --git a/brainpy/layers.py b/brainpy/layers.py index fc91f6b63..f9e1306c6 100644 --- a/brainpy/layers.py +++ b/brainpy/layers.py @@ -82,8 +82,5 @@ Conv1dLSTMCell as Conv1dLSTMCell, Conv2dLSTMCell as Conv2dLSTMCell, Conv3dLSTMCell as Conv3dLSTMCell, - - VanillaRNN as VanillaRNN, - GRU as GRU, - LSTM as LSTM, ) + diff --git a/brainpy/math/__init__.py b/brainpy/math/__init__.py index 3d1582f19..e8456d6b0 100644 --- a/brainpy/math/__init__.py +++ b/brainpy/math/__init__.py @@ -26,7 +26,6 @@ from .modes import * from .environment import * from .others import * -from .context import share mode = NonBatchingMode() '''Default computation mode.''' diff --git a/brainpy/math/context.py b/brainpy/math/context.py index 19631e022..8b1378917 100644 --- a/brainpy/math/context.py +++ b/brainpy/math/context.py @@ -1,4 +1 @@ -from brainpy._src.math.context import ( - share as share, -) diff --git a/brainpy/math/delayvars.py b/brainpy/math/delayvars.py index e1249e751..bae3b9cc2 100644 --- a/brainpy/math/delayvars.py +++ b/brainpy/math/delayvars.py @@ -5,7 +5,6 @@ LengthDelay as LengthDelay, NeuTimeDelay as NeuTimeDelay, NeuLenDelay as NeuLenDelay, - DelayVariable as DelayVariable, ROTATE_UPDATE as ROTATE_UPDATE, CONCAT_UPDATE as CONCAT_UPDATE, ) diff --git a/brainpy/math/object_transform.py b/brainpy/math/object_transform.py index 43158658a..844f01bfc 100644 --- a/brainpy/math/object_transform.py +++ b/brainpy/math/object_transform.py @@ -24,7 +24,6 @@ from brainpy._src.math.object_transform.function import ( to_object as to_object, - to_dynsys as to_dynsys, function as function, ) diff --git a/brainpy/neurons.py b/brainpy/neurons.py index 19c380a35..fc084025e 100644 --- a/brainpy/neurons.py +++ b/brainpy/neurons.py @@ -36,5 +36,5 @@ Izhikevich as Izhikevich, HindmarshRose as HindmarshRose, FHN as FHN, + LIF_SFA_Bellec2020, ) - diff --git a/brainpy/syn.py b/brainpy/syn.py new file mode 100644 index 000000000..8efc9f151 --- /dev/null +++ b/brainpy/syn.py @@ -0,0 +1,17 @@ + +from brainpy._src.dyn.synapses_v2.base import ( + SynConn as SynConn, + SynOut as SynOut, + SynSTP as SynSTP, +) +from brainpy._src.dyn.synapses_v2.syn_plasticity import ( + STD as STD, + STP as STP, +) +from 
brainpy._src.dyn.synapses_v2.syn_outs import ( + CUBA as CUBA, + COBA as COBA, +) +from brainpy._src.dyn.synapses_v2.abstract_models import ( + Exponential as Exponential, +) diff --git a/brainpy/synapses.py b/brainpy/synapses.py index 05cec9cc0..77e339982 100644 --- a/brainpy/synapses.py +++ b/brainpy/synapses.py @@ -22,3 +22,4 @@ GapJunction as GapJunction, ) + diff --git a/brainpy/synouts.py b/brainpy/synouts.py index 5f66035b2..3365f1038 100644 --- a/brainpy/synouts.py +++ b/brainpy/synouts.py @@ -4,7 +4,14 @@ COBA as COBA, CUBA as CUBA, ) +from brainpy._src.dyn.synouts.conductances import ( + eCOBA, + eCUBA, +) from brainpy._src.dyn.synouts.ions import ( MgBlock as MgBlock, ) +from brainpy._src.dyn.synouts.ions import ( + eMgBlock, +) diff --git a/brainpy/tools/__init__.py b/brainpy/tools.py similarity index 100% rename from brainpy/tools/__init__.py rename to brainpy/tools.py diff --git a/examples/dynamics_analysis/3d_reduced_trn_model.py b/examples/dynamics_analysis/3d_reduced_trn_model.py index ce3d0e8c0..247e91281 100644 --- a/examples/dynamics_analysis/3d_reduced_trn_model.py +++ b/examples/dynamics_analysis/3d_reduced_trn_model.py @@ -7,7 +7,7 @@ bp.math.set_platform('cpu') -class ReducedTRNModel(bp.dyn.NeuGroup): +class ReducedTRNModel(bp.NeuGroup): def __init__(self, size, name=None, T=36., method='rk4'): super(ReducedTRNModel, self).__init__(size=size, name=name) diff --git a/examples/dynamics_analysis/4d_HH_model.py b/examples/dynamics_analysis/4d_HH_model.py index 2d83c8265..3eb063564 100644 --- a/examples/dynamics_analysis/4d_HH_model.py +++ b/examples/dynamics_analysis/4d_HH_model.py @@ -7,9 +7,9 @@ I = 5. model = bp.neurons.HH(1) -runner = bp.DSRunner(model, inputs=(model.input, I), monitors=['V']) -runner.run(100) -bp.visualize.line_plot(runner.mon.ts, runner.mon.V, legend='V', show=True) +# runner = bp.DSRunner(model, inputs=(model.input, I), monitors=['V']) +# runner.run(100) +# bp.visualize.line_plot(runner.mon.ts, runner.mon.V, legend='V', show=True) # analysis bm.enable_x64() diff --git a/examples/dynamics_analysis/highdim_RNN_Analysis.py b/examples/dynamics_analysis/highdim_RNN_Analysis.py index 3bf851b03..75b844247 100644 --- a/examples/dynamics_analysis/highdim_RNN_Analysis.py +++ b/examples/dynamics_analysis/highdim_RNN_Analysis.py @@ -13,32 +13,7 @@ import numpy as np import matplotlib.pyplot as plt from sklearn.decomposition import PCA - -# In this tutorial, we will use supervised learning to train a recurrent -# neural network on a simple perceptual decision making task, and analyze -# the trained network using dynamical system analysis. - -# Defining a cognitive task -# ---- -# We will import the task from the neurogym library. 
-# Please install neurogym: -# https://github.com/neurogym/neurogym - -import neurogym as ngym - -# Environment -task = 'PerceptualDecisionMaking-v0' -kwargs = {'dt': 100} -seq_len = 100 - -# Make supervised dataset -dataset = ngym.Dataset(task, - env_kwargs=kwargs, - batch_size=16, - seq_len=seq_len) - -# A sample environment from dataset -env = dataset.env +import brainpy_datasets as bd # Define a vanilla continuous-time recurrent network @@ -94,14 +69,17 @@ def update(self, sha, x): return self.readout(self.h.value) +ds = bd.cognitive.RatePerceptualDecisionMaking(dt=100., num_trial=16 * 200) +loader = bd.cognitive.TaskLoader(ds, max_seq_len=100, batch_size=16, data_first_axis='B') + # Train the recurrent network on the decision-making task # --- # Instantiate the network and print information with bm.training_environment(): - net = RNNNet(num_input=env.observation_space.shape[0], + net = RNNNet(num_input=ds.num_inputs, num_hidden=64, - num_output=env.action_space.n, - dt=env.dt) + num_output=ds.num_outputs, + dt=ds.dt) def loss(predictions, targets): @@ -117,38 +95,29 @@ def loss(predictions, targets): return total_loss, {'accuracy': accuracy} -def data_generation(): - for _ in range(100): - inputs, labels = dataset() - inputs = bm.asarray(np.moveaxis(inputs, 0, 1)) - labels = bm.asarray(np.moveaxis(labels, 0, 1)) - yield inputs, labels - - -trainer = bp.train.BPTT(net, - loss_fun=loss, - loss_has_aux=True, - optimizer=bp.optim.Adam(lr=1e-3)) -trainer.fit(data_generation, num_epoch=20, num_report=100) +trainer = bp.BPTT(net, + loss_fun=loss, + loss_has_aux=True, + optimizer=bp.optim.Adam(lr=1e-3)) +trainer.fit(loader, num_epoch=20) # Visualize neural activity for in sample trials # --- # We will run the network for 100 sample trials, then visual the neural activity trajectories in a PCA space. 
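# Trial-by-trial evaluation: each `ds[i]` item is expected to yield the trial's
# input and target sequences as its first two elements; the input gets a leading
# batch axis before `runner.predict`, and the recorded hidden states feed the
# PCA projection further below.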
runner = bp.DSTrainer(net, monitors={'r': net.h}, progress_bar=False) -env.reset(no_step=True) num_trial = 100 -activity_dict = {} -trial_infos = {} +activity_dict = [] +groundtruths = [] for i in range(num_trial): - env.new_trial() - inputs = bm.asarray(env.ob[np.newaxis]) + ob, re = ds[i][:2] + groundtruths.append(re[-1] - 1) + inputs = bm.asarray(ob[np.newaxis]) _ = runner.predict(inputs) - activity_dict[i] = runner.mon['r'][0] - trial_infos[i] = env.trial + activity_dict.append(runner.mon['r'][0]) # Concatenate activity for PCA -activity = np.concatenate(list(activity_dict[i] for i in range(num_trial)), axis=0) +activity = np.concatenate(activity_dict, axis=0) print('Shape of the neural activity: (Time points, Neurons): ', activity.shape) pca = PCA(n_components=2) @@ -159,7 +128,7 @@ def data_generation(): fig, (ax1, ax2) = plt.subplots(1, 2, sharey=True, sharex=True, figsize=(12, 5)) for i in range(num_trial): activity_pc = pca.transform(activity_dict[i]) - color = 'red' if trial_infos[i]['ground_truth'] == 0 else 'blue' + color = 'red' if groundtruths[i] == 0 else 'blue' _ = ax1.plot(activity_pc[:, 0], activity_pc[:, 1], 'o-', color=color) if i < 5: _ = ax2.plot(activity_pc[:, 0], activity_pc[:, 1], 'o-', color=color) @@ -193,8 +162,7 @@ def data_generation(): plt.figure(figsize=(10, 5)) for i in range(10): activity_pc = pca.transform(activity_dict[i]) - trial = trial_infos[i] - color = 'red' if trial['ground_truth'] == 0 else 'blue' + color = 'red' if groundtruths[i] == 0 else 'blue' plt.plot(activity_pc[:, 0], activity_pc[:, 1], 'o-', color=color, alpha=0.1) # Fixed points are shown in cross fixedpoints_pc = pca.transform(finder.fixed_points['h']) diff --git a/examples/dynamics_simulation/COBA.py b/examples/dynamics_simulation/COBA.py new file mode 100644 index 000000000..2a0e29b84 --- /dev/null +++ b/examples/dynamics_simulation/COBA.py @@ -0,0 +1,114 @@ +import brainpy as bp +import brainpy.connect as C + + +class EINet(bp.DynamicalSystemNS): + def __init__(self, scale=1.0, e_input=20., i_input=20., delay=None): + super().__init__() + + self.bg_exc = e_input + self.bg_inh = i_input + + # network size + num_exc = int(3200 * scale) + num_inh = int(800 * scale) + + # neurons + pars = dict(V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.), input_var=False) + self.E = bp.neurons.LIF(num_exc, **pars) + self.I = bp.neurons.LIF(num_inh, **pars) + + # synapses + we = 0.6 / scale # excitatory synaptic weight (voltage) + wi = 6.7 / scale # inhibitory synaptic weight + self.E2E = bp.syn.Exponential( + C.FixedProb(0.02, pre=self.E.size, post=self.E.size), + g_max=we, tau=5., out=bp.syn.COBA(self.E.V, E=0.) + ) + self.E2I = bp.syn.Exponential( + C.FixedProb(0.02, pre=self.E.size, post=self.I.size, ), + g_max=we, tau=5., out=bp.syn.COBA(self.I.V, E=0.) + ) + self.I2E = bp.syn.Exponential( + C.FixedProb(0.02, pre=self.I.size, post=self.E.size), + g_max=wi, tau=10., out=bp.syn.COBA(self.E.V, E=-80.) + ) + self.I2I = bp.syn.Exponential( + C.FixedProb(0.02, pre=self.I.size, post=self.I.size), + g_max=wi, tau=10., out=bp.syn.COBA(self.I.V, E=-80.) 
+ ) + self.delayE = bp.Delay(self.E.spike, entries={'E': delay}) + self.delayI = bp.Delay(self.I.spike, entries={'I': delay}) + + @bp.not_pass_sha + def update(self): + e_spike = self.delayE.at('E') + i_spike = self.delayI.at('I') + e_inp = self.E2E(e_spike) + self.I2E(i_spike) + self.bg_exc + i_inp = self.I2I(i_spike) + self.E2I(e_spike) + self.bg_inh + self.delayE(self.E(e_inp)) + self.delayI(self.I(i_inp)) + + +class EINetv2(bp.DynamicalSystemNS): + def __init__(self, scale=1.0, e_input=20., i_input=20., delay=None): + super().__init__() + + self.bg_exc = e_input + self.bg_inh = i_input + + # network size + num_exc = int(3200 * scale) + num_inh = int(800 * scale) + + # neurons + pars = dict(V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.), input_var=False) + self.E = bp.neurons.LIF(num_exc, **pars) + self.I = bp.neurons.LIF(num_inh, **pars) + + # synapses + we = 0.6 / scale # excitatory synaptic weight (voltage) + wi = 6.7 / scale # inhibitory synaptic weight + self.E2E = bp.syn.Exponential( + C.FixedProb(0.02, pre=self.E.size, post=self.E.size), + g_max=we, tau=5., out=bp.syn.COBA(self.E.V, E=0.) + ) + self.E2I = bp.syn.Exponential( + C.FixedProb(0.02, pre=self.E.size, post=self.I.size, ), + g_max=we, tau=5., out=bp.syn.COBA(self.I.V, E=0.) + ) + self.I2E = bp.syn.Exponential( + C.FixedProb(0.02, pre=self.I.size, post=self.E.size), + g_max=wi, tau=10., out=bp.syn.COBA(self.E.V, E=-80.) + ) + self.I2I = bp.syn.Exponential( + C.FixedProb(0.02, pre=self.I.size, post=self.I.size), + g_max=wi, tau=10., out=bp.syn.COBA(self.I.V, E=-80.) + ) + bp.share.save('t', 2.) + bp.share.save('dt', 0.1) + bp.share.save('E-spike', bp.Delay(self.E.spike, entries={'E': delay})) + bp.share.save('I-spike', bp.Delay(self.I.spike, entries={'I': delay})) + + @bp.not_pass_sha + def update(self): + e_spike = bp.share.load('E-spike').at('E') + i_spike = bp.share.load('I-spike').at('I') + e_inp = self.E2E(e_spike) + self.I2E(i_spike) + self.bg_exc + i_inp = self.I2I(i_spike) + self.E2I(e_spike) + self.bg_inh + self.E(e_inp) + self.I(i_inp) + + +# simulation +net = EINet(delay=0., scale=2.) +# net = EINetv2(delay=0., scale=2.) 
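A minimal manual-stepping sketch of the network above, given only as an illustration: it assumes the positional `bp.share.save(key, value)` form shown in `EINetv2` and that calling a `DynamicalSystemNS` instance invokes its `update` method.

import brainpy as bp
import brainpy.math as bm

model = EINet(delay=0., scale=1.)
for i in range(1000):                      # 1000 steps of length bm.get_dt()
    bp.share.save('i', i)                  # step index in the shared context
    bp.share.save('t', i * bm.get_dt())    # current time
    bp.share.save('dt', bm.get_dt())
    model()  # read delayed spikes, update synapses and neurons, refill the delays

`bp.DSRunner` used just below performs the same bookkeeping automatically while also recording the requested monitors.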
+runner = bp.DSRunner(net, monitors={'E.spike': net.E.spike}) +r = runner.run(100., eval_time=True) +print(r) +bp.visualize.raster_plot(runner.mon.ts, runner.mon['E.spike'], show=True) + + + diff --git a/examples/dynamics_simulation/hh_model.py b/examples/dynamics_simulation/hh_model.py index dc54c1c9a..7e81554ce 100644 --- a/examples/dynamics_simulation/hh_model.py +++ b/examples/dynamics_simulation/hh_model.py @@ -16,7 +16,7 @@ def __init__(self, size): I, length = bp.inputs.section_input(values=[0, 5, 0], durations=[100, 500, 100], return_length=True) -runner = bp.dyn.DSRunner( +runner = bp.DSRunner( hh, monitors=['V', 'INa.p', 'INa.q', 'IK.p'], inputs=[hh.input, I, 'iter'], diff --git a/examples/dynamics_simulation/whole_brain_simulation_with_fhn.py b/examples/dynamics_simulation/whole_brain_simulation_with_fhn.py index bca51a989..acc530986 100644 --- a/examples/dynamics_simulation/whole_brain_simulation_with_fhn.py +++ b/examples/dynamics_simulation/whole_brain_simulation_with_fhn.py @@ -4,7 +4,6 @@ import matplotlib.pyplot as plt import numpy as np -import jax.numpy as jnp import brainpy as bp import brainpy.math as bm @@ -95,6 +94,6 @@ def net_analysis(): if __name__ == '__main__': - bifurcation_analysis() - net_simulation() + # bifurcation_analysis() + # net_simulation() net_analysis() diff --git a/examples/dynamics_training/Song_2016_EI_RNN.py b/examples/dynamics_training/Song_2016_EI_RNN.py index ff693b84c..404d604a7 100644 --- a/examples/dynamics_training/Song_2016_EI_RNN.py +++ b/examples/dynamics_training/Song_2016_EI_RNN.py @@ -1,127 +1,44 @@ -# %% [markdown] -# # *(Song, et al., 2016)*: Training excitatory-inhibitory recurrent network - -# %% [markdown] -# Implementation of the paper: -# -# - Song, H. F. , G. R. Yang , and X. J. Wang . "Training Excitatory-Inhibitory Recurrent Neural Networks for Cognitive Tasks: A Simple and Flexible Framework." Plos Computational Biology 12.2(2016):e1004792. -# -# The original code is based on PyTorch (https://github.com/gyyang/nn-brain/blob/master/EI_RNN.ipynb). However, comparing with the PyTorch codes, the training on BrainPy speeds up nearly four folds. - -# %% [markdown] -# Here we will train recurrent neural network with excitatory and inhibitory neurons on a simple perceptual decision making task. +import brainpy_datasets as bp_data +import matplotlib.pyplot as plt +import numpy as np -# %% import brainpy as bp import brainpy.math as bm -bm.set_platform('cpu') -bm.set_environment(bm.training_mode) - -# %% -import numpy as np -import matplotlib.pyplot as plt - -# %% [markdown] -# ## Defining a perceptual decision making task - -# %% -# We will import the task from the neurogym library. 
-# Please install neurogym: -# -# https://github.com/neurogym/neurogym - -import neurogym as ngym - -# %% -# Environment -task = 'PerceptualDecisionMaking-v0' -timing = { - 'fixation': ('choice', (50, 100, 200, 400)), - 'stimulus': ('choice', (100, 200, 400, 800)), -} -kwargs = {'dt': 20, 'timing': timing} -seq_len = 100 - -# Make supervised dataset -dataset = ngym.Dataset(task, - env_kwargs=kwargs, - batch_size=16, - seq_len=seq_len) - -# A sample environment from dataset -env = dataset.env - -# Visualize the environment with 2 sample trials -_ = ngym.utils.plot_env(env, num_trials=2, fig_kwargs={'figsize': (10, 6)}) -plt.show() - -# %% -input_size = env.observation_space.shape[0] -output_size = env.action_space.n -batch_size = dataset.batch_size - -print(f'Input size = {input_size}') -print(f'Output size = {output_size}') -print(f'Bacth size = {batch_size}') - - -# %% [markdown] -# ## Define E-I recurrent network -# -# Here we define a E-I recurrent network, in particular, no self-connections are allowed. - -# %% -class RNN(bp.DynamicalSystem): - r"""E-I RNN. - - The RNNs are described by the equations - - .. math:: - - \begin{gathered} - \tau \dot{\mathbf{x}}=-\mathbf{x}+W^{\mathrm{rec}} \mathbf{r}+W^{\mathrm{in}} - \mathbf{u}+\sqrt{2 \tau \sigma_{\mathrm{rec}}^{2}} \xi \\ - \mathbf{r}=[\mathbf{x}]_{+} \\ - \mathbf{z}=W^{\text {out }} \mathbf{r} - \end{gathered} - - In practice, the continuous-time dynamics are discretized to Euler form - in time steps of size :math:`\Delta t` as - - .. math:: - - \begin{gathered} - \mathbf{x}_{t}=(1-\alpha) \mathbf{x}_{t-1}+\alpha\left(W^{\mathrm{rec}} \mathbf{r}_{t-1}+ - W^{\mathrm{in}} \mathbf{u}_{t}\right)+\sqrt{2 \alpha \sigma_{\mathrm{rec}}^{2}} \mathbf{N}(0,1) \\ - \mathbf{r}_{t}=\left[\mathbf{x}_{t}\right]_{+} \\ - \mathbf{z}_{t}=W^{\mathrm{out}} \mathbf{r}_{t} - \end{gathered} - - where :math:`\alpha = \Delta t/\tau` and :math:`N(0, 1)` are normally distributed - random numbers with zero mean and unit variance, sampled independently at every time step. - """ - def __init__(self, num_input, num_hidden, num_output, num_batch, - dt=None, e_ratio=0.8, sigma_rec=0., seed=None, - w_ir=bp.init.KaimingUniform(scale=1.), - w_rr=bp.init.KaimingUniform(scale=1.), - w_ro=bp.init.KaimingUniform(scale=1.)): - super(RNN, self).__init__() +# data +ds = bp_data.cognitive.RatePerceptualDecisionMaking( + dt=20., + t_fixation=lambda: np.random.choice((50, 100, 200, 400)), + t_stimulus=lambda: np.random.choice((100, 200, 400, 800)), + num_trial=64 * 100 +) +loader = bp_data.cognitive.TaskLoader(ds, + max_seq_len=100, + batch_size=64, + data_first_axis='T') + + +# EI RNN model +class EI_RNN(bp.DynamicalSystem): + def __init__( + self, num_input, num_hidden, num_output, dt, + e_ratio=0.8, sigma_rec=0., seed=None, + w_ir=bp.init.KaimingUniform(scale=1.), + w_rr=bp.init.KaimingUniform(scale=1.), + w_ro=bp.init.KaimingUniform(scale=1.) 
+ ): + super(EI_RNN, self).__init__() # parameters self.tau = 100 - self.num_batch = num_batch self.num_input = num_input self.num_hidden = num_hidden self.num_output = num_output self.e_size = int(num_hidden * e_ratio) self.i_size = num_hidden - self.e_size - if dt is None: - self.alpha = 1 - else: - self.alpha = dt / self.tau + self.alpha = dt / self.tau self.sigma_rec = (2 * self.alpha) ** 0.5 * sigma_rec # Recurrent noise - self.rng = bm.random.RandomState(seed) + self.rng = bm.random.RandomState(seed=seed) # hidden mask mask = np.tile([1] * self.e_size + [-1] * self.i_size, (num_hidden, 1)) @@ -129,22 +46,25 @@ def __init__(self, num_input, num_hidden, num_output, num_batch, self.mask = bm.asarray(mask, dtype=bm.float_) # input weight - self.w_ir = bm.TrainVar(w_ir((num_input, num_hidden))) + self.w_ir = bm.TrainVar(bp.init.parameter(w_ir, (num_input, num_hidden))) # recurrent weight bound = 1 / num_hidden ** 0.5 - self.w_rr = bm.TrainVar(w_rr((num_hidden, num_hidden))) + self.w_rr = bm.TrainVar(bp.init.parameter(w_rr, (num_hidden, num_hidden))) self.w_rr[:, :self.e_size] /= (self.e_size / self.i_size) self.b_rr = bm.TrainVar(self.rng.uniform(-bound, bound, num_hidden)) # readout weight bound = 1 / self.e_size ** 0.5 - self.w_ro = bm.TrainVar(w_ro((self.e_size, num_output))) + self.w_ro = bm.TrainVar(bp.init.parameter(w_ro, (self.e_size, num_output))) self.b_ro = bm.TrainVar(self.rng.uniform(-bound, bound, num_output)) # variables - self.h = bm.Variable(bm.zeros((num_batch, num_hidden))) - self.o = bm.Variable(bm.zeros((num_batch, num_output))) + self.reset_state(1) + + def reset_state(self, batch_size): + self.h = bm.Variable(bm.zeros((batch_size, self.num_hidden)), batch_axis=0) + self.o = bm.Variable(bm.zeros((batch_size, self.num_output)), batch_axis=0) def cell(self, x, h): ins = x @ self.w_ir + h @ (bm.abs(self.w_rr) * self.mask) + self.b_rr @@ -155,172 +75,84 @@ def cell(self, x, h): def readout(self, h): return h @ self.w_ro + self.b_ro - def make_update(self, h: bm.Array, o: bm.Array): - def f(x): - h.value = self.cell(x, h.value) - o.value = self.readout(h.value[:, :self.e_size]) - return h.value, o.value - - return f + @bp.not_pass_sha + def update(self, x): + self.h.value = self.cell(x, self.h) + self.o.value = self.readout(self.h[:, :self.e_size]) + return self.h.value, self.o.value def predict(self, xs): self.h[:] = 0. 
- return bm.for_loop(self.make_update(self.h, self.o), xs, dyn_vars=self.vars()) + return bm.for_loop(self.update, xs) def loss(self, xs, ys): hs, os = self.predict(xs) - os = os.reshape((-1, os.shape[-1])) - return bp.losses.cross_entropy_loss(os, ys.flatten()) + l = bp.losses.cross_entropy_loss(os.reshape((-1, os.shape[-1])), ys.flatten()) + acc = bm.mean(bm.argmax(os, axis=-1) == ys) + return l, acc -# %% [markdown] -# ## Train the network on the decision making task - -# %% # Instantiate the network and print information hidden_size = 50 -net = RNN(num_input=input_size, - num_hidden=hidden_size, - num_output=output_size, - num_batch=batch_size, - dt=env.dt, - sigma_rec=0.15) +net = EI_RNN(num_input=len(ds.input_features), + num_hidden=hidden_size, + num_output=len(ds.output_features), + dt=ds.dt, + sigma_rec=0.15) + -# %% # Adam optimizer opt = bp.optim.Adam(lr=0.001, train_vars=net.train_vars().unique()) -# %% + # gradient function -grad = bm.grad(net.loss, - dyn_vars=net.vars().unique(), - grad_vars=net.train_vars().unique(), - return_value=True) +grad_f = bm.grad(net.loss, + child_objs=net, + grad_vars=net.train_vars().unique(), + return_value=True, + has_aux=True) -# %% -@bm.jit -@bm.to_object(child_objs=(grad, opt)) # add nodes and vars used +# training function +@bm.jit(child_objs=(net, opt)) def train(xs, ys): - grads, l = grad(xs, ys) + grads, loss, acc = grad_f(xs, ys) opt.update(grads) - return l - - -# %% [markdown] -# The training speeds up nearly 4 times, comparing with the original PyTorch codes. - -# %% -running_loss = 0 -print_step = 200 -for i in range(5000): - inputs, labels = dataset() - inputs = bm.asarray(inputs) - labels = bm.asarray(labels) - loss = train(inputs, labels) - running_loss += loss - if i % print_step == (print_step - 1): - running_loss /= print_step - print('Step {}, Loss {:0.4f}'.format(i + 1, running_loss)) - running_loss = 0 - - -# %% [markdown] -# ## Run the network post-training and record neural activity - -# %% -predict = bm.jit(net.predict, dyn_vars=net.vars()) - -# %% -env.reset(no_step=True) -env.timing.update({'fixation': ('constant', 500), 'stimulus': ('constant', 500)}) -perf = 0 -num_trial = 500 -activity_dict = {} -trial_infos = {} -stim_activity = [[], []] # response for ground-truth 0 and 1 -for i in range(num_trial): - env.new_trial() - ob, gt = env.ob, env.gt - inputs = bm.asarray(ob[:, np.newaxis, :]) - rnn_activity, action_pred = predict(inputs) - - # Compute performance - action_pred = bm.as_numpy(action_pred) - choice = np.argmax(action_pred[-1, 0, :]) - correct = choice == gt[-1] - - # Log trial info - trial_info = env.trial - trial_info.update({'correct': correct, 'choice': choice}) - trial_infos[i] = trial_info - - # Log stimulus period activity - rnn_activity = bm.as_numpy(rnn_activity)[:, 0, :] - activity_dict[i] = rnn_activity - - # Compute stimulus selectivity for all units - # Compute each neuron's response in trials where ground_truth=0 and 1 respectively - rnn_activity = rnn_activity[env.start_ind['stimulus']: env.end_ind['stimulus']] - stim_activity[env.trial['ground_truth']].append(rnn_activity) - -print('Average performance', np.mean([val['correct'] for val in trial_infos.values()])) - -# %% [markdown] -# ### Plot neural activity from sample trials - -# %% -trial = 2 + return loss, acc + + +# training +for epoch_i in range(30): + losses = [] + accs = [] + for x, y in loader: + net.reset_state(x.shape[1]) + l, a = train(x, y) + losses.append(l) + accs.append(a) + print(f'Epoch {epoch_i}, loss {np.mean(losses)}, acc 
{np.mean(accs)}') + + +# testing +ds.t_fixation = 500. # set the fixed time duration for fixation and stimulus +ds.t_stimulus = 500. +x, y = zip(*[ds[i] for i in range(50)]) # get 50 trials +x = np.asarray(x).transpose(1, 0, 2) +y = np.asarray(y).transpose(1, 0) +net.reset_state(x.shape[1]) +rnn_activity, action_pred = net.predict(x) +rnn_activity = bm.as_numpy(rnn_activity) +choice = np.argmax(bm.as_numpy(action_pred[-1]), axis=1) +correct = choice == y[-1] +print('Average performance', np.mean(correct)) + +# plot activity +trial = 0 plt.figure(figsize=(8, 6)) -_ = plt.plot(activity_dict[trial][:, :net.e_size], color='blue', label='Excitatory') -_ = plt.plot(activity_dict[trial][:, net.e_size:], color='red', label='Inhibitory') +_ = plt.plot(rnn_activity[:, trial, :net.e_size], color='blue', label='Excitatory') +_ = plt.plot(rnn_activity[:, trial, net.e_size:], color='red', label='Inhibitory') plt.xlabel('Time step') plt.ylabel('Activity') plt.show() -# %% [markdown] -# ### Compute stimulus selectivity for sorting neurons -# -# Here for each neuron we compute its stimulus period selectivity $d'$ - -# %% -mean_activity = [] -std_activity = [] -for ground_truth in [0, 1]: - activity = np.concatenate(stim_activity[ground_truth], axis=0) - mean_activity.append(np.mean(activity, axis=0)) - std_activity.append(np.std(activity, axis=0)) - -# Compute d' -selectivity = (mean_activity[0] - mean_activity[1]) -selectivity /= np.sqrt((std_activity[0] ** 2 + std_activity[1] ** 2 + 1e-7) / 2) -# Sort index for selectivity, separately for E and I -ind_sort = np.concatenate((np.argsort(selectivity[:net.e_size]), - np.argsort(selectivity[net.e_size:]) + net.e_size)) - -# %% [markdown] -# ### Plot network connectivity sorted by stimulus selectivity - -# %% -# Plot distribution of stimulus selectivity -plt.figure(figsize=(6, 4)) -plt.hist(selectivity) -plt.xlabel('Selectivity') -plt.ylabel('Number of neurons') -plt.show() - -# %% -W = bm.as_numpy(bm.abs(net.w_rr) * net.mask) -# Sort by selectivity -W = W[:, ind_sort][ind_sort, :] -wlim = np.max(np.abs(W)) - -plt.figure(figsize=(10, 10)) -plt.imshow(W, cmap='bwr_r', vmin=-wlim, vmax=wlim) -plt.colorbar() -plt.xlabel('From neurons') -plt.ylabel('To neurons') -plt.title('Network connectivity') -plt.tight_layout() -plt.show() diff --git a/examples/dynamics_training/integrator_rnn.py b/examples/dynamics_training/integrator_rnn.py index 9eb3075ed..706e51bd6 100644 --- a/examples/dynamics_training/integrator_rnn.py +++ b/examples/dynamics_training/integrator_rnn.py @@ -1,10 +1,7 @@ # -*- coding: utf-8 -*- -from functools import partial - import matplotlib.pyplot as plt -import jax.numpy as jnp import brainpy as bp import brainpy.math as bm @@ -13,8 +10,7 @@ num_batch = 128 -@partial(bm.jit, static_argnames=['batch_size']) -@bm.to_object(dyn_vars=bm.random.DEFAULT) +@bm.jit(static_argnames=['batch_size'], dyn_vars=bm.random.DEFAULT) def build_inputs_and_targets(mean=0.025, scale=0.01, batch_size=10): # Create the white noise input sample = bm.random.normal(size=(batch_size, 1, 1)) @@ -22,7 +18,7 @@ def build_inputs_and_targets(mean=0.025, scale=0.01, batch_size=10): samples = bm.random.normal(size=(batch_size, num_step, 1)) noise_t = scale / dt ** 0.5 * samples inputs = bias + noise_t - targets = jnp.cumsum(inputs, axis=1) + targets = bm.cumsum(inputs, axis=1) return inputs, targets @@ -74,3 +70,4 @@ def loss(predictions, targets, l2_reg=2e-4): plt.plot(bm.as_numpy(predicts[0]).flatten(), label='Prediction') plt.legend() plt.show() + diff --git 
a/examples/training_ann_models/mnist-cnn.py b/examples/training_ann_models/mnist-cnn.py index 99c227fe6..602191156 100644 --- a/examples/training_ann_models/mnist-cnn.py +++ b/examples/training_ann_models/mnist-cnn.py @@ -58,3 +58,5 @@ def generator(): trainer.fit(get_data(train_dataset), get_data(test_dataset), num_epoch=2) + + diff --git a/examples/training_snn_models/spikebased_bp_for_cifar10.py b/examples/training_snn_models/spikebased_bp_for_cifar10.py index c356ac84e..48c93b871 100644 --- a/examples/training_snn_models/spikebased_bp_for_cifar10.py +++ b/examples/training_snn_models/spikebased_bp_for_cifar10.py @@ -33,7 +33,7 @@ parser = argparse.ArgumentParser(description='CIFAR10 Training') parser.add_argument('-data', default='/mnt/d/data', type=str, help='path to dataset') -parser.add_argument('-b', default=64, type=int, metavar='N') +parser.add_argument('-b', default=16, type=int, metavar='N') parser.add_argument('-T', default=100, type=int, help='Simulation timesteps') parser.add_argument('-lr', default=0.0025, type=float, help='initial learning rate') parser.add_argument('-resume', action='store_true', help='resume from the checkpoint path') @@ -41,7 +41,7 @@ help='number of data loading workers (default: 4)') -class LIFNode(bp.DynamicalSystem): +class LIFNode(bp.DynamicalSystemNS): def __init__(self, size, tau=100.0, v_threshold=1.0, v_reset=0.0, fire: bool = True): super().__init__() bp.check.is_subclass(self.mode, [bp.math.TrainingMode, bp.math.BatchingMode]) @@ -77,7 +77,7 @@ def f_bwd(self, res, g): g = bm.where(res[0], 0., g).value return (self.grad_acc.value * g,) - def update(self, s, dv): + def update(self, dv): self.v += dv if self.fire: spike = self.relu_grad(bm.as_jax(self.v.value - self.v_threshold)) @@ -93,7 +93,7 @@ def update(self, s, dv): return self.v.value -class IFNode(bp.DynamicalSystem): +class IFNode(bp.DynamicalSystemNS): def __init__(self, size, v_threshold=0.75, v_reset=0.0): super().__init__() bp.check.is_subclass(self.mode, [bm.TrainingMode, bm.BatchingMode]) @@ -114,14 +114,14 @@ def grad(dz): return bm.asarray(x > 0.0, bm.float_).value, grad - def update(self, s, dv): + def update(self, dv): self.v += dv spike = self.relu_grad(bm.as_jax(self.v - self.v_threshold)) self.v.value = bm.where(self.v > self.v_threshold, self.v_reset, self.v.value) return spike -class ResNet11(bp.DynamicalSystem): +class ResNet11(bp.DynamicalSystemNS): def __init__(self): super().__init__() @@ -178,14 +178,14 @@ def conv_init(self, shape): def linear_init(self, shape): return bm.random.normal(0., np.sqrt(1.0 / shape[0]), shape) - def update(self, s, x): - x = self.if1(s, self.avgpool1(s, self.lif11(s, self.cnn11(s, x)))) - x = self.lif2(s, self.cnn22(s, self.lif21(s, self.cnn21(s, x))) + self.shortcut1(s, x)) - x = self.lif3(s, self.cnn32(s, self.lif31(s, self.cnn31(s, x))) + self.shortcut2(s, x)) - x = self.lif4(s, self.cnn42(s, self.lif41(s, self.cnn41(s, x))) + self.shortcut3(s, x)) - x = self.lif5(s, self.cnn52(s, self.lif51(s, self.cnn51(s, x))) + self.shortcut4(s, x)) + def update(self, x): + x = self.if1(self.avgpool1(self.lif11(self.cnn11(x)))) + x = self.lif2(self.cnn22(self.lif21(self.cnn21(x))) + self.shortcut1(x)) + x = self.lif3(self.cnn32(self.lif31(self.cnn31(x))) + self.shortcut2(x)) + x = self.lif4(self.cnn42(self.lif41(self.cnn41(x))) + self.shortcut3(x)) + x = self.lif5(self.cnn52(self.lif51(self.cnn51(x))) + self.shortcut4(x)) x = x.reshape(x.shape[0], -1) - x = self.lif_out(s, self.fc1(s, self.lif6(s, self.fc0(s, x)))) + x = 
self.lif_out(self.fc1(self.lif6(self.fc0(x)))) return x @@ -239,16 +239,14 @@ def main(): bm.random.seed(1234) net = ResNet11() - @bm.jit - @bm.to_object(child_objs=net, dyn_vars=bm.random.DEFAULT) + @bm.jit(child_objs=net, dyn_vars=bm.random.DEFAULT) def loss_fun(x, y, fit=True): + bp.share.save(fit=fit) yy = bm.one_hot(y, 10, dtype=bm.float_) # poisson encoding x = (bm.random.rand(num_time, *x.shape) < bm.abs(x)).astype(bm.float_) * bm.sign(x) # loop over time - s = {'fit': fit} - for i in range(num_time): - o = net(s, x[i]) + o = bm.for_loop(net, x, jit=False)[-1] for m in net.nodes().unique(): if isinstance(m, LIFNode) and m.fire: m.v_acc += (m.v_acc < 1e-3).astype(bm.float_) @@ -264,8 +262,7 @@ def loss_fun(x, y, fit=True): train_vars=net.train_vars().unique(), weight_decay=5e-4) - @bm.jit - @bm.to_object(child_objs=(optimizer, grad_fun)) + @bm.jit(child_objs=(optimizer, grad_fun)) def train_fun(x, y): grads, l, n = grad_fun(x, y) optimizer.update(grads)
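# `bm.jit(child_objs=(optimizer, grad_fun))` is assumed to collect the Variables
# owned by `optimizer` and `grad_fun`, so the jitted `train_fun` can update them
# in place on every call.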