Skip to content

Commit

Permalink
Merge pull request #205 from NeuroBench/feature/membrane_updates
Browse files Browse the repository at this point in the history
Membrane Updates Metric
  • Loading branch information
jasonlyik authored May 8, 2024
2 parents 3d42b39 + 4909120 commit 5b4f1db
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 1 deletion.
7 changes: 6 additions & 1 deletion neurobench/benchmarks/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,12 @@
from . import static_metrics, workload_metrics

# workload metrics which require hooks
requires_hooks = ["activation_sparsity", "number_neuron_updates", "synaptic_operations"]
requires_hooks = [
"activation_sparsity",
"number_neuron_updates",
"synaptic_operations",
"membrane_updates",
]


class Benchmark:
Expand Down
7 changes: 7 additions & 0 deletions neurobench/benchmarks/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ def __init__(self, layer, connection_layer=None, prev_act_layer_hook=None):
"""
self.activation_outputs = []
self.activation_inputs = []
self.pre_fire_mem_potential = []
self.post_fire_mem_potential = []
if layer is not None:
self.hook = layer.register_forward_hook(self.hook_fn)
self.hook_pre = layer.register_forward_pre_hook(self.pre_hook_fn)
Expand All @@ -46,6 +48,8 @@ def pre_hook_fn(self, layer, input):
"""
self.activation_inputs.append(input)
if self.spiking:
self.pre_fire_mem_potential.append(layer.mem)

def hook_fn(self, layer, input, output):
"""
Expand All @@ -62,6 +66,7 @@ def hook_fn(self, layer, input, output):
"""
if self.spiking:
self.activation_outputs.append(output[0])
self.post_fire_mem_potential.append(layer.mem)

else:
self.activation_outputs.append(output)
Expand All @@ -75,6 +80,8 @@ def reset(self):
"""Resets the stored activation outputs and inputs."""
self.activation_outputs = []
self.activation_inputs = []
self.pre_fire_mem_potential = []
self.post_fire_mem_potential = []

def close(self):
"""Remove the registered hook."""
Expand Down
63 changes: 63 additions & 0 deletions neurobench/benchmarks/workload_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import numpy as np
from ..utils import check_shape, make_binary_copy, single_layer_MACs
from .hooks import ActivationHook, LayerHook
from collections import defaultdict


class AccumulatedMetric:
Expand Down Expand Up @@ -121,6 +122,68 @@ def activation_sparsity(model, preds, data):
return sparsity


class membrane_updates(AccumulatedMetric):
"""
Number of membrane potential updates.
This metric can only be used for spiking models implemented with SNNTorch.
"""

def __init__(self):
"""Init metric state."""
self.total_samples = 0
self.neuron_membrane_updates = defaultdict(int)

def reset(self):
"""Reset metric state."""
self.total_samples = 0
self.neuron_membrane_updates = defaultdict(int)

def __call__(self, model, preds, data):
"""
Number of membrane updates of the model forward.
Args:
model: A NeuroBenchModel.
preds: A tensor of model predictions.
data: A tuple of data and labels.
Returns:
float: Number of membrane potential updates.
"""
for hook in model.activation_hooks:
for index_mem in range(len(hook.pre_fire_mem_potential) - 1):
pre_fire_mem = hook.pre_fire_mem_potential[index_mem + 1]
post_fire_mem = hook.post_fire_mem_potential[index_mem + 1]
nr_updates = torch.count_nonzero(pre_fire_mem - post_fire_mem)
self.neuron_membrane_updates[str(type(hook.layer))] += int(nr_updates)
self.neuron_membrane_updates[str(type(hook.layer))] += int(
torch.numel(hook.post_fire_mem_potential[0])
)
self.total_samples += data[0].size(0)
return self.compute()

def compute(self):
"""
Compute membrane updates using accumulated data.
Returns:
float: Compute the total updates to each neuron's membrane potential within the model,
aggregated across all neurons and normalized by the number of samples processed.
"""
if self.total_samples == 0:
return 0

total_mem_updates = 0
for key in self.neuron_membrane_updates:
total_mem_updates += self.neuron_membrane_updates[key]

total_updates_per_sample = total_mem_updates / self.total_samples
return total_updates_per_sample


def number_neuron_updates(model, preds, data):
"""
Number of times each neuron type is updated.
Expand Down
28 changes: 28 additions & 0 deletions tests/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
detect_activations_connections,
synaptic_operations,
number_neuron_updates,
membrane_updates,
)
from torch.profiler import profile, record_function, ProfilerActivity

Expand Down Expand Up @@ -531,6 +532,33 @@ def test_neuron_update_metric():
print("Passed neuron update metric")


def test_membrane_potential_updates():

# test snn layers
net_snn = nn.Sequential(
# nn.Flatten(),
nn.Linear(20, 5, bias=False),
snn.Leaky(
beta=0.9, spike_grad=surrogate.fast_sigmoid(), init_hidden=True, output=True
),
)

# simulate spiking input with only ones
inp = torch.ones(5, 10, 20) # batch size, time steps, input size

model = SNNTorchModel(net_snn)

detect_activations_connections(model)

out = model(inp)
mem_updates = membrane_updates()
tot_mem_updates = mem_updates(model, out, (inp, 0))

assert tot_mem_updates == 50

print("Passed membrane updates")


class simple_LSTM(nn.Module):
"""Nonsense LSTM for operations testing Should be 615 MACs."""

Expand Down

0 comments on commit 5b4f1db

Please sign in to comment.