diff --git a/tensorflow_gnn/keras/layers/convolutions.py b/tensorflow_gnn/keras/layers/convolutions.py
index f19acc89..a4a740a2 100644
--- a/tensorflow_gnn/keras/layers/convolutions.py
+++ b/tensorflow_gnn/keras/layers/convolutions.py
@@ -62,7 +62,7 @@ class SimpleConv(convolution_base.AnyToAnyConvolutionBase):
       combined input features (see combine_type).
     reduce_type: Specifies how to pool the messages to receivers. Defaults to
       `"sum"`, can be any reduce_type understood by `tfgnn.pool()`, including
-      concatenations like `"sum|max"` (but mind the increased dimension of the
+      concatenations like `"sum|mean"` (but mind the increased dimension of the
       result and the growing number of model weights in the next-state layer).
     combine_type: a string understood by tfgnn.combine_values(), to specify how
       the inputs are combined before passing them to the message_fn. Defaults
diff --git a/tensorflow_gnn/models/mt_albis/README.md b/tensorflow_gnn/models/mt_albis/README.md
index 98cbe9eb..917c755b 100644
--- a/tensorflow_gnn/models/mt_albis/README.md
+++ b/tensorflow_gnn/models/mt_albis/README.md
@@ -16,7 +16,7 @@ states from incoming messages. Its main architectural choices are:
 *   how to aggregate the incoming messages from each node set:
     *   by element-wise averaging (reduce type `"mean"`),
     *   by a concatenation of the average with other fixed expressions
-        (e.g., `"mean|max"`, `"mean|sum"`), or
+        (e.g., `"mean|max_no_inf"`, `"mean|sum"`), or
     *   with attention, that is, a trained, data-dependent weighting;
 *   whether to use residual connections for updating node states;
 *   if and how to normalize node states.
diff --git a/tensorflow_gnn/models/mt_albis/layers.py b/tensorflow_gnn/models/mt_albis/layers.py
index 3eb9abe2..4c1df1b8 100644
--- a/tensorflow_gnn/models/mt_albis/layers.py
+++ b/tensorflow_gnn/models/mt_albis/layers.py
@@ -47,7 +47,7 @@ def MtAlbisSimpleConv(  # To be called like a class initializer. pylint: disabl
       If left unset for init, the tag must be passed at call time.
     reduce_type: Controls how messages are aggregated on an EdgeSet for each
       receiver node; defaults to `"mean"`. Can be any reduce_type understood by
-      `tfgnn.pool()`, including concatenations like `"mean|max"` (but mind the
+      `tfgnn.pool()`, including concatenations like `"mean|sum"` (but mind the
       increased dimension of the result and the growing number of model weights
       in the next-state layer).
     activation: The nonlinearity used on each message before pooling.
@@ -291,10 +291,10 @@ def MtAlbisGraphUpdate(  # To be called like a class initializer. pylint: disab
     simple_conv_reduce_type: For attention_type `"none"`, controls how messages
       are aggregated on an EdgeSet for each receiver node. Defaults to
       `"mean"`; other recommended values are the concatenations `"mean|sum"`,
-      `"mean|max"`, and `"mean|sum|max"` (but mind the increased output
-      dimension and the corresponding increase in the number of weights in the
-      next-state layer). Technically, can be set to any reduce_type understood
-      by `tfgnn.pool()`.
+      `"mean|max_no_inf"`, and `"mean|sum|max_no_inf"` (but mind the increased
+      output dimension and the corresponding increase in the number of weights
+      in the next-state layer). Technically, can be set to any reduce_type
+      understood by `tfgnn.pool()`.
     simple_conv_use_receiver_state: For attention_type `"none"`, controls
       whether the receiver node state is used in computing each edge's message
       (in addition to the sender node state and possibly an `edge feature`).
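For readers of this change, here is a minimal sketch of what a concatenated reduce_type does at the `tfgnn.pool()` level. The tiny graph, the `"edges"` edge set, and the `"messages"` feature below are hypothetical names chosen for illustration, not part of this change:

```python
import tensorflow as tf
import tensorflow_gnn as tfgnn

# A tiny hypothetical graph: 3 nodes, 3 edges, each edge carrying a
# 4-dimensional "messages" feature.
graph = tfgnn.GraphTensor.from_pieces(
    node_sets={
        "nodes": tfgnn.NodeSet.from_fields(
            sizes=tf.constant([3]),
            features={"state": tf.zeros([3, 4])}),
    },
    edge_sets={
        "edges": tfgnn.EdgeSet.from_fields(
            sizes=tf.constant([3]),
            adjacency=tfgnn.Adjacency.from_indices(
                source=("nodes", tf.constant([0, 1, 2])),
                target=("nodes", tf.constant([1, 2, 0]))),
            features={"messages": tf.random.normal([3, 4])}),
    })

# A single reduce keeps the feature dimension: shape [3, 4].
mean_only = tfgnn.pool(
    graph, tfgnn.TARGET, edge_set_name="edges",
    reduce_type="mean", feature_name="messages")

# A concatenation like "mean|max_no_inf" stacks both aggregates along the
# last axis: shape [3, 8], hence the docstring's note about the growing
# number of weights in the next-state layer.
mean_and_max = tfgnn.pool(
    graph, tfgnn.TARGET, edge_set_name="edges",
    reduce_type="mean|max_no_inf", feature_name="messages")
```

The switch from `"max"` to `"max_no_inf"` in the recommended values matters for receivers with no incoming edges: `"max_no_inf"` yields a finite fill value there instead of negative infinity.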
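Likewise, a sketch of passing the newly recommended `simple_conv_reduce_type` values through `MtAlbisGraphUpdate`. The hyperparameter values are illustrative only, not taken from this change:

```python
import tensorflow_gnn as tfgnn
from tensorflow_gnn.models import mt_albis

# With attention_type="none", simple_conv_reduce_type picks the aggregation.
# "mean|sum|max_no_inf" concatenates three aggregates, so the next-state
# layer sees three times the pooled feature dimension of "mean" alone.
graph_update = mt_albis.MtAlbisGraphUpdate(
    units=128,       # illustrative
    message_dim=64,  # illustrative
    receiver_tag=tfgnn.TARGET,
    attention_type="none",
    simple_conv_reduce_type="mean|sum|max_no_inf",
)
```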
diff --git a/tensorflow_gnn/models/vanilla_mpnn/layers.py b/tensorflow_gnn/models/vanilla_mpnn/layers.py
index 5178b2a7..51d525dd 100644
--- a/tensorflow_gnn/models/vanilla_mpnn/layers.py
+++ b/tensorflow_gnn/models/vanilla_mpnn/layers.py
@@ -74,7 +74,7 @@ def VanillaMPNNGraphUpdate(  # To be called like a class initializer. pylint: d
       this input.
     reduce_type: How to pool the messages from edges to receiver nodes; defaults
       to `"sum"`. Can be any reduce_type understood by `tfgnn.pool()`, including
-      concatenations like `"sum|max"` (but mind the increased dimension of the
+      concatenations like `"sum|mean"` (but mind the increased dimension of the
       result and the growing number of model weights in the next-state layer).
     l2_regularization: The coefficient of L2 regularization for weights and
       biases.
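And a sketch for the Vanilla MPNN case; again, the hyperparameter values are illustrative, not taken from this change:

```python
import tensorflow_gnn as tfgnn
from tensorflow_gnn.models import vanilla_mpnn

# reduce_type="sum|mean" pools each edge set's messages twice and
# concatenates the results, doubling the input to the next-state layer.
graph_update = vanilla_mpnn.VanillaMPNNGraphUpdate(
    units=128,       # illustrative
    message_dim=64,  # illustrative
    receiver_tag=tfgnn.TARGET,
    reduce_type="sum|mean",
)
```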