From adf228626cc99dcf690557240a5351f02bf72956 Mon Sep 17 00:00:00 2001 From: Arno Eigenwillig Date: Fri, 20 Oct 2023 03:29:32 -0700 Subject: [PATCH] Two minor fixes for MtAlbis: * edge_feature_name is made configurable through ConfigDict. Edges without real features can use MakeEmptyFeature. * MtAlbisNextNodeState can handle the case of zero incoming edge sets (for experimental use only). PiperOrigin-RevId: 575169338 --- tensorflow_gnn/models/mt_albis/config_dict.py | 4 ++-- tensorflow_gnn/models/mt_albis/layers.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tensorflow_gnn/models/mt_albis/config_dict.py b/tensorflow_gnn/models/mt_albis/config_dict.py index 18e8ac78..186849a3 100644 --- a/tensorflow_gnn/models/mt_albis/config_dict.py +++ b/tensorflow_gnn/models/mt_albis/config_dict.py @@ -22,12 +22,12 @@ def graph_update_get_config_dict() -> config_dict.ConfigDict: """Returns ConfigDict for graph_update_from_config_dict() with defaults.""" # LINT.IfChange(graph_update_get_config_dict) - # TODO(b/261835577): What about node_set_names, edge_feature_name, - # attention_edge_set_names? + # TODO(b/261835577): What about node_set_names, attention_edge_set_names? cfg = config_dict.ConfigDict() cfg.units = config_dict.placeholder(int) # Sets type to Optional[int]. cfg.message_dim = config_dict.placeholder(int) cfg.receiver_tag = config_dict.placeholder(int) + cfg.edge_feature_name = config_dict.placeholder(str) cfg.attention_type = "none" cfg.attention_num_heads = 4 cfg.simple_conv_reduce_type = "mean" diff --git a/tensorflow_gnn/models/mt_albis/layers.py b/tensorflow_gnn/models/mt_albis/layers.py index 903c421b..8bfdb7f6 100644 --- a/tensorflow_gnn/models/mt_albis/layers.py +++ b/tensorflow_gnn/models/mt_albis/layers.py @@ -193,9 +193,10 @@ def call( input_state = _require_single_tensor(input_state, "input state") flat_inputs.append(input_state) # Collect and combine pooled messages (conv results) from edge sets. - edge_input = self._combine_edge_inputs(edge_set_inputs) - edge_input = self._dropout(edge_input) - flat_inputs.append(edge_input) + if edge_set_inputs: + edge_input = self._combine_edge_inputs(edge_set_inputs) + edge_input = self._dropout(edge_input) + flat_inputs.append(edge_input) # Collect a context input, if any. (Empty Mapping means none.) if isinstance(context_input, Mapping) and not context_input: pass @@ -243,7 +244,6 @@ def MtAlbisGraphUpdate( # To be called like a class initializer. pylint: disab message_dim: int, receiver_tag: tfgnn.IncidentNodeTag, node_set_names: Optional[Collection[tfgnn.NodeSetName]] = None, - # TODO(b/261835577): Can edge_feature be set for some EdgeSets only? edge_feature_name: Optional[tfgnn.FieldName] = None, attention_type: Literal["none", "multi_head", "gat_v2"] = "none", attention_edge_set_names: Optional[Collection[tfgnn.EdgeSetName]] = None,