-
Notifications
You must be signed in to change notification settings - Fork 3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Ubuntu
committed
Feb 9, 2024
1 parent
8617a24
commit ffa3000
Showing
3 changed files
with
358 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,318 @@ | ||
import re | ||
import unittest | ||
|
||
from enum import Enum | ||
from functools import partial | ||
|
||
import backend as F | ||
|
||
import dgl | ||
import dgl.graphbolt as gb | ||
import pytest | ||
import torch | ||
from torchdata.datapipes.iter import Mapper | ||
|
||
from . import gb_test_utils | ||
|
||
|
||
def test_add_reverse_edges_homo(): | ||
edges = (torch.tensor([0, 1, 2, 3]), torch.tensor([4, 5, 6, 7])) | ||
combined_edges = gb.add_reverse_edges(edges) | ||
assert torch.equal( | ||
combined_edges[0], torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]) | ||
) | ||
assert torch.equal( | ||
combined_edges[1], torch.tensor([4, 5, 6, 7, 0, 1, 2, 3]) | ||
) | ||
|
||
|
||
def test_add_reverse_edges_hetero(): | ||
# reverse_etype doesn't exist in original etypes. | ||
edges = {"n1:e1:n2": (torch.tensor([0, 1, 2]), torch.tensor([4, 5, 6]))} | ||
reverse_etype_mapping = {"n1:e1:n2": "n2:e2:n1"} | ||
combined_edges = gb.add_reverse_edges(edges, reverse_etype_mapping) | ||
assert torch.equal(combined_edges["n1:e1:n2"][0], torch.tensor([0, 1, 2])) | ||
assert torch.equal(combined_edges["n1:e1:n2"][1], torch.tensor([4, 5, 6])) | ||
assert torch.equal(combined_edges["n2:e2:n1"][0], torch.tensor([4, 5, 6])) | ||
assert torch.equal(combined_edges["n2:e2:n1"][1], torch.tensor([0, 1, 2])) | ||
# reverse_etype exists in original etypes. | ||
edges = { | ||
"n1:e1:n2": (torch.tensor([0, 1, 2]), torch.tensor([4, 5, 6])), | ||
"n2:e2:n1": (torch.tensor([7, 8, 9]), torch.tensor([10, 11, 12])), | ||
} | ||
reverse_etype_mapping = {"n1:e1:n2": "n2:e2:n1"} | ||
combined_edges = gb.add_reverse_edges(edges, reverse_etype_mapping) | ||
assert torch.equal(combined_edges["n1:e1:n2"][0], torch.tensor([0, 1, 2])) | ||
assert torch.equal(combined_edges["n1:e1:n2"][1], torch.tensor([4, 5, 6])) | ||
assert torch.equal( | ||
combined_edges["n2:e2:n1"][0], torch.tensor([7, 8, 9, 4, 5, 6]) | ||
) | ||
assert torch.equal( | ||
combined_edges["n2:e2:n1"][1], torch.tensor([10, 11, 12, 0, 1, 2]) | ||
) | ||
|
||
|
||
def test_add_reverse_edges_2_homo(): | ||
edges = torch.tensor([[0, 1, 2, 3], [4, 5, 6, 7]]).T | ||
combined_edges = gb.add_reverse_edges_2(edges) | ||
assert torch.equal( | ||
combined_edges, | ||
torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7], [4, 5, 6, 7, 0, 1, 2, 3]]).T, | ||
) | ||
# Tensor with uncorrect dimensions. | ||
edges = torch.tensor([0, 1, 2, 3]) | ||
with pytest.raises( | ||
AssertionError, | ||
match=re.escape( | ||
"Only tensor with shape N*2 is supported now, but got torch.Size([4])." | ||
), | ||
): | ||
gb.add_reverse_edges_2(edges) | ||
|
||
|
||
def test_add_reverse_edges_2_hetero(): | ||
# reverse_etype doesn't exist in original etypes. | ||
edges = {"n1:e1:n2": torch.tensor([[0, 1, 2], [4, 5, 6]]).T} | ||
reverse_etype_mapping = {"n1:e1:n2": "n2:e2:n1"} | ||
combined_edges = gb.add_reverse_edges_2(edges, reverse_etype_mapping) | ||
assert torch.equal( | ||
combined_edges["n1:e1:n2"], torch.tensor([[0, 1, 2], [4, 5, 6]]).T | ||
) | ||
assert torch.equal( | ||
combined_edges["n2:e2:n1"], torch.tensor([[4, 5, 6], [0, 1, 2]]).T | ||
) | ||
# reverse_etype exists in original etypes. | ||
edges = { | ||
"n1:e1:n2": torch.tensor([[0, 1, 2], [4, 5, 6]]).T, | ||
"n2:e2:n1": torch.tensor([[7, 8, 9], [10, 11, 12]]).T, | ||
} | ||
reverse_etype_mapping = {"n1:e1:n2": "n2:e2:n1"} | ||
combined_edges = gb.add_reverse_edges_2(edges, reverse_etype_mapping) | ||
assert torch.equal( | ||
combined_edges["n1:e1:n2"], torch.tensor([[0, 1, 2], [4, 5, 6]]).T | ||
) | ||
assert torch.equal( | ||
combined_edges["n2:e2:n1"], | ||
torch.tensor([[7, 8, 9, 4, 5, 6], [10, 11, 12, 0, 1, 2]]).T, | ||
) | ||
# Tensor with uncorrect dimensions. | ||
edges = { | ||
"n1:e1:n2": torch.tensor([0, 1, 2]), | ||
"n2:e2:n1": torch.tensor([7, 8, 9]), | ||
} | ||
with pytest.raises( | ||
AssertionError, | ||
match=re.escape( | ||
"Only tensor with shape N*2 is supported now, but got torch.Size([3])." | ||
), | ||
): | ||
gb.add_reverse_edges_2(edges, reverse_etype_mapping) | ||
|
||
|
||
@unittest.skipIf( | ||
F._default_context_str == "gpu", | ||
reason="Fails due to different result on the GPU.", | ||
) | ||
def test_exclude_seed_edges_homo_cpu(): | ||
graph = dgl.graph(([5, 0, 6, 7, 2, 2, 4], [0, 1, 2, 2, 3, 4, 4])) | ||
graph = gb.from_dglgraph(graph, True).to(F.ctx()) | ||
items = torch.LongTensor([[0, 3], [4, 4]]) | ||
names = "seeds" | ||
itemset = gb.ItemSet(items, names=names) | ||
datapipe = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx()) | ||
num_layer = 2 | ||
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)] | ||
sampler = gb.NeighborSampler | ||
datapipe = sampler(datapipe, graph, fanouts) | ||
datapipe = datapipe.transform(partial(gb.exclude_seed_edges)) | ||
original_row_node_ids = [ | ||
torch.tensor([0, 3, 4, 5, 2, 6, 7]).to(F.ctx()), | ||
torch.tensor([0, 3, 4, 5, 2]).to(F.ctx()), | ||
] | ||
compacted_indices = [ | ||
torch.tensor([3, 4, 4, 5, 6]).to(F.ctx()), | ||
torch.tensor([3, 4, 4]).to(F.ctx()), | ||
] | ||
indptr = [ | ||
torch.tensor([0, 1, 2, 3, 3, 5]).to(F.ctx()), | ||
torch.tensor([0, 1, 2, 3]).to(F.ctx()), | ||
] | ||
seeds = [ | ||
torch.tensor([0, 3, 4, 5, 2]).to(F.ctx()), | ||
torch.tensor([0, 3, 4]).to(F.ctx()), | ||
] | ||
for data in datapipe: | ||
for step, sampled_subgraph in enumerate(data.sampled_subgraphs): | ||
assert torch.equal( | ||
sampled_subgraph.original_row_node_ids, | ||
original_row_node_ids[step], | ||
) | ||
assert torch.equal( | ||
sampled_subgraph.sampled_csc.indices, compacted_indices[step] | ||
) | ||
assert torch.equal( | ||
sampled_subgraph.sampled_csc.indptr, indptr[step] | ||
) | ||
assert torch.equal( | ||
sampled_subgraph.original_column_node_ids, seeds[step] | ||
) | ||
|
||
|
||
@unittest.skipIf( | ||
F._default_context_str == "cpu", | ||
reason="Fails due to different result on the CPU.", | ||
) | ||
def test_exclude_seed_edges_gpu(): | ||
graph = dgl.graph(([5, 0, 7, 7, 2, 4], [0, 1, 2, 2, 3, 4])) | ||
graph = gb.from_dglgraph(graph, is_homogeneous=True).to(F.ctx()) | ||
items = torch.LongTensor([[0, 3], [4, 4]]) | ||
names = "seeds" | ||
itemset = gb.ItemSet(items, names=names) | ||
datapipe = gb.ItemSampler(itemset, batch_size=4).copy_to(F.ctx()) | ||
num_layer = 2 | ||
fanouts = [torch.LongTensor([-1]) for _ in range(num_layer)] | ||
sampler = gb.NeighborSampler | ||
datapipe = sampler( | ||
datapipe, | ||
graph, | ||
fanouts, | ||
deduplicate=True, | ||
) | ||
datapipe = datapipe.transform(partial(gb.exclude_seed_edges)) | ||
original_row_node_ids = [ | ||
torch.tensor([0, 3, 4, 2, 5, 7]).to(F.ctx()), | ||
torch.tensor([0, 3, 4, 2, 5]).to(F.ctx()), | ||
] | ||
compacted_indices = [ | ||
torch.tensor([4, 3, 5, 5]).to(F.ctx()), | ||
torch.tensor([4, 3]).to(F.ctx()), | ||
] | ||
indptr = [ | ||
torch.tensor([0, 1, 2, 2, 4, 4]).to(F.ctx()), | ||
torch.tensor([0, 1, 2, 2]).to(F.ctx()), | ||
] | ||
seeds = [ | ||
torch.tensor([0, 3, 4, 2, 5]).to(F.ctx()), | ||
torch.tensor([0, 3, 4]).to(F.ctx()), | ||
] | ||
for data in datapipe: | ||
for step, sampled_subgraph in enumerate(data.sampled_subgraphs): | ||
assert torch.equal( | ||
sampled_subgraph.original_row_node_ids, | ||
original_row_node_ids[step], | ||
) | ||
assert torch.equal( | ||
(sampled_subgraph.sampled_csc.indices), compacted_indices[step] | ||
) | ||
assert torch.equal( | ||
sampled_subgraph.sampled_csc.indptr, indptr[step] | ||
) | ||
assert torch.equal( | ||
sampled_subgraph.original_column_node_ids, seeds[step] | ||
) | ||
|
||
|
||
def get_hetero_graph(): | ||
# COO graph: | ||
# [0, 0, 1, 1, 2, 2, 3, 3, 4, 4] | ||
# [2, 4, 2, 3, 0, 1, 1, 0, 0, 1] | ||
# [1, 1, 1, 1, 0, 0, 0, 0, 0] - > edge type. | ||
# num_nodes = 5, num_n1 = 2, num_n2 = 3 | ||
ntypes = {"n1": 0, "n2": 1} | ||
etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1} | ||
indptr = torch.LongTensor([0, 2, 4, 6, 8, 10]) | ||
indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 0, 1]) | ||
type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0, 0]) | ||
node_type_offset = torch.LongTensor([0, 2, 5]) | ||
return gb.fused_csc_sampling_graph( | ||
indptr, | ||
indices, | ||
node_type_offset=node_type_offset, | ||
type_per_edge=type_per_edge, | ||
node_type_to_id=ntypes, | ||
edge_type_to_id=etypes, | ||
) | ||
|
||
|
||
def test_exclude_seed_edges_hetero(): | ||
graph = get_hetero_graph().to(F.ctx()) | ||
itemset = gb.ItemSetDict( | ||
{"n1:e1:n2": gb.ItemSet(torch.tensor([[0, 1]]), names="seeds")} | ||
) | ||
item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx()) | ||
num_layer = 2 | ||
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)] | ||
Sampler = gb.NeighborSampler | ||
datapipe = Sampler( | ||
item_sampler, | ||
graph, | ||
fanouts, | ||
deduplicate=True, | ||
) | ||
datapipe = datapipe.transform(partial(gb.exclude_seed_edges)) | ||
csc_formats = [ | ||
{ | ||
"n1:e1:n2": gb.CSCFormatBase( | ||
indptr=torch.tensor([0, 1, 3, 5]), | ||
indices=torch.tensor([1, 0, 1, 0, 1]), | ||
), | ||
"n2:e2:n1": gb.CSCFormatBase( | ||
indptr=torch.tensor([0, 2, 4]), | ||
indices=torch.tensor([1, 2, 1, 0]), | ||
), | ||
}, | ||
{ | ||
"n1:e1:n2": gb.CSCFormatBase( | ||
indptr=torch.tensor([0, 1]), | ||
indices=torch.tensor([1]), | ||
), | ||
"n2:e2:n1": gb.CSCFormatBase( | ||
indptr=torch.tensor([0, 2]), | ||
indices=torch.tensor([1, 2], dtype=torch.int64), | ||
), | ||
}, | ||
] | ||
original_column_node_ids = [ | ||
{ | ||
"n1": torch.tensor([0, 1]), | ||
"n2": torch.tensor([0, 1, 2]), | ||
}, | ||
{ | ||
"n1": torch.tensor([0]), | ||
"n2": torch.tensor([1]), | ||
}, | ||
] | ||
original_row_node_ids = [ | ||
{ | ||
"n1": torch.tensor([0, 1]), | ||
"n2": torch.tensor([0, 1, 2]), | ||
}, | ||
{ | ||
"n1": torch.tensor([0, 1]), | ||
"n2": torch.tensor([0, 1, 2]), | ||
}, | ||
] | ||
for data in datapipe: | ||
for step, sampled_subgraph in enumerate(data.sampled_subgraphs): | ||
for ntype in ["n1", "n2"]: | ||
assert torch.equal( | ||
torch.sort(sampled_subgraph.original_row_node_ids[ntype])[ | ||
0 | ||
], | ||
original_row_node_ids[step][ntype].to(F.ctx()), | ||
) | ||
assert torch.equal( | ||
torch.sort( | ||
sampled_subgraph.original_column_node_ids[ntype] | ||
)[0], | ||
original_column_node_ids[step][ntype].to(F.ctx()), | ||
) | ||
for etype in ["n1:e1:n2", "n2:e2:n1"]: | ||
assert torch.equal( | ||
sampled_subgraph.sampled_csc[etype].indices, | ||
csc_formats[step][etype].indices.to(F.ctx()), | ||
) | ||
assert torch.equal( | ||
sampled_subgraph.sampled_csc[etype].indptr, | ||
csc_formats[step][etype].indptr.to(F.ctx()), | ||
) |