From ee50f333cd728ee51b12289b53e94a3074117fc6 Mon Sep 17 00:00:00 2001 From: ago109 Date: Thu, 8 Aug 2024 17:52:46 -0400 Subject: [PATCH] cleaned up raf --- .../components/neurons/spiking/RAFCell.py | 35 ++++++++++--------- 1 file changed, 18 insertions(+), 17 deletions(-) diff --git a/ngclearn/components/neurons/spiking/RAFCell.py b/ngclearn/components/neurons/spiking/RAFCell.py index 7c1b946b7..0f23107cf 100755 --- a/ngclearn/components/neurons/spiking/RAFCell.py +++ b/ngclearn/components/neurons/spiking/RAFCell.py @@ -93,11 +93,17 @@ class RAFCell(JaxComponent): thr: voltage/membrane threshold (to obtain action potentials in terms of binary spikes) (Default: 5 mV) + omega: angular frequency (Default: 10) + + b: oscillation dampening factor (Default: -1) + v_reset: membrane reset potential condition (Default: 0 mV) - w_reset: reset condition for angular driver (Default: 0 mV) + w_reset: reset condition for angular driver (Default: 0) + + v0: membrane potential initial condition (Default: 0 mV) - b: oscillation dampening factor (Default: -1.) + w0: angular driver initial condition (Default: 0) integration_type: type of integration to use for this cell's dynamics; current supported forms include "euler" (Euler/RK-1 integration) @@ -110,8 +116,8 @@ class RAFCell(JaxComponent): # Define Functions def __init__(self, name, n_units, tau_m=15., resist_m=1., tau_w=400., - omega=10., thr=5., v_reset=0., w_reset=0., b=-1., - integration_type="euler", batch_size=1, **kwargs): + thr=5., omega=10., b=-1., v_reset=0., w_reset=0., + v0=0., w0=0., integration_type="euler", batch_size=1, **kwargs): #v_rest=-72., v_reset=-75., w_reset=0., thr=5., v0=-70., w0=0., super().__init__(name, **kwargs) @@ -129,6 +135,8 @@ def __init__(self, name, n_units, tau_m=15., resist_m=1., tau_w=400., #self.v_rest = v_rest self.v_reset = v_reset self.w_reset = w_reset + self.v0 = v0 + self.w0 = w0 self.thr = thr ## Layer Size Setup @@ -147,9 +155,6 @@ def __init__(self, name, n_units, tau_m=15., resist_m=1., tau_w=400., @staticmethod def _advance_state(t, dt, tau_m, resist_m, tau_w, thr, omega, b, v_reset, w_reset, intgFlag, j, v, w, tols): - ## center variables before running dynamics - v = v - v_reset - w = w - w_reset ## continue with centered dynamics j_ = j * resist_m if intgFlag == 1: ## RK-2/midpoint @@ -164,11 +169,8 @@ def _advance_state(t, dt, tau_m, resist_m, tau_w, thr, omega, b, _, _v = step_euler(0., v, _dfv, dt, v_params) s = _emit_spike(_v, thr) ## hyperpolarize/reset/snap variables - v = _v * (1. - s) + s #* v_reset - w = _w * (1. - s) + s #* w_reset - ## artificially shift variables back to rest/reset values - v = v + v_reset - w = w + w_reset + v = _v * (1. - s) + s * v_reset + w = _w * (1. - s) + s * w_reset tols = _update_times(t, s, tols) return j, v, w, s, tols @@ -181,11 +183,11 @@ def advance_state(self, j, v, w, s, tols): self.tols.set(tols) @staticmethod - def _reset(batch_size, n_units, v_reset, w_reset): + def _reset(batch_size, n_units, v0, w0): restVals = jnp.zeros((batch_size, n_units)) j = restVals # None - v = restVals + v_reset - w = restVals + w_reset + v = restVals + v0 + w = restVals + w0 s = restVals #+ 0 tols = restVals #+ 0 return j, v, w, s, tols @@ -221,9 +223,8 @@ def help(cls): ## component help function "tau_m": "Cell membrane time constant", "resist_m": "Membrane resistance value", "tau_w": "Recovery variable time constant", - "v_thr": "Base voltage threshold value", - "v_rest": "Resting membrane potential value", "v_reset": "Reset membrane potential value", + "w_reset": "Reset angular driver value", "b": "Exponential dampening factor applied to oscillations", "omega": "Angular frequency of neuronal progress per second (radians)", "v0": "Initial condition for membrane potential/voltage",