Skip to content

Commit

Permalink
[release] cherry-pick from master and release for 2.2.1 (#7388)
Browse files Browse the repository at this point in the history
Co-authored-by: Muhammed Fatih BALIN <m.f.balin@gmail.com>
Co-authored-by: Xinyu Yao <77922129+yxy235@users.noreply.github.com>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
  • Loading branch information
4 people authored May 10, 2024
1 parent 8873fb2 commit 1b5c02c
Show file tree
Hide file tree
Showing 10 changed files with 124 additions and 40 deletions.
2 changes: 1 addition & 1 deletion examples/multigpu/graphbolt/node_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,7 @@ def parse_args():
"--gpu-cache-size",
type=int,
default=0,
help="The capacity of the GPU cache, the number of features to store.",
help="The capacity of the GPU cache in bytes.",
)
parser.add_argument(
"--dataset",
Expand Down
12 changes: 12 additions & 0 deletions examples/sampling/graphbolt/pyg/node_classification_advanced.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,12 @@ def parse_args():
help="Graph storage - feature storage - Train device: 'cpu' for CPU and RAM,"
" 'pinned' for pinned memory in RAM, 'cuda' for GPU and GPU memory.",
)
parser.add_argument(
"--gpu-cache-size",
type=int,
default=0,
help="The capacity of the GPU cache in bytes.",
)
parser.add_argument(
"--sample-mode",
default="sample_neighbor",
Expand Down Expand Up @@ -403,6 +409,12 @@ def main():

num_classes = dataset.tasks[0].metadata["num_classes"]

if args.gpu_cache_size > 0 and args.feature_device != "cuda":
features._features[("node", None, "feat")] = gb.GPUCachedFeature(
features._features[("node", None, "feat")],
args.gpu_cache_size,
)

train_dataloader, valid_dataloader = (
create_dataloader(
graph=graph,
Expand Down
12 changes: 12 additions & 0 deletions python/dgl/graphbolt/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,17 @@

import torch
from torch.torch_version import TorchVersion

if TorchVersion(torch.__version__) >= "2.3.0":
# [TODO][https://github.com/dmlc/dgl/issues/7387] Remove or refine below
# check.
# Due to https://github.com/dmlc/dgl/issues/7380, we need to check if dill
# is available before using it.
torch.utils.data.datapipes.utils.common.DILL_AVAILABLE = (
torch.utils._import_utils.dill_available()
)

# pylint: disable=wrong-import-position
from torch.utils.data import functional_datapipe
from torchdata.datapipes.iter import IterDataPipe

Expand Down Expand Up @@ -342,6 +353,7 @@ class CSCFormatBase:
>>> print(csc_foramt_base)
... torch.tensor([1, 4, 2])
"""

indptr: torch.Tensor = None
indices: torch.Tensor = None

Expand Down
28 changes: 14 additions & 14 deletions python/dgl/graphbolt/feature_fetcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,13 +88,12 @@ def record_stream(tensor):

if self.node_feature_keys and input_nodes is not None:
if is_heterogeneous:
for type_name, feature_names in self.node_feature_keys.items():
nodes = input_nodes[type_name]
if nodes is None:
for type_name, nodes in input_nodes.items():
if type_name not in self.node_feature_keys or nodes is None:
continue
if nodes.is_cuda:
nodes.record_stream(torch.cuda.current_stream())
for feature_name in feature_names:
for feature_name in self.node_feature_keys[type_name]:
node_features[
(type_name, feature_name)
] = record_stream(
Expand Down Expand Up @@ -126,21 +125,22 @@ def record_stream(tensor):
if is_heterogeneous:
# Convert edge type to string.
original_edge_ids = {
etype_tuple_to_str(key)
if isinstance(key, tuple)
else key: value
(
etype_tuple_to_str(key)
if isinstance(key, tuple)
else key
): value
for key, value in original_edge_ids.items()
}
for (
type_name,
feature_names,
) in self.edge_feature_keys.items():
edges = original_edge_ids.get(type_name, None)
if edges is None:
for type_name, edges in original_edge_ids.items():
if (
type_name not in self.edge_feature_keys
or edges is None
):
continue
if edges.is_cuda:
edges.record_stream(torch.cuda.current_stream())
for feature_name in feature_names:
for feature_name in self.edge_feature_keys[type_name]:
edge_features[i][
(type_name, feature_name)
] = record_stream(
Expand Down
37 changes: 29 additions & 8 deletions python/dgl/graphbolt/impl/gpu_cached_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,22 @@
__all__ = ["GPUCachedFeature"]


def nbytes(tensor):
"""Returns the number of bytes to store the given tensor.
Needs to be defined only for torch versions less than 2.1. In torch >= 2.1,
we can simply use "tensor.nbytes".
"""
return tensor.numel() * tensor.element_size()


def num_cache_items(cache_capacity_in_bytes, single_item):
"""Returns the number of rows to be cached."""
item_bytes = nbytes(single_item)
# Round up so that we never get a size of 0, unless bytes is 0.
return (cache_capacity_in_bytes + item_bytes - 1) // item_bytes


class GPUCachedFeature(Feature):
r"""GPU cached feature wrapping a fallback feature.
Expand All @@ -17,8 +33,8 @@ class GPUCachedFeature(Feature):
----------
fallback_feature : Feature
The fallback feature.
cache_size : int
The capacity of the GPU cache, the number of features to store.
max_cache_size_in_bytes : int
The capacity of the GPU cache in bytes.
Examples
--------
Expand All @@ -42,16 +58,17 @@ class GPUCachedFeature(Feature):
torch.Size([5])
"""

def __init__(self, fallback_feature: Feature, cache_size: int):
def __init__(self, fallback_feature: Feature, max_cache_size_in_bytes: int):
super(GPUCachedFeature, self).__init__()
assert isinstance(fallback_feature, Feature), (
f"The fallback_feature must be an instance of Feature, but got "
f"{type(fallback_feature)}."
)
self._fallback_feature = fallback_feature
self.cache_size = cache_size
self.max_cache_size_in_bytes = max_cache_size_in_bytes
# Fetching the feature dimension from the underlying feature.
feat0 = fallback_feature.read(torch.tensor([0]))
cache_size = num_cache_items(max_cache_size_in_bytes, feat0)
self._feature = GPUCache((cache_size,) + feat0.shape[1:], feat0.dtype)

def read(self, ids: torch.Tensor = None):
Expand Down Expand Up @@ -104,11 +121,15 @@ def update(self, value: torch.Tensor, ids: torch.Tensor = None):
updated.
"""
if ids is None:
feat0 = value[:1]
self._fallback_feature.update(value)
size = min(self.cache_size, value.shape[0])
self._feature.replace(
torch.arange(0, size, device="cuda"),
value[:size].to("cuda"),
cache_size = min(
num_cache_items(self.max_cache_size_in_bytes, feat0),
value.shape[0],
)
self._feature = None # Destroy the existing cache first.
self._feature = GPUCache(
(cache_size,) + feat0.shape[1:], feat0.dtype
)
else:
self._fallback_feature.update(value, ids)
Expand Down
1 change: 1 addition & 0 deletions script/dgl_dev.yml.template
Original file line number Diff line number Diff line change
Expand Up @@ -49,5 +49,6 @@ dependencies:
- lintrunner
- jupyterlab
- ipywidgets
- expecttest
variables:
DGL_HOME: __DGL_HOME__
Original file line number Diff line number Diff line change
Expand Up @@ -1613,10 +1613,14 @@ def test_csc_sampling_graph_to_pinned_memory():
is_graph_pinned(graph)


@pytest.mark.parametrize("indptr_dtype", [torch.int32, torch.int64])
@pytest.mark.parametrize("indices_dtype", [torch.int32, torch.int64])
@pytest.mark.parametrize("labor", [False, True])
@pytest.mark.parametrize("is_pinned", [False, True])
@pytest.mark.parametrize("nodes", [None, True])
def test_sample_neighbors_homo(labor, is_pinned, nodes):
def test_sample_neighbors_homo(
indptr_dtype, indices_dtype, labor, is_pinned, nodes
):
if is_pinned and nodes is None:
pytest.skip("Optional nodes and is_pinned is not supported together.")
"""Original graph in COO:
Expand All @@ -1630,8 +1634,10 @@ def test_sample_neighbors_homo(labor, is_pinned, nodes):
pytest.skip("Pinning is not meaningful without a GPU.")
# Initialize data.
total_num_edges = 12
indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])
indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])
indptr = torch.tensor([0, 3, 5, 7, 9, 12], dtype=indptr_dtype)
indices = torch.tensor(
[0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4], dtype=indices_dtype
)
assert indptr[-1] == total_num_edges
assert indptr[-1] == len(indices)

Expand All @@ -1642,7 +1648,7 @@ def test_sample_neighbors_homo(labor, is_pinned, nodes):

# Generate subgraph via sample neighbors.
if nodes:
nodes = torch.LongTensor([1, 3, 4]).to(F.ctx())
nodes = torch.tensor([1, 3, 4], dtype=indices_dtype).to(F.ctx())
elif F._default_context_str != "gpu":
pytest.skip("Optional nodes is supported only for the GPU.")
sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors
Expand All @@ -1662,8 +1668,10 @@ def test_sample_neighbors_homo(labor, is_pinned, nodes):
assert subgraph.original_edge_ids is None


@pytest.mark.parametrize("indptr_dtype", [torch.int32, torch.int64])
@pytest.mark.parametrize("indices_dtype", [torch.int32, torch.int64])
@pytest.mark.parametrize("labor", [False, True])
def test_sample_neighbors_hetero(labor):
def test_sample_neighbors_hetero(indptr_dtype, indices_dtype, labor):
"""Original graph in COO:
"n1:e1:n2":[0, 0, 1, 1, 1], [0, 2, 0, 1, 2]
"n2:e2:n1":[0, 0, 1, 2], [0, 1, 1 ,0]
Expand All @@ -1677,10 +1685,12 @@ def test_sample_neighbors_hetero(labor):
ntypes = {"n1": 0, "n2": 1}
etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1}
total_num_edges = 9
indptr = torch.LongTensor([0, 2, 4, 6, 7, 9])
indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 1])
type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0])
node_type_offset = torch.LongTensor([0, 2, 5])
indptr = torch.tensor([0, 2, 4, 6, 7, 9], dtype=indptr_dtype)
indices = torch.tensor([2, 4, 2, 3, 0, 1, 1, 0, 1], dtype=indices_dtype)
type_per_edge = torch.tensor(
[1, 1, 1, 1, 0, 0, 0, 0, 0], dtype=indices_dtype
)
node_type_offset = torch.tensor([0, 2, 5], dtype=indices_dtype)
assert indptr[-1] == total_num_edges
assert indptr[-1] == len(indices)

Expand All @@ -1696,8 +1706,8 @@ def test_sample_neighbors_hetero(labor):

# Sample on both node types.
nodes = {
"n1": torch.tensor([0], device=F.ctx()),
"n2": torch.tensor([0], device=F.ctx()),
"n1": torch.tensor([0], dtype=indices_dtype, device=F.ctx()),
"n2": torch.tensor([0], dtype=indices_dtype, device=F.ctx()),
}
fanouts = torch.tensor([-1, -1])
sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors
Expand Down Expand Up @@ -1725,7 +1735,7 @@ def test_sample_neighbors_hetero(labor):
assert subgraph.original_edge_ids is None

# Sample on single node type.
nodes = {"n1": torch.tensor([0], device=F.ctx())}
nodes = {"n1": torch.tensor([0], dtype=indices_dtype, device=F.ctx())}
fanouts = torch.tensor([-1, -1])
sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors
subgraph = sampler(nodes, fanouts)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ def test_gpu_cached_feature(dtype, cache_size_a, cache_size_b):
[[[1, 2], [3, 4]], [[4, 5], [6, 7]]], dtype=dtype, pin_memory=True
)

cache_size_a *= a[:1].element_size() * a[:1].numel()
cache_size_b *= b[:1].element_size() * b[:1].numel()

feat_store_a = gb.GPUCachedFeature(gb.TorchBasedFeature(a), cache_size_a)
feat_store_b = gb.GPUCachedFeature(gb.TorchBasedFeature(b), cache_size_b)

Expand Down Expand Up @@ -94,3 +97,7 @@ def test_gpu_cached_feature(dtype, cache_size_a, cache_size_b):
feat_store_a.read(),
torch.tensor([[2, 0, 1], [3, 5, 2]], dtype=dtype).to("cuda"),
)

# Test with different dimensionality
feat_store_a.update(b)
assert torch.equal(feat_store_a.read(), b.to("cuda"))
29 changes: 25 additions & 4 deletions tests/python/pytorch/graphbolt/test_feature_fetcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,12 +160,21 @@ def test_FeatureFetcher_hetero():
num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
sampler_dp = gb.NeighborSampler(item_sampler, graph, fanouts)
# "n3" is not in the sampled input nodes.
node_feature_keys = {"n1": ["a"], "n2": ["a"], "n3": ["a"]}
fetcher_dp = gb.FeatureFetcher(
sampler_dp, feature_store, {"n1": ["a"], "n2": ["a"]}
sampler_dp, feature_store, node_feature_keys=node_feature_keys
)

assert len(list(fetcher_dp)) == 3

# Do not fetch feature for "n1".
node_feature_keys = {"n2": ["a"]}
fetcher_dp = gb.FeatureFetcher(
sampler_dp, feature_store, node_feature_keys=node_feature_keys
)
for mini_batch in fetcher_dp:
assert ("n1", "a") not in mini_batch.node_features


def test_FeatureFetcher_with_edges_hetero():
a = torch.tensor([[random.randint(0, 10)] for _ in range(20)])
Expand Down Expand Up @@ -208,7 +217,11 @@ def add_node_and_edge_ids(minibatch):
return data

features = {}
keys = [("node", "n1", "a"), ("edge", "n1:e1:n2", "a")]
keys = [
("node", "n1", "a"),
("edge", "n1:e1:n2", "a"),
("edge", "n2:e2:n1", "a"),
]
features[keys[0]] = gb.TorchBasedFeature(a)
features[keys[1]] = gb.TorchBasedFeature(b)
feature_store = gb.BasicFeatureStore(features)
Expand All @@ -220,8 +233,15 @@ def add_node_and_edge_ids(minibatch):
)
item_sampler_dp = gb.ItemSampler(itemset, batch_size=2)
converter_dp = Mapper(item_sampler_dp, add_node_and_edge_ids)
# "n3:e3:n3" is not in the sampled edges.
# Do not fetch feature for "n2:e2:n1".
node_feature_keys = {"n1": ["a"]}
edge_feature_keys = {"n1:e1:n2": ["a"], "n3:e3:n3": ["a"]}
fetcher_dp = gb.FeatureFetcher(
converter_dp, feature_store, {"n1": ["a"]}, {"n1:e1:n2": ["a"]}
converter_dp,
feature_store,
node_feature_keys=node_feature_keys,
edge_feature_keys=edge_feature_keys,
)

assert len(list(fetcher_dp)) == 5
Expand All @@ -230,3 +250,4 @@ def add_node_and_edge_ids(minibatch):
assert len(data.edge_features) == 3
for edge_feature in data.edge_features:
assert edge_feature[("n1:e1:n2", "a")].size(0) == 10
assert ("n2:e2:n1", "a") not in edge_feature
2 changes: 1 addition & 1 deletion third_party/cccl
Submodule cccl updated 9797 files

0 comments on commit 1b5c02c

Please sign in to comment.