Skip to content

Commit

Permalink
introducing a membrane potential updates as a new workload metric
Browse files Browse the repository at this point in the history
  • Loading branch information
ben9809 committed Apr 17, 2024
1 parent 3b33435 commit 8ff7f8d
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 1 deletion.
7 changes: 6 additions & 1 deletion neurobench/benchmarks/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,12 @@
from . import static_metrics, workload_metrics

# workload metrics which require hooks
requires_hooks = ["activation_sparsity", "number_neuron_updates", "synaptic_operations"]
requires_hooks = [
"activation_sparsity",
"number_neuron_updates",
"synaptic_operations",
"membrane_updates",
]


class Benchmark:
Expand Down
7 changes: 7 additions & 0 deletions neurobench/benchmarks/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ def __init__(self, layer, connection_layer=None, prev_act_layer_hook=None):
"""
self.activation_outputs = []
self.activation_inputs = []
self.pre_fire_mem_potential = []
self.post_fire_mem_potential = []
if layer is not None:
self.hook = layer.register_forward_hook(self.hook_fn)
self.hook_pre = layer.register_forward_pre_hook(self.pre_hook_fn)
Expand All @@ -46,6 +48,8 @@ def pre_hook_fn(self, layer, input):
"""
self.activation_inputs.append(input)
if self.spiking:
self.pre_fire_mem_potential.append(layer.mem)

def hook_fn(self, layer, input, output):
"""
Expand All @@ -62,6 +66,7 @@ def hook_fn(self, layer, input, output):
"""
if self.spiking:
self.activation_outputs.append(output[0])
self.post_fire_mem_potential.append(layer.mem)

else:
self.activation_outputs.append(output)
Expand All @@ -75,6 +80,8 @@ def reset(self):
"""Resets the stored activation outputs and inputs."""
self.activation_outputs = []
self.activation_inputs = []
self.pre_fire_mem_potential = []
self.post_fire_mem_potential = []

def close(self):
"""Remove the registered hook."""
Expand Down
54 changes: 54 additions & 0 deletions neurobench/benchmarks/workload_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import numpy as np
from ..utils import check_shape, make_binary_copy, single_layer_MACs
from .hooks import ActivationHook, LayerHook
from collections import defaultdict


class AccumulatedMetric:
Expand Down Expand Up @@ -121,6 +122,59 @@ def activation_sparsity(model, preds, data):
return sparsity


class membrane_updates(AccumulatedMetric):
"""
Number of synaptic operations.
MACs for ANN ACs for SNN
"""

def __init__(self):
self.total_samples = 0
self.neuron_membrane_updates = defaultdict(int)

def reset(self):
self.total_samples = 0
self.neuron_membrane_updates = defaultdict(int)

def __call__(self, model, preds, data):
"""
Multiply-accumulates (MACs) of the model forward.
Args:
model: A NeuroBenchModel.
preds: A tensor of model predictions.
data: A tuple of data and labels.
inputs: A tensor of model inputs.
Returns:
float: Multiply-accumulates.
"""
for hook in model.activation_hooks:
for index_mem in range(len(hook.pre_fire_mem_potential) - 1):
pre_fire_mem = hook.pre_fire_mem_potential[index_mem + 1]
post_fire_mem = hook.post_fire_mem_potential[index_mem + 1]
nr_updates = torch.count_nonzero(pre_fire_mem - post_fire_mem)
self.neuron_membrane_updates[str(type(hook.layer))] += int(nr_updates)
self.neuron_membrane_updates[str(type(hook.layer))] += int(
torch.numel(hook.post_fire_mem_potential[0])
)
self.total_samples += data[0].size(0)
return self.compute()

def compute(self):
if self.total_samples == 0:
return 0

total_mem_updates = 0
for key in self.neuron_membrane_updates:
total_mem_updates += self.neuron_membrane_updates[key]

total_updates_per_sample = total_mem_updates / self.total_samples
return total_updates_per_sample


def number_neuron_updates(model, preds, data):
"""
Number of times each neuron type is updated.
Expand Down

0 comments on commit 8ff7f8d

Please sign in to comment.