Skip to content

Commit

Permalink
add-bn
Browse files Browse the repository at this point in the history
  • Loading branch information
nkaenzig committed Feb 13, 2025
1 parent ae9337a commit c36f363
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions src/eva/core/models/networks/mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def __init__(
hidden_activation_fn: Type[torch.nn.Module] | None = nn.ReLU,
output_activation_fn: Type[torch.nn.Module] | None = None,
dropout: float = 0.0,
use_batch_norm: bool = False,
) -> None:
"""Initializes the MLP.
Expand All @@ -36,6 +37,7 @@ def __init__(
self.hidden_activation_fn = hidden_activation_fn
self.output_activation_fn = output_activation_fn
self.dropout = dropout
self.use_batch_norm = use_batch_norm

self._network = self._build_network()

Expand All @@ -45,6 +47,8 @@ def _build_network(self) -> nn.Sequential:
prev_size = self.input_size
for size in self.hidden_layer_sizes:
layers.append(nn.Linear(prev_size, size))
if self.use_batch_norm:
layers.append(nn.BatchNorm1d(size))
if self.hidden_activation_fn is not None:
layers.append(self.hidden_activation_fn())
if self.dropout > 0:
Expand Down

0 comments on commit c36f363

Please sign in to comment.