Merge pull request #58 from NACLab/dev
Nudging over recent dev revisions/patches to main
ago109 authored Jul 4, 2024
2 parents ecd76cd + 22024ac commit d9497bd
Showing 11 changed files with 113 additions and 64 deletions.
2 changes: 1 addition & 1 deletion ngclearn/components/input_encoders/bernoulliCell.py
@@ -100,7 +100,7 @@ def save(self, directory, **kwargs):
def load(self, directory, **kwargs):
file_name = directory + "/" + self.name + ".npz"
data = jnp.load(file_name)
self.key.set( data['key'] )
self.key.set(data['key'])

@classmethod
def help(cls): ## component help function
3 changes: 2 additions & 1 deletion ngclearn/components/input_encoders/latencyCell.py
@@ -207,6 +207,7 @@ def _advance_state(t, dt, key, inputs, mask, targ_sp_times, tols):
key, *subkeys = random.split(key, 2)
data = inputs ## get sensory pattern data / features
spikes, spk_mask = _extract_spike(targ_sp_times, t, mask) ## get spikes at t
tols = _update_times(t, spikes, tols)
return spikes, tols, spk_mask, targ_sp_times, key

@resolver(_advance_state)
@@ -237,7 +238,7 @@ def save(self, directory, **kwargs):
def load(self, directory, **kwargs):
file_name = directory + "/" + self.name + ".npz"
data = jnp.load(file_name)
self.key.set( data['key'] )
self.key.set(data['key'])

@classmethod
def help(cls): ## component help function
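The functional change in latencyCell.py is the added `tols = _update_times(t, spikes, tols)` call, which keeps the time-of-last-spike compartment in step with the spikes emitted at time t. The diff does not show `_update_times` itself; purely as an illustration, the conventional form of such an update (assumed here, not copied from the module) is:

import jax.numpy as jnp
from jax import jit

@jit
def _update_times_sketch(t, s, tols):
    ## wherever a spike occurred (s == 1), stamp the current time t;
    ## otherwise carry the previous time-of-last-spike forward
    return (1. - s) * tols + s * t

For example, with t = 3., s = jnp.array([[0., 1.]]) and tols = jnp.array([[1., 1.]]), this sketch returns [[1., 3.]].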
2 changes: 1 addition & 1 deletion ngclearn/components/input_encoders/poissonCell.py
@@ -112,7 +112,7 @@ def save(self, directory, **kwargs):
def load(self, directory, **kwargs):
file_name = directory + "/" + self.name + ".npz"
data = jnp.load(file_name)
self.key.set( data['key'] )
self.key.set(data['key'])

@classmethod
def help(cls): ## component help function
34 changes: 21 additions & 13 deletions ngclearn/components/neurons/graded/gaussianErrorCell.py
@@ -52,6 +52,7 @@ class GaussianErrorCell(JaxComponent): ## Rate-coded/real-valued error unit/cell
| mu - predicted value (takes in external signals)
| target - desired/goal value (takes in external signals)
| modulator - modulation signal (takes in optional external signals)
| mask - binary/gating mask to apply to error neuron calculations
| --- Cell Output Compartments: ---
| L - local loss function embodied by this cell
| dmu - derivative of L w.r.t. mu
@@ -84,39 +85,45 @@ def __init__(self, name, n_units, batch_size=1, **kwargs):
self.target = Compartment(restVals) # target. input wire
self.dtarget = Compartment(restVals) # derivative target
self.modulator = Compartment(restVals + 1.0) # to be set/consumed
self.mask = Compartment(restVals + 1.0)

@staticmethod
def _advance_state(dt, mu, dmu, target, dtarget, modulator):
def _advance_state(dt, mu, dmu, target, dtarget, modulator, mask):
## compute Gaussian error cell output
dmu, dtarget, L = _run_cell(dt, target, mu)
dmu = dmu * modulator
dtarget = dtarget * modulator
return dmu, dtarget, L
dmu, dtarget, L = _run_cell(dt, target * mask, mu * mask)
dmu = dmu * modulator * mask
dtarget = dtarget * modulator * mask
mask = mask * 0. + 1. ## "eat" the mask as it should only apply at time t
return dmu, dtarget, L, mask

@resolver(_advance_state)
def advance_state(self, dmu, dtarget, L):
def advance_state(self, dmu, dtarget, L, mask):
self.dmu.set(dmu)
self.dtarget.set(dtarget)
self.L.set(L)
self.mask.set(mask)

@staticmethod
def _reset(batch_size, n_units):
dmu = jnp.zeros((batch_size, n_units))
dtarget = jnp.zeros((batch_size, n_units))
target = jnp.zeros((batch_size, n_units)) #None
mu = jnp.zeros((batch_size, n_units)) #None
restVals = jnp.zeros((batch_size, n_units))
dmu = restVals
dtarget = restVals
target = restVals
mu = restVals
modulator = mu + 1.
L = 0.
return dmu, dtarget, target, mu, modulator, L
mask = jnp.ones((batch_size, n_units))
return dmu, dtarget, target, mu, modulator, L, mask

@resolver(_reset)
def reset(self, dmu, dtarget, target, mu, modulator, L):
def reset(self, dmu, dtarget, target, mu, modulator, L, mask):
self.dmu.set(dmu)
self.dtarget.set(dtarget)
self.target.set(target)
self.mu.set(mu)
self.modulator.set(modulator)
self.L.set(L)
self.mask.set(mask)

@classmethod
def help(cls): ## component help function
@@ -128,7 +135,8 @@ def help(cls): ## component help function
"inputs":
{"mu": "External input prediction value(s)",
"target": "External input target signal value(s)",
"modulator": "External input modulatory/scaling signal(s)"},
"modulator": "External input modulatory/scaling signal(s)",
"mask": "External binary/gating mask to apply to signals"},
"outputs":
{"L": "Local loss value computed/embodied by this error-cell",
"dmu": "first derivative of loss w.r.t. prediction value(s)",
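The substantive change to GaussianErrorCell is the new `mask` compartment: it multiplicatively gates both the cell's inputs (target, mu) and its outputs (dmu, dtarget), and is then reset to ones inside `_advance_state`, so the gate only affects the step at which it was set. A minimal sketch of the gated update, assuming the conventional Gaussian error-cell quantities (dmu proportional to target - mu, L = -0.5 * ||target - mu||^2); the exact `_run_cell` is not reproduced in this diff:

import jax.numpy as jnp

def masked_gaussian_error_sketch(target, mu, modulator, mask):
    e = (target - mu) * mask             ## gate the prediction error
    dmu = e * modulator * mask           ## modulated derivative w.r.t. mu
    dtarget = -e * modulator * mask      ## derivative w.r.t. target
    L = -0.5 * jnp.sum(jnp.square(e))    ## local Gaussian (squared-error) loss
    mask = mask * 0. + 1.                ## "eat" the mask; it only applies at this step
    return dmu, dtarget, L, mask

Because the returned mask is all ones, a caller that wants to gate several consecutive steps must set the `mask` compartment again before each step.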
34 changes: 21 additions & 13 deletions ngclearn/components/neurons/graded/laplacianErrorCell.py
@@ -52,6 +52,7 @@ class LaplacianErrorCell(JaxComponent): ## Rate-coded/real-valued error unit/cel
| mu - predicted value (takes in external signals)
| target - desired/goal value (takes in external signals)
| modulator - modulation signal (takes in optional external signals)
| mask - binary/gating mask to apply to error neuron calculations
| --- Cell Output Compartments: ---
| L - local loss function embodied by this cell
| dmu - derivative of L w.r.t. mu
@@ -86,39 +87,45 @@ def __init__(self, name, n_units, batch_size=1, **kwargs):
self.target = Compartment(restVals) # target. input wire
self.dtarget = Compartment(restVals) # derivative target
self.modulator = Compartment(restVals + 1.0) # to be set/consumed
self.mask = Compartment(restVals + 1.0)

@staticmethod
def _advance_state(dt, mu, target, modulator):
def _advance_state(dt, mu, target, modulator, mask):
## compute Laplacian error cell output
dmu, dtarget, L = _run_cell(dt, target, mu)
dmu = dmu * modulator
dtarget = dtarget * modulator
return dmu, dtarget, L
dmu, dtarget, L = _run_cell(dt, target * mask, mu * mask)
dmu = dmu * modulator * mask
dtarget = dtarget * modulator * mask
mask = mask * 0. + 1. ## "eat" the mask as it should only apply at time t
return dmu, dtarget, L, mask

@resolver(_advance_state)
def advance_state(self, dmu, dtarget, L):
def advance_state(self, dmu, dtarget, L, mask):
self.dmu.set(dmu)
self.dtarget.set(dtarget)
self.L.set(L)
self.mask.set(mask)

@staticmethod
def _reset(batch_size, n_units):
dmu = jnp.zeros((batch_size, n_units))
dtarget = jnp.zeros((batch_size, n_units))
target = jnp.zeros((batch_size, n_units)) #None
mu = jnp.zeros((batch_size, n_units)) #None
restVals = jnp.zeros((batch_size, n_units))
dmu = restVals
dtarget = restVals
target = restVals
mu = restVals
modulator = mu + 1.
L = 0.
return dmu, dtarget, target, mu, modulator, L
mask = jnp.ones((batch_size, n_units))
return dmu, dtarget, target, mu, modulator, L, mask

@resolver(_reset)
def reset(self, dmu, dtarget, target, mu, modulator, L):
def reset(self, dmu, dtarget, target, mu, modulator, L, mask):
self.dmu.set(dmu)
self.dtarget.set(dtarget)
self.target.set(target)
self.mu.set(mu)
self.modulator.set(modulator)
self.L.set(L)
self.mask.set(mask)

@classmethod
def help(cls): ## component help function
@@ -130,7 +137,8 @@ def help(cls): ## component help function
"inputs":
{"mu": "External input prediction value(s)",
"target": "External input target signal value(s)",
"modulator": "External input modulatory/scaling signal(s)"},
"modulator": "External input modulatory/scaling signal(s)",
"mask": "External binary/gating mask to apply to signals"},
"outputs":
{"L": "Local loss value computed/embodied by this error-cell",
"dmu": "first derivative of loss w.r.t. prediction value(s)",
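The LaplacianErrorCell change mirrors the Gaussian one above; only the local loss differs. A compact sketch under the usual Laplacian (L1) assumptions, which are not spelled out in this diff:

import jax.numpy as jnp

def masked_laplacian_error_sketch(target, mu, modulator, mask):
    e = (target - mu) * mask              ## gated error under an L1 objective
    dmu = jnp.sign(e) * modulator * mask  ## assumed sign-based derivative w.r.t. mu
    dtarget = -dmu                        ## derivative w.r.t. target
    L = -jnp.sum(jnp.abs(e))              ## local Laplacian (absolute-error) loss
    mask = mask * 0. + 1.                 ## mask is consumed after one step, as in the Gaussian cell
    return dmu, dtarget, L, mask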
26 changes: 19 additions & 7 deletions ngclearn/components/neurons/spiking/LIFCell.py
@@ -5,6 +5,7 @@
from ngclearn.components.jaxComponent import JaxComponent
from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \
step_euler, step_rk2
from ngclearn.utils.surrogate_fx import secant_lif_estimator, arctan_estimator, triangular_estimator

@jit
def _update_times(t, s, tols):
@@ -235,6 +236,11 @@ def __init__(self, name, n_units, tau_m, resist_m=1., thr=-52., v_rest=-65.,
self.batch_size = 1
self.n_units = n_units

## set up surrogate function for spike emission
self.spike_fx, self.d_spike_fx = secant_lif_estimator()
#self.spike_fx, self.d_spike_fx = arctan_estimator() #
#self.spike_fx, self.d_spike_fx = triangular_estimator() # straight_through_estimator()

## Compartment setup
restVals = jnp.zeros((self.batch_size, self.n_units))
thr0 = 0.
@@ -249,36 +255,40 @@ def __init__(self, name, n_units, tau_m, resist_m=1., thr=-52., v_rest=-65.,
self.rfr = Compartment(restVals + self.refract_T)
self.thr_theta = Compartment(restVals + thr0)
self.tols = Compartment(restVals) ## time-of-last-spike
self.surrogate = Compartment(restVals + 1.) ## surrogate signal

@staticmethod
def _advance_state(t, dt, tau_m, R_m, v_rest, v_reset, v_decay, refract_T,
thr, tau_theta, theta_plus, one_spike, intgFlag,
thr, tau_theta, theta_plus, one_spike, intgFlag, d_spike_fx,
key, j, v, s, rfr, thr_theta, tols):
skey = None ## this is an empty dkey if single_spike mode turned off
if one_spike: ## old code ~> if self.one_spike is False:
if one_spike:
key, skey = random.split(key, 2)
## run one integration step for neuronal dynamics
#j = _modify_current(j, dt, tau_m, R_m) ## re-scale current in prep for volt ODE
j = j * R_m
#surrogate = d_spike_fx(v, thr + thr_theta)
v, s, raw_spikes, rfr = _run_cell(dt, j, v, thr, thr_theta, rfr, skey,
tau_m, v_rest, v_reset, v_decay, refract_T,
intgFlag)
surrogate = d_spike_fx(v, thr + thr_theta)
#surrogate = d_spike_fx(j, thr + thr_theta)
if tau_theta > 0.:
## run one integration step for threshold dynamics
thr_theta = _update_theta(dt, thr_theta, raw_spikes, tau_theta, theta_plus)
## update tols
tols = _update_times(t, s, tols)
return v, s, raw_spikes, rfr, thr_theta, tols, key
return v, s, raw_spikes, rfr, thr_theta, tols, key, surrogate

@resolver(_advance_state)
def advance_state(self, v, s, s_raw, rfr, thr_theta, tols, key):
def advance_state(self, v, s, s_raw, rfr, thr_theta, tols, key, surrogate):
self.v.set(v)
self.s.set(s)
self.s_raw.set(s_raw)
self.rfr.set(rfr)
self.thr_theta.set(thr_theta)
self.tols.set(tols)
self.key.set(key)
self.surrogate.set(surrogate)

@staticmethod
def _reset(batch_size, n_units, v_rest, refract_T):
@@ -290,17 +300,19 @@ def _reset(batch_size, n_units, v_rest, refract_T):
rfr = restVals + refract_T
#thr_theta = restVals ## do not reset thr_theta
tols = restVals #+ 0
return j, v, s, s_raw, rfr, tols
surrogate = restVals + 1.
return j, v, s, s_raw, rfr, tols, surrogate

@resolver(_reset)
def reset(self, j, v, s, s_raw, rfr, tols):
def reset(self, j, v, s, s_raw, rfr, tols, surrogate):
self.j.set(j)
self.v.set(v)
self.s.set(s)
self.s_raw.set(s_raw)
self.rfr.set(rfr)
#self.thr_theta.set(thr_theta)
self.tols.set(tols)
self.surrogate.set(surrogate)

def save(self, directory, **kwargs):
file_name = directory + "/" + self.name + ".npz"
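The LIFCell changes wire a surrogate derivative of the spike nonlinearity into the cell: `secant_lif_estimator` (with arctan and triangular alternatives left commented out) supplies `d_spike_fx`, that derivative is evaluated at the post-update voltage against the effective threshold `thr + thr_theta`, and the result is exposed through a new `surrogate` compartment so downstream plasticity rules can read a differentiable proxy for the otherwise non-differentiable spike. Purely as an illustration of the idea (a generic arctan-style surrogate, not the formula implemented in ngclearn.utils.surrogate_fx):

import jax.numpy as jnp

def arctan_surrogate_grad(v, thr, alpha=2.):
    ## derivative of (1/pi) * arctan(pi * alpha * (v - thr)) with respect to v:
    ## it peaks at v == thr (value alpha) and decays smoothly away from threshold
    x = jnp.pi * alpha * (v - thr)
    return alpha / (1. + x * x)

Units sitting near threshold would report surrogate values close to alpha, while strongly sub- or supra-threshold units report values near zero.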
2 changes: 2 additions & 0 deletions ngclearn/components/neurons/spiking/WTASCell.py
@@ -118,6 +118,8 @@ def __init__(self, name, n_units, tau_m, resist_m=1., thr_base=0.4, thr_gain=0.0
self.n_units = n_units

## base threshold setup
## according to eqn 26 of the source paper, the initial condition for the
## threshold should technically be between: 1/n_units < threshold0 << 0.5, e.g., 0.15
key, subkey = random.split(self.key.value)
self.threshold0 = thr_base + random.uniform(subkey, (1, n_units),
minval=-thr_jitter, maxval=thr_jitter,
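The WTASCell edit only documents the initialization constraint from eqn 26 of the source paper: 1/n_units < threshold0 << 0.5. A quick illustrative check of that bound against the jittered initialization used in the constructor (the concrete values below are chosen for the example, e.g. thr_base = 0.15 as the new comment suggests):

import jax.numpy as jnp
from jax import random

n_units, thr_base, thr_jitter = 20, 0.15, 0.05
key = random.PRNGKey(1234)
threshold0 = thr_base + random.uniform(key, (1, n_units),
                                       minval=-thr_jitter, maxval=thr_jitter)
## every entry lies in (0.1, 0.2), inside the window 1/n_units = 0.05 < threshold0 << 0.5
assert jnp.all(threshold0 > 1. / n_units) and jnp.all(threshold0 < 0.5)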