From 01b77dcb6ad30a818f4c4645200cd600e1e591c2 Mon Sep 17 00:00:00 2001 From: Muhammed Fatih BALIN Date: Fri, 11 Oct 2024 19:00:19 +0000 Subject: [PATCH] more fixes. --- python/dgl/graphbolt/feature_fetcher.py | 4 ++-- python/dgl/graphbolt/impl/cooperative_conv.py | 6 ++++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/python/dgl/graphbolt/feature_fetcher.py b/python/dgl/graphbolt/feature_fetcher.py index e3db4df38bd2..832d00181588 100644 --- a/python/dgl/graphbolt/feature_fetcher.py +++ b/python/dgl/graphbolt/feature_fetcher.py @@ -169,12 +169,12 @@ def _cooperative_exchange(self, data): node_features = {key: {} for _, key in data.node_features.keys()} for (ntype, key), feature in data.node_features.items(): node_features[key][ntype] = feature - for key, feature in node_features.items(): + for key, feature in sorted(node_features.items()): new_feature = CooperativeConvFunction.apply(subgraph, feature) for ntype, tensor in new_feature.items(): data.node_features[(ntype, key)] = tensor else: - for key in data.node_features: + for key in sorted(data.node_features): feature = data.node_features[key] new_feature = CooperativeConvFunction.apply(subgraph, feature) data.node_features[key] = new_feature diff --git a/python/dgl/graphbolt/impl/cooperative_conv.py b/python/dgl/graphbolt/impl/cooperative_conv.py index 0daa106c17a6..c2d40d94d186 100644 --- a/python/dgl/graphbolt/impl/cooperative_conv.py +++ b/python/dgl/graphbolt/impl/cooperative_conv.py @@ -42,7 +42,7 @@ def forward( seed_sizes, ) outs = {} - for ntype, typed_tensor in convert_to_hetero(tensor).items(): + for ntype, typed_tensor in sorted(convert_to_hetero(tensor).items()): out = typed_tensor.new_empty( (sum(counts_sent.get(ntype, [0])),) + typed_tensor.shape[1:], requires_grad=typed_tensor.requires_grad, @@ -70,7 +70,9 @@ def backward( ) = ctx.communication_variables delattr(ctx, "communication_variables") outs = {} - for ntype, typed_grad_output in convert_to_hetero(grad_output).items(): + for ntype, typed_grad_output in sorted( + convert_to_hetero(grad_output).items() + ): out = typed_grad_output.new_empty( (sum(counts_received[ntype]),) + typed_grad_output.shape[1:] )