From 3c0ed42b8153ef074665864022babba7aeda156f Mon Sep 17 00:00:00 2001 From: Muhammed Fatih Balin Date: Mon, 15 Jul 2024 17:37:46 -0400 Subject: [PATCH] [GraphBolt][CUDA] Add `FeatureCache::QueryDirect`. --- graphbolt/src/feature_cache.cc | 17 +++++++++++++---- graphbolt/src/feature_cache.h | 16 ++++++++++++++-- graphbolt/src/python_binding.cc | 1 + 3 files changed, 28 insertions(+), 6 deletions(-) diff --git a/graphbolt/src/feature_cache.cc b/graphbolt/src/feature_cache.cc index c532a0efe79d..e9f7b1a96caa 100644 --- a/graphbolt/src/feature_cache.cc +++ b/graphbolt/src/feature_cache.cc @@ -19,14 +19,18 @@ */ #include "./feature_cache.h" +#include "./index_select.h" + namespace graphbolt { namespace storage { constexpr int kIntGrainSize = 64; FeatureCache::FeatureCache( - const std::vector& shape, torch::ScalarType dtype) - : tensor_(torch::empty(shape, c10::TensorOptions().dtype(dtype))) {} + const std::vector& shape, torch::ScalarType dtype, bool pin_memory) + : tensor_(torch::empty( + shape, c10::TensorOptions().dtype(dtype).pinned_memory(pin_memory))) { +} torch::Tensor FeatureCache::Query( torch::Tensor positions, torch::Tensor indices, int64_t size) { @@ -52,6 +56,10 @@ torch::Tensor FeatureCache::Query( return values; } +torch::Tensor FeatureCache::QueryDirect(torch::Tensor positions) { + return ops::IndexSelect(tensor_, positions); +} + void FeatureCache::Replace(torch::Tensor positions, torch::Tensor values) { const auto row_bytes = values.slice(0, 0, 1).numel() * values.element_size(); auto values_ptr = reinterpret_cast(values.data_ptr()); @@ -68,8 +76,9 @@ void FeatureCache::Replace(torch::Tensor positions, torch::Tensor values) { } c10::intrusive_ptr FeatureCache::Create( - const std::vector& shape, torch::ScalarType dtype) { - return c10::make_intrusive(shape, dtype); + const std::vector& shape, torch::ScalarType dtype, + bool pin_memory) { + return c10::make_intrusive(shape, dtype, pin_memory); } } // namespace storage diff --git a/graphbolt/src/feature_cache.h b/graphbolt/src/feature_cache.h index 880e7bb37cdf..a26f1a21aa2e 100644 --- a/graphbolt/src/feature_cache.h +++ b/graphbolt/src/feature_cache.h @@ -35,7 +35,9 @@ struct FeatureCache : public torch::CustomClassHolder { * @param shape The shape of the cache. * @param dtype The dtype of elements stored in the cache. */ - FeatureCache(const std::vector& shape, torch::ScalarType dtype); + FeatureCache( + const std::vector& shape, torch::ScalarType dtype, + bool pin_memory); /** * @brief The cache query function. Allocates an empty tensor `values` with @@ -57,6 +59,15 @@ struct FeatureCache : public torch::CustomClassHolder { torch::Tensor Query( torch::Tensor positions, torch::Tensor indices, int64_t size); + /** + * @brief The cache query function. Returns cache_tensor[positions]. + * + * @param positions The positions of the queries items. + * + * @return The values tensor is returned on the same device as positions. + */ + torch::Tensor QueryDirect(torch::Tensor positions); + /** * @brief The cache replace function. * @@ -66,7 +77,8 @@ struct FeatureCache : public torch::CustomClassHolder { void Replace(torch::Tensor positions, torch::Tensor values); static c10::intrusive_ptr Create( - const std::vector& shape, torch::ScalarType dtype); + const std::vector& shape, torch::ScalarType dtype, + bool pin_memory); private: torch::Tensor tensor_; diff --git a/graphbolt/src/python_binding.cc b/graphbolt/src/python_binding.cc index 15feffec5e2c..f6c6c8dd3789 100644 --- a/graphbolt/src/python_binding.cc +++ b/graphbolt/src/python_binding.cc @@ -119,6 +119,7 @@ TORCH_LIBRARY(graphbolt, m) { "clock_cache_policy", &storage::PartitionedCachePolicy::Create); m.class_("FeatureCache") + .def("query_direct", &storage::FeatureCache::QueryDirect) .def("query", &storage::FeatureCache::Query) .def("replace", &storage::FeatureCache::Replace); m.def("feature_cache", &storage::FeatureCache::Create);