Commit

[GraphBolt][CUDA] Add FeatureCache::QueryDirect.
mfbalin committed Jul 15, 2024
1 parent 4d497d1 commit 3c0ed42
Showing 3 changed files with 28 additions and 6 deletions.
17 changes: 13 additions & 4 deletions graphbolt/src/feature_cache.cc
@@ -19,14 +19,18 @@
  */
 #include "./feature_cache.h"
 
+#include "./index_select.h"
+
 namespace graphbolt {
 namespace storage {
 
 constexpr int kIntGrainSize = 64;
 
 FeatureCache::FeatureCache(
-    const std::vector<int64_t>& shape, torch::ScalarType dtype)
-    : tensor_(torch::empty(shape, c10::TensorOptions().dtype(dtype))) {}
+    const std::vector<int64_t>& shape, torch::ScalarType dtype, bool pin_memory)
+    : tensor_(torch::empty(
+          shape, c10::TensorOptions().dtype(dtype).pinned_memory(pin_memory))) {
+}
 
 torch::Tensor FeatureCache::Query(
     torch::Tensor positions, torch::Tensor indices, int64_t size) {
@@ -52,6 +56,10 @@ torch::Tensor FeatureCache::Query(
   return values;
 }
 
+torch::Tensor FeatureCache::QueryDirect(torch::Tensor positions) {
+  return ops::IndexSelect(tensor_, positions);
+}
+
 void FeatureCache::Replace(torch::Tensor positions, torch::Tensor values) {
   const auto row_bytes = values.slice(0, 0, 1).numel() * values.element_size();
   auto values_ptr = reinterpret_cast<std::byte*>(values.data_ptr());
@@ -68,8 +76,9 @@ void FeatureCache::Replace(torch::Tensor positions, torch::Tensor values) {
 }
 
 c10::intrusive_ptr<FeatureCache> FeatureCache::Create(
-    const std::vector<int64_t>& shape, torch::ScalarType dtype) {
-  return c10::make_intrusive<FeatureCache>(shape, dtype);
+    const std::vector<int64_t>& shape, torch::ScalarType dtype,
+    bool pin_memory) {
+  return c10::make_intrusive<FeatureCache>(shape, dtype, pin_memory);
 }
 
 }  // namespace storage
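For context, here is a minimal usage sketch of the two pieces this commit adds: the pin_memory flag on Create and the QueryDirect gather. It is not part of the change; the shapes, values, and variable names are assumptions, and gathering from a pinned CPU cache onto the GPU is the expected use implied by the commit title rather than something shown in the diff.

// Hypothetical caller, assuming a CUDA-capable build of GraphBolt.
#include <torch/torch.h>

#include "./feature_cache.h"

void ExampleQueryDirect() {
  // A 1024-row, 128-wide float cache whose backing tensor is pinned.
  auto cache = graphbolt::storage::FeatureCache::Create(
      /*shape=*/{1024, 128}, torch::kFloat32, /*pin_memory=*/true);

  // Write three rows into cache slots 0, 5, and 9.
  auto positions = torch::tensor({0, 5, 9}, torch::kInt64);
  cache->Replace(positions, torch::randn({3, 128}));

  // Gather them back directly. With positions on the GPU, ops::IndexSelect
  // should return the rows on that same device.
  auto rows = cache->QueryDirect(positions.to(torch::kCUDA));
}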
16 changes: 14 additions & 2 deletions graphbolt/src/feature_cache.h
@@ -35,7 +35,9 @@ struct FeatureCache : public torch::CustomClassHolder {
    * @param shape The shape of the cache.
    * @param dtype The dtype of elements stored in the cache.
    */
-  FeatureCache(const std::vector<int64_t>& shape, torch::ScalarType dtype);
+  FeatureCache(
+      const std::vector<int64_t>& shape, torch::ScalarType dtype,
+      bool pin_memory);
 
   /**
    * @brief The cache query function. Allocates an empty tensor `values` with
@@ -57,6 +59,15 @@
   torch::Tensor Query(
       torch::Tensor positions, torch::Tensor indices, int64_t size);
 
+  /**
+   * @brief The cache query function. Returns cache_tensor[positions].
+   *
+   * @param positions The positions of the queried items.
+   *
+   * @return The values tensor, returned on the same device as positions.
+   */
+  torch::Tensor QueryDirect(torch::Tensor positions);
+
   /**
    * @brief The cache replace function.
    *
@@ -66,7 +77,8 @@
   void Replace(torch::Tensor positions, torch::Tensor values);
 
   static c10::intrusive_ptr<FeatureCache> Create(
-      const std::vector<int64_t>& shape, torch::ScalarType dtype);
+      const std::vector<int64_t>& shape, torch::ScalarType dtype,
+      bool pin_memory);
 
  private:
   torch::Tensor tensor_;
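As a quick orientation on the header changes, a hypothetical caller-side comparison of the two query entry points follows; cache, positions, indices, and num_rows are assumed names, and Query's exact fill semantics are described by the doc comment that is truncated in this diff.

// Existing path: Query allocates a fresh `values` tensor with `num_rows`
// rows and fills it from the cache using `positions` and `indices`.
torch::Tensor values = cache->Query(positions, indices, /*size=*/num_rows);

// Path added here: a direct gather equivalent to cache_tensor[positions],
// returned on the same device as `positions`.
torch::Tensor rows = cache->QueryDirect(positions);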
1 change: 1 addition & 0 deletions graphbolt/src/python_binding.cc
@@ -119,6 +119,7 @@ TORCH_LIBRARY(graphbolt, m) {
       "clock_cache_policy",
       &storage::PartitionedCachePolicy::Create<storage::ClockCachePolicy>);
   m.class_<storage::FeatureCache>("FeatureCache")
+      .def("query_direct", &storage::FeatureCache::QueryDirect)
       .def("query", &storage::FeatureCache::Query)
       .def("replace", &storage::FeatureCache::Replace);
   m.def("feature_cache", &storage::FeatureCache::Create);
