Skip to content

Commit

Permalink
[documentation] update documentation to brainpy>=2.4.0 (#361)
Browse files Browse the repository at this point in the history
[documentation] Update documentation to brainpy>=2.4.0
  • Loading branch information
chaoming0625 authored Apr 15, 2023
2 parents 3d63531 + be3c613 commit dd238fc
Show file tree
Hide file tree
Showing 61 changed files with 5,429 additions and 3,568 deletions.
2 changes: 2 additions & 0 deletions brainpy/_src/analysis/highdim/slow_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,8 @@ def __init__(
if f_loss_batch is not None:
raise UnsupportedError('"f_loss_batch" is no longer supported, please '
'use "f_loss" instead.')
if fun_inputs is not None:
raise UnsupportedError('"fun_inputs" is no longer supported.')
if f_loss is None:
f_loss = losses.mean_squared_error if f_type == constants.DISCRETE else losses.mean_square
self.f_loss = f_loss
Expand Down
1 change: 1 addition & 0 deletions brainpy/_src/analysis/lowdim/lowdim_bifurcation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from functools import partial

import jax
import jax.numpy as jnp
from jax import vmap
import numpy as np
Expand Down
1 change: 1 addition & 0 deletions brainpy/_src/analysis/lowdim/lowdim_phase_plane.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-

import jax
import jax.numpy as jnp
import numpy as np
from jax import vmap
Expand Down
3 changes: 2 additions & 1 deletion brainpy/_src/checkpoints/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@

from brainpy import errors
import brainpy.math as bm
from brainpy._src.math.object_transform.base import BrainPyObject, ArrayCollector
from brainpy._src.math.object_transform.base import BrainPyObject
from brainpy._src.math.object_transform.collectors import ArrayCollector


logger = logging.getLogger('brainpy.brainpy_object.io')
Expand Down
12 changes: 6 additions & 6 deletions brainpy/_src/checkpoints/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -951,7 +951,7 @@ def _record_saved_duration(checkpoint_start_time: float):
# Note: for the very first checkpoint, this is the interval between program
# init and the current checkpoint start time.
duration_since_last_checkpoint = checkpoint_start_time - _LAST_CHECKPOINT_WRITE_TIME
if jax.version.__version_info__ > (0, 3, 25):
if monitoring is not None:
monitoring.record_event_duration_secs(
'/jax/checkpoint/write/duration_since_last_checkpoint_secs',
duration_since_last_checkpoint)
Expand Down Expand Up @@ -1151,7 +1151,7 @@ def save_main_ckpt_task():
else:
save_main_ckpt_task()
end_time = time.time()
if jax.version.__version_info__ > (0, 3, 25):
if monitoring is not None:
monitoring.record_event_duration_secs(_WRITE_CHECKPOINT_EVENT,
end_time - start_time)
return ckpt_path
Expand Down Expand Up @@ -1281,7 +1281,7 @@ def save_main_ckpt_task():
else:
save_main_ckpt_task()
end_time = time.time()
if jax.version.__version_info__ > (0, 3, 25):
if monitoring is not None:
monitoring.record_event_duration_secs(_WRITE_CHECKPOINT_EVENT,
end_time - start_time)

Expand Down Expand Up @@ -1390,7 +1390,7 @@ def save_main_ckpt_task():
keep, overwrite, keep_every_n_steps, start_time, async_manager)

end_time = time.time()
if jax.version.__version_info__ > (0, 3, 25):
if monitoring is not None:
monitoring.record_event_duration_secs(_WRITE_CHECKPOINT_EVENT,
end_time - start_time)
return ckpt_path
Expand Down Expand Up @@ -1553,7 +1553,7 @@ def read_chunk(i):
restored_checkpoint = from_state_dict(target, state_dict)

end_time = time.time()
if jax.version.__version_info__ > (0, 3, 25):
if monitoring is not None:
monitoring.record_event_duration_secs(_READ_CHECKPOINT_EVENT, end_time - start_time)

return restored_checkpoint
Expand Down Expand Up @@ -1616,7 +1616,7 @@ def read_chunk(i):

state_dict = msgpack_restore(checkpoint_contents)
end_time = time.time()
if jax.version.__version_info__ > (0, 3, 25):
if monitoring is not None:
monitoring.record_event_duration_secs(_READ_CHECKPOINT_EVENT, end_time - start_time)

return state_dict
12 changes: 3 additions & 9 deletions brainpy/_src/dyn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1475,16 +1475,10 @@ def __getitem__(self, key: Union[int, slice, str]):
elif isinstance(key, slice):
return Sequential(*(self.__all_nodes()[key]))
elif isinstance(key, int):
key = self.__format_key(key)
return self._static_modules[key] if (key not in self._dyn_modules) else self._dyn_modules[key]
return self.__all_nodes()[key]
elif isinstance(key, (tuple, list)):
nodes = []
for i in key:
if isinstance(i, int):
i = self.__format_key(i)
assert isinstance(i, str)
nodes.append(self._static_modules[i] if (i not in self._dyn_modules) else self._dyn_modules[i])
return Sequential(*nodes)
_all_nodes = self.__all_nodes()
return Sequential(*[_all_nodes[k] for k in key])
else:
raise KeyError(f'Unknown type of key: {type(key)}')

Expand Down
8 changes: 8 additions & 0 deletions brainpy/_src/dyn/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,14 @@ def save(self, *args, **kwargs) -> None:
for identifier, data in kwargs.items():
self._arguments[identifier] = data

def __setitem__(self, key, value):
"""Enable setting the shared item by ``bp.share[key] = value``."""
self.save(key, value)

def __getitem__(self, item):
"""Enable loading the shared parameter by ``bp.share[key]``."""
return self.load(item)

def get_shargs(self) -> DotDict:
"""Get all shared arguments in the global context."""
return self._arguments.copy()
Expand Down
6 changes: 3 additions & 3 deletions brainpy/_src/dyn/neurons/biological_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,8 +456,8 @@ def __init__(
self.input_var = input_var

# initializers
self._W_initializer = check.is_initializer(V_initializer, allow_none=False)
self._V_initializer = check.is_initializer(W_initializer, allow_none=False)
self._W_initializer = check.is_initializer(W_initializer, allow_none=False)
self._V_initializer = check.is_initializer(V_initializer, allow_none=False)

# variables
self.reset_state(self.mode)
Expand Down Expand Up @@ -491,7 +491,7 @@ def dW(self, W, t, V):

@property
def derivative(self):
return JointEq([self.dV, self.dW])
return JointEq(self.dV, self.dW)

def update(self, x=None):
t = share.load('t')
Expand Down
8 changes: 4 additions & 4 deletions brainpy/_src/dyn/neurons/input_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,10 +131,10 @@ def __init__(
self.num_times = len(times)

# data about times and indices
self.times = jnp.asarray(times)
self.indices = jnp.asarray(indices, dtype=bm.int_)
self.times = bm.asarray(times)
self.indices = bm.asarray(indices, dtype=bm.int_)
if need_sort:
sort_idx = jnp.argsort(self.times)
sort_idx = bm.argsort(self.times)
self.indices.value = self.indices[sort_idx]
self.times.value = self.times[sort_idx]

Expand All @@ -144,7 +144,7 @@ def __init__(
# functions
def cond_fun(t):
i = self.i.value
return jnp.logical_and(i < self.num_times, t >= self.times[i])
return bm.logical_and(i < self.num_times, t >= self.times[i])

def body_fun(t):
i = self.i.value
Expand Down
2 changes: 1 addition & 1 deletion brainpy/_src/dyn/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -640,7 +640,6 @@ def _step_func_predict(self, shared_args, t, i, x):
shared = tools.DotDict(t=t, i=i, dt=self.dt)
shared.update(shared_args)
share.save(**shared)
self.target.clear_input()
self._step_func_input(shared)

# dynamics update step
Expand All @@ -655,6 +654,7 @@ def _step_func_predict(self, shared_args, t, i, x):
if self.progress_bar:
id_tap(lambda *arg: self._pbar.update(), ())
share.clear_shargs()
self.target.clear_input()

if self._memory_efficient:
id_tap(self._step_mon_on_cpu, mon)
Expand Down
2 changes: 2 additions & 0 deletions brainpy/_src/initialize/random_inits.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ def calculate_gain(nonlinearity, param=None):


def _format_shape(shape):
if isinstance(shape, int):
return (shape, )
if len(shape) == 0:
raise ValueError('Please provide shape.')
if len(shape) == 1:
Expand Down
2 changes: 1 addition & 1 deletion brainpy/_src/integrators/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ def run(
end_t = start_t + duration
# times
times = bm.arange(start_t, end_t, self.dt).value
indices = bm.arange(times.size).value + self.idx
indices = bm.arange(times.size).value + self.idx.value

_dyn_args, _ = tree_flatten(dyn_args)
for _d in _dyn_args:
Expand Down
20 changes: 9 additions & 11 deletions brainpy/_src/math/delayvars.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,21 @@
# -*- coding: utf-8 -*-

from typing import Union, Callable, Optional, Dict
from typing import Union, Callable

import jax
import jax.numpy as jnp
import numpy as np
from jax import vmap
from jax.lax import cond, stop_gradient
from jax.lax import stop_gradient

from brainpy import check
from brainpy.check import is_float, is_integer, jit_error_checking
from brainpy.errors import UnsupportedError
from .interoperability import as_jax
from .compat_numpy import vstack, broadcast_to
from .environment import get_dt, get_float
from .interoperability import as_jax
from .ndarray import ndarray, Array
from .object_transform.base import BrainPyObject
from .object_transform.controls import cond
from .object_transform.variables import Variable

__all__ = [
Expand Down Expand Up @@ -159,8 +159,8 @@ def __init__(
if before_t0 is None:
self._before_type = _DATA_BEFORE
elif callable(before_t0):
self._before_t0 = lambda t: jnp.asarray(jnp.broadcast_to(before_t0(t), delay_target.shape),
dtype=delay_target.dtype)
self._before_t0 = lambda t: as_jax(broadcast_to(before_t0(t), delay_target.shape),
dtype=delay_target.dtype)
self._before_type = _FUNC_BEFORE
elif isinstance(before_t0, (ndarray, jnp.ndarray, float, int)):
self._before_type = _DATA_BEFORE
Expand Down Expand Up @@ -248,17 +248,15 @@ def _after_t0(self, prev_time):
return cond(extra == 0., self._true_fn, self._false_fn, (req_num_step, extra))
elif self.interp_method == _INTERP_ROUND:
req_num_step = jnp.asarray(jnp.round(diff / self.dt), dtype=jnp.int32)
return self._true_fn([req_num_step, 0.])
return self._true_fn(req_num_step, 0.)
else:
raise UnsupportedError(f'Un-supported interpolation method {self.interp_method}, '
f'we only support: {[_INTERP_LINEAR, _INTERP_ROUND]}')

def _true_fn(self, div_mod):
req_num_step, extra = div_mod
def _true_fn(self, req_num_step, extra):
return self.data[self.idx[0] + req_num_step]

def _false_fn(self, div_mod):
req_num_step, extra = div_mod
def _false_fn(self, req_num_step, extra):
idx = jnp.asarray([self.idx[0] + req_num_step,
self.idx[0] + req_num_step + 1])
idx %= self.num_delay_step
Expand Down
22 changes: 12 additions & 10 deletions brainpy/_src/math/ndarray.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
# -*- coding: utf-8 -*-

import operator
from typing import Union, Optional, Sequence, Any, Tuple as TupleType, List
from typing import Union, Optional, Sequence

import jax
import numpy as np
from jax import numpy as jnp
from jax.dtypes import canonicalize_dtype
from jax.tree_util import register_pytree_node
from jax.tree_util import register_pytree_node_class

import brainpy.math
from brainpy.errors import MathError
Expand Down Expand Up @@ -60,6 +60,7 @@ def _get_dtype(v):
return dtype


@register_pytree_node_class
class Array(object):
"""Multiple-dimensional array in BrainPy.
"""
Expand Down Expand Up @@ -170,8 +171,8 @@ def __iter__(self):
- https://github.com/google/jax/issues/7713
- https://github.com/google/jax/pull/3821
"""
for v in self.value:
yield v
for i in range(self.value.shape[0]):
yield self.value[i]

def __getitem__(self, index):
if isinstance(index, slice) and (index == _all_slice):
Expand Down Expand Up @@ -1378,12 +1379,13 @@ def expand(self, *sizes) -> 'Array':
f'dimension {i}. Target sizes: {sizes}. Tensor sizes: {self.shape}')
return Array(jnp.broadcast_to(self.value, sizes_list))

def tree_flatten(self):
return (self._value,), None

@classmethod
def tree_unflatten(cls, aux_data, flat_contents):
return cls(*flat_contents)


JaxArray = Array
ndarray = Array

register_pytree_node(
Array,
lambda t: ((t.value,), None),
lambda aux_data, flat_contents: Array(*flat_contents)
)
38 changes: 14 additions & 24 deletions brainpy/_src/math/object_transform/controls.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,18 @@
from jax.errors import UnexpectedTracerError
from jax.tree_util import tree_flatten, tree_unflatten

from brainpy import errors, tools, check
from brainpy import errors, tools
from brainpy._src.math.interoperability import as_jax
from brainpy._src.math.ndarray import (Array, )
from ._tools import (evaluate_dyn_vars,
dynvar_deprecation,
node_deprecation,
abstract)
from .variables import (Variable, VariableStack)
from .base import BrainPyObject, ObjectTransform
from .naming import (get_unique_name,
get_stack_cache,
cache_stack)
from ._utils import infer_dyn_vars
from .base import BrainPyObject, ArrayCollector, ObjectTransform
from .variables import (Variable, VariableStack)

__all__ = [
'make_loop',
Expand Down Expand Up @@ -520,9 +519,11 @@ def ifelse(
conditions: Union[bool, Sequence[bool]],
branches: Sequence[Any],
operands: Any = None,
show_code: bool = False,

# deprecated
dyn_vars: Union[Variable, Sequence[Variable], Dict[str, Variable]] = None,
child_objs: Optional[Union[BrainPyObject, Sequence[BrainPyObject], Dict[str, BrainPyObject]]] = None,
show_code: bool = False,
):
"""``If-else`` control flows looks like native Pythonic programming.
Expand Down Expand Up @@ -585,13 +586,9 @@ def ifelse(
raise ValueError(f'The numbers of branches and conditions do not match. '
f'Got len(conditions)={len(conditions)} and len(branches)={len(branches)}. '
f'We expect len(conditions) + 1 == len(branches). ')
dyn_vars = check.is_all_vars(dyn_vars, out_as='dict')
dyn_vars = ArrayCollector(dyn_vars)
for f in branches:
dyn_vars += infer_dyn_vars(f)
for obj in check.is_all_objs(child_objs, out_as='tuple'):
dyn_vars.update(obj.vars().unique())
dyn_vars = tuple(dyn_vars.unique().values())

dynvar_deprecation(dyn_vars)
node_deprecation(child_objs)

# format new codes
if len(conditions) == 1:
Expand All @@ -604,18 +601,10 @@ def ifelse(
codes = ['def f(operands):',
f' f0 = branches[{len(conditions)}]']
num_cond = len(conditions) - 1
if len(dyn_vars) > 0:
code_scope['_cond'] = cond
code_scope['dyn_vars'] = dyn_vars
for i in range(len(conditions) - 1):
codes.append(f' f{i + 1} = lambda r: _cond(conditions[{num_cond - i}], '
f'branches[{num_cond - i}], f{i}, r, dyn_vars)')
codes.append(f' return _cond(conditions[0], branches[0], f{len(conditions) - 1}, operands, dyn_vars)')
else:
code_scope['_cond'] = lax.cond
for i in range(len(conditions) - 1):
codes.append(f' f{i + 1} = lambda r: _cond(conditions[{num_cond - i}], branches[{num_cond - i}], f{i}, r)')
codes.append(f' return _cond(conditions[0], branches[0], f{len(conditions) - 1}, operands)')
code_scope['_cond'] = cond
for i in range(len(conditions) - 1):
codes.append(f' f{i + 1} = lambda r: _cond(conditions[{num_cond - i}], branches[{num_cond - i}], f{i}, r)')
codes.append(f' return _cond(conditions[0], branches[0], f{len(conditions) - 1}, operands)')
codes = '\n'.join(codes)
if show_code: print(codes)
exec(compile(codes.strip(), '', 'exec'), code_scope)
Expand Down Expand Up @@ -751,6 +740,7 @@ def for_loop(
with VariableStack() as dyn_vars:
_ = jax.eval_shape(body_fun, *op_vals)
cache_stack(body_fun, dyn_vars) # cache
del op_vals

# functions
def fun2scan(carry, x):
Expand Down
Loading

0 comments on commit dd238fc

Please sign in to comment.