diff --git a/src/04kernel/cuda/src/gather.cu b/src/04kernel/cuda/src/gather.cu index f906d1778..dd0399352 100644 --- a/src/04kernel/cuda/src/gather.cu +++ b/src/04kernel/cuda/src/gather.cu @@ -20,9 +20,10 @@ namespace refactor::kernel::cuda { tid += step) { auto i = tid / batch, j = tid % batch; - auto index = __ldg(indices + i % midSizeO); + auto k = __ldg(indices + i % midSizeO); + auto quot = k >= 0 ? i / midSizeO : i / midSizeO + 1; optimizedMemcpy(unit * tid + output, - unit * (batch * (i / midSizeO * midSizeI + index) + j) + data, + unit * (batch * (quot * midSizeI + k) + j) + data, unit); } } diff --git a/src/04kernel/src/kernels/gather/cpu_kernel.cc b/src/04kernel/src/kernels/gather/cpu_kernel.cc index c848b2282..160711155 100644 --- a/src/04kernel/src/kernels/gather/cpu_kernel.cc +++ b/src/04kernel/src/kernels/gather/cpu_kernel.cc @@ -33,8 +33,9 @@ namespace refactor::kernel { int64_t k = info.idxType == DataType::I64 ? reinterpret_cast(inputs[1])[d.rem] : reinterpret_cast(inputs[1])[d.rem]; + auto quot = k >= 0 ? d.quot : d.quot + 1; std::memcpy(info.postfix * i + output, - info.postfix * (d.quot * info.midSizeI + k) + data, + info.postfix * (quot * info.midSizeI + k) + data, info.postfix); }); };