Skip to content

Commit

Permalink
Remove v3 impl
Browse files Browse the repository at this point in the history
  • Loading branch information
oppenheimli committed Aug 21, 2024
1 parent 7aaabe7 commit 2d8a9b6
Showing 1 changed file with 0 additions and 78 deletions.
78 changes: 0 additions & 78 deletions include/merlin/core_kernels.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -1010,83 +1010,5 @@ __global__ void size_if_kernel(const Table<K, V, S>* __restrict table,
}
}

template <class K, class V, class S,
template <typename, typename> class PredFunctor, int TILE_SIZE>
__global__ void dump_kernel_v3(const Table<K, V, S>* __restrict table,
Bucket<K, V, S>* buckets, const K pattern,
const S threshold, K* d_key, V* __restrict d_val,
S* __restrict d_score, const size_t offset,
const size_t search_length,
size_t* d_dump_counter) {
const size_t bucket_max_size = table->bucket_max_size;
int dim = table->dim;
auto g = cg::tiled_partition<TILE_SIZE>(cg::this_thread_block());

PredFunctor<K, S> pred;

__shared__ int block_cnt;
__shared__ size_t block_offset;

size_t tid = static_cast<size_t>(blockIdx.x * blockDim.x + threadIdx.x);

for (size_t ii = tid; ii < search_length; ii += gridDim.x * blockDim.x) {
size_t bkt_idx = (ii + offset) / bucket_max_size;
size_t key_idx = (ii + offset) % bucket_max_size;
size_t leading_key_idx = key_idx / TILE_SIZE * TILE_SIZE;
Bucket<K, V, S>* bucket = &(buckets[bkt_idx]);

const K key =
(bucket->keys(key_idx))->load(cuda::std::memory_order_relaxed);
S score = bucket->scores(key_idx)->load(cuda::std::memory_order_relaxed);

if (threadIdx.x == 0) {
block_cnt = 0;
}
__syncthreads();

bool match =
(!IS_RESERVED_KEY<K>(key)) && pred(key, score, pattern, threshold);
unsigned int vote = g.ballot(match);
int tile_cnt = __popc(vote);

int in_block_tile_offset = 0;
if (g.thread_rank() == 0) {
in_block_tile_offset =
atomicAdd(reinterpret_cast<int*>(&block_cnt), tile_cnt);
}
in_block_tile_offset = g.shfl(in_block_tile_offset, 0);
__syncthreads();

if (threadIdx.x == 0) {
block_offset = atomicAdd(d_dump_counter, static_cast<size_t>(block_cnt));
}
__syncthreads();

int tile_offset = block_offset + in_block_tile_offset;
int bias_g = tile_cnt - __popc(vote >> (key_idx % TILE_SIZE));

if (match) {
d_key[tile_offset + bias_g] = key;
if (d_score) {
d_score[tile_offset + bias_g] = score;
}
}

#pragma unroll
for (int r = 0; r < TILE_SIZE; r++) {
unsigned int biased_vote = vote >> r;
bool cur_match = biased_vote & 1;
if (cur_match) {
int bias = tile_cnt - __popc(biased_vote);
int cur_idx = leading_key_idx + r;
for (int j = g.thread_rank(); j < dim; j += TILE_SIZE) {
d_val[(tile_offset + bias) * dim + j] =
bucket->vectors[cur_idx * dim + j];
}
}
}
}
}

} // namespace merlin
} // namespace nv

0 comments on commit 2d8a9b6

Please sign in to comment.