Skip to content

Commit

Permalink
enhance documentation for membrane updates metric and generate tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ben9809 committed Apr 20, 2024
1 parent 8ff7f8d commit 4909120
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 5 deletions.
19 changes: 14 additions & 5 deletions neurobench/benchmarks/workload_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,31 +124,32 @@ def activation_sparsity(model, preds, data):

class membrane_updates(AccumulatedMetric):
"""
Number of synaptic operations.
Number of membrane potential updates.
MACs for ANN ACs for SNN
This metric can only be used for spiking models implemented with SNNTorch.
"""

def __init__(self):
"""Init metric state."""
self.total_samples = 0
self.neuron_membrane_updates = defaultdict(int)

def reset(self):
"""Reset metric state."""
self.total_samples = 0
self.neuron_membrane_updates = defaultdict(int)

def __call__(self, model, preds, data):
"""
Multiply-accumulates (MACs) of the model forward.
Number of membrane updates of the model forward.
Args:
model: A NeuroBenchModel.
preds: A tensor of model predictions.
data: A tuple of data and labels.
inputs: A tensor of model inputs.
Returns:
float: Multiply-accumulates.
float: Number of membrane potential updates.
"""
for hook in model.activation_hooks:
Expand All @@ -164,6 +165,14 @@ def __call__(self, model, preds, data):
return self.compute()

def compute(self):
"""
Compute membrane updates using accumulated data.
Returns:
float: Compute the total updates to each neuron's membrane potential within the model,
aggregated across all neurons and normalized by the number of samples processed.
"""
if self.total_samples == 0:
return 0

Expand Down
28 changes: 28 additions & 0 deletions tests/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
detect_activations_connections,
synaptic_operations,
number_neuron_updates,
membrane_updates,
)
from torch.profiler import profile, record_function, ProfilerActivity

Expand Down Expand Up @@ -531,6 +532,33 @@ def test_neuron_update_metric():
print("Passed neuron update metric")


def test_membrane_potential_updates():

# test snn layers
net_snn = nn.Sequential(
# nn.Flatten(),
nn.Linear(20, 5, bias=False),
snn.Leaky(
beta=0.9, spike_grad=surrogate.fast_sigmoid(), init_hidden=True, output=True
),
)

# simulate spiking input with only ones
inp = torch.ones(5, 10, 20) # batch size, time steps, input size

model = SNNTorchModel(net_snn)

detect_activations_connections(model)

out = model(inp)
mem_updates = membrane_updates()
tot_mem_updates = mem_updates(model, out, (inp, 0))

assert tot_mem_updates == 50

print("Passed membrane updates")


class simple_LSTM(nn.Module):
"""Nonsense LSTM for operations testing Should be 615 MACs."""

Expand Down

0 comments on commit 4909120

Please sign in to comment.