Skip to content

Commit

Permalink
Merge pull request #342 from chaoming0625/master
Browse files Browse the repository at this point in the history
Formalize new styles of neuron and synapse models
  • Loading branch information
chaoming0625 authored Mar 3, 2023
2 parents 8692d3e + f21aff0 commit d8d23db
Show file tree
Hide file tree
Showing 27 changed files with 665 additions and 608 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ BrainPy is a flexible, efficient, and extensible framework for computational neu

- **[BrainPy](https://github.com/brainpy/BrainPy)**: The solution for the general-purpose brain dynamics programming.
- **[brainpylib](https://github.com/brainpy/brainpylib)**: Efficient operators for the sparse and event-driven computation.
- **[BrainPyExamples](https://github.com/brainpy/BrainPyExamples)**: Comprehensive examples of BrainPy computation.
- **[brainpy-largescale](https://github.com/NH-NCL/brainpy-largescale)**: One solution for the large-scale brain modeling.
- **[brainpy-examples](https://github.com/brainpy/examples)**: Comprehensive examples of BrainPy computation.
- **[brainpy-datasets](https://github.com/brainpy/datasets)**: Neuromorphic and Cognitive Datasets for Brain Dynamics Modeling.



Expand Down
8 changes: 4 additions & 4 deletions brainpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,15 +58,16 @@
synapses, # synaptic dynamics
synouts, # synaptic output
synplast, # synaptic plasticity
syn,
experimental,
)
from brainpy._src.dyn.base import not_pass_sha
from brainpy._src.dyn.base import not_pass_shared
from brainpy._src.dyn.base import (DynamicalSystem,
DynamicalSystemNS,
Container as Container,
Sequential as Sequential,
Network as Network,
NeuGroup as NeuGroup,
NeuGroupNS as NeuGroupNS,
SynConn as SynConn,
SynOut as SynOut,
SynSTP as SynSTP,
Expand All @@ -77,8 +78,7 @@
from brainpy._src.dyn.transform import (NoSharedArg as NoSharedArg, # transformations
LoopOverTime as LoopOverTime,)
from brainpy._src.dyn.runners import (DSRunner as DSRunner) # runner
from brainpy._src.dyn.context import share
from brainpy._src.dyn.delay import Delay
from brainpy._src.dyn.context import share, Delay


# Part 4: Training #
Expand Down
1 change: 1 addition & 0 deletions brainpy/_src/analysis/highdim/slow_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,7 @@ def f_loss():
def train(idx):
gradients, loss = grad_f()
optimizer.update(gradients if isinstance(gradients, dict) else {'a': gradients})
optimizer.lr.step_epoch()
return loss

def batch_train(start_i, n_batch):
Expand Down
19 changes: 10 additions & 9 deletions brainpy/_src/dyn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
SLICE_VARS = 'slice_vars'


def not_pass_sha(func: Callable):
def not_pass_shared(func: Callable):
"""Label the update function as the one without passing shared arguments.
The original update function explicitly requires shared arguments at the first place::
Expand Down Expand Up @@ -610,7 +610,8 @@ def __repr__(self):
entries = '\n'.join(f' [{i}] {tools.repr_object(x)}' for i, x in enumerate(self._modules))
return f'{self.__class__.__name__}(\n{entries}\n)'

def update(self, s, x) -> ArrayType:
@not_pass_shared
def update(self, x) -> ArrayType:
"""Update function of a sequential model.
Parameters
Expand All @@ -626,12 +627,7 @@ def update(self, s, x) -> ArrayType:
The output tensor.
"""
for m in self._modules:
if isinstance(m, DynamicalSystemNS):
x = m(x)
elif isinstance(m, DynamicalSystem):
x = m(s, x)
else:
x = m(x)
x = m(x)
return x


Expand Down Expand Up @@ -665,7 +661,7 @@ def __init__(
mode=mode,
**ds_dict)

@not_pass_sha
@not_pass_shared
def update(self, *args, **kwargs):
"""Step function of a network.
Expand Down Expand Up @@ -807,6 +803,11 @@ def __getitem__(self, item):
return NeuGroupView(target=self, index=item)


class NeuGroupNS(NeuGroup):
"""Base class for neuron group without shared arguments passed."""
pass_shared = False


class SynConn(DynamicalSystem):
"""Base class to model two-end synaptic connections.
Expand Down
Loading

0 comments on commit d8d23db

Please sign in to comment.