Skip to content

Commit

Permalink
simplify the code.
Browse files Browse the repository at this point in the history
  • Loading branch information
mfbalin committed Sep 5, 2024
1 parent a789b6b commit 58e2354
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions graphbolt/src/cuda/cooperative_minibatching_utils.cu
Original file line number Diff line number Diff line change
Expand Up @@ -62,19 +62,18 @@ RankSortImpl(
auto nodes_sorted = torch::empty_like(nodes);
auto index = torch::arange(nodes.numel(), nodes.options());
auto index_sorted = torch::empty_like(index);
auto offsets = torch::empty(
num_batches * world_size + 1, c10::TensorOptions()
.dtype(offsets_dev.scalar_type())
.pinned_memory(true));
at::cuda::CUDAEvent offsets_event;
AT_DISPATCH_INDEX_TYPES(
return AT_DISPATCH_INDEX_TYPES(
nodes.scalar_type(), "RankSortImpl", ([&] {
CUB_CALL(
DeviceSegmentedRadixSort::SortPairs,
part_ids.data_ptr<cuda::part_t>(),
part_ids_sorted.data_ptr<cuda::part_t>(), nodes.data_ptr<index_t>(),
nodes_sorted.data_ptr<index_t>(), nodes.numel(), num_batches,
offsets_dev_ptr, offsets_dev_ptr + 1, 0, num_bits);
auto offsets = torch::empty(
num_batches * world_size + 1, c10::TensorOptions()
.dtype(offsets_dev.scalar_type())
.pinned_memory(true));
CUB_CALL(
DeviceFor::Bulk, num_batches * world_size + 1,
[=, part_ids = part_ids_sorted.data_ptr<cuda::part_t>(),
Expand All @@ -89,6 +88,7 @@ RankSortImpl(
offset_end - offset_begin, rank) +
offset_begin;
});
at::cuda::CUDAEvent offsets_event;
offsets_event.record();
CUB_CALL(
DeviceSegmentedRadixSort::SortPairs,
Expand All @@ -97,8 +97,8 @@ RankSortImpl(
index.data_ptr<index_t>(), index_sorted.data_ptr<index_t>(),
nodes.numel(), num_batches, offsets_dev_ptr, offsets_dev_ptr + 1, 0,
num_bits);
return {nodes_sorted, index_sorted, offsets, std::move(offsets_event)};
}));
return {nodes_sorted, index_sorted, offsets, std::move(offsets_event)};
}

std::vector<std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>> RankSort(
Expand Down

0 comments on commit 58e2354

Please sign in to comment.