diff --git a/pytorch_tabnet/tab_network.py b/pytorch_tabnet/tab_network.py index cd6838bf..a50151ca 100644 --- a/pytorch_tabnet/tab_network.py +++ b/pytorch_tabnet/tab_network.py @@ -331,7 +331,7 @@ def __init__( input_dim=self.post_embed_dim, output_dim=self.post_embed_dim, n_d=n_d, - n_a=n_d, + n_a=n_a, n_steps=n_steps, gamma=gamma, n_independent=n_independent, @@ -442,7 +442,7 @@ def __init__( input_dim=input_dim, output_dim=output_dim, n_d=n_d, - n_a=n_d, + n_a=n_a, n_steps=n_steps, gamma=gamma, n_independent=n_independent,