diff --git a/graphbolt/src/cuda/extension/unique_and_compact_map.cu b/graphbolt/src/cuda/extension/unique_and_compact_map.cu
index a36c63925d7f..3db918ee7fdb 100644
--- a/graphbolt/src/cuda/extension/unique_and_compact_map.cu
+++ b/graphbolt/src/cuda/extension/unique_and_compact_map.cu
@@ -284,14 +284,18 @@ UniqueAndCompactBatchedHashMapBased(
         unique_ids_offsets_dev.data_ptr<int64_t>();
   }
   at::cuda::CUDAEvent unique_ids_offsets_event;
+  unique_ids_offsets_event.record();
   torch::optional<torch::Tensor> index;
   if (part_ids) {
+    unique_ids_offsets_event.synchronize();
+    const auto num_unique =
+        unique_ids_offsets.data_ptr<int64_t>()[num_batches];
+    unique_ids = unique_ids.slice(0, 0, num_unique);
+    part_ids = part_ids->slice(0, 0, num_unique);
     std::tie(
         unique_ids, index, unique_ids_offsets, unique_ids_offsets_event) =
         cuda::RankSortImpl(
             unique_ids, *part_ids, unique_ids_offsets_dev, world_size);
-  } else {
-    unique_ids_offsets_event.record();
   }
   auto mapped_ids =
       torch::empty(offsets_ptr[3 * num_batches], unique_ids.options());
diff --git a/python/dgl/graphbolt/impl/cooperative_conv.py b/python/dgl/graphbolt/impl/cooperative_conv.py
index cb3d39d4d980..8040dabd286c 100644
--- a/python/dgl/graphbolt/impl/cooperative_conv.py
+++ b/python/dgl/graphbolt/impl/cooperative_conv.py
@@ -35,8 +35,11 @@ def forward(
         counts_received = convert_to_hetero(subgraph._counts_received)
         seed_inverse_ids = convert_to_hetero(subgraph._seed_inverse_ids)
         seed_sizes = convert_to_hetero(subgraph._seed_sizes)
-        ctx.save_for_backward(
-            counts_sent, counts_received, seed_inverse_ids, seed_sizes
+        ctx.communication_variables = (
+            counts_sent,
+            counts_received,
+            seed_inverse_ids,
+            seed_sizes,
         )
         outs = {}
         for ntype, typed_tensor in convert_to_hetero(tensor).items():
@@ -63,7 +66,8 @@ def backward(
             counts_received,
             seed_inverse_ids,
             seed_sizes,
-        ) = ctx.saved_tensors
+        ) = ctx.communication_variables
+        delattr(ctx, "communication_variables")
         outs = {}
         for ntype, typed_grad_output in convert_to_hetero(grad_output).items():
             out = typed_grad_output.new_empty(
@@ -79,7 +83,11 @@ def backward(
             )  # src
             i[1] = seed_inverse_ids[ntype]  # dst
             coo = torch.sparse_coo_tensor(
-                i, 1, size=(seed_sizes[ntype], i.shape[1])
+                i,
+                torch.ones(
+                    i.shape[1], dtype=grad_output.dtype, device=i.device
+                ),
+                size=(seed_sizes[ntype], i.shape[1]),
             )
             outs[ntype] = torch.sparse.mm(coo, out)
         return None, revert_to_homo(outs)
diff --git a/tests/python/pytorch/graphbolt/impl/test_cooperative_minibatching_utils.py b/tests/python/pytorch/graphbolt/impl/test_cooperative_minibatching_utils.py
index f85676578bd5..1de8669b0e08 100644
--- a/tests/python/pytorch/graphbolt/impl/test_cooperative_minibatching_utils.py
+++ b/tests/python/pytorch/graphbolt/impl/test_cooperative_minibatching_utils.py
@@ -19,9 +19,10 @@
 @pytest.mark.parametrize("dtype", [torch.int32, torch.int64])
 @pytest.mark.parametrize("rank", list(range(WORLD_SIZE)))
 def test_gpu_cached_feature_read_async(dtype, rank):
+    torch.manual_seed(7)
     nodes_list1 = [
-        torch.randint(0, 11111111, [777], dtype=dtype, device=F.ctx())
-        for i in range(10)
+        torch.randint(0, 2111111111, [777], dtype=dtype, device=F.ctx())
+        for _ in range(10)
     ]
     nodes_list2 = [nodes.sort()[0] for nodes in nodes_list1]
@@ -57,3 +58,13 @@ def test_gpu_cached_feature_read_async(dtype, rank):
             assert_equal(
                 idx1[off1[j] : off1[j + 1]], idx4[off4[i] : off4[i + 1]]
             )
+
+    unique, compacted, offsets = gb.unique_and_compact(
+        nodes_list1[:1], rank, WORLD_SIZE
+    )
+
+    nodes1, idx1, offsets1 = res1[0]
+
+    assert_equal(unique, nodes1)
+    assert_equal(compacted[0], idx1)
+    assert_equal(offsets, offsets1)
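
Note on the `unique_and_compact_map.cu` hunk: recording `unique_ids_offsets_event` is moved ahead of the `part_ids` branch so the event is recorded on every path, and the new `synchronize()` ensures the asynchronously produced offsets are visible on the host before `num_unique` is read and used to slice `unique_ids`. The following is a minimal Python sketch of the same record-then-synchronize pattern using `torch.cuda.Event`; the tensor names are hypothetical and this is not the repository's code:

```python
import torch

if torch.cuda.is_available():
    offsets_dev = torch.tensor([0, 777], device="cuda")
    # Enqueue an async device-to-host copy into pinned host memory.
    offsets_host = torch.empty(
        offsets_dev.shape, dtype=offsets_dev.dtype, pin_memory=True
    )
    offsets_host.copy_(offsets_dev, non_blocking=True)
    event = torch.cuda.Event()
    event.record()  # record unconditionally, on every code path
    # ... other GPU work can be enqueued here ...
    # Reading offsets_host before the copy finishes would race with the
    # DtoH transfer, so wait on the event right before the host read.
    event.synchronize()
    num_unique = offsets_host[-1].item()
```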
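
Note on the `cooperative_conv.py` hunks: `ctx.save_for_backward` accepts only tensors, while the objects saved here are per-node-type dictionaries, so they are stashed as a plain `ctx` attribute and deleted in `backward` to release them eagerly. Separately, the scalar `1` passed to `torch.sparse_coo_tensor` infers an integer values tensor, whereas `torch.sparse.mm` needs values matching `grad_output`'s dtype and device, hence the explicit `torch.ones(...)`. Below is a self-contained sketch of both patterns under assumed names (`ScatterGather` is not the repository's class):

```python
import torch


class ScatterGather(torch.autograd.Function):
    """Gathers rows in forward; scatter-adds gradients in backward."""

    @staticmethod
    def forward(ctx, tensor, inverse_ids, seed_size):
        # Non-tensor state (here a dict-free analogue: an int) cannot go
        # through ctx.save_for_backward, which only accepts tensors;
        # store it as a plain attribute, as the hunk above does.
        ctx.communication_variables = (inverse_ids, seed_size)
        return tensor[inverse_ids]

    @staticmethod
    def backward(ctx, grad_output):
        inverse_ids, seed_size = ctx.communication_variables
        delattr(ctx, "communication_variables")  # free state eagerly
        nnz = inverse_ids.shape[0]
        i = torch.stack(
            [inverse_ids, torch.arange(nnz, device=inverse_ids.device)]
        )
        # A scalar 1 for values would yield an int64 sparse tensor;
        # torch.sparse.mm needs values with grad_output's dtype/device.
        coo = torch.sparse_coo_tensor(
            i,
            torch.ones(nnz, dtype=grad_output.dtype, device=i.device),
            size=(seed_size, nnz),
        )
        return torch.sparse.mm(coo, grad_output), None, None


x = torch.randn(4, 3, requires_grad=True)
ids = torch.tensor([0, 1, 1, 3])
ScatterGather.apply(x, ids, 4).sum().backward()  # x.grad[1] accumulates 2
```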