From 58e2354f9f313c19e4eb0a24787db1aec90efc28 Mon Sep 17 00:00:00 2001 From: Muhammed Fatih Balin Date: Thu, 5 Sep 2024 01:01:10 -0400 Subject: [PATCH] simplify the code. --- .../src/cuda/cooperative_minibatching_utils.cu | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/graphbolt/src/cuda/cooperative_minibatching_utils.cu b/graphbolt/src/cuda/cooperative_minibatching_utils.cu index 50ac45048193..4822681544eb 100644 --- a/graphbolt/src/cuda/cooperative_minibatching_utils.cu +++ b/graphbolt/src/cuda/cooperative_minibatching_utils.cu @@ -62,12 +62,7 @@ RankSortImpl( auto nodes_sorted = torch::empty_like(nodes); auto index = torch::arange(nodes.numel(), nodes.options()); auto index_sorted = torch::empty_like(index); - auto offsets = torch::empty( - num_batches * world_size + 1, c10::TensorOptions() - .dtype(offsets_dev.scalar_type()) - .pinned_memory(true)); - at::cuda::CUDAEvent offsets_event; - AT_DISPATCH_INDEX_TYPES( + return AT_DISPATCH_INDEX_TYPES( nodes.scalar_type(), "RankSortImpl", ([&] { CUB_CALL( DeviceSegmentedRadixSort::SortPairs, @@ -75,6 +70,10 @@ RankSortImpl( part_ids_sorted.data_ptr(), nodes.data_ptr(), nodes_sorted.data_ptr(), nodes.numel(), num_batches, offsets_dev_ptr, offsets_dev_ptr + 1, 0, num_bits); + auto offsets = torch::empty( + num_batches * world_size + 1, c10::TensorOptions() + .dtype(offsets_dev.scalar_type()) + .pinned_memory(true)); CUB_CALL( DeviceFor::Bulk, num_batches * world_size + 1, [=, part_ids = part_ids_sorted.data_ptr(), @@ -89,6 +88,7 @@ RankSortImpl( offset_end - offset_begin, rank) + offset_begin; }); + at::cuda::CUDAEvent offsets_event; offsets_event.record(); CUB_CALL( DeviceSegmentedRadixSort::SortPairs, @@ -97,8 +97,8 @@ RankSortImpl( index.data_ptr(), index_sorted.data_ptr(), nodes.numel(), num_batches, offsets_dev_ptr, offsets_dev_ptr + 1, 0, num_bits); + return {nodes_sorted, index_sorted, offsets, std::move(offsets_event)}; })); - return {nodes_sorted, index_sorted, offsets, std::move(offsets_event)}; } std::vector> RankSort(