[Graphbolt] Implement Temporal Neighbor Sampling. (#6784)
czkkkkkk authored Dec 22, 2023
1 parent 8a8f2b0 commit e42c7fc
Showing 4 changed files with 309 additions and 11 deletions.
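
At a high level, this commit makes a seed node's candidate neighbors subject to a temporal constraint: a neighbor is eligible only if its node timestamp and, when available, the timestamp of the connecting edge do not exceed the seed's timestamp, and sampling then draws only from the eligible neighbors. A minimal illustrative sketch of that eligibility test (not part of this commit; the function name is hypothetical and mirrors the TemporalMask logic added below):

    #include <cstdint>
    #include <optional>

    // A neighbor is a valid temporal candidate only if neither its node
    // timestamp nor the connecting edge's timestamp is newer than the seed's.
    bool IsTemporallyValid(
        int64_t seed_ts, std::optional<int64_t> neighbor_node_ts,
        std::optional<int64_t> edge_ts) {
      if (neighbor_node_ts.has_value() && *neighbor_node_ts > seed_ts)
        return false;
      if (edge_ts.has_value() && *edge_ts > seed_ts) return false;
      return true;
    }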
37 changes: 37 additions & 0 deletions graphbolt/include/graphbolt/fused_csc_sampling_graph.h
@@ -508,12 +508,28 @@ int64_t NumPick(
const torch::optional<torch::Tensor>& probs_or_mask, int64_t offset,
int64_t num_neighbors);

int64_t TemporalNumPick(
torch::Tensor seed_timestamp, torch::Tensor csc_indices, int64_t fanout,
bool replace, const torch::optional<torch::Tensor>& probs_or_mask,
const torch::optional<torch::Tensor>& node_timestamp,
const torch::optional<torch::Tensor>& edge_timestamp, int64_t seed_offset,
int64_t offset, int64_t num_neighbors);

int64_t NumPickByEtype(
const std::vector<int64_t>& fanouts, bool replace,
const torch::Tensor& type_per_edge,
const torch::optional<torch::Tensor>& probs_or_mask, int64_t offset,
int64_t num_neighbors);

int64_t TemporalNumPickByEtype(
torch::Tensor seed_timestamp, torch::Tensor csc_indices,
const std::vector<int64_t>& fanouts, bool replace,
const torch::Tensor& type_per_edge,
const torch::optional<torch::Tensor>& probs_or_mask,
const torch::optional<torch::Tensor>& node_timestamp,
const torch::optional<torch::Tensor>& edge_timestamp, int64_t seed_offset,
int64_t offset, int64_t num_neighbors);

/**
* @brief Picks a specified number of neighbors for a node, starting from the
* given offset and having the specified number of neighbors.
@@ -562,6 +578,16 @@ int64_t Pick(
const torch::optional<torch::Tensor>& probs_or_mask,
SamplerArgs<SamplerType::LABOR> args, PickedType* picked_data_ptr);

template <typename PickedType>
int64_t TemporalPick(
torch::Tensor seed_timestamp, torch::Tensor csc_indices,
int64_t seed_offset, int64_t offset, int64_t num_neighbors, int64_t fanout,
bool replace, const torch::TensorOptions& options,
const torch::optional<torch::Tensor>& probs_or_mask,
const torch::optional<torch::Tensor>& node_timestamp,
const torch::optional<torch::Tensor>& edge_timestamp,
PickedType* picked_data_ptr);

/**
* @brief Picks a specified number of neighbors for a node per edge type,
* starting from the given offset and having the specified number of neighbors.
@@ -597,6 +623,17 @@ int64_t PickByEtype(
const torch::optional<torch::Tensor>& probs_or_mask, SamplerArgs<S> args,
PickedType* picked_data_ptr);

template <typename PickedType>
int64_t TemporalPickByEtype(
torch::Tensor seed_timestamp, torch::Tensor csc_indices,
int64_t seed_offset, int64_t offset, int64_t num_neighbors,
const std::vector<int64_t>& fanouts, bool replace,
const torch::TensorOptions& options, const torch::Tensor& type_per_edge,
const torch::optional<torch::Tensor>& probs_or_mask,
const torch::optional<torch::Tensor>& node_timestamp,
const torch::optional<torch::Tensor>& edge_timestamp,
PickedType* picked_data_ptr);

template <
bool NonUniform, bool Replace, typename ProbsType, typename PickedType,
int StackSize = 1024>
228 changes: 223 additions & 5 deletions graphbolt/src/fused_csc_sampling_graph.cc
@@ -18,6 +18,7 @@

#include "./random.h"
#include "./shared_memory_helper.h"
#include "./utils.h"

namespace {
torch::optional<torch::Dict<std::string, torch::Tensor>> TensorizeDict(
@@ -349,6 +350,31 @@ auto GetNumPickFn(
};
}

auto GetTemporalNumPickFn(
torch::Tensor seed_timestamp, torch::Tensor csc_indices,
const std::vector<int64_t>& fanouts, bool replace,
const torch::optional<torch::Tensor>& type_per_edge,
const torch::optional<torch::Tensor>& probs_or_mask,
const torch::optional<torch::Tensor>& node_timestamp,
const torch::optional<torch::Tensor>& edge_timestamp) {
// If fanouts.size() > 1, the returned lambda gives the total number of
// neighbors picked across all edge types of the given node.
return [&seed_timestamp, &csc_indices, &fanouts, replace, &probs_or_mask,
&type_per_edge, &node_timestamp, &edge_timestamp](
int64_t seed_offset, int64_t offset, int64_t num_neighbors) {
if (fanouts.size() > 1) {
return TemporalNumPickByEtype(
seed_timestamp, csc_indices, fanouts, replace, type_per_edge.value(),
probs_or_mask, node_timestamp, edge_timestamp, seed_offset, offset,
num_neighbors);
} else {
return TemporalNumPick(
seed_timestamp, csc_indices, fanouts[0], replace, probs_or_mask,
node_timestamp, edge_timestamp, seed_offset, offset, num_neighbors);
}
};
}

/**
* @brief Get a lambda function which contains the sampling process.
*
@@ -400,6 +426,39 @@ auto GetPickFn(
};
}

auto GetTemporalPickFn(
torch::Tensor seed_timestamp, torch::Tensor csc_indices,
const std::vector<int64_t>& fanouts, bool replace,
const torch::TensorOptions& options,
const torch::optional<torch::Tensor>& type_per_edge,
const torch::optional<torch::Tensor>& probs_or_mask,
const torch::optional<torch::Tensor>& node_timestamp,
const torch::optional<torch::Tensor>& edge_timestamp) {
return [&seed_timestamp, &csc_indices, &fanouts, replace, &options,
&type_per_edge, &probs_or_mask, &node_timestamp, &edge_timestamp](
int64_t seed_offset, int64_t offset, int64_t num_neighbors,
auto picked_data_ptr) {
// If fanouts.size() > 1, perform sampling for each edge type of each
// node; otherwise just sample once for each node without regard to edge
// types.
if (fanouts.size() > 1) {
return TemporalPickByEtype(
seed_timestamp, csc_indices, seed_offset, offset, num_neighbors,
fanouts, replace, options, type_per_edge.value(), probs_or_mask,
node_timestamp, edge_timestamp, picked_data_ptr);
} else {
int64_t num_sampled = TemporalPick(
seed_timestamp, csc_indices, seed_offset, offset, num_neighbors,
fanouts[0], replace, options, probs_or_mask, node_timestamp,
edge_timestamp, picked_data_ptr);
if (type_per_edge) {
std::sort(picked_data_ptr, picked_data_ptr + num_sampled);
}
return num_sampled;
}
};
}
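
Both lambdas are consumed by SampleNeighborsImpl, which presumably evaluates the num-pick function for every seed first to size the output and then invokes the pick function to write into the preallocated buffer. A rough sketch of that two-pass pattern, under this assumption and with simplified signatures (illustrative only, not part of this diff):

    #include <cstdint>
    #include <functional>
    #include <vector>

    // Assumed two-pass flow: count picks per seed, build the prefix-sum
    // indptr, then let the pick function fill each seed's slot.
    std::vector<int64_t> TwoPassSample(
        int64_t num_seeds,
        const std::function<int64_t(int64_t)>& num_pick_fn,
        const std::function<void(int64_t, int64_t*)>& pick_fn) {
      std::vector<int64_t> indptr(num_seeds + 1, 0);
      for (int64_t i = 0; i < num_seeds; ++i)
        indptr[i + 1] = indptr[i] + num_pick_fn(i);
      std::vector<int64_t> picked(indptr.back());
      for (int64_t i = 0; i < num_seeds; ++i)
        pick_fn(i, picked.data() + indptr[i]);
      return picked;
    }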

template <typename NumPickFn, typename PickFn>
c10::intrusive_ptr<FusedSampledSubgraph>
FusedCSCSamplingGraph::SampleNeighborsImpl(
@@ -579,14 +638,31 @@ FusedCSCSamplingGraph::TemporalSampleNeighbors(
torch::optional<std::string> probs_name,
torch::optional<std::string> node_timestamp_attr_name,
torch::optional<std::string> edge_timestamp_attr_name) const {
- // TODO(zhenkun):
// 1. Get probs_or_mask.
auto probs_or_mask = this->EdgeAttribute(probs_name);
if (probs_name.has_value()) {
// Note probs will be passed as input to 'torch.multinomial' further down
// the stack, which doesn't support 'torch.half' and 'torch.bool' data types.
// To avoid crashes, convert 'probs_or_mask' to 'float32' data type.
if (probs_or_mask.value().dtype() == torch::kBool ||
probs_or_mask.value().dtype() == torch::kFloat16) {
probs_or_mask = probs_or_mask.value().to(torch::kFloat32);
}
}
// 2. Get the timestamp attribute for nodes of the graph
auto node_timestamp = this->NodeAttribute(node_timestamp_attr_name);
// 3. Get the timestamp attribute for edges of the graph
- // 4. GetTemporalNumPickFn (New implementation)
- // 5. GetTemporalPickFn (New implementation)
- // 6. Call SampleNeighborsImpl (Old implementation)
- return c10::intrusive_ptr<FusedSampledSubgraph>();
auto edge_timestamp = this->EdgeAttribute(edge_timestamp_attr_name);
// 4. Call SampleNeighborsImpl
return SampleNeighborsImpl(
input_nodes, return_eids,
GetTemporalNumPickFn(
input_nodes_timestamp, this->indices_, fanouts, replace,
type_per_edge_, probs_or_mask, node_timestamp, edge_timestamp),
GetTemporalPickFn(
input_nodes_timestamp, this->indices_, fanouts, replace,
indptr_.options(), type_per_edge_, probs_or_mask, node_timestamp,
edge_timestamp));
}

std::tuple<torch::Tensor, torch::Tensor>
@@ -669,6 +745,43 @@ int64_t NumPick(
return replace ? fanout : std::min(fanout, num_valid_neighbors);
}

torch::Tensor TemporalMask(
int64_t seed_timestamp, torch::Tensor csc_indices,
const torch::optional<torch::Tensor>& probs_or_mask,
const torch::optional<torch::Tensor>& node_timestamp,
const torch::optional<torch::Tensor>& edge_timestamp,
std::pair<int64_t, int64_t> edge_range) {
auto [l, r] = edge_range;
torch::Tensor mask = torch::ones({r - l}, torch::kBool);
if (node_timestamp.has_value()) {
auto neighbor_timestamp =
node_timestamp.value().index_select(0, csc_indices.slice(0, l, r));
mask &= neighbor_timestamp <= seed_timestamp;
}
if (edge_timestamp.has_value()) {
mask &= edge_timestamp.value().slice(0, l, r) <= seed_timestamp;
}
if (probs_or_mask.has_value()) {
mask &= probs_or_mask.value().slice(0, l, r) != 0;
}
return mask;
}
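
To make the masking concrete, here is a small self-contained libtorch example that mirrors TemporalMask for one seed's neighbor range (the timestamps are made up for illustration and are not part of this commit):

    #include <torch/torch.h>
    #include <iostream>

    int main() {
      // One seed with timestamp 5 and four candidate neighbors.
      int64_t seed_ts = 5;
      torch::Tensor neighbor_node_ts =
          torch::tensor({3, 7, 5, 1}, torch::kInt64);
      torch::Tensor edge_ts = torch::tensor({2, 2, 6, 1}, torch::kInt64);

      torch::Tensor mask = torch::ones({4}, torch::kBool);
      mask &= neighbor_node_ts <= seed_ts;  // neighbor exists at seed time
      mask &= edge_ts <= seed_ts;           // edge exists at seed time
      std::cout << mask << std::endl;  // [true, false, false, true]
      return 0;
    }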

int64_t TemporalNumPick(
torch::Tensor seed_timestamp, torch::Tensor csc_indices, int64_t fanout,
bool replace, const torch::optional<torch::Tensor>& probs_or_mask,
const torch::optional<torch::Tensor>& node_timestamp,
const torch::optional<torch::Tensor>& edge_timestamp, int64_t seed_offset,
int64_t offset, int64_t num_neighbors) {
auto mask = TemporalMask(
utils::GetValueByIndex<int64_t>(seed_timestamp, seed_offset), csc_indices,
probs_or_mask, node_timestamp, edge_timestamp,
{offset, offset + num_neighbors});
int64_t num_valid_neighbors = utils::GetValueByIndex<int64_t>(mask.sum(), 0);
if (num_valid_neighbors == 0 || fanout == -1) return num_valid_neighbors;
return replace ? fanout : std::min(fanout, num_valid_neighbors);
}
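
The returned count follows the same rule as the non-temporal NumPick, only over the temporally valid neighbors: with replacement the answer is always the fanout, without replacement it is capped by the number of valid neighbors. A tiny worked example of that arithmetic (illustrative values):

    #include <algorithm>
    #include <cstdint>
    #include <iostream>

    int main() {
      int64_t num_valid = 2;  // temporally valid neighbors
      int64_t fanout = 3;
      bool replace = false;
      int64_t picks = replace ? fanout : std::min(fanout, num_valid);
      std::cout << picks << std::endl;  // 2 without replacement, 3 with
      return 0;
    }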

int64_t NumPickByEtype(
const std::vector<int64_t>& fanouts, bool replace,
const torch::Tensor& type_per_edge,
@@ -699,6 +812,40 @@ int64_t NumPickByEtype(
return total_count;
}

int64_t TemporalNumPickByEtype(
torch::Tensor seed_timestamp, torch::Tensor csc_indices,
const std::vector<int64_t>& fanouts, bool replace,
const torch::Tensor& type_per_edge,
const torch::optional<torch::Tensor>& probs_or_mask,
const torch::optional<torch::Tensor>& node_timestamp,
const torch::optional<torch::Tensor>& edge_timestamp, int64_t seed_offset,
int64_t offset, int64_t num_neighbors) {
int64_t etype_begin = offset;
const int64_t end = offset + num_neighbors;
int64_t total_count = 0;
AT_DISPATCH_INTEGRAL_TYPES(
type_per_edge.scalar_type(), "TemporalNumPickFnByEtype", ([&] {
const scalar_t* type_per_edge_data = type_per_edge.data_ptr<scalar_t>();
while (etype_begin < end) {
scalar_t etype = type_per_edge_data[etype_begin];
TORCH_CHECK(
etype >= 0 && etype < (int64_t)fanouts.size(),
"Etype values exceed the number of fanouts.");
auto etype_end_it = std::upper_bound(
type_per_edge_data + etype_begin, type_per_edge_data + end,
etype);
int64_t etype_end = etype_end_it - type_per_edge_data;
// Count the picks for one etype.
total_count += TemporalNumPick(
seed_timestamp, csc_indices, fanouts[etype], replace,
probs_or_mask, node_timestamp, edge_timestamp, seed_offset,
etype_begin, etype_end - etype_begin);
etype_begin = etype_end;
}
}));
return total_count;
}
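
The per-etype walk relies on type_per_edge being sorted within each node's neighbor range, so std::upper_bound can locate where the current edge type's segment ends; TemporalNumPick is then applied to each segment independently. A standalone illustration of the segmentation step (hypothetical data, sorted edge types assumed, not part of this diff):

    #include <algorithm>
    #include <cstdint>
    #include <iostream>
    #include <vector>

    int main() {
      // Edge types of one node's neighbors, sorted by type.
      std::vector<int64_t> type_per_edge = {0, 0, 1, 1, 1, 2};
      int64_t begin = 0;
      const int64_t end = static_cast<int64_t>(type_per_edge.size());
      while (begin < end) {
        int64_t etype = type_per_edge[begin];
        auto seg_end = std::upper_bound(
            type_per_edge.begin() + begin, type_per_edge.begin() + end, etype);
        int64_t seg_len = seg_end - (type_per_edge.begin() + begin);
        std::cout << "etype " << etype << ": " << seg_len << " edges\n";
        begin += seg_len;  // jump to the next edge type's segment
      }
      return 0;
    }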

/**
* @brief Perform uniform sampling of elements and return the sampled indices.
*
@@ -983,6 +1130,35 @@ int64_t Pick(
}
}

template <typename PickedType>
int64_t TemporalPick(
torch::Tensor seed_timestamp, torch::Tensor csc_indices,
int64_t seed_offset, int64_t offset, int64_t num_neighbors, int64_t fanout,
bool replace, const torch::TensorOptions& options,
const torch::optional<torch::Tensor>& probs_or_mask,
const torch::optional<torch::Tensor>& node_timestamp,
const torch::optional<torch::Tensor>& edge_timestamp,
PickedType* picked_data_ptr) {
auto mask = TemporalMask(
utils::GetValueByIndex<int64_t>(seed_timestamp, seed_offset), csc_indices,
probs_or_mask, node_timestamp, edge_timestamp,
{offset, offset + num_neighbors});
torch::Tensor masked_prob;
if (probs_or_mask.has_value()) {
masked_prob =
probs_or_mask.value().slice(0, offset, offset + num_neighbors) * mask;
} else {
masked_prob = mask.to(torch::kFloat32);
}
auto picked_indices = NonUniformPickOp(masked_prob, fanout, replace);
auto picked_indices_ptr = picked_indices.data_ptr<int64_t>();
for (int i = 0; i < picked_indices.numel(); ++i) {
picked_data_ptr[i] =
static_cast<PickedType>(picked_indices_ptr[i]) + offset;
}
return picked_indices.numel();
}
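
TemporalPick zeroes out ineligible neighbors by multiplying the (optional) probabilities with the temporal mask, samples from the resulting weights, and shifts the local picks by offset so they index into the global CSC arrays. A compact illustration of that index arithmetic, with torch::multinomial standing in for the internal NonUniformPickOp (an assumption for this sketch; both draw proportionally to the non-zero weights):

    #include <torch/torch.h>
    #include <iostream>

    int main() {
      // One seed whose neighbors occupy global CSC positions [10, 14);
      // only local positions 0 and 3 are temporally eligible.
      int64_t offset = 10;
      torch::Tensor mask = torch::tensor({1, 0, 0, 1}).to(torch::kBool);
      torch::Tensor masked_prob = mask.to(torch::kFloat32);

      int64_t fanout = 2;
      torch::Tensor picked_local =
          torch::multinomial(masked_prob, fanout, /*replacement=*/false);
      torch::Tensor picked_global = picked_local + offset;
      std::cout << picked_global << std::endl;  // a permutation of {10, 13}
      return 0;
    }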

template <SamplerType S, typename PickedType>
int64_t PickByEtype(
int64_t offset, int64_t num_neighbors, const std::vector<int64_t>& fanouts,
@@ -1020,6 +1196,48 @@ int64_t PickByEtype(
return pick_offset;
}

template <typename PickedType>
int64_t TemporalPickByEtype(
torch::Tensor seed_timestamp, torch::Tensor csc_indices,
int64_t seed_offset, int64_t offset, int64_t num_neighbors,
const std::vector<int64_t>& fanouts, bool replace,
const torch::TensorOptions& options, const torch::Tensor& type_per_edge,
const torch::optional<torch::Tensor>& probs_or_mask,
const torch::optional<torch::Tensor>& node_timestamp,
const torch::optional<torch::Tensor>& edge_timestamp,
PickedType* picked_data_ptr) {
int64_t etype_begin = offset;
int64_t etype_end = offset;
int64_t pick_offset = 0;
AT_DISPATCH_INTEGRAL_TYPES(
type_per_edge.scalar_type(), "TemporalPickByEtype", ([&] {
const scalar_t* type_per_edge_data = type_per_edge.data_ptr<scalar_t>();
const auto end = offset + num_neighbors;
while (etype_begin < end) {
scalar_t etype = type_per_edge_data[etype_begin];
TORCH_CHECK(
etype >= 0 && etype < (int64_t)fanouts.size(),
"Etype values exceed the number of fanouts.");
int64_t fanout = fanouts[etype];
auto etype_end_it = std::upper_bound(
type_per_edge_data + etype_begin, type_per_edge_data + end,
etype);
etype_end = etype_end_it - type_per_edge_data;
// Do sampling for one etype.
if (fanout != 0) {
int64_t picked_count = TemporalPick(
seed_timestamp, csc_indices, seed_offset, etype_begin,
etype_end - etype_begin, fanout, replace, options,
probs_or_mask, node_timestamp, edge_timestamp,
picked_data_ptr + pick_offset);
pick_offset += picked_count;
}
etype_begin = etype_end;
}
}));
return pick_offset;
}

template <typename PickedType>
int64_t Pick(
int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace,
23 changes: 23 additions & 0 deletions graphbolt/src/utils.h
@@ -19,6 +19,29 @@ inline bool is_accessible_from_gpu(torch::Tensor tensor) {
return tensor.is_pinned() || tensor.device().type() == c10::DeviceType::CUDA;
}

/**
* @brief Retrieves the value of the tensor at the given index.
*
* @note If the tensor is not contiguous, it will be copied to a contiguous
* tensor.
*
* @tparam T The type of the tensor.
* @param tensor The tensor.
* @param index The index.
*
* @return T The value of the tensor at the given index.
*/
template <typename T>
T GetValueByIndex(const torch::Tensor& tensor, int64_t index) {
TORCH_CHECK(
index >= 0 && index < tensor.numel(),
"The index should be within the range of the tensor, but got index ",
index, " and tensor size ", tensor.numel());
auto contiguous_tensor = tensor.contiguous();
auto data_ptr = contiguous_tensor.data_ptr<T>();
return data_ptr[index];
}

} // namespace utils
} // namespace graphbolt
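
A short usage sketch of the new helper; the stand-in below re-declares the template so the example is self-contained (the real one lives in graphbolt/src/utils.h under graphbolt::utils, and the data here is illustrative):

    #include <torch/torch.h>
    #include <cstdint>

    // Self-contained stand-in for graphbolt::utils::GetValueByIndex.
    template <typename T>
    T GetValueByIndex(const torch::Tensor& tensor, int64_t index) {
      auto contiguous_tensor = tensor.contiguous();
      return contiguous_tensor.data_ptr<T>()[index];
    }

    int main() {
      torch::Tensor m = torch::arange(6, torch::kInt64).reshape({2, 3});
      // Selecting a column yields a non-contiguous view; contiguous() copies
      // it so raw pointer indexing stays valid.
      torch::Tensor col = m.select(/*dim=*/1, /*index=*/1);  // values {1, 4}
      int64_t v = GetValueByIndex<int64_t>(col, 1);          // v == 4
      return v == 4 ? 0 : 1;
    }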

