Merging over dev revisions to main (prep for pip nudge) #56

Merged: 8 commits merged into main from dev on Jun 30, 2024
68 changes: 57 additions & 11 deletions ngclearn/components/neurons/graded/rewardErrorCell.py
@@ -9,21 +9,34 @@ class RewardErrorCell(JaxComponent): ## Reward prediction error cell

| --- Cell Input Compartments: ---
| reward - current reward signal at time `t`
| accum_reward - current accumulated episodic reward signal
| --- Cell Output Compartments: ---
| mu - current moving average prediction of reward at time `t`
| rpe - current reward prediction error (RPE) signal
| accum_reward - current accumulated episodic reward signal (IF online predictor not used)

Args:
name: the string name of this cell

n_units: number of cellular entities (neural population size)

alpha: decay factor to apply to (exponential) moving average prediction

ema_window_len: exponential moving average window length -- for use only
in `evolve` step for updating episodic reward signals; (default: 10)

use_online_predictor: use online prediction of reward signal (default: True)
-- if set to False, then reward prediction will only occur upon a call
to this cell's `evolve` function
"""
def __init__(self, name, n_units, alpha, batch_size=1, **kwargs):
def __init__(self, name, n_units, alpha, ema_window_len=10,
use_online_predictor=True, batch_size=1, **kwargs):
super().__init__(name, **kwargs)

## RPE meta-parameters
self.alpha = alpha
self.ema_window_len = ema_window_len
self.use_online_predictor = use_online_predictor

## Layer Size Setup
self.n_units = n_units
@@ -34,29 +47,55 @@ def __init__(self, name, n_units, alpha, batch_size=1, **kwargs):
self.mu = Compartment(restVals) ## reward predictor state(s)
self.reward = Compartment(restVals) ## target reward signal(s)
self.rpe = Compartment(restVals) ## reward prediction error(s)
self.accum_reward = Compartment(restVals) ## accumulated reward signal(s)
self.n_ep_steps = Compartment(jnp.zeros((self.batch_size, 1))) ## number of episode steps taken

@staticmethod
def _advance_state(dt, alpha, mu, rpe, reward):
def _advance_state(dt, use_online_predictor, alpha, mu, rpe, reward,
n_ep_steps, accum_reward):
## compute/update RPE and predictor values
accum_reward = accum_reward + reward
rpe = reward - mu
mu = mu * (1. - alpha) + reward * alpha
return mu, rpe
if use_online_predictor:
mu = mu * (1. - alpha) + reward * alpha
n_ep_steps = n_ep_steps + 1
return mu, rpe, n_ep_steps, accum_reward

@resolver(_advance_state)
def advance_state(self, mu, rpe):
def advance_state(self, mu, rpe, n_ep_steps, accum_reward):
self.mu.set(mu)
self.rpe.set(rpe)
self.n_ep_steps.set(n_ep_steps)
self.accum_reward.set(accum_reward)

@staticmethod
def _evolve(dt, use_online_predictor, ema_window_len, n_ep_steps, mu,
accum_reward):
if use_online_predictor:
## average per-step episodic reward signal
r = accum_reward/n_ep_steps
mu = (1. - 1./ema_window_len) * mu + (1./ema_window_len) * r
return mu

@resolver(_evolve)
def evolve(self, mu):
self.mu.set(mu)

@staticmethod
def _reset(batch_size, n_units):
mu = jnp.zeros((batch_size, n_units)) #None
rpe = jnp.zeros((batch_size, n_units)) #None
return mu, rpe
restVals = jnp.zeros((batch_size, n_units))
mu = restVals
rpe = restVals
accum_reward = restVals
n_ep_steps = jnp.zeros((batch_size, 1))
return mu, rpe, accum_reward, n_ep_steps

@resolver(_reset)
def reset(self, mu, rpe):
def reset(self, mu, rpe, accum_reward, n_ep_steps):
self.mu.set(mu)
self.rpe.set(rpe)
self.accum_reward.set(accum_reward)
self.n_ep_steps.set(n_ep_steps)

@classmethod
def help(cls): ## component help function
@@ -69,16 +108,23 @@ def help(cls): ## component help function
{"reward": "External reward signals/values"},
"outputs":
{"mu": "Current state of reward predictor",
"rpe": "Current value of reward prediction error at time `t`"},
"rpe": "Current value of reward prediction error at time `t`",
"accum_reward": "Current accumulated episodic reward signal (generally "
"produced at the end of a control episode of `n_steps`)",
"n_ep_steps": "Number of episodic steps taken/tracked thus far "
"(since last `reset` call)",
"use_online_predictor": "Should an online reward predictor be used/maintained?"},
}
hyperparams = {
"n_units": "Number of neuronal cells to model in this layer",
"alpha": "Moving average decay factor",
"ema_window_len": "Exponential moving average window length",
"batch_size": "Batch size dimension of this component"
}
info = {cls.__name__: properties,
"compartments": compartment_props,
"dynamics": "rpe = reward - mu; mu = mu * (1 - alpha) + reward * alpha",
"dynamics": "rpe = reward - mu; mu = mu * (1 - alpha) + reward * alpha; "
"accum_reward = accum_reward + reward",
"hyperparameters": hyperparams}
return info
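To make the new reward dynamics concrete, the following is a minimal standalone sketch of the per-step (`advance_state`) and episodic (`evolve`) updates described above; the helper names `step_update` and `episode_update` are illustrative only and are not part of the component's API.

import jax.numpy as jnp

def step_update(mu, reward, accum_reward, n_ep_steps, alpha, use_online_predictor=True):
    accum_reward = accum_reward + reward         ## accumulate episodic reward
    rpe = reward - mu                            ## reward prediction error
    if use_online_predictor:
        mu = mu * (1. - alpha) + reward * alpha  ## online moving-average predictor
    n_ep_steps = n_ep_steps + 1
    return mu, rpe, accum_reward, n_ep_steps

def episode_update(mu, accum_reward, n_ep_steps, ema_window_len=10):
    r = accum_reward / n_ep_steps                ## average per-step reward this episode
    return (1. - 1. / ema_window_len) * mu + (1. / ema_window_len) * r

mu = jnp.zeros((1, 1)); accum = jnp.zeros((1, 1)); steps = jnp.zeros((1, 1))
for reward in (1., 0., 1.):
    mu, rpe, accum, steps = step_update(mu, jnp.asarray([[reward]]), accum, steps, alpha=0.1)
mu = episode_update(mu, accum, steps)            ## end-of-episode predictor refresh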

22 changes: 16 additions & 6 deletions ngclearn/components/neurons/spiking/fitzhughNagumoCell.py
@@ -113,13 +113,17 @@ class FitzhughNagumoCell(JaxComponent):

gamma: power-term divisor (Default: 3.)

v_thr: voltage/membrane threshold (to obtain action potentials in terms
of binary spikes)

v0: initial condition / reset for voltage

w0: initial condition / reset for recovery

v_thr: voltage/membrane threshold (to obtain action potentials in terms
of binary spikes)

spike_reset: if True, once voltage crosses threshold, then dynamics
of voltage and recovery are reset/snapped to initial conditions
(default: False)

integration_type: type of integration to use for this cell's dynamics;
current supported forms include "euler" (Euler/RK-1 integration)
and "midpoint" or "rk2" (midpoint method/RK-2 integration) (Default: "euler")
@@ -131,7 +135,7 @@ class FitzhughNagumoCell(JaxComponent):

# Define Functions
def __init__(self, name, n_units, tau_m=1., resist_m=1., tau_w=12.5, alpha=0.7,
beta=0.8, gamma=3., v_thr=1.07, v0=0., w0=0.,
beta=0.8, gamma=3., v0=0., w0=0., v_thr=1.07, spike_reset=False,
integration_type="euler", **kwargs):
super().__init__(name, **kwargs)

@@ -150,6 +154,7 @@ def __init__(self, name, n_units, tau_m=1., resist_m=1., tau_w=12.5, alpha=0.7,
self.v0 = v0 ## initial membrane potential/voltage condition
self.w0 = w0 ## initial w-parameter condition
self.v_thr = v_thr
self.spike_reset = spike_reset

## Layer Size Setup
self.batch_size = 1
@@ -164,10 +169,13 @@ def __init__(self, name, n_units, tau_m=1., resist_m=1., tau_w=12.5, alpha=0.7,
self.tols = Compartment(restVals) ## time-of-last-spike

@staticmethod
def _advance_state(t, dt, tau_m, R_m, tau_w, v_thr, alpha, beta, gamma,
intgFlag, j, v, w, tols):
def _advance_state(t, dt, tau_m, R_m, tau_w, v_thr, spike_reset, v0, w0, alpha,
beta, gamma, intgFlag, j, v, w, tols):
v, w, s = _run_cell(dt, j * R_m, v, w, v_thr, tau_m, tau_w, alpha, beta,
gamma, intgFlag)
if spike_reset: ## if spike-reset used, variables snapped back to initial conditions
v = v * (1. - s) + s * v0
w = w * (1. - s) + s * w0
tols = _update_times(t, s, tols)
return j, v, w, s, tols
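The `spike_reset` branch above uses the binary spike vector `s` as a mask: units that just spiked are snapped back to the initial conditions `v0`/`w0`, while non-spiking units are left untouched. A small illustrative sketch (values chosen arbitrarily):

import jax.numpy as jnp

v = jnp.asarray([[1.2, 0.3]])  ## post-update voltages
w = jnp.asarray([[0.9, 0.1]])  ## post-update recovery values
s = jnp.asarray([[1., 0.]])    ## only the first unit crossed v_thr
v0, w0 = 0., 0.

v = v * (1. - s) + s * v0      ## -> [[0.0, 0.3]]
w = w * (1. - s) + s * w0      ## -> [[0.0, 0.1]]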

@@ -220,6 +228,8 @@ def help(cls): ## component help function
"resist_m": "Membrane resistance value",
"tau_w": "Recovery variable time constant",
"v_thr": "Base voltage threshold value",
"spike_reset": "Should voltage/recover be snapped to initial "
"condition(s) if spike emitted?",
"alpha": "Dimensionless recovery variable shift factor `a",
"beta": "Dimensionless recovery variable scale factor `b`",
"gamma": "Power-term divisor constant",
29 changes: 4 additions & 25 deletions ngclearn/components/synapses/denseSynapse.py
@@ -5,27 +5,6 @@
from ngclearn.utils.weight_distribution import initialize_params
from ngcsimlib.logger import info

@jit
def _compute_layer(inp, weight, biases=0., Rscale=1.):
"""
Applies the transformation/projection induced by the synaptic efficacies
associated with this synaptic cable

Args:
inp: signal input to run through this synaptic cable

weight: this cable's synaptic value matrix

biases: this cable's bias value vector (default: 0.)

Rscale: scale factor to apply to synapses before transform applied
to input values (default: 1.)

Returns:
a projection/transformation of input "inp"
"""
return jnp.matmul(inp, weight * Rscale) + biases

class DenseSynapse(JaxComponent): ## base dense synaptic cable
"""
A dense synaptic cable; no form of synaptic evolution/adaptation
@@ -51,7 +30,7 @@ class DenseSynapse(JaxComponent): ## base dense synaptic cable
(Default: None, which turns off/disables biases)

resist_scale: a fixed (resistance) scaling factor to apply to synaptic
transform (Default: 1.), i.e., yields: out = ((W * Rscale) * in)
transform (Default: 1.), i.e., yields: out = ((W * in) * resist_scale) + bias

p_conn: probability of a connection existing (default: 1.); setting
this to < 1 and > 0. will result in a sparser synaptic structure
@@ -98,7 +77,7 @@ def __init__(self, name, shape, weight_init=None, bias_init=None,

@staticmethod
def _advance_state(Rscale, inputs, weights, biases):
outputs = _compute_layer(inputs, weights, biases, Rscale)
outputs = (jnp.matmul(inputs, weights) * Rscale) + biases
return outputs

@resolver(_advance_state)
@@ -155,12 +134,12 @@ def help(cls): ## component help function
"batch_size": "Batch size dimension of this component",
"weight_init": "Initialization conditions for synaptic weight (W) values",
"bias_init": "Initialization conditions for bias/base-rate (b) values",
"resist_scale": "Resistance level scaling factor (applied to output of transformation)",
"resist_scale": "Resistance level scaling factor (Rscale); applied to output of transformation",
"p_conn": "Probability of a connection existing (otherwise, it is masked to zero)"
}
info = {cls.__name__: properties,
"compartments": compartment_props,
"dynamics": "outputs = [(W * Rscale) * inputs] + b",
"dynamics": "outputs = [W * inputs] * Rscale + b",
"hyperparameters": hyperparams}
return info
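As a quick illustration of the revised transform, where `resist_scale` (Rscale) now scales the input-weight product before the bias is added, here is a sketch with arbitrary shapes and values:

import jax.numpy as jnp

inputs = jnp.ones((4, 3))       ## batch of 4 input vectors
W = jnp.full((3, 2), 0.5)       ## 3 -> 2 synaptic cable
b = jnp.asarray([[0.1, -0.1]])  ## bias/base-rate values
Rscale = 2.

outputs = (jnp.matmul(inputs, W) * Rscale) + b  ## shape (4, 2)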

38 changes: 37 additions & 1 deletion ngclearn/utils/weight_distribution.py
@@ -66,6 +66,17 @@ def uniform(amin=0., amax=1., **kwargs):
dist_dict = {"dist": "uniform", "amin": amin, "amax": amax}
return {**kwargs, **dist_dict}

def fan_in_uniform(**kwargs):
"""
Produce a configuration for a fan-in scaled unit uniform
distribution initializer.

Returns:
a fan-in scaled (unit) uniform distribution configuration
"""
dist_dict = {"dist": "fan_in_uniform"}
return {**kwargs, **dist_dict}
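A brief usage sketch for the new initializer (the import path follows this file's location in the diff; the shape below is arbitrary):

import jax
from ngclearn.utils import weight_distribution as dist
from ngclearn.utils.weight_distribution import initialize_params

## weights are drawn from U(-k, k) with k = sqrt(1 / fan_in) = sqrt(1 / shape[0])
kernel = dist.fan_in_uniform()
W = initialize_params(jax.random.PRNGKey(0), kernel, shape=(64, 32))  ## (fan_in, fan_out)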

def hollow(scale, **kwargs):
"""
Produce a configuration for a constant hollow distribution initializer.
Expand Down Expand Up @@ -103,18 +114,28 @@ def initialize_params(dkey, init_kernel, shape, use_numpy=False):
dkey: PRNG key to control determinism of this routine

init_kernel: dictionary specifying the distribution type and its
parameters (default: `uniform` dist w/ `amin=0.02`, `amax=0.8`)
parameters (default: `uniform` dist w/ `amin=0.02`, `amax=0.8`) --
note that the kernel dictionary may contain "post-processing" arguments
that are "stacked" on top of the base matrix; for example, passing in
the dictionary
{"dist": "uniform", "hollow": True, "lower_triangle": True}
will create a unit-uniform value matrix with its upper triangle and main
diagonal values masked to zero (lower-triangle masking is applied after
hollow matrix masking)

:Note: Currently supported distribution (dist) kernel schemes include:
"constant" (value);
"uniform" (amin, amax);
"gaussian" (mu, sigma);
"fan_in_gaussian" (NO params);
"fan_in_uniform" (NO params);
"hollow" (scale);
"eye" (scale);
while currently supported post-processing keyword arguments include:
"amin" (clip weights values to be >= amin);
"amax" (clip weights values to be <= amin);
"lower_triangle" (extract lower triangle of params, set rest to 0);
"upper_triangle" (extract upper triangle of params, set rest to 0);
"hollow" (zero out values along main diagonal);
"eye" (zero out off-diagonal values);
"n_row_active" (keep only n random rows non-masked/zero);
@@ -169,6 +190,13 @@ def initialize_params(dkey, init_kernel, shape, use_numpy=False):
phi = jax.random.normal(dkey, shape)
phi = phi * jnp.sqrt(1.0 / (shape[0] * 1.))
params = phi.astype(jnp.float32)
elif dist_type == "fan_in_uniform": ## fan-in scaled unit uniform init
phi = jnp.sqrt(1.0 / (shape[0] * 1.)) # sometimes "k" in other libraries
if use_numpy:
params = np.random.uniform(low=-phi, high=phi, size=shape)
else:
params = jax.random.uniform(dkey, shape, minval=-phi, maxval=phi)
params = params.astype(jnp.float32)
elif dist_type == "constant": ## constant value (everywhere) init
scale = _init_kernel.get("value", 1.)
if use_numpy:
@@ -180,6 +208,8 @@ def initialize_params(dkey, init_kernel, shape, use_numpy=False):
## check for any additional distribution post-processing kwargs (e.g., clipping)
clip_min = _init_kernel.get("amin")
clip_max = _init_kernel.get("amax")
lower_triangle = _init_kernel.get("lower_triangle", False)
upper_triangle = _init_kernel.get("upper_triangle", False)
is_hollow = _init_kernel.get("hollow", False)
is_eye = _init_kernel.get("eye", False)
n_row_active = _init_kernel.get("n_row_active", None)
Expand All @@ -195,6 +225,12 @@ def initialize_params(dkey, init_kernel, shape, use_numpy=False):
params = np.minimum(params, clip_max)
else:
params = jnp.minimum(params, clip_max)
if lower_triangle: ## extract lower triangle of params matrix
ltri_params = jax.numpy.tril(params)
params = ltri_params
if upper_triangle: ## extract upper triangle of params matrix
utri_params = jax.numpy.triu(params)
params = utri_params
if is_hollow: ## apply a hollow mask
if use_numpy:
params = (1. - np.eye(N=shape[0], M=shape[1])) * params
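Finally, a usage sketch for the stacked post-processing arguments documented in `initialize_params`; the kernel dictionary below mirrors the docstring's own example:

import jax
from ngclearn.utils.weight_distribution import initialize_params

## unit-uniform base matrix with the main diagonal zeroed (hollow) and the
## upper triangle zeroed (lower_triangle); triangle masking is applied after
## the hollow masking
kernel = {"dist": "uniform", "hollow": True, "lower_triangle": True}
W = initialize_params(jax.random.PRNGKey(42), kernel, shape=(5, 5))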