From 1b5c02c16cfd10e84e6b5d2c9f1f11094d82d620 Mon Sep 17 00:00:00 2001 From: Rhett Ying <85214957+Rhett-Ying@users.noreply.github.com> Date: Fri, 10 May 2024 08:16:17 +0800 Subject: [PATCH] [release] cherry-pick from master and release for 2.2.1 (#7388) Co-authored-by: Muhammed Fatih BALIN Co-authored-by: Xinyu Yao <77922129+yxy235@users.noreply.github.com> Co-authored-by: Ubuntu --- .../multigpu/graphbolt/node_classification.py | 2 +- .../pyg/node_classification_advanced.py | 12 ++++++ python/dgl/graphbolt/base.py | 12 ++++++ python/dgl/graphbolt/feature_fetcher.py | 28 +++++++------- .../dgl/graphbolt/impl/gpu_cached_feature.py | 37 +++++++++++++++---- script/dgl_dev.yml.template | 1 + .../impl/test_fused_csc_sampling_graph.py | 34 +++++++++++------ .../graphbolt/impl/test_gpu_cached_feature.py | 7 ++++ .../pytorch/graphbolt/test_feature_fetcher.py | 29 +++++++++++++-- third_party/cccl | 2 +- 10 files changed, 124 insertions(+), 40 deletions(-) diff --git a/examples/multigpu/graphbolt/node_classification.py b/examples/multigpu/graphbolt/node_classification.py index 3df09bf852ea..3ab8ed41a839 100644 --- a/examples/multigpu/graphbolt/node_classification.py +++ b/examples/multigpu/graphbolt/node_classification.py @@ -399,7 +399,7 @@ def parse_args(): "--gpu-cache-size", type=int, default=0, - help="The capacity of the GPU cache, the number of features to store.", + help="The capacity of the GPU cache in bytes.", ) parser.add_argument( "--dataset", diff --git a/examples/sampling/graphbolt/pyg/node_classification_advanced.py b/examples/sampling/graphbolt/pyg/node_classification_advanced.py index 2b5fb19d7518..2f25db523b56 100644 --- a/examples/sampling/graphbolt/pyg/node_classification_advanced.py +++ b/examples/sampling/graphbolt/pyg/node_classification_advanced.py @@ -350,6 +350,12 @@ def parse_args(): help="Graph storage - feature storage - Train device: 'cpu' for CPU and RAM," " 'pinned' for pinned memory in RAM, 'cuda' for GPU and GPU memory.", ) + parser.add_argument( + "--gpu-cache-size", + type=int, + default=0, + help="The capacity of the GPU cache in bytes.", + ) parser.add_argument( "--sample-mode", default="sample_neighbor", @@ -403,6 +409,12 @@ def main(): num_classes = dataset.tasks[0].metadata["num_classes"] + if args.gpu_cache_size > 0 and args.feature_device != "cuda": + features._features[("node", None, "feat")] = gb.GPUCachedFeature( + features._features[("node", None, "feat")], + args.gpu_cache_size, + ) + train_dataloader, valid_dataloader = ( create_dataloader( graph=graph, diff --git a/python/dgl/graphbolt/base.py b/python/dgl/graphbolt/base.py index 398ac31e5290..13f493756d48 100644 --- a/python/dgl/graphbolt/base.py +++ b/python/dgl/graphbolt/base.py @@ -5,6 +5,17 @@ import torch from torch.torch_version import TorchVersion + +if TorchVersion(torch.__version__) >= "2.3.0": + # [TODO][https://github.com/dmlc/dgl/issues/7387] Remove or refine below + # check. + # Due to https://github.com/dmlc/dgl/issues/7380, we need to check if dill + # is available before using it. + torch.utils.data.datapipes.utils.common.DILL_AVAILABLE = ( + torch.utils._import_utils.dill_available() + ) + +# pylint: disable=wrong-import-position from torch.utils.data import functional_datapipe from torchdata.datapipes.iter import IterDataPipe @@ -342,6 +353,7 @@ class CSCFormatBase: >>> print(csc_foramt_base) ... 
torch.tensor([1, 4, 2]) """ + indptr: torch.Tensor = None indices: torch.Tensor = None diff --git a/python/dgl/graphbolt/feature_fetcher.py b/python/dgl/graphbolt/feature_fetcher.py index 01ff25af8c15..dc41e3883890 100644 --- a/python/dgl/graphbolt/feature_fetcher.py +++ b/python/dgl/graphbolt/feature_fetcher.py @@ -88,13 +88,12 @@ def record_stream(tensor): if self.node_feature_keys and input_nodes is not None: if is_heterogeneous: - for type_name, feature_names in self.node_feature_keys.items(): - nodes = input_nodes[type_name] - if nodes is None: + for type_name, nodes in input_nodes.items(): + if type_name not in self.node_feature_keys or nodes is None: continue if nodes.is_cuda: nodes.record_stream(torch.cuda.current_stream()) - for feature_name in feature_names: + for feature_name in self.node_feature_keys[type_name]: node_features[ (type_name, feature_name) ] = record_stream( @@ -126,21 +125,22 @@ def record_stream(tensor): if is_heterogeneous: # Convert edge type to string. original_edge_ids = { - etype_tuple_to_str(key) - if isinstance(key, tuple) - else key: value + ( + etype_tuple_to_str(key) + if isinstance(key, tuple) + else key + ): value for key, value in original_edge_ids.items() } - for ( - type_name, - feature_names, - ) in self.edge_feature_keys.items(): - edges = original_edge_ids.get(type_name, None) - if edges is None: + for type_name, edges in original_edge_ids.items(): + if ( + type_name not in self.edge_feature_keys + or edges is None + ): continue if edges.is_cuda: edges.record_stream(torch.cuda.current_stream()) - for feature_name in feature_names: + for feature_name in self.edge_feature_keys[type_name]: edge_features[i][ (type_name, feature_name) ] = record_stream( diff --git a/python/dgl/graphbolt/impl/gpu_cached_feature.py b/python/dgl/graphbolt/impl/gpu_cached_feature.py index 0be929ba4abf..e03402ad4162 100644 --- a/python/dgl/graphbolt/impl/gpu_cached_feature.py +++ b/python/dgl/graphbolt/impl/gpu_cached_feature.py @@ -8,6 +8,22 @@ __all__ = ["GPUCachedFeature"] +def nbytes(tensor): + """Returns the number of bytes to store the given tensor. + + Needs to be defined only for torch versions less than 2.1. In torch >= 2.1, + we can simply use "tensor.nbytes". + """ + return tensor.numel() * tensor.element_size() + + +def num_cache_items(cache_capacity_in_bytes, single_item): + """Returns the number of rows to be cached.""" + item_bytes = nbytes(single_item) + # Round up so that we never get a size of 0, unless bytes is 0. + return (cache_capacity_in_bytes + item_bytes - 1) // item_bytes + + class GPUCachedFeature(Feature): r"""GPU cached feature wrapping a fallback feature. @@ -17,8 +33,8 @@ class GPUCachedFeature(Feature): ---------- fallback_feature : Feature The fallback feature. - cache_size : int - The capacity of the GPU cache, the number of features to store. + max_cache_size_in_bytes : int + The capacity of the GPU cache in bytes. Examples -------- @@ -42,16 +58,17 @@ class GPUCachedFeature(Feature): torch.Size([5]) """ - def __init__(self, fallback_feature: Feature, cache_size: int): + def __init__(self, fallback_feature: Feature, max_cache_size_in_bytes: int): super(GPUCachedFeature, self).__init__() assert isinstance(fallback_feature, Feature), ( f"The fallback_feature must be an instance of Feature, but got " f"{type(fallback_feature)}." ) self._fallback_feature = fallback_feature - self.cache_size = cache_size + self.max_cache_size_in_bytes = max_cache_size_in_bytes # Fetching the feature dimension from the underlying feature. 
feat0 = fallback_feature.read(torch.tensor([0])) + cache_size = num_cache_items(max_cache_size_in_bytes, feat0) self._feature = GPUCache((cache_size,) + feat0.shape[1:], feat0.dtype) def read(self, ids: torch.Tensor = None): @@ -104,11 +121,15 @@ def update(self, value: torch.Tensor, ids: torch.Tensor = None): updated. """ if ids is None: + feat0 = value[:1] self._fallback_feature.update(value) - size = min(self.cache_size, value.shape[0]) - self._feature.replace( - torch.arange(0, size, device="cuda"), - value[:size].to("cuda"), + cache_size = min( + num_cache_items(self.max_cache_size_in_bytes, feat0), + value.shape[0], + ) + self._feature = None # Destroy the existing cache first. + self._feature = GPUCache( + (cache_size,) + feat0.shape[1:], feat0.dtype ) else: self._fallback_feature.update(value, ids) diff --git a/script/dgl_dev.yml.template b/script/dgl_dev.yml.template index 28c5f50eef10..3d71b9de4702 100644 --- a/script/dgl_dev.yml.template +++ b/script/dgl_dev.yml.template @@ -49,5 +49,6 @@ dependencies: - lintrunner - jupyterlab - ipywidgets + - expecttest variables: DGL_HOME: __DGL_HOME__ diff --git a/tests/python/pytorch/graphbolt/impl/test_fused_csc_sampling_graph.py b/tests/python/pytorch/graphbolt/impl/test_fused_csc_sampling_graph.py index e4622deef010..13843f2886b1 100644 --- a/tests/python/pytorch/graphbolt/impl/test_fused_csc_sampling_graph.py +++ b/tests/python/pytorch/graphbolt/impl/test_fused_csc_sampling_graph.py @@ -1613,10 +1613,14 @@ def test_csc_sampling_graph_to_pinned_memory(): is_graph_pinned(graph) +@pytest.mark.parametrize("indptr_dtype", [torch.int32, torch.int64]) +@pytest.mark.parametrize("indices_dtype", [torch.int32, torch.int64]) @pytest.mark.parametrize("labor", [False, True]) @pytest.mark.parametrize("is_pinned", [False, True]) @pytest.mark.parametrize("nodes", [None, True]) -def test_sample_neighbors_homo(labor, is_pinned, nodes): +def test_sample_neighbors_homo( + indptr_dtype, indices_dtype, labor, is_pinned, nodes +): if is_pinned and nodes is None: pytest.skip("Optional nodes and is_pinned is not supported together.") """Original graph in COO: @@ -1630,8 +1634,10 @@ def test_sample_neighbors_homo(labor, is_pinned, nodes): pytest.skip("Pinning is not meaningful without a GPU.") # Initialize data. total_num_edges = 12 - indptr = torch.LongTensor([0, 3, 5, 7, 9, 12]) - indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4]) + indptr = torch.tensor([0, 3, 5, 7, 9, 12], dtype=indptr_dtype) + indices = torch.tensor( + [0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4], dtype=indices_dtype + ) assert indptr[-1] == total_num_edges assert indptr[-1] == len(indices) @@ -1642,7 +1648,7 @@ def test_sample_neighbors_homo(labor, is_pinned, nodes): # Generate subgraph via sample neighbors. 
if nodes: - nodes = torch.LongTensor([1, 3, 4]).to(F.ctx()) + nodes = torch.tensor([1, 3, 4], dtype=indices_dtype).to(F.ctx()) elif F._default_context_str != "gpu": pytest.skip("Optional nodes is supported only for the GPU.") sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors @@ -1662,8 +1668,10 @@ def test_sample_neighbors_homo(labor, is_pinned, nodes): assert subgraph.original_edge_ids is None +@pytest.mark.parametrize("indptr_dtype", [torch.int32, torch.int64]) +@pytest.mark.parametrize("indices_dtype", [torch.int32, torch.int64]) @pytest.mark.parametrize("labor", [False, True]) -def test_sample_neighbors_hetero(labor): +def test_sample_neighbors_hetero(indptr_dtype, indices_dtype, labor): """Original graph in COO: "n1:e1:n2":[0, 0, 1, 1, 1], [0, 2, 0, 1, 2] "n2:e2:n1":[0, 0, 1, 2], [0, 1, 1 ,0] @@ -1677,10 +1685,12 @@ def test_sample_neighbors_hetero(labor): ntypes = {"n1": 0, "n2": 1} etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1} total_num_edges = 9 - indptr = torch.LongTensor([0, 2, 4, 6, 7, 9]) - indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 1]) - type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0]) - node_type_offset = torch.LongTensor([0, 2, 5]) + indptr = torch.tensor([0, 2, 4, 6, 7, 9], dtype=indptr_dtype) + indices = torch.tensor([2, 4, 2, 3, 0, 1, 1, 0, 1], dtype=indices_dtype) + type_per_edge = torch.tensor( + [1, 1, 1, 1, 0, 0, 0, 0, 0], dtype=indices_dtype + ) + node_type_offset = torch.tensor([0, 2, 5], dtype=indices_dtype) assert indptr[-1] == total_num_edges assert indptr[-1] == len(indices) @@ -1696,8 +1706,8 @@ def test_sample_neighbors_hetero(labor): # Sample on both node types. nodes = { - "n1": torch.tensor([0], device=F.ctx()), - "n2": torch.tensor([0], device=F.ctx()), + "n1": torch.tensor([0], dtype=indices_dtype, device=F.ctx()), + "n2": torch.tensor([0], dtype=indices_dtype, device=F.ctx()), } fanouts = torch.tensor([-1, -1]) sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors @@ -1725,7 +1735,7 @@ def test_sample_neighbors_hetero(labor): assert subgraph.original_edge_ids is None # Sample on single node type. 
- nodes = {"n1": torch.tensor([0], device=F.ctx())} + nodes = {"n1": torch.tensor([0], dtype=indices_dtype, device=F.ctx())} fanouts = torch.tensor([-1, -1]) sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors subgraph = sampler(nodes, fanouts) diff --git a/tests/python/pytorch/graphbolt/impl/test_gpu_cached_feature.py b/tests/python/pytorch/graphbolt/impl/test_gpu_cached_feature.py index eb9a62babff1..2a2c82fc7101 100644 --- a/tests/python/pytorch/graphbolt/impl/test_gpu_cached_feature.py +++ b/tests/python/pytorch/graphbolt/impl/test_gpu_cached_feature.py @@ -36,6 +36,9 @@ def test_gpu_cached_feature(dtype, cache_size_a, cache_size_b): [[[1, 2], [3, 4]], [[4, 5], [6, 7]]], dtype=dtype, pin_memory=True ) + cache_size_a *= a[:1].element_size() * a[:1].numel() + cache_size_b *= b[:1].element_size() * b[:1].numel() + feat_store_a = gb.GPUCachedFeature(gb.TorchBasedFeature(a), cache_size_a) feat_store_b = gb.GPUCachedFeature(gb.TorchBasedFeature(b), cache_size_b) @@ -94,3 +97,7 @@ def test_gpu_cached_feature(dtype, cache_size_a, cache_size_b): feat_store_a.read(), torch.tensor([[2, 0, 1], [3, 5, 2]], dtype=dtype).to("cuda"), ) + + # Test with different dimensionality + feat_store_a.update(b) + assert torch.equal(feat_store_a.read(), b.to("cuda")) diff --git a/tests/python/pytorch/graphbolt/test_feature_fetcher.py b/tests/python/pytorch/graphbolt/test_feature_fetcher.py index b1944f06bc44..552a0bf5b055 100644 --- a/tests/python/pytorch/graphbolt/test_feature_fetcher.py +++ b/tests/python/pytorch/graphbolt/test_feature_fetcher.py @@ -160,12 +160,21 @@ def test_FeatureFetcher_hetero(): num_layer = 2 fanouts = [torch.LongTensor([2]) for _ in range(num_layer)] sampler_dp = gb.NeighborSampler(item_sampler, graph, fanouts) + # "n3" is not in the sampled input nodes. + node_feature_keys = {"n1": ["a"], "n2": ["a"], "n3": ["a"]} fetcher_dp = gb.FeatureFetcher( - sampler_dp, feature_store, {"n1": ["a"], "n2": ["a"]} + sampler_dp, feature_store, node_feature_keys=node_feature_keys ) - assert len(list(fetcher_dp)) == 3 + # Do not fetch feature for "n1". + node_feature_keys = {"n2": ["a"]} + fetcher_dp = gb.FeatureFetcher( + sampler_dp, feature_store, node_feature_keys=node_feature_keys + ) + for mini_batch in fetcher_dp: + assert ("n1", "a") not in mini_batch.node_features + def test_FeatureFetcher_with_edges_hetero(): a = torch.tensor([[random.randint(0, 10)] for _ in range(20)]) @@ -208,7 +217,11 @@ def add_node_and_edge_ids(minibatch): return data features = {} - keys = [("node", "n1", "a"), ("edge", "n1:e1:n2", "a")] + keys = [ + ("node", "n1", "a"), + ("edge", "n1:e1:n2", "a"), + ("edge", "n2:e2:n1", "a"), + ] features[keys[0]] = gb.TorchBasedFeature(a) features[keys[1]] = gb.TorchBasedFeature(b) feature_store = gb.BasicFeatureStore(features) @@ -220,8 +233,15 @@ def add_node_and_edge_ids(minibatch): ) item_sampler_dp = gb.ItemSampler(itemset, batch_size=2) converter_dp = Mapper(item_sampler_dp, add_node_and_edge_ids) + # "n3:e3:n3" is not in the sampled edges. + # Do not fetch feature for "n2:e2:n1". 
+ node_feature_keys = {"n1": ["a"]} + edge_feature_keys = {"n1:e1:n2": ["a"], "n3:e3:n3": ["a"]} fetcher_dp = gb.FeatureFetcher( - converter_dp, feature_store, {"n1": ["a"]}, {"n1:e1:n2": ["a"]} + converter_dp, + feature_store, + node_feature_keys=node_feature_keys, + edge_feature_keys=edge_feature_keys, ) assert len(list(fetcher_dp)) == 5 @@ -230,3 +250,4 @@ def add_node_and_edge_ids(minibatch): assert len(data.edge_features) == 3 for edge_feature in data.edge_features: assert edge_feature[("n1:e1:n2", "a")].size(0) == 10 + assert ("n2:e2:n1", "a") not in edge_feature diff --git a/third_party/cccl b/third_party/cccl index 64d3a5f0c1c8..1c009d23abf3 160000 --- a/third_party/cccl +++ b/third_party/cccl @@ -1 +1 @@ -Subproject commit 64d3a5f0c1c83ed83be8c0a9a1f0cdb31f913e81 +Subproject commit 1c009d23abf3e6c13d5e1f0ee54222c43b2c1785