Commit

fix
Lifann authored and oppenheimli committed Aug 14, 2024
1 parent 28f6ffc commit 12ccbd3
Showing 4 changed files with 167 additions and 25 deletions.
CMakeLists.txt (6 changes: 5 additions & 1 deletion)
@@ -163,4 +163,8 @@ TARGET_LINK_LIBRARIES(find_with_missed_keys_test gtest_main)
 add_executable(reserved_keys_test tests/reserved_keys_test.cc.cu)
 target_compile_features(reserved_keys_test PUBLIC cxx_std_14)
 set_target_properties(reserved_keys_test PROPERTIES CUDA_ARCHITECTURES OFF)
-TARGET_LINK_LIBRARIES(reserved_keys_test gtest_main)
+TARGET_LINK_LIBRARIES(reserved_keys_test gtest_main)
+
+add_executable(export_batch_if_test tests/export_batch_if_test.cc.cu)
+target_compile_features(export_batch_if_test PUBLIC cxx_std_14)
+set_target_properties(export_batch_if_test PROPERTIES CUDA_ARCHITECTURES OFF)
include/merlin/core_kernels.cuh (30 changes: 11 additions & 19 deletions)
@@ -910,7 +910,6 @@ __global__ void dump_kernel(const Table<K, V, S>* __restrict table,
   }
 }
 
-/* Dump with score. */
 template <class K, class V, class S,
           template <typename, typename> class PredFunctor,
           int TILE_SIZE>
@@ -924,31 +923,24 @@ __global__ void dump_kernel_v2(const Table<K, V, S>* __restrict table,
   int dim = table->dim;
   auto g = cg::tiled_partition<TILE_SIZE>(cg::this_thread_block());
 
-  __shared__ block_acc;
-  if (threadIdx.x == 0) {
-    block_acc = 0;
-  }
-  __syncthreads();
 
-  size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
-  size_t N = n * TILE_SIZE;
   PredFunctor<K, S> pred;
+  size_t tid = static_cast<size_t>(blockIdx.x * blockDim.x + threadIdx.x);
 
-  for (size_t ii = tid; ii < N; ii += gridDim.x * blockDim.x) {
-    size_t i = ii / TILE_SIZE;
+  for (size_t ii = tid; ii < search_length; ii += gridDim.x * blockDim.x) {
     size_t bkt_idx = (ii + offset) / bucket_max_size;
     int key_idx = (ii + offset) % bucket_max_size;
-    int leading_key_idx = key_idx % TILE_SIZE;
+    int leading_key_idx = key_idx / TILE_SIZE * TILE_SIZE;
     Bucket<K, V, S>* bucket = &(buckets[bkt_idx]);
 
     const K key =
         (bucket->keys(key_idx))->load(cuda::std::memory_order_relaxed);
     S score = bucket->scores(key_idx)->load(cuda::std::memory_order_relaxed);
-    bool match = !IS_RESERVED_KEY<K>(key) && pred(key, score, pattern, threshold);
-    int vote = g.ballot(match);
+    bool match = (!IS_RESERVED_KEY<K>(key)) && pred(key, score, pattern, threshold);
+    unsigned int vote = g.ballot(match);
     int tile_cnt = __popc(vote);
     int tile_offset = 0;
-    if (g.rank() == 0) {
-      tile_offset = atomicAdd(d_dump_counter, static_cast<size_t>(tile_cnt));
+    if (g.thread_rank() == 0) {
+      tile_offset = static_cast<int>(atomicAdd(d_dump_counter, static_cast<size_t>(tile_cnt)));
     }
     tile_offset = g.shfl(tile_offset, 0);
 
@@ -962,10 +954,10 @@
     #pragma unroll
     for (int r = 0; r < TILE_SIZE; r++) {
       bool cur_match = vote >> r & 1;
-      if (match) {
+      if (cur_match) {
         int cur_idx = leading_key_idx + r;
-        for (int j = g.rank(); j < dim; j += TILE_SIZE) {
-          d_val[(tile_offset + cur_idx) * dim + j] = bucket->vector[cur_idx * dim + j];
+        for (int j = g.thread_rank(); j < dim; j += TILE_SIZE) {
+          d_val[(tile_offset + cur_idx) * dim + j] = bucket->vectors[cur_idx * dim + j];
         }
       }
     }
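Note on the hunks above: dump_kernel_v2 compacts matching entries with a tile-wide reservation scheme. Each tile of TILE_SIZE threads votes on its matches with g.ballot, counts them with __popc, lets the leading lane reserve a contiguous slice of the output through a single atomicAdd on the dump counter, and broadcasts the reserved offset to the tile with g.shfl. The standalone CUDA sketch below illustrates only that reservation pattern on plain integers; the kernel name compact_if_even, the even-number predicate, and the scalar output are illustrative assumptions, not code from the library.

#include <cooperative_groups.h>
#include <cstddef>

namespace cg = cooperative_groups;

// Illustrative only: compact the even elements of `in` into `out`, using the
// same tile-wide ballot / __popc / atomicAdd / shfl scheme as dump_kernel_v2.
// Launch, for example, as compact_if_even<8><<<grid, block>>>(in, n, out, cnt).
template <int TILE_SIZE>
__global__ void compact_if_even(const int* in, size_t n, int* out,
                                unsigned long long* d_counter) {
  auto g = cg::tiled_partition<TILE_SIZE>(cg::this_thread_block());
  size_t tid = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;

  for (size_t i = tid; i < n; i += gridDim.x * blockDim.x) {
    bool match = (in[i] % 2 == 0);
    unsigned int vote = g.ballot(match);  // one bit per lane of the tile
    int tile_cnt = __popc(vote);          // number of matches in this tile
    unsigned long long tile_offset = 0;
    if (g.thread_rank() == 0) {
      // The leading lane reserves a contiguous output range once per tile.
      tile_offset = atomicAdd(d_counter, static_cast<unsigned long long>(tile_cnt));
    }
    tile_offset = g.shfl(tile_offset, 0);  // broadcast the reserved offset

    // Rank of this lane among the matching lanes below it in the same tile.
    int rank_in_tile = __popc(vote & ((1u << g.thread_rank()) - 1));
    if (match) {
      out[tile_offset + rank_in_tile] = in[i];
    }
  }
}

Because only one atomicAdd is issued per tile instead of one per matching thread, contention on the global counter drops by roughly a factor of TILE_SIZE.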
include/merlin_hashtable.cuh (32 changes: 27 additions & 5 deletions)
@@ -2616,12 +2616,34 @@ class HashTable : public HashTableBase<K, V, S> {
       return;
     }
 
-    int grid_size = std::min(sm_cnt_, static_cast<int>(SAFE_GET_GRID_SIZE(n, options_.block_size)));
+    bool match_fast_cond = options_.max_bucket_size % TILE_SIZE == 0 \
+        && options_.max_bucket_size >= TILE_SIZE && offset % TILE_SIZE == 0 \
+        && n % TILE_SIZE == 0;
 
-    dump_kernel<key_type, value_type, score_type, PredFunctor>
-        <<<grid_size, options_.block_size, 0, stream>>>(
-            d_table_, table_->buckets, pattern, threshold, keys, values, scores,
-            offset, n, d_counter);
+    if (match_fast_cond) {
+      int grid_size = std::min(sm_cnt_, static_cast<int>(SAFE_GET_GRID_SIZE(n, options_.block_size)));
+      const int TILE_SIZE = 8;
+
+      dump_kernel_v2<key_type, value_type, score_type, PredFunctor, TILE_SIZE>
+          <<<grid_size, options_.block_size, 0, stream>>>(
+              d_table_, table_->buckets, pattern, threshold, keys, values, scores,
+              offset, n, d_counter);
+    } else {
+      const size_t score_size = scores ? sizeof(score_type) : 0;
+      const size_t kvm_size =
+          sizeof(key_type) + sizeof(value_type) * dim() + score_size;
+      const size_t block_size = std::min(shared_mem_size_ / 2 / kvm_size, 1024UL);
+      MERLIN_CHECK(
+          block_size > 0,
+          "[HierarchicalKV] block_size <= 0, the K-V-S size may be too large!");
+
+      const size_t shared_size = kvm_size * block_size;
+      const size_t grid_size = SAFE_GET_GRID_SIZE(n, block_size);
+      dump_kernel<key_type, value_type, score_type, PredFunctor>
+          <<<grid_size, block_size, shared_size, stream>>>(
+              d_table_, table_->buckets, pattern, threshold, keys, values, scores,
+              offset, n, d_counter);
+    }
 
     CudaCheckError();
   }
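The dispatch above uses dump_kernel_v2 only when every tile of TILE_SIZE threads lands on a TILE_SIZE-aligned slice of a single bucket, which requires the bucket size, the offset, and n to all be multiples of TILE_SIZE; otherwise it falls back to the shared-memory dump_kernel. A compact restatement of that eligibility test, using a hypothetical helper name purely for illustration:

#include <cstddef>

// Sketch of the fast-path condition above; can_use_tiled_dump is not part of
// the library's API, it only restates the match_fast_cond expression.
constexpr bool can_use_tiled_dump(size_t max_bucket_size, size_t offset,
                                  size_t n, size_t tile_size = 8) {
  return max_bucket_size % tile_size == 0 && max_bucket_size >= tile_size &&
         offset % tile_size == 0 && n % tile_size == 0;
}

static_assert(can_use_tiled_dump(128, 0, 1024),
              "aligned buckets and ranges take the tiled fast path");
static_assert(!can_use_tiled_dump(100, 0, 1024),
              "a bucket size that is not a multiple of the tile falls back");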
tests/export_batch_if_test.cc.cu (124 changes: 124 additions & 0 deletions, new file)
@@ -0,0 +1,124 @@
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <algorithm>
#include <array>
#include <iostream>
#include <thread>
#include <vector>
#include <map>
#include <unordered_map>
#include "merlin_hashtable.cuh"
#include "merlin/types.cuh"
#include "test_util.cuh"

using K = uint64_t;
using V = float;
using S = uint64_t;
using i64 = int64_t;
using u64 = uint64_t;
using f32 = float;
using EvictStrategy = nv::merlin::EvictStrategy;
using TableOptions = nv::merlin::HashTableOptions;

template <class K, class S>
struct ExportIfPredFunctor {
__forceinline__ __device__ bool operator()(const K& key, S& score,
const K& pattern,
const S& threshold) {
return score < threshold;
}
};

void test_export_batch_if() {
constexpr uint64_t CAP = 1024ul;
size_t n = 256;
size_t n0 = 127;
size_t n1 = 128;
size_t n2 = 163;
size_t dim = 32;
size_t table_size = 0;
i64 pattern = 0;
u64 threshold = 40;

cudaStream_t stream;
CUDA_CHECK(cudaStreamCreate(&stream));

TableOptions options;
options.init_capacity = CAP;
options.max_capacity = CAP;
options.dim = dim;
options.max_hbm_for_vectors = nv::merlin::GB(100);
using Table = nv::merlin::HashTable<i64, f32, u64, EvictStrategy::kCustomized>;

std::unique_ptr<Table> table = std::make_unique<Table>();
table->init(options);

test_util::KVMSBuffer<i64, f32, u64> buffer0;
buffer0.Reserve(n0, dim, stream);
buffer0.ToRange(0, 1, stream);
buffer0.Setscore((u64)15, stream);
table->insert_or_assign(n0, buffer0.keys_ptr(), buffer0.values_ptr(), buffer0.scores_ptr(), stream, true, false);
table_size = table->size(stream);
CUDA_CHECK(cudaStreamSynchronize(stream));
MERLIN_EXPECT_TRUE(table_size == n0, "Invalid table size.");

test_util::KVMSBuffer<i64, f32, u64> buffer1;
buffer1.Reserve(n1, dim, stream);
buffer1.ToRange(n0, 1, stream);
buffer1.Setscore((u64)30, stream);
table->insert_or_assign(n1, buffer1.keys_ptr(), buffer1.values_ptr(), buffer1.scores_ptr(), stream, true, false);
table_size = table->size(stream);
CUDA_CHECK(cudaStreamSynchronize(stream));
MERLIN_EXPECT_TRUE(table_size == n0 + n1, "Invalid table size.");

test_util::KVMSBuffer<i64, f32, u64> buffer2;
buffer2.Reserve(n2, dim, stream);
buffer2.ToRange(n0 + n1, 1, stream);
buffer2.Setscore((u64)45, stream);
table->insert_or_assign(n2, buffer2.keys_ptr(), buffer2.values_ptr(), buffer2.scores_ptr(), stream, true, false);
table_size = table->size(stream);
CUDA_CHECK(cudaStreamSynchronize(stream));
MERLIN_EXPECT_TRUE(table_size == n0 + n1 + n2, "Invalid table size.");

test_util::KVMSBuffer<i64, f32, u64> buffer_out;
buffer_out.Reserve(CAP, dim, stream);
buffer_out.ToZeros(stream);

size_t* d_cnt = nullptr;
size_t h_cnt = 0;
CUDA_CHECK(cudaMallocAsync(&d_cnt, sizeof(size_t), stream));
CUDA_CHECK(cudaMemsetAsync(d_cnt, 0, sizeof(size_t), stream));
CUDA_CHECK(cudaStreamSynchronize(stream));
table->export_batch_if<ExportIfPredFunctor>(pattern, threshold,
static_cast<size_t>(CAP), 0,
d_cnt, buffer_out.keys_ptr(),
buffer_out.values_ptr(),
buffer_out.scores_ptr(),
stream);
CUDA_CHECK(cudaMemcpyAsync(&h_cnt, d_cnt, sizeof(size_t), cudaMemcpyDeviceToHost, stream));
CUDA_CHECK(cudaStreamSynchronize(stream));
MERLIN_EXPECT_TRUE(h_cnt == n0 + n1, "export_batch_if get invalid cnt.");

buffer_out.SyncData(false, stream);
CUDA_CHECK(cudaStreamSynchronize(stream));

std::unordered_map<i64, u64> record;
for (size_t i = 0; i < h_cnt; i++) {
i64 key = buffer_out.keys_ptr(false)[i];
u64 score = buffer_out.scores_ptr(false)[i];
MERLIN_EXPECT_TRUE(key == static_cast<i64>(score), "");
record[key] = score;
for (int j = 0; j < dim; j++) {
f32 value = buffer_out.values_ptr(false)[i * dim + j];
MERLIN_EXPECT_TRUE(key == static_cast<i64>(value), "");
}
}
MERLIN_EXPECT_TRUE(record.size() == n0 + n1 + n2, "");
printf("done\n");
}

int main() {
test_export_batch_if();
return 0;
}
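
For reference, the count asserted after export_batch_if follows directly from ExportIfPredFunctor, which keeps entries whose score is below the threshold: the batches inserted with scores 15 and 30 are exported, the batch with score 45 is not, so h_cnt is expected to be n0 + n1 = 255. A host-side sketch of that arithmetic, illustrative only and not part of the test:

#include <cassert>
#include <cstddef>

int main() {
  const size_t sizes[3] = {127, 128, 163};                 // n0, n1, n2
  const unsigned long long batch_scores[3] = {15, 30, 45};
  const unsigned long long threshold = 40;

  // ExportIfPredFunctor keeps entries with score < threshold.
  size_t expected = 0;
  for (int b = 0; b < 3; ++b) {
    if (batch_scores[b] < threshold) expected += sizes[b];
  }
  assert(expected == 255);  // only the first two batches (127 + 128) pass
  return 0;
}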
