From c36f363bffa7cf32a5601fa731c91fbaccf0a465 Mon Sep 17 00:00:00 2001
From: Nicolas Kaenzig
Date: Thu, 13 Feb 2025 18:08:39 +0100
Subject: [PATCH] add-bn

---
 src/eva/core/models/networks/mlp.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/src/eva/core/models/networks/mlp.py b/src/eva/core/models/networks/mlp.py
index c8403dbe..0d2b0c90 100644
--- a/src/eva/core/models/networks/mlp.py
+++ b/src/eva/core/models/networks/mlp.py
@@ -17,6 +17,7 @@ def __init__(
         hidden_activation_fn: Type[torch.nn.Module] | None = nn.ReLU,
         output_activation_fn: Type[torch.nn.Module] | None = None,
         dropout: float = 0.0,
+        use_batch_norm: bool = False,
     ) -> None:
         """Initializes the MLP.

@@ -36,6 +37,7 @@ def __init__(
         self.hidden_activation_fn = hidden_activation_fn
         self.output_activation_fn = output_activation_fn
         self.dropout = dropout
+        self.use_batch_norm = use_batch_norm

         self._network = self._build_network()

@@ -45,6 +47,8 @@ def _build_network(self) -> nn.Sequential:
         prev_size = self.input_size
         for size in self.hidden_layer_sizes:
             layers.append(nn.Linear(prev_size, size))
+            if self.use_batch_norm:
+                layers.append(nn.BatchNorm1d(size))
             if self.hidden_activation_fn is not None:
                 layers.append(self.hidden_activation_fn())
             if self.dropout > 0:
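
Usage note (not part of the patch): a minimal sketch of how the new flag might be
exercised. It assumes the constructor also accepts an output_size parameter and
that MLP.forward applies the built self._network; neither appears in the hunks
above, so treat both as assumptions.

    import torch

    from eva.core.models.networks.mlp import MLP

    # output_size is assumed here; only input_size, hidden_layer_sizes,
    # hidden_activation_fn, output_activation_fn, dropout, and the new
    # use_batch_norm are visible in the patch.
    model = MLP(
        input_size=384,
        output_size=10,
        hidden_layer_sizes=(128, 64),
        dropout=0.1,
        use_batch_norm=True,  # inserts nn.BatchNorm1d(size) after each hidden nn.Linear
    )

    # BatchNorm1d computes statistics over the batch dimension, so
    # training-mode forward passes need batch_size > 1.
    model.train()
    logits = model(torch.randn(32, 384))  # expected shape: (32, 10)

With the flag enabled, each hidden block becomes Linear -> BatchNorm1d ->
activation -> Dropout, i.e. normalization is applied before the nonlinearity,
which is the common placement.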