Skip to content

Commit

Permalink
Additions for inhibition stuff
Browse files Browse the repository at this point in the history
  • Loading branch information
willgebhardt committed Nov 19, 2024
1 parent 6408ee0 commit 35eae76
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 17 deletions.
5 changes: 3 additions & 2 deletions ngclearn/components/base_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,10 @@ def watch(self, compartment, window_length):
"""
cs, end = self._add_path(compartment.path)

dtype = compartment.value.dtype
shape = compartment.value.shape
new_comp = Compartment(np.zeros(shape))
new_comp_store = Compartment(np.zeros((window_length, *shape)))
new_comp = Compartment(np.zeros(shape, dtype=dtype))
new_comp_store = Compartment(np.zeros((window_length, *shape), dtype=dtype))

comp_key = "*".join(compartment.path.split("/"))
store_comp_key = comp_key + "*store"
Expand Down
2 changes: 1 addition & 1 deletion ngclearn/components/input_encoders/poissonCell.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class PoissonCell(JaxComponent):
"""

@deprecate_args(max_freq="target_freq")
def __init__(self, name, n_units, target_freq=0., batch_size=1, **kwargs):
def __init__(self, name, n_units, target_freq=63.75, batch_size=1, **kwargs):
super().__init__(name, **kwargs)

## Constrained Bernoulli meta-parameters
Expand Down
24 changes: 14 additions & 10 deletions ngclearn/components/other/varTrace.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from ngclearn.utils import tensorstats

@partial(jit, static_argnums=[4])
def _run_varfilter(dt, x, x_tr, decayFactor, a_delta=0.):
def _run_varfilter(dt, x, x_tr, decayFactor, gamma_tr, a_delta=0.):
"""
Run variable trace filter (low-pass filter) dynamics one step forward.
Expand All @@ -22,7 +22,7 @@ def _run_varfilter(dt, x, x_tr, decayFactor, a_delta=0.):
Returns:
updated trace/filter value/state
"""
_x_tr = x_tr * decayFactor
_x_tr = gamma_tr * x_tr * decayFactor
#x_tr + (-x_tr) * (dt / tau_tr) = (1 - dt/tau_tr) * x_tr
if a_delta > 0.: ## perform additive form of trace ODE
_x_tr = _x_tr + x * a_delta
Expand Down Expand Up @@ -64,13 +64,14 @@ class VarTrace(JaxComponent): ## low-pass filter
"""

# Define Functions
def __init__(self, name, n_units, tau_tr, a_delta, decay_type="exp",
def __init__(self, name, n_units, tau_tr, a_delta, gamma_tr=1, decay_type="exp",
batch_size=1, **kwargs):
super().__init__(name, **kwargs)

## Trace control coefficients
self.tau_tr = tau_tr ## trace time constant
self.a_delta = a_delta ## trace increment (if spike occurred)
self.gamma_tr = gamma_tr
self.decay_type = decay_type ## lin --> linear decay; exp --> exponential decay

## Layer Size Setup
Expand All @@ -83,17 +84,20 @@ def __init__(self, name, n_units, tau_tr, a_delta, decay_type="exp",
self.trace = Compartment(restVals)

@staticmethod
def _advance_state(dt, decay_type, tau_tr, a_delta, inputs, trace):
## compute the decay factor
decayFactor = 0. ## <-- pulse filter decay (default)
def _advance_state(dt, decay_type, tau_tr, a_delta, gamma_tr, inputs, trace):
decayFactor = 0.
if "exp" in decay_type:
decayFactor = jnp.exp(-dt/tau_tr)
elif "lin" in decay_type:
decayFactor = (1. - dt/tau_tr)
## else "step" == decay_type, yielding a step/pulse-like filter
trace = _run_varfilter(dt, inputs, trace, decayFactor, a_delta)
outputs = trace
return outputs, trace

_x_tr = gamma_tr * trace * decayFactor
if a_delta > 0.:
_x_tr = _x_tr + inputs * a_delta
else:
_x_tr = _x_tr * (1. - inputs) + inputs

return trace, trace

@resolver(_advance_state)
def advance_state(self, outputs, trace):
Expand Down
4 changes: 2 additions & 2 deletions ngclearn/components/synapses/denseSynapse.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ def __init__(self, name, shape, weight_init=None, bias_init=None,
self.bias_init = bias_init

## Synapse meta-parameters
self.shape = shape ## shape of synaptic efficacy matrix
self.Rscale = resist_scale ## post-transformation scale factor
self.shape = shape
self.Rscale = resist_scale

## Set up synaptic weight values
tmp_key, *subkeys = random.split(self.key.value, 4)
Expand Down
4 changes: 2 additions & 2 deletions ngclearn/utils/viz/synapse_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,13 +138,13 @@ def visualize_gif(frames, path='.', name='tmp', suffix='.jpg', **kwargs):
_frames = [f.astype(jnp.uint8) for f in frames]
iio.imwrite(path + '/' + name + '.gif', _frames, **kwargs)

def make_video(f_start, f_end, path, prefix, suffix='.jpg', skip=1):
def make_video(f_start, f_end, path, prefix, suffix='.jpg', skip=1, **kwargs):
images = []
for i in range(f_start, f_end+1, skip):
print("Reading frame " + str(i))
images.append(iio.imread(path + "/" + prefix + str(i) + suffix))
print("writing gif")
iio.imwrite(path + '/training.gif', images, loop=0, duration=200)
iio.imwrite(path + '/training.gif', images, **kwargs)


# def visualize_norm(thetas, sizes, prefix, suffix='.jpg'):
Expand Down

0 comments on commit 35eae76

Please sign in to comment.