
Commit

modify and add tests
Ubuntu committed Feb 6, 2024
1 parent 845864d commit 8617a24
Showing 4 changed files with 325 additions and 18 deletions.
8 changes: 7 additions & 1 deletion python/dgl/graphbolt/sampled_subgraph.py
@@ -116,7 +116,9 @@ def exclude_edges(
self,
edges: Union[
Dict[str, Tuple[torch.Tensor, torch.Tensor]],
Dict[str, torch.Tensor],
Tuple[torch.Tensor, torch.Tensor],
torch.Tensor,
],
assume_num_node_within_int32: bool = True,
):
@@ -183,7 +185,7 @@ def exclude_edges(
), "Values > int32 are not supported yet."
assert (
isinstance(self.sampled_csc, (CSCFormatBase, tuple))
) == isinstance(edges, tuple), (
) == isinstance(edges, (tuple, torch.Tensor)), (
"The sampled subgraph and the edges to exclude should be both "
"homogeneous or both heterogeneous."
)
@@ -200,6 +202,8 @@ def exclude_edges(
self.original_row_node_ids,
self.original_column_node_ids,
)
if isinstance(edges, torch.Tensor):
edges = edges.T
index = _exclude_homo_edges(
reverse_edges, edges, assume_num_node_within_int32
)
@@ -227,6 +231,8 @@ def exclude_edges(
original_row_node_ids,
original_column_node_ids,
)
if isinstance(edges[etype], torch.Tensor):
edges[etype] = edges[etype].T
index[etype] = _exclude_homo_edges(
reverse_edges,
edges[etype],
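[Note, not part of the commit] A minimal usage sketch of the new homogeneous-tensor path in `exclude_edges`, mirroring the deduplicated test case added below. The import path for `SampledSubgraphImpl` is an assumption; the rest follows the new test code.

import torch
import dgl.graphbolt as gb
# Assumed import path; the new tests below use the same class.
from dgl.graphbolt.impl.sampled_subgraph_impl import SampledSubgraphImpl

# A homogeneous sampled subgraph in CSC form with three edges.
csc = gb.CSCFormatBase(
    indptr=torch.tensor([0, 0, 1, 2, 2, 3]), indices=torch.tensor([0, 3, 2])
)
subgraph = SampledSubgraphImpl(csc, None, None, torch.tensor([5, 9, 10]))

# With this change, the edges to exclude can be passed directly as an N*2
# tensor of (src, dst) pairs instead of a (src_tensor, dst_tensor) tuple.
edges_to_exclude = torch.tensor([[2, 4]])  # a single edge: src 2, dst 4
pruned = subgraph.exclude_edges(edges_to_exclude)
# The edge (2, 4) is removed; original edge IDs 5 and 9 remain.
print(pruned.original_edge_ids)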
85 changes: 80 additions & 5 deletions python/dgl/graphbolt/utils.py
@@ -73,6 +73,74 @@ def add_reverse_edges(
return combined_edges


def add_reverse_edges_2(
edges: Union[Dict[str, torch.Tensor], torch.Tensor],
reverse_etypes_mapping: Dict[str, str] = None,
):
r"""
This function finds the reverse edges of the given `edges` and returns the
original edges combined with their reverse counterparts. In a homogeneous
graph, reverse edges have inverted source and destination node IDs, while in
a heterogeneous graph, reversing also swaps the node types. This function
can be used before the `exclude_edges` function to help find the target
edges.
Note: The reverse edges found may not actually exist in the original graph,
and duplicate edges could be added because reverse edges may already exist
in `edges`.
Parameters
----------
edges : Union[Dict[str, torch.Tensor], torch.Tensor]
- If the sampled subgraph is homogeneous, then `edges` should be an N*2
tensor.
- If the sampled subgraph is heterogeneous, then `edges` should be a
dictionary mapping edge types to the corresponding N*2 edge tensors.
reverse_etypes_mapping : Dict[str, str], optional
The mapping from the original edge types to their reverse edge types.
Returns
-------
Union[Dict[str, torch.Tensor], torch.Tensor]
The node pairs contain both the original edges and their reverse
counterparts.
Examples
--------
>>> edges = {"A:r:B": torch.tensor([[0, 1],[1, 2]]))}
>>> print(gb.add_reverse_edges(edges, {"A:r:B": "B:rr:A"}))
{'A:r:B': torch.tensor([[0, 1],[1, 2]]),
'B:rr:A': torch.tensor([[1, 0],[2, 1]])}
>>> edges = torch.tensor([[0, 1],[1, 2]])
>>> print(gb.add_reverse_edges(edges))
torch.tensor([[1, 0],[2, 1]])
"""
if isinstance(edges, torch.Tensor):
assert edges.ndim == 2 and edges.shape[1] == 2, (
"Only tensor with shape N*2 is supported now, but got "
+ f"{edges.shape}."
)
reverse_edges = edges.flip(dims=(1,))
return torch.cat((edges, reverse_edges))
else:
combined_edges = edges.copy()
for etype, reverse_etype in reverse_etypes_mapping.items():
if etype in edges:
assert edges[etype].ndim == 2 and edges[etype].shape[1] == 2, (
"Only tensor with shape N*2 is supported now, but got "
+ f"{edges[etype].shape}."
)
if reverse_etype in combined_edges:
combined_edges[reverse_etype] = torch.cat(
(
combined_edges[reverse_etype],
edges[etype].flip(dims=(1,)),
)
)
else:
combined_edges[reverse_etype] = edges[etype].flip(dims=(1,))
return combined_edges


def exclude_seed_edges(
minibatch: MiniBatch,
include_reverse_edges: bool = False,
@@ -89,11 +157,18 @@ def exclude_seed_edges(
reverse_etypes_mapping : Dict[str, str] = None
The mapping from the original edge types to their reverse edge types.
"""
edges_to_exclude = minibatch.node_pairs
if include_reverse_edges:
edges_to_exclude = add_reverse_edges(
minibatch.node_pairs, reverse_etypes_mapping
)
if minibatch.node_pairs is not None:
edges_to_exclude = minibatch.node_pairs
if include_reverse_edges:
edges_to_exclude = add_reverse_edges(
minibatch.node_pairs, reverse_etypes_mapping
)
else:
edges_to_exclude = minibatch.seeds
if include_reverse_edges:
edges_to_exclude = add_reverse_edges_2(
edges_to_exclude, reverse_etypes_mapping
)
minibatch.sampled_subgraphs = [
subgraph.exclude_edges(edges_to_exclude)
for subgraph in minibatch.sampled_subgraphs
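[Note, not part of the commit] A minimal sketch of the updated `exclude_seed_edges` flow under the same assumptions as the sketch above, plus the assumption that `gb.MiniBatch` accepts `seeds` and `sampled_subgraphs` as keyword fields. When `minibatch.node_pairs` is unset, the seeds (and, if requested, their reverses from `add_reverse_edges_2`) are excluded from every sampled subgraph.

import torch
import dgl.graphbolt as gb
# Assumed import path, as in the sketch above.
from dgl.graphbolt.impl.sampled_subgraph_impl import SampledSubgraphImpl

subgraph = SampledSubgraphImpl(
    gb.CSCFormatBase(
        indptr=torch.tensor([0, 0, 1, 2, 2, 3]),
        indices=torch.tensor([0, 3, 2]),
    ),
    None,
    None,
    torch.tensor([5, 9, 10]),
)
# Homogeneous seeds given as an N*2 tensor of (src, dst) pairs.
minibatch = gb.MiniBatch(
    seeds=torch.tensor([[2, 4]]), sampled_subgraphs=[subgraph]
)

# node_pairs is unset, so the seeds are reversed via add_reverse_edges_2 and
# excluded. The reverse edge (4, 2) does not exist in the subgraph, so only
# the seed edge (2, 4) is dropped.
gb.exclude_seed_edges(minibatch, include_reverse_edges=True)
print(minibatch.sampled_subgraphs[0].original_edge_ids)  # edge ID 10 is gone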
232 changes: 232 additions & 0 deletions tests/python/pytorch/graphbolt/impl/test_sampled_subgraph_impl.py
@@ -265,6 +265,238 @@ def test_exclude_edges_hetero_duplicated(reverse_row, reverse_column):
_assert_container_equal(result.original_edge_ids, expected_edge_ids)


@pytest.mark.parametrize("reverse_row", [True, False])
@pytest.mark.parametrize("reverse_column", [True, False])
def test_exclude_edges_homo_deduplicated_tensor(reverse_row, reverse_column):
csc_formats = gb.CSCFormatBase(
indptr=torch.tensor([0, 0, 1, 2, 2, 3]), indices=torch.tensor([0, 3, 2])
)
if reverse_row:
original_row_node_ids = torch.tensor([10, 15, 11, 24, 9])
src_to_exclude = torch.tensor([11])
else:
original_row_node_ids = None
src_to_exclude = torch.tensor([2])

if reverse_column:
original_column_node_ids = torch.tensor([10, 15, 11, 24, 9])
dst_to_exclude = torch.tensor([9])
else:
original_column_node_ids = None
dst_to_exclude = torch.tensor([4])
original_edge_ids = torch.Tensor([5, 9, 10])
subgraph = SampledSubgraphImpl(
csc_formats,
original_column_node_ids,
original_row_node_ids,
original_edge_ids,
)
edges_to_exclude = torch.cat((src_to_exclude, dst_to_exclude)).view(1, -1)
result = subgraph.exclude_edges(edges_to_exclude)
expected_csc_formats = gb.CSCFormatBase(
indptr=torch.tensor([0, 0, 1, 2, 2, 2]), indices=torch.tensor([0, 3])
)
if reverse_row:
expected_row_node_ids = torch.tensor([10, 15, 11, 24, 9])
else:
expected_row_node_ids = None
if reverse_column:
expected_column_node_ids = torch.tensor([10, 15, 11, 24, 9])
else:
expected_column_node_ids = None
expected_edge_ids = torch.Tensor([5, 9])

_assert_container_equal(result.sampled_csc, expected_csc_formats)
_assert_container_equal(
result.original_column_node_ids, expected_column_node_ids
)
_assert_container_equal(result.original_row_node_ids, expected_row_node_ids)
_assert_container_equal(result.original_edge_ids, expected_edge_ids)


@pytest.mark.parametrize("reverse_row", [True, False])
@pytest.mark.parametrize("reverse_column", [True, False])
def test_exclude_edges_homo_duplicated_tensor(reverse_row, reverse_column):
csc_formats = gb.CSCFormatBase(
indptr=torch.tensor([0, 0, 1, 3, 3, 5]),
indices=torch.tensor([0, 3, 3, 2, 2]),
)
if reverse_row:
original_row_node_ids = torch.tensor([10, 15, 11, 24, 9])
src_to_exclude = torch.tensor([24])
else:
original_row_node_ids = None
src_to_exclude = torch.tensor([3])

if reverse_column:
original_column_node_ids = torch.tensor([10, 15, 11, 24, 9])
dst_to_exclude = torch.tensor([11])
else:
original_column_node_ids = None
dst_to_exclude = torch.tensor([2])
original_edge_ids = torch.Tensor([5, 9, 9, 10, 10])
subgraph = SampledSubgraphImpl(
csc_formats,
original_column_node_ids,
original_row_node_ids,
original_edge_ids,
)
edges_to_exclude = torch.cat((src_to_exclude, dst_to_exclude)).view(1, -1)
result = subgraph.exclude_edges(edges_to_exclude)
expected_csc_formats = gb.CSCFormatBase(
indptr=torch.tensor([0, 0, 1, 1, 1, 3]), indices=torch.tensor([0, 2, 2])
)
if reverse_row:
expected_row_node_ids = torch.tensor([10, 15, 11, 24, 9])
else:
expected_row_node_ids = None
if reverse_column:
expected_column_node_ids = torch.tensor([10, 15, 11, 24, 9])
else:
expected_column_node_ids = None
expected_edge_ids = torch.Tensor([5, 10, 10])
_assert_container_equal(result.sampled_csc, expected_csc_formats)
_assert_container_equal(
result.original_column_node_ids, expected_column_node_ids
)
_assert_container_equal(result.original_row_node_ids, expected_row_node_ids)
_assert_container_equal(result.original_edge_ids, expected_edge_ids)


@pytest.mark.parametrize("reverse_row", [True, False])
@pytest.mark.parametrize("reverse_column", [True, False])
def test_exclude_edges_hetero_deduplicated_tensor(reverse_row, reverse_column):
csc_formats = {
"A:relation:B": gb.CSCFormatBase(
indptr=torch.tensor([0, 1, 2, 3]),
indices=torch.tensor([2, 1, 0]),
)
}
if reverse_row:
original_row_node_ids = {
"A": torch.tensor([13, 14, 15]),
}
src_to_exclude = torch.tensor([15, 13])
else:
original_row_node_ids = None
src_to_exclude = torch.tensor([2, 0])
if reverse_column:
original_column_node_ids = {
"B": torch.tensor([10, 11, 12]),
}
dst_to_exclude = torch.tensor([10, 12])
else:
original_column_node_ids = None
dst_to_exclude = torch.tensor([0, 2])
original_edge_ids = {"A:relation:B": torch.tensor([19, 20, 21])}
subgraph = SampledSubgraphImpl(
sampled_csc=csc_formats,
original_column_node_ids=original_column_node_ids,
original_row_node_ids=original_row_node_ids,
original_edge_ids=original_edge_ids,
)

edges_to_exclude = {
"A:relation:B": torch.cat((src_to_exclude, dst_to_exclude))
.view(2, -1)
.T
}
result = subgraph.exclude_edges(edges_to_exclude)
expected_csc_formats = {
"A:relation:B": gb.CSCFormatBase(
indptr=torch.tensor([0, 0, 1, 1]),
indices=torch.tensor([1]),
)
}
if reverse_row:
expected_row_node_ids = {
"A": torch.tensor([13, 14, 15]),
}
else:
expected_row_node_ids = None
if reverse_column:
expected_column_node_ids = {
"B": torch.tensor([10, 11, 12]),
}
else:
expected_column_node_ids = None
expected_edge_ids = {"A:relation:B": torch.tensor([20])}

_assert_container_equal(result.sampled_csc, expected_csc_formats)
_assert_container_equal(
result.original_column_node_ids, expected_column_node_ids
)
_assert_container_equal(result.original_row_node_ids, expected_row_node_ids)
_assert_container_equal(result.original_edge_ids, expected_edge_ids)


@pytest.mark.parametrize("reverse_row", [True, False])
@pytest.mark.parametrize("reverse_column", [True, False])
def test_exclude_edges_hetero_duplicated_tensor(reverse_row, reverse_column):
csc_formats = {
"A:relation:B": gb.CSCFormatBase(
indptr=torch.tensor([0, 2, 4, 5]),
indices=torch.tensor([2, 2, 1, 1, 0]),
)
}
if reverse_row:
original_row_node_ids = {
"A": torch.tensor([13, 14, 15]),
}
src_to_exclude = torch.tensor([15, 13])
else:
original_row_node_ids = None
src_to_exclude = torch.tensor([2, 0])
if reverse_column:
original_column_node_ids = {
"B": torch.tensor([10, 11, 12]),
}
dst_to_exclude = torch.tensor([10, 12])
else:
original_column_node_ids = None
dst_to_exclude = torch.tensor([0, 2])
original_edge_ids = {"A:relation:B": torch.tensor([19, 19, 20, 20, 21])}
subgraph = SampledSubgraphImpl(
sampled_csc=csc_formats,
original_column_node_ids=original_column_node_ids,
original_row_node_ids=original_row_node_ids,
original_edge_ids=original_edge_ids,
)

edges_to_exclude = {
"A:relation:B": torch.cat((src_to_exclude, dst_to_exclude))
.view(2, -1)
.T
}
result = subgraph.exclude_edges(edges_to_exclude)
expected_csc_formats = {
"A:relation:B": gb.CSCFormatBase(
indptr=torch.tensor([0, 0, 2, 2]),
indices=torch.tensor([1, 1]),
)
}
if reverse_row:
expected_row_node_ids = {
"A": torch.tensor([13, 14, 15]),
}
else:
expected_row_node_ids = None
if reverse_column:
expected_column_node_ids = {
"B": torch.tensor([10, 11, 12]),
}
else:
expected_column_node_ids = None
expected_edge_ids = {"A:relation:B": torch.tensor([20, 20])}

_assert_container_equal(result.sampled_csc, expected_csc_formats)
_assert_container_equal(
result.original_column_node_ids, expected_column_node_ids
)
_assert_container_equal(result.original_row_node_ids, expected_row_node_ids)
_assert_container_equal(result.original_edge_ids, expected_edge_ids)


@unittest.skipIf(
F._default_context_str == "cpu",
reason="`to` function needs GPU to test.",

