Skip to content

Commit

Permalink
[Fix] CAGRA 'merge' API compilation error under CUDA 12.4
Browse files Browse the repository at this point in the history
  • Loading branch information
rhdong committed Feb 14, 2025
1 parent ab91f93 commit ac5cc15
Showing 1 changed file with 2 additions and 4 deletions.
6 changes: 2 additions & 4 deletions cpp/src/neighbors/detail/cagra/cagra_merge.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,7 @@ index<T, IdxT> merge(raft::resources const& handle,
for (auto index : indices) {
RAFT_EXPECTS(index != nullptr,
"Null pointer detected in 'indices'. Ensure all elements are valid before usage.");
using ds_idx_type = decltype(index->data().n_rows());
if (auto* strided_dset = dynamic_cast<const strided_dataset<T, ds_idx_type>*>(&index->data());
if (auto* strided_dset = dynamic_cast<const strided_dataset<T, int64_t>*>(&index->data());
strided_dset != nullptr) {
if (dim == 0) {
dim = index->dim();
Expand All @@ -75,8 +74,7 @@ index<T, IdxT> merge(raft::resources const& handle,

auto merge_dataset = [&](T* dst) {
for (auto index : indices) {
using ds_idx_type = decltype(index->data().n_rows());
auto* strided_dset = dynamic_cast<const strided_dataset<T, ds_idx_type>*>(&index->data());
auto* strided_dset = dynamic_cast<const strided_dataset<T, int64_t>*>(&index->data());

RAFT_CUDA_TRY(cudaMemcpy2DAsync(dst + offset * dim,
sizeof(T) * dim,
Expand Down

0 comments on commit ac5cc15

Please sign in to comment.