From 68bd455cf0a27615a55cf2b5a82c874edb14b106 Mon Sep 17 00:00:00 2001 From: Ayush Noori Date: Mon, 20 Nov 2023 17:40:00 -0500 Subject: [PATCH 01/10] Fix typo in ShaDowKHopSampler Fix typo in ShaDowKHopSampler sample() function --- python/dgl/dataloading/shadow.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/dgl/dataloading/shadow.py b/python/dgl/dataloading/shadow.py index 5618ff9e94eb..05ccf99aef87 100644 --- a/python/dgl/dataloading/shadow.py +++ b/python/dgl/dataloading/shadow.py @@ -95,7 +95,7 @@ def sample( Parameters ---------- g : DGLGraph - The graph to sampler from. + The graph to sample nodes from. seed_nodes : Tensor or dict[str, Tensor] The nodes sampled in the current minibatch. exclude_eids : Tensor or dict[etype, Tensor], optional From f16d04ae30bf24d332042d713ae7a81c3b5bee4a Mon Sep 17 00:00:00 2001 From: Ayush Noori Date: Sat, 2 Dec 2023 19:46:53 -0500 Subject: [PATCH 02/10] Draft of fixed sampler --- python/dgl/dataloading/fixed.py | 146 ++++++++++++++++++++++++++++++++ 1 file changed, 146 insertions(+) create mode 100644 python/dgl/dataloading/fixed.py diff --git a/python/dgl/dataloading/fixed.py b/python/dgl/dataloading/fixed.py new file mode 100644 index 000000000000..b7c15b90b7df --- /dev/null +++ b/python/dgl/dataloading/fixed.py @@ -0,0 +1,146 @@ +"""Fixed subgraph sampler.""" +from ..sampling.utils import EidExcluder +from .base import set_node_lazy_features, set_edge_lazy_features, Sampler + +# import non-DGL libraries +import numpy as np +import torch +from collections import defaultdict + +class FixedSampler(Sampler): + """Subgraph sampler that heterogeneous sampler that sets an upper + bound on the number of nodes included in each layer of the sampled subgraph. + + At each layer, the frontier is randomly subsampled. Rare node types can also be + upsampled by taking the scaled square root of the sampling probabilities. + + It performs node-wise neighbor sampling and returns the subgraph induced by + all the sampled nodes. + + Parameters + ---------- + fanouts : list[int] or list[dict[etype, int]] + List of neighbors to sample per edge type for each GNN layer, with the i-th + element being the fanout for the i-th GNN layer. + + If only a single integer is provided, DGL assumes that every edge type + will have the same fanout. + + If -1 is provided for one edge type on one layer, then all inbound edges + of that edge type will be included. + fixed_k : int + The number of nodes to sample for each GNN layer. + upsample_rare_types : bool + Whether or not to upsample rare node types. + replace : bool, default True + Whether to sample with replacement + prob : str, optional + If given, the probability of each neighbor being sampled is proportional + to the edge feature value with the given name in ``g.edata``. The feature must be + a scalar on each edge. + """ + def __init__(self, fanouts, fixed_k, upsample_rare_types, replace=False, prob=None, + prefetch_node_feats=None, prefetch_edge_feats=None, output_device=None): + super().__init__() + self.fanouts = fanouts + self.replace = replace + self.fixed_k = fixed_k + self.upsample_rare_types = upsample_rare_types + self.prob = prob + self.prefetch_node_feats = prefetch_node_feats + self.prefetch_edge_feats = prefetch_edge_feats + self.output_device = output_device + + def sample(self, g, seed_nodes, exclude_eids=None): + """Sampling function. + + Parameters + ---------- + g : DGLGraph + The graph to sampler from. + seed_nodes : Tensor or dict[str, Tensor] + The nodes sampled in the current minibatch. + exclude_eids : Tensor or dict[etype, Tensor], optional + The edges to exclude from neighborhood expansion. + + Returns + ------- + input_nodes, output_nodes, subg + A triplet containing (1) the node IDs inducing the subgraph, (2) the node + IDs that are sampled in this minibatch, and (3) the subgraph itself. + """ + + # define empty dictionary to store reached nodes + output_nodes = seed_nodes + all_reached_nodes = [seed_nodes] + + # iterate over fanout + for fanout in reversed(self.fanouts): + + # sample frontier + frontier = g.sample_neighbors( + seed_nodes, fanout, output_device=self.output_device, + replace=self.replace, prob=self.prob, exclude_edges=exclude_eids) + + # get reached nodes + curr_reached = defaultdict(list) + for c_etype in frontier.canonical_etypes: + (src_type, rel_type, dst_type) = c_etype + src, _ = frontier.edges(etype = c_etype) + curr_reached[src_type].append(src) + + # de-duplication + curr_reached = {ntype : torch.unique(torch.cat(srcs)) for ntype, srcs in curr_reached.items()} + + # generate type sampling probabilties + type_count = {node_type: indices.shape[0] for node_type, indices in curr_reached.items()} + total_count = sum(type_count.values()) + probs = {node_type: count / total_count for node_type, count in type_count.items()} + + # upsample rare node types + if self.upsample_rare_types: + + # take scaled square root of probabilities + prob_dist = list(probs.values()) + prob_dist = np.sqrt(prob_dist) + prob_dist = prob_dist / prob_dist.sum() + + # update probabilities + probs = {node_type: prob_dist[i] for i, node_type in enumerate(probs.keys())} + + # generate node counts per type + n_per_type = {node_type: int(self.fixed_k * prob) for node_type, prob in probs.items()} + remainder = self.fixed_k - sum(n_per_type.values()) + for _ in range(remainder): + node_type = np.random.choice(list(probs.keys()), p=list(probs.values())) + n_per_type[node_type] += 1 + + # downsample nodes + curr_reached_k = {} + for node_type, node_IDs in curr_reached.items(): + + # get number of total nodes and number to sample + num_nodes = node_IDs.shape[0] + n_to_sample = min(num_nodes, n_per_type[node_type]) + + # downsample nodes of current type + random_indices = torch.randperm(num_nodes)[:n_to_sample] + curr_reached_k[node_type] = node_IDs[random_indices] + + # update seed nodes + seed_nodes = curr_reached_k + all_reached_nodes.append(curr_reached_k) + + # merge all reached nodes before sending to DGLGraph.subgraph + merged_nodes = {} + for ntype in g.ntypes: + merged_nodes[ntype] = torch.unique(torch.cat([reached.get(ntype, []) for reached in all_reached_nodes])) + subg = g.subgraph(merged_nodes, relabel_nodes=True, output_device=self.output_device) + + if exclude_eids is not None: + subg = EidExcluder(exclude_eids)(subg) + + set_node_lazy_features(subg, self.prefetch_node_feats) + set_edge_lazy_features(subg, self.prefetch_edge_feats) + + return seed_nodes, output_nodes, subg \ No newline at end of file From 928fe213334241781dfd646932af5ca24e85a93f Mon Sep 17 00:00:00 2001 From: Ayush Noori Date: Sun, 3 Dec 2023 21:36:24 -0500 Subject: [PATCH 03/10] Fix documentation --- python/dgl/dataloading/fixed.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/dgl/dataloading/fixed.py b/python/dgl/dataloading/fixed.py index b7c15b90b7df..7b0e6238a08d 100644 --- a/python/dgl/dataloading/fixed.py +++ b/python/dgl/dataloading/fixed.py @@ -8,8 +8,8 @@ from collections import defaultdict class FixedSampler(Sampler): - """Subgraph sampler that heterogeneous sampler that sets an upper - bound on the number of nodes included in each layer of the sampled subgraph. + """Subgraph sampler that sets an upper bound on the number of nodes included in + each layer of the sampled subgraph. At each layer, the frontier is randomly subsampled. Rare node types can also be upsampled by taking the scaled square root of the sampling probabilities. From 6f8acdb2fac5a7eb477c51a9162b92ab1e343eaa Mon Sep 17 00:00:00 2001 From: Ayush Noori Date: Sat, 27 Jan 2024 20:16:53 -0500 Subject: [PATCH 04/10] Address feedback from @frozenbugs --- .../{fixed.py => capped_neighbor_sampler.py} | 99 ++++++++++--------- 1 file changed, 52 insertions(+), 47 deletions(-) rename python/dgl/dataloading/{fixed.py => capped_neighbor_sampler.py} (61%) diff --git a/python/dgl/dataloading/fixed.py b/python/dgl/dataloading/capped_neighbor_sampler.py similarity index 61% rename from python/dgl/dataloading/fixed.py rename to python/dgl/dataloading/capped_neighbor_sampler.py index 7b0e6238a08d..b218f12fc950 100644 --- a/python/dgl/dataloading/fixed.py +++ b/python/dgl/dataloading/capped_neighbor_sampler.py @@ -1,46 +1,43 @@ """Fixed subgraph sampler.""" -from ..sampling.utils import EidExcluder -from .base import set_node_lazy_features, set_edge_lazy_features, Sampler - -# import non-DGL libraries +from collections import defaultdict import numpy as np import torch -from collections import defaultdict -class FixedSampler(Sampler): +from ..sampling.utils import EidExcluder +from .base import set_node_lazy_features, set_edge_lazy_features, Sampler + +class CappedNeighborSampler(Sampler): """Subgraph sampler that sets an upper bound on the number of nodes included in - each layer of the sampled subgraph. - - At each layer, the frontier is randomly subsampled. Rare node types can also be - upsampled by taking the scaled square root of the sampling probabilities. + each layer of the sampled subgraph. At each layer, the frontier is randomly + subsampled. Rare node types can also be upsampled by taking the scaled square + root of the sampling probabilities. It performs node-wise neighbor sampling and returns the subgraph induced by all the sampled nodes. Parameters ---------- - fanouts : list[int] or list[dict[etype, int]] + fanouts : list[int] or dict[etype, int] List of neighbors to sample per edge type for each GNN layer, with the i-th element being the fanout for the i-th GNN layer. - - If only a single integer is provided, DGL assumes that every edge type - will have the same fanout. - - If -1 is provided for one edge type on one layer, then all inbound edges - of that edge type will be included. + - If only a single integer is provided, DGL assumes that every edge type + will have the same fanout. + - If -1 is provided for one edge type on one layer, then all inbound edges + of that edge type will be included. fixed_k : int The number of nodes to sample for each GNN layer. upsample_rare_types : bool Whether or not to upsample rare node types. replace : bool, default True - Whether to sample with replacement + Whether to sample with replacement. prob : str, optional If given, the probability of each neighbor being sampled is proportional to the edge feature value with the given name in ``g.edata``. The feature must be a scalar on each edge. """ - def __init__(self, fanouts, fixed_k, upsample_rare_types, replace=False, prob=None, - prefetch_node_feats=None, prefetch_edge_feats=None, output_device=None): + def __init__(self, fanouts, fixed_k, upsample_rare_types, replace=False, + prob=None, prefetch_node_feats=None, prefetch_edge_feats=None, + output_device=None): super().__init__() self.fanouts = fanouts self.replace = replace @@ -57,81 +54,89 @@ def sample(self, g, seed_nodes, exclude_eids=None): Parameters ---------- g : DGLGraph - The graph to sampler from. + The graph to sample from. seed_nodes : Tensor or dict[str, Tensor] The nodes sampled in the current minibatch. exclude_eids : Tensor or dict[etype, Tensor], optional - The edges to exclude from neighborhood expansion. + The edges to exclude from the sampled subgraph. Returns ------- - input_nodes, output_nodes, subg - A triplet containing (1) the node IDs inducing the subgraph, (2) the node - IDs that are sampled in this minibatch, and (3) the subgraph itself. + input_nodes : Tensor or dict[str, Tensor] + The node IDs inducing the subgraph. + output_nodes : Tensor or dict[str, Tensor] + The node IDs that are sampled in this minibatch. + subg : DGLGraph + The subgraph itself. """ - # define empty dictionary to store reached nodes + # Define empty dictionary to store reached nodes. output_nodes = seed_nodes all_reached_nodes = [seed_nodes] - # iterate over fanout + # Iterate over fanout. for fanout in reversed(self.fanouts): - # sample frontier + # Sample frontier. frontier = g.sample_neighbors( - seed_nodes, fanout, output_device=self.output_device, - replace=self.replace, prob=self.prob, exclude_edges=exclude_eids) + seed_nodes, fanout, + output_device=self.output_device, replace=self.replace, + prob=self.prob, exclude_edges=exclude_eids) - # get reached nodes + # Get reached nodes. curr_reached = defaultdict(list) for c_etype in frontier.canonical_etypes: (src_type, rel_type, dst_type) = c_etype src, _ = frontier.edges(etype = c_etype) curr_reached[src_type].append(src) - # de-duplication - curr_reached = {ntype : torch.unique(torch.cat(srcs)) for ntype, srcs in curr_reached.items()} + # De-duplication. + curr_reached = {ntype : torch.unique(torch.cat(srcs)) + for ntype, srcs in curr_reached.items()} - # generate type sampling probabilties - type_count = {node_type: indices.shape[0] for node_type, indices in curr_reached.items()} + # Generate type sampling probabilties. + type_count = {node_type: indices.shape[0] + for node_type, indices in curr_reached.items()} total_count = sum(type_count.values()) - probs = {node_type: count / total_count for node_type, count in type_count.items()} + probs = {node_type: count / total_count + for node_type, count in type_count.items()} - # upsample rare node types + # Upsample rare node types. if self.upsample_rare_types: - # take scaled square root of probabilities + # Take scaled square root of probabilities. prob_dist = list(probs.values()) prob_dist = np.sqrt(prob_dist) prob_dist = prob_dist / prob_dist.sum() - # update probabilities + # Update probabilities. probs = {node_type: prob_dist[i] for i, node_type in enumerate(probs.keys())} - # generate node counts per type - n_per_type = {node_type: int(self.fixed_k * prob) for node_type, prob in probs.items()} + # Generate node counts per type. + n_per_type = {node_type: int(self.fixed_k * prob) + for node_type, prob in probs.items()} remainder = self.fixed_k - sum(n_per_type.values()) for _ in range(remainder): node_type = np.random.choice(list(probs.keys()), p=list(probs.values())) n_per_type[node_type] += 1 - # downsample nodes + # Downsample nodes. curr_reached_k = {} for node_type, node_IDs in curr_reached.items(): - # get number of total nodes and number to sample + # Get number of total nodes and number to sample. num_nodes = node_IDs.shape[0] n_to_sample = min(num_nodes, n_per_type[node_type]) - # downsample nodes of current type + # Downsample nodes of current type. random_indices = torch.randperm(num_nodes)[:n_to_sample] curr_reached_k[node_type] = node_IDs[random_indices] - # update seed nodes + # Update seed nodes. seed_nodes = curr_reached_k all_reached_nodes.append(curr_reached_k) - # merge all reached nodes before sending to DGLGraph.subgraph + # Merge all reached nodes before sending to `DGLGraph.subgraph`. merged_nodes = {} for ntype in g.ntypes: merged_nodes[ntype] = torch.unique(torch.cat([reached.get(ntype, []) for reached in all_reached_nodes])) @@ -143,4 +148,4 @@ def sample(self, g, seed_nodes, exclude_eids=None): set_node_lazy_features(subg, self.prefetch_node_feats) set_edge_lazy_features(subg, self.prefetch_edge_feats) - return seed_nodes, output_nodes, subg \ No newline at end of file + return seed_nodes, output_nodes, subg From deb98552ba872ccd7a1ccba06414aa8c1499990e Mon Sep 17 00:00:00 2001 From: Ayush Noori Date: Fri, 9 Feb 2024 15:28:38 -0500 Subject: [PATCH 05/10] Fixed linting errors --- .../dataloading/capped_neighbor_sampler.py | 86 +++++++++++++------ 1 file changed, 59 insertions(+), 27 deletions(-) diff --git a/python/dgl/dataloading/capped_neighbor_sampler.py b/python/dgl/dataloading/capped_neighbor_sampler.py index b218f12fc950..0e2597f376fd 100644 --- a/python/dgl/dataloading/capped_neighbor_sampler.py +++ b/python/dgl/dataloading/capped_neighbor_sampler.py @@ -1,18 +1,18 @@ """Fixed subgraph sampler.""" from collections import defaultdict + import numpy as np import torch from ..sampling.utils import EidExcluder -from .base import set_node_lazy_features, set_edge_lazy_features, Sampler +from .base import Sampler, set_edge_lazy_features, set_node_lazy_features -class CappedNeighborSampler(Sampler): - """Subgraph sampler that sets an upper bound on the number of nodes included in - each layer of the sampled subgraph. At each layer, the frontier is randomly - subsampled. Rare node types can also be upsampled by taking the scaled square - root of the sampling probabilities. - It performs node-wise neighbor sampling and returns the subgraph induced by +class CappedNeighborSampler(Sampler): + """Subgraph sampler that sets an upper bound on the number of nodes included in + each layer of the sampled subgraph. At each layer, the frontier is randomly + subsampled. Rare node types can also be upsampled by taking the scaled square + root of the sampling probabilities. The sampler returns the subgraph induced by all the sampled nodes. Parameters @@ -35,9 +35,18 @@ class CappedNeighborSampler(Sampler): to the edge feature value with the given name in ``g.edata``. The feature must be a scalar on each edge. """ - def __init__(self, fanouts, fixed_k, upsample_rare_types, replace=False, - prob=None, prefetch_node_feats=None, prefetch_edge_feats=None, - output_device=None): + + def __init__( + self, + fanouts, + fixed_k, + upsample_rare_types, + replace=False, + prob=None, + prefetch_node_feats=None, + prefetch_edge_feats=None, + output_device=None, + ): super().__init__() self.fanouts = fanouts self.replace = replace @@ -56,7 +65,7 @@ def sample(self, g, seed_nodes, exclude_eids=None): g : DGLGraph The graph to sample from. seed_nodes : Tensor or dict[str, Tensor] - The nodes sampled in the current minibatch. + Nodes which induce the subgraph. exclude_eids : Tensor or dict[etype, Tensor], optional The edges to exclude from the sampled subgraph. @@ -79,27 +88,37 @@ def sample(self, g, seed_nodes, exclude_eids=None): # Sample frontier. frontier = g.sample_neighbors( - seed_nodes, fanout, - output_device=self.output_device, replace=self.replace, - prob=self.prob, exclude_edges=exclude_eids) + seed_nodes, + fanout, + output_device=self.output_device, + replace=self.replace, + prob=self.prob, + exclude_edges=exclude_eids, + ) # Get reached nodes. curr_reached = defaultdict(list) for c_etype in frontier.canonical_etypes: (src_type, rel_type, dst_type) = c_etype - src, _ = frontier.edges(etype = c_etype) + src, _ = frontier.edges(etype=c_etype) curr_reached[src_type].append(src) # De-duplication. - curr_reached = {ntype : torch.unique(torch.cat(srcs)) - for ntype, srcs in curr_reached.items()} + curr_reached = { + ntype: torch.unique(torch.cat(srcs)) + for ntype, srcs in curr_reached.items() + } # Generate type sampling probabilties. - type_count = {node_type: indices.shape[0] - for node_type, indices in curr_reached.items()} + type_count = { + node_type: indices.shape[0] + for node_type, indices in curr_reached.items() + } total_count = sum(type_count.values()) - probs = {node_type: count / total_count - for node_type, count in type_count.items()} + probs = { + node_type: count / total_count + for node_type, count in type_count.items() + } # Upsample rare node types. if self.upsample_rare_types: @@ -110,14 +129,21 @@ def sample(self, g, seed_nodes, exclude_eids=None): prob_dist = prob_dist / prob_dist.sum() # Update probabilities. - probs = {node_type: prob_dist[i] for i, node_type in enumerate(probs.keys())} + probs = { + node_type: prob_dist[i] + for i, node_type in enumerate(probs.keys()) + } # Generate node counts per type. - n_per_type = {node_type: int(self.fixed_k * prob) - for node_type, prob in probs.items()} + n_per_type = { + node_type: int(self.fixed_k * prob) + for node_type, prob in probs.items() + } remainder = self.fixed_k - sum(n_per_type.values()) for _ in range(remainder): - node_type = np.random.choice(list(probs.keys()), p=list(probs.values())) + node_type = np.random.choice( + list(probs.keys()), p=list(probs.values()) + ) n_per_type[node_type] += 1 # Downsample nodes. @@ -139,8 +165,14 @@ def sample(self, g, seed_nodes, exclude_eids=None): # Merge all reached nodes before sending to `DGLGraph.subgraph`. merged_nodes = {} for ntype in g.ntypes: - merged_nodes[ntype] = torch.unique(torch.cat([reached.get(ntype, []) for reached in all_reached_nodes])) - subg = g.subgraph(merged_nodes, relabel_nodes=True, output_device=self.output_device) + merged_nodes[ntype] = torch.unique( + torch.cat( + [reached.get(ntype, []) for reached in all_reached_nodes] + ) + ) + subg = g.subgraph( + merged_nodes, relabel_nodes=True, output_device=self.output_device + ) if exclude_eids is not None: subg = EidExcluder(exclude_eids)(subg) From ba9e1f2619d48922251019dfe65f320936f9a541 Mon Sep 17 00:00:00 2001 From: Ayush Noori Date: Fri, 5 Jul 2024 22:07:44 -0700 Subject: [PATCH 06/10] Add unit test comment requested by @frozenbugs --- python/dgl/dataloading/capped_neighbor_sampler.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/python/dgl/dataloading/capped_neighbor_sampler.py b/python/dgl/dataloading/capped_neighbor_sampler.py index 0e2597f376fd..f1ee24ea3db6 100644 --- a/python/dgl/dataloading/capped_neighbor_sampler.py +++ b/python/dgl/dataloading/capped_neighbor_sampler.py @@ -15,6 +15,11 @@ class CappedNeighborSampler(Sampler): root of the sampling probabilities. The sampler returns the subgraph induced by all the sampled nodes. + This code was contributed by a community member + ([@ayushnoori](https://github.com/ayushnoori)). There aren't currently any unit + tests in place to verify its functionality, so please be cautious if you need + to make any changes to the code's logic. + Parameters ---------- fanouts : list[int] or dict[etype, int] From 87df98e8a439894e29abc99f131abb4771785df3 Mon Sep 17 00:00:00 2001 From: Ayush Noori Date: Fri, 5 Jul 2024 22:09:07 -0700 Subject: [PATCH 07/10] Fix linting issues --- python/dgl/dataloading/capped_neighbor_sampler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/dgl/dataloading/capped_neighbor_sampler.py b/python/dgl/dataloading/capped_neighbor_sampler.py index f1ee24ea3db6..2c13d7778e82 100644 --- a/python/dgl/dataloading/capped_neighbor_sampler.py +++ b/python/dgl/dataloading/capped_neighbor_sampler.py @@ -17,7 +17,7 @@ class CappedNeighborSampler(Sampler): This code was contributed by a community member ([@ayushnoori](https://github.com/ayushnoori)). There aren't currently any unit - tests in place to verify its functionality, so please be cautious if you need + tests in place to verify its functionality, so please be cautious if you need to make any changes to the code's logic. Parameters From 161edd6edb02542b286ac2eb1cb651560a4ecf20 Mon Sep 17 00:00:00 2001 From: Ayush Noori Date: Fri, 5 Jul 2024 23:07:00 -0700 Subject: [PATCH 08/10] Fix second linter errors --- .../dataloading/capped_neighbor_sampler.py | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/python/dgl/dataloading/capped_neighbor_sampler.py b/python/dgl/dataloading/capped_neighbor_sampler.py index 2c13d7778e82..7abb299a0d66 100644 --- a/python/dgl/dataloading/capped_neighbor_sampler.py +++ b/python/dgl/dataloading/capped_neighbor_sampler.py @@ -62,14 +62,14 @@ def __init__( self.prefetch_edge_feats = prefetch_edge_feats self.output_device = output_device - def sample(self, g, seed_nodes, exclude_eids=None): + def sample(self, g, indices, exclude_eids=None): """Sampling function. Parameters ---------- g : DGLGraph The graph to sample from. - seed_nodes : Tensor or dict[str, Tensor] + indices : Tensor or dict[str, Tensor] Nodes which induce the subgraph. exclude_eids : Tensor or dict[etype, Tensor], optional The edges to exclude from the sampled subgraph. @@ -85,15 +85,15 @@ def sample(self, g, seed_nodes, exclude_eids=None): """ # Define empty dictionary to store reached nodes. - output_nodes = seed_nodes - all_reached_nodes = [seed_nodes] + output_nodes = indices + all_reached_nodes = [indices] # Iterate over fanout. for fanout in reversed(self.fanouts): # Sample frontier. frontier = g.sample_neighbors( - seed_nodes, + indices, fanout, output_device=self.output_device, replace=self.replace, @@ -104,7 +104,7 @@ def sample(self, g, seed_nodes, exclude_eids=None): # Get reached nodes. curr_reached = defaultdict(list) for c_etype in frontier.canonical_etypes: - (src_type, rel_type, dst_type) = c_etype + (src_type, _, _) = c_etype src, _ = frontier.edges(etype=c_etype) curr_reached[src_type].append(src) @@ -153,18 +153,18 @@ def sample(self, g, seed_nodes, exclude_eids=None): # Downsample nodes. curr_reached_k = {} - for node_type, node_IDs in curr_reached.items(): + for node_type, node_ids in curr_reached.items(): # Get number of total nodes and number to sample. - num_nodes = node_IDs.shape[0] + num_nodes = node_ids.shape[0] n_to_sample = min(num_nodes, n_per_type[node_type]) # Downsample nodes of current type. random_indices = torch.randperm(num_nodes)[:n_to_sample] - curr_reached_k[node_type] = node_IDs[random_indices] + curr_reached_k[node_type] = node_ids[random_indices] # Update seed nodes. - seed_nodes = curr_reached_k + indices = curr_reached_k all_reached_nodes.append(curr_reached_k) # Merge all reached nodes before sending to `DGLGraph.subgraph`. @@ -185,4 +185,4 @@ def sample(self, g, seed_nodes, exclude_eids=None): set_node_lazy_features(subg, self.prefetch_node_feats) set_edge_lazy_features(subg, self.prefetch_edge_feats) - return seed_nodes, output_nodes, subg + return indices, output_nodes, subg From 722cce1012328386b41f760a984c0a9b20639364 Mon Sep 17 00:00:00 2001 From: Ayush Noori Date: Sun, 7 Jul 2024 17:42:06 -0700 Subject: [PATCH 09/10] Disable differing arguments warning --- python/dgl/dataloading/capped_neighbor_sampler.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/dgl/dataloading/capped_neighbor_sampler.py b/python/dgl/dataloading/capped_neighbor_sampler.py index 7abb299a0d66..e5749f536324 100644 --- a/python/dgl/dataloading/capped_neighbor_sampler.py +++ b/python/dgl/dataloading/capped_neighbor_sampler.py @@ -62,7 +62,9 @@ def __init__( self.prefetch_edge_feats = prefetch_edge_feats self.output_device = output_device - def sample(self, g, indices, exclude_eids=None): + def sample( + self, g, indices, exclude_eids=None + ): # pylint: disable=arguments-differ """Sampling function. Parameters From 7b8e30f67eab35425c1a3bed43737ed0c653462e Mon Sep 17 00:00:00 2001 From: "Hongzhi (Steve), Chen" Date: Tue, 9 Jul 2024 13:59:57 +0800 Subject: [PATCH 10/10] Update python/dgl/dataloading/capped_neighbor_sampler.py --- python/dgl/dataloading/capped_neighbor_sampler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/dgl/dataloading/capped_neighbor_sampler.py b/python/dgl/dataloading/capped_neighbor_sampler.py index e5749f536324..70e1d9cc6636 100644 --- a/python/dgl/dataloading/capped_neighbor_sampler.py +++ b/python/dgl/dataloading/capped_neighbor_sampler.py @@ -1,4 +1,4 @@ -"""Fixed subgraph sampler.""" +"""Capped neighbor sampler.""" from collections import defaultdict import numpy as np