Skip to content

Commit

Permalink
todo
Browse files Browse the repository at this point in the history
  • Loading branch information
Lifann authored and oppenheimli committed Aug 13, 2024
1 parent 46c9f89 commit 808b073
Showing 1 changed file with 7 additions and 11 deletions.
18 changes: 7 additions & 11 deletions include/merlin_hashtable.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -917,6 +917,7 @@ class HashTable : public HashTableBase<K, V, S> {
cudaDeviceProp deviceProp;
CUDA_CHECK(cudaGetDeviceProperties(&deviceProp, options_.device_id));
shared_mem_size_ = deviceProp.sharedMemPerBlock;
+ sm_cnt_ = deviceProp.multiProcessorCount;
create_table<key_type, value_type, score_type>(
&table_, allocator_, options_.dim, options_.init_capacity,
options_.max_capacity, options_.max_hbm_for_vectors,
Expand Down Expand Up @@ -2611,20 +2612,14 @@ class HashTable : public HashTableBase<K, V, S> {
return;
}
n = std::min(table_->capacity - offset, n);
+ if (n == 0) {
+ return;
+ }

- const size_t score_size = scores ? sizeof(score_type) : 0;
- const size_t kvm_size =
- sizeof(key_type) + sizeof(value_type) * dim() + score_size;
- const size_t block_size = std::min(shared_mem_size_ / 2 / kvm_size, 1024UL);
- MERLIN_CHECK(
- block_size > 0,
- "[HierarchicalKV] block_size <= 0, the K-V-S size may be too large!");
-
- const size_t shared_size = kvm_size * block_size;
- const size_t grid_size = SAFE_GET_GRID_SIZE(n, block_size);
+ int grid_size = std::min(sm_cnt_, static_cast<int>(SAFE_GET_GRID_SIZE(n, options_.block_size)));

dump_kernel<key_type, value_type, score_type, PredFunctor>
- <<<grid_size, block_size, shared_size, stream>>>(
+ <<<grid_size, options_.block_size, 0, stream>>>(
d_table_, table_->buckets, pattern, threshold, keys, values, scores,
offset, n, d_counter);

Expand Down Expand Up @@ -3037,6 +3032,7 @@ class HashTable : public HashTableBase<K, V, S> {
TableCore* table_ = nullptr;
TableCore* d_table_ = nullptr;
size_t shared_mem_size_ = 0;
+ int sm_cnt_ = 0;
std::atomic_bool reach_max_capacity_{false};
bool initialized_ = false;
mutable group_shared_mutex mutex_;
Expand Down

0 comments on commit 808b073

Please sign in to comment.