Skip to content

Commit

Permalink
Initial design of the external storage interface.
Browse files Browse the repository at this point in the history
  • Loading branch information
bashimao committed Jan 17, 2023
1 parent 985df72 commit f93eadf
Show file tree
Hide file tree
Showing 2 changed files with 156 additions and 5 deletions.
113 changes: 113 additions & 0 deletions include/merlin/external_storage.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include <cstdint>
#include <type_traits>
#include "merlin/memory_pool.cuh"

namespace nv {
namespace merlin {

template <class Key, class Value>
class ExternalStorage {
public:
using size_type = size_t;
using key_type = Key;
using value_type = Value;

using dev_mem_pool_type = MemoryPool<DeviceAllocator<char>>;
using host_mem_pool_type = MemoryPool<HostAllocator<char>>;

const size_type value_dim;

ExternalStorage() = delete;

/**
* Constructs external storage object.
*
* @param value_dim The dimensionality of the values. In other words, each
* value stored is exactly `value_dim * sizeof(value_type)` bytes large.
*/
ExternalStorage(const size_type value_dim) : value_dim{value_dim} {}

/**
* @brief Inserts key/value pairs into the external storage that are about to
* be evicted from the Merlin hashtable. If a key/value pair already exists,
* overwrites the current value.
*
* @param dev_mem_pool Memory pool for temporarily allocating device memory.
* @param host_mem_pool Memory pool for temporarily allocating host memory.
* @param hkvs_is_pure_hbm True if the Merlin hashtable store is currently
* operating in pure HBM mode, false otherwise. In pure HBM mode, all `values`
* pointers are GUARANTEED to point to device memory.
* @param n Number of key/value slots provided in other arguments.
* @param d_masked_keys Device pointer to an (n)-sized array of keys.
* Key-Value slots that should be ignored have the key set tO `EMPTY_KEY`.
* @param d_values Device pointer to an (n)-sized array containing pointers to
* respectively a memory location where the current values for a key are
* stored. Each pointer points to a vector of length `value_dim`. Pointers
* *can* be set to `nullptr` for slots where the corresponding key equated to
* the `EMPTY_KEY`. The memory locations can be device or host memory (see
* also `hkvs_is_pure_hbm`).
* @param stream Stream that MUST be used for queuing asynchronous CUDA
* operations. If only the input arguments or resources obtained from
* respectively `dev_mem_pool` and `host_mem_pool` are used for such
* operations, it is not necessary to synchronize the stream prior to
* returning from the function.
*/
virtual void insert_or_assign(dev_mem_pool_type& dev_mem_pool,
host_mem_pool_type& host_mem_pool,
bool hkvs_is_pure_hbm, size_type n,
const key_type* d_masked_keys, // (n)
const value_type* const* d_values, // (n)
cudaStream_t stream) = 0;

/**
* @brief Attempts to find the supplied `d_keys` if the corresponding
* `d_founds`-flag is `false` and fills the stored into the supplied memory
* locations (i.e. in `d_values`).
*
* @param dev_mem_pool Memory pool for temporarily allocating device memory.
* @param host_mem_pool Memory pool for temporarily allocating host memory.
* @param n Number of key/value slots provided in other arguments.
* @param d_keys Device pointer to an (n)-sized array of keys.
* @param d_values Device pointer to an (n * value_dim)-sized array to store
* the retrieved `d_values`. For slots where the corresponding `d_founds`-flag
* is not `false`, the value may already have been assigned and, thus, MUST
* not be altered.
* @param d_founds Device pointer to an (n)-sized array which indicates
* whether the corresponding `d_values` slot is already filled or not. So, if
* and only if `d_founds` is still false, the implementation shall attempt to
* retrieve and fill in the value for the corresponding key. If a key/value
* was retrieved successfully from external storage, the implementation MUST
* also set `d_founds` to `true`.
* @param stream Stream that MUST be used for queuing asynchronous CUDA
* operations. If only the input arguments or resources obtained from
* respectively `dev_mem_pool` and `host_mem_pool` are used for such
* operations, it is not necessary to synchronize the stream prior to
* returning from the function.
*/
virtual void find(dev_mem_pool_type& dev_mem_pool,
host_mem_pool_type& host_mem_pool, size_type n,
const key_type* d_keys, // (n)
value_type* d_values, // (n * value_dim)
bool* d_founds, // (n)
cudaStream_t stream) = 0;
};

} // namespace merlin
} // namespace nv
48 changes: 43 additions & 5 deletions include/merlin_hashtable.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include <shared_mutex>
#include <type_traits>
#include "merlin/core_kernels.cuh"
#include "merlin/external_storage.cuh"
#include "merlin/flexible_buffer.cuh"
#include "merlin/memory_pool.cuh"
#include "merlin/types.cuh"
Expand Down Expand Up @@ -160,6 +161,8 @@ class HashTable {
using DeviceMemoryPool = MemoryPool<DeviceAllocator<char>>;
using HostMemoryPool = MemoryPool<HostAllocator<char>>;

using external_storage_type = ExternalStorage<K, V>;

#if THRUST_VERSION >= 101600
static constexpr auto thrust_par = thrust::cuda::par_nosync;
#else
Expand All @@ -179,6 +182,8 @@ class HashTable {
~HashTable() {
CUDA_CHECK(cudaDeviceSynchronize());

unlink_external_storage();

// Erase table.
if (initialized_) {
destroy_table<key_type, vector_type, meta_type, DIM>(&table_);
Expand Down Expand Up @@ -308,9 +313,12 @@ class HashTable {
}
} else {
const size_t dev_ws_size = n * (sizeof(vector_type*) + sizeof(int));
auto dev_ws = dev_mem_pool_->get_workspace<1>(dev_ws_size, stream);
auto dev_ws = dev_mem_pool_->get_workspace<1>(
dev_ws_size + (ext_store_ ? n * sizeof(key_type) : 0), stream);
auto d_dst = dev_ws.get<vector_type**>(0);
auto d_src_offset = reinterpret_cast<int*>(d_dst + n);
auto d_evicted_keys =
ext_store_ ? reinterpret_cast<key_type*>(d_src_offset + n) : nullptr;

CUDA_CHECK(cudaMemsetAsync(d_dst, 0, dev_ws_size, stream));

Expand All @@ -322,18 +330,26 @@ class HashTable {
if (metas == nullptr) {
upsert_kernel<key_type, vector_type, meta_type, DIM, TILE_SIZE>
<<<grid_size, block_size, 0, stream>>>(
table_, keys, d_dst, table_->buckets, table_->buckets_size,
table_->bucket_max_size, table_->buckets_num, d_src_offset,
N);
table_, keys,
/* d_evicted_keys, */ d_dst, table_->buckets,
table_->buckets_size, table_->bucket_max_size,
table_->buckets_num, d_src_offset, N);
} else {
upsert_kernel<key_type, vector_type, meta_type, DIM, TILE_SIZE>
<<<grid_size, block_size, 0, stream>>>(
table_, keys, d_dst, metas, table_->buckets,
table_, keys,
/* d_evicted_keys, */ d_dst, metas, table_->buckets,
table_->buckets_size, table_->bucket_max_size,
table_->buckets_num, d_src_offset, N);
}
}

if (ext_store_) {
ext_store_->insert_or_assign(
*dev_mem_pool_, *host_mem_pool_, table_->is_pure_hbm, n,
d_evicted_keys, reinterpret_cast<value_type**>(d_dst), stream);
}

{
thrust::device_ptr<uintptr_t> d_dst_ptr(
reinterpret_cast<uintptr_t*>(d_dst));
Expand Down Expand Up @@ -575,6 +591,11 @@ class HashTable {
}
}

if (ext_store_) {
ext_store_->find(*dev_mem_pool_, *host_mem_pool_, n, keys, values, founds,
stream);
}

CudaCheckError();
}

Expand Down Expand Up @@ -1113,6 +1134,21 @@ class HashTable {
return total_count;
}

void link_external_storage(
std::shared_ptr<external_storage_type>& ext_store) {
MERLIN_CHECK(
ext_store->value_dim == DIM,
"Provided external storage value dimension is not incompatible!");

std::unique_lock<std::shared_timed_mutex> lock(mutex_);
ext_store_ = ext_store;
}

void unlink_external_storage() {
std::unique_lock<std::shared_timed_mutex> lock(mutex_);
ext_store_.reset();
}

private:
inline bool is_fast_mode() const noexcept { return table_->is_pure_hbm; }

Expand Down Expand Up @@ -1171,6 +1207,8 @@ class HashTable {

std::unique_ptr<DeviceMemoryPool> dev_mem_pool_;
std::unique_ptr<HostMemoryPool> host_mem_pool_;

std::shared_ptr<external_storage_type> ext_store_;
};

} // namespace merlin
Expand Down

0 comments on commit f93eadf

Please sign in to comment.