diff --git a/neurobench/benchmarks/workload_metrics.py b/neurobench/benchmarks/workload_metrics.py index 45b60d8..114f083 100644 --- a/neurobench/benchmarks/workload_metrics.py +++ b/neurobench/benchmarks/workload_metrics.py @@ -124,31 +124,32 @@ def activation_sparsity(model, preds, data): class membrane_updates(AccumulatedMetric): """ - Number of synaptic operations. + Number of membrane potential updates. - MACs for ANN ACs for SNN + This metric can only be used for spiking models implemented with SNNTorch. """ def __init__(self): + """Init metric state.""" self.total_samples = 0 self.neuron_membrane_updates = defaultdict(int) def reset(self): + """Reset metric state.""" self.total_samples = 0 self.neuron_membrane_updates = defaultdict(int) def __call__(self, model, preds, data): """ - Multiply-accumulates (MACs) of the model forward. + Number of membrane updates of the model forward. Args: model: A NeuroBenchModel. preds: A tensor of model predictions. data: A tuple of data and labels. - inputs: A tensor of model inputs. Returns: - float: Multiply-accumulates. + float: Number of membrane potential updates. """ for hook in model.activation_hooks: @@ -164,6 +165,14 @@ def __call__(self, model, preds, data): return self.compute() def compute(self): + """ + Compute membrane updates using accumulated data. + + Returns: + float: The total updates to each neuron's membrane potential within the model, + aggregated across all neurons and normalized by the number of samples processed. 
+ + """ if self.total_samples == 0: return 0 diff --git a/tests/test_metrics.py b/tests/test_metrics.py index 6d0b425..7edcdef 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -21,6 +21,7 @@ detect_activations_connections, synaptic_operations, number_neuron_updates, + membrane_updates, ) from torch.profiler import profile, record_function, ProfilerActivity @@ -531,6 +532,33 @@ def test_neuron_update_metric(): print("Passed neuron update metric") +def test_membrane_potential_updates(): + + # test snn layers + net_snn = nn.Sequential( + # nn.Flatten(), + nn.Linear(20, 5, bias=False), + snn.Leaky( + beta=0.9, spike_grad=surrogate.fast_sigmoid(), init_hidden=True, output=True + ), + ) + + # simulate spiking input with only ones + inp = torch.ones(5, 10, 20) # batch size, time steps, input size + + model = SNNTorchModel(net_snn) + + detect_activations_connections(model) + + out = model(inp) + mem_updates = membrane_updates() + tot_mem_updates = mem_updates(model, out, (inp, 0)) + + assert tot_mem_updates == 50 + + print("Passed membrane updates") + + class simple_LSTM(nn.Module): """Nonsense LSTM for operations testing Should be 615 MACs."""