diff --git a/rl4co/models/zoo/ham/attention.py b/rl4co/models/zoo/ham/attention.py
index 4636b88d..0c4d593e 100644
--- a/rl4co/models/zoo/ham/attention.py
+++ b/rl4co/models/zoo/ham/attention.py
@@ -56,7 +56,8 @@ def forward(self, q, h=None, mask=None):
         q: queries (batch_size, n_query, input_dim)
         h: data (batch_size, graph_size, input_dim)
         mask: mask (batch_size, n_query, graph_size) or viewable as that (i.e. can be 2 dim if n_query == 1)
-        Mask should contain 1 if attention is not possible (i.e. mask is negative adjacency)
+
+        Mask should contain 1 if attention is not possible (i.e. mask is negative adjacency)
         """
         if h is None:
             h = q  # compute self-attention
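
For context on the docstring being touched: the mask uses a "negative adjacency" convention, where a 1 marks positions that must not be attended to. Below is a minimal sketch (not part of the diff, and not rl4co's actual implementation) of how such a mask is typically applied to attention logits before the softmax; the function name and shapes are illustrative only.

```python
import torch

def apply_negative_adjacency_mask(compatibility, mask):
    """Illustrative helper, not from the rl4co codebase.

    compatibility: attention logits, shape (batch_size, n_query, graph_size)
    mask: same (or broadcastable) shape, 1 where attention is NOT possible
    """
    # Positions marked 1 are pushed to -inf so the softmax assigns them zero weight.
    return compatibility.masked_fill(mask.bool(), float("-inf"))

# Example usage with hypothetical shapes (batch of 2, single query, 5 nodes).
logits = torch.randn(2, 1, 5)
mask = torch.tensor([[[0, 1, 0, 0, 1]],
                     [[1, 0, 0, 0, 0]]])
attn = torch.softmax(apply_negative_adjacency_mask(logits, mask), dim=-1)
# attn now has zero probability at every masked position.
```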