From e42c7fcdecbcc0fd7bfaf9cbb928e47e7c5f87b1 Mon Sep 17 00:00:00 2001 From: czkkkkkk Date: Fri, 22 Dec 2023 14:50:05 +0800 Subject: [PATCH] [Graphbolt] Implement Temporal Neighbor Sampling. (#6784) --- .../graphbolt/fused_csc_sampling_graph.h | 37 +++ graphbolt/src/fused_csc_sampling_graph.cc | 228 +++++++++++++++++- graphbolt/src/utils.h | 23 ++ .../impl/fused_csc_sampling_graph.py | 32 ++- 4 files changed, 309 insertions(+), 11 deletions(-) diff --git a/graphbolt/include/graphbolt/fused_csc_sampling_graph.h b/graphbolt/include/graphbolt/fused_csc_sampling_graph.h index 52fea624c6e8..8a3e23ab509f 100644 --- a/graphbolt/include/graphbolt/fused_csc_sampling_graph.h +++ b/graphbolt/include/graphbolt/fused_csc_sampling_graph.h @@ -508,12 +508,28 @@ int64_t NumPick( const torch::optional& probs_or_mask, int64_t offset, int64_t num_neighbors); +int64_t TemporalNumPick( + torch::Tensor seed_timestamp, torch::Tensor csc_indics, int64_t fanout, + bool replace, const torch::optional& probs_or_mask, + const torch::optional& node_timestamp, + const torch::optional& edge_timestamp, int64_t seed_offset, + int64_t offset, int64_t num_neighbors); + int64_t NumPickByEtype( const std::vector& fanouts, bool replace, const torch::Tensor& type_per_edge, const torch::optional& probs_or_mask, int64_t offset, int64_t num_neighbors); +int64_t TemporalNumPickByEtype( + torch::Tensor seed_timestamp, torch::Tensor csc_indices, + const std::vector& fanouts, bool replace, + const torch::Tensor& type_per_edge, + const torch::optional& probs_or_mask, + const torch::optional& node_timestamp, + const torch::optional& edge_timestamp, int64_t seed_offset, + int64_t offset, int64_t num_neighbors); + /** * @brief Picks a specified number of neighbors for a node, starting from the * given offset and having the specified number of neighbors. @@ -562,6 +578,16 @@ int64_t Pick( const torch::optional& probs_or_mask, SamplerArgs args, PickedType* picked_data_ptr); +template +int64_t TemporalPick( + torch::Tensor seed_timestamp, torch::Tensor csc_indices, + int64_t seed_offset, int64_t offset, int64_t num_neighbors, int64_t fanout, + bool replace, const torch::TensorOptions& options, + const torch::optional& probs_or_mask, + const torch::optional& node_timestamp, + const torch::optional& edge_timestamp, + PickedType* picked_data_ptr); + /** * @brief Picks a specified number of neighbors for a node per edge type, * starting from the given offset and having the specified number of neighbors. @@ -597,6 +623,17 @@ int64_t PickByEtype( const torch::optional& probs_or_mask, SamplerArgs args, PickedType* picked_data_ptr); +template +int64_t TemporalPickByEtype( + torch::Tensor seed_timestamp, torch::Tensor csc_indices, + int64_t seed_offset, int64_t offset, int64_t num_neighbors, + const std::vector& fanouts, bool replace, + const torch::TensorOptions& options, const torch::Tensor& type_per_edge, + const torch::optional& probs_or_mask, + const torch::optional& node_timestamp, + const torch::optional& edge_timestamp, + PickedType* picked_data_ptr); + template < bool NonUniform, bool Replace, typename ProbsType, typename PickedType, int StackSize = 1024> diff --git a/graphbolt/src/fused_csc_sampling_graph.cc b/graphbolt/src/fused_csc_sampling_graph.cc index 20337519e418..fbb7ba12cb41 100644 --- a/graphbolt/src/fused_csc_sampling_graph.cc +++ b/graphbolt/src/fused_csc_sampling_graph.cc @@ -18,6 +18,7 @@ #include "./random.h" #include "./shared_memory_helper.h" +#include "./utils.h" namespace { torch::optional> TensorizeDict( @@ -349,6 +350,31 @@ auto GetNumPickFn( }; } +auto GetTemporalNumPickFn( + torch::Tensor seed_timestamp, torch::Tensor csc_indices, + const std::vector& fanouts, bool replace, + const torch::optional& type_per_edge, + const torch::optional& probs_or_mask, + const torch::optional& node_timestamp, + const torch::optional& edge_timestamp) { + // If fanouts.size() > 1, returns the total number of all edge types of the + // given node. + return [&seed_timestamp, &csc_indices, &fanouts, replace, &probs_or_mask, + &type_per_edge, &node_timestamp, &edge_timestamp]( + int64_t seed_offset, int64_t offset, int64_t num_neighbors) { + if (fanouts.size() > 1) { + return TemporalNumPickByEtype( + seed_timestamp, csc_indices, fanouts, replace, type_per_edge.value(), + probs_or_mask, node_timestamp, edge_timestamp, seed_offset, offset, + num_neighbors); + } else { + return TemporalNumPick( + seed_timestamp, csc_indices, fanouts[0], replace, probs_or_mask, + node_timestamp, edge_timestamp, seed_offset, offset, num_neighbors); + } + }; +} + /** * @brief Get a lambda function which contains the sampling process. * @@ -400,6 +426,39 @@ auto GetPickFn( }; } +auto GetTemporalPickFn( + torch::Tensor seed_timestamp, torch::Tensor csc_indices, + const std::vector& fanouts, bool replace, + const torch::TensorOptions& options, + const torch::optional& type_per_edge, + const torch::optional& probs_or_mask, + const torch::optional& node_timestamp, + const torch::optional& edge_timestamp) { + return [&seed_timestamp, &csc_indices, &fanouts, replace, &options, + &type_per_edge, &probs_or_mask, &node_timestamp, &edge_timestamp]( + int64_t seed_offset, int64_t offset, int64_t num_neighbors, + auto picked_data_ptr) { + // If fanouts.size() > 1, perform sampling for each edge type of each + // node; otherwise just sample once for each node with no regard of edge + // types. + if (fanouts.size() > 1) { + return TemporalPickByEtype( + seed_timestamp, csc_indices, seed_offset, offset, num_neighbors, + fanouts, replace, options, type_per_edge.value(), probs_or_mask, + node_timestamp, edge_timestamp, picked_data_ptr); + } else { + int64_t num_sampled = TemporalPick( + seed_timestamp, csc_indices, seed_offset, offset, num_neighbors, + fanouts[0], replace, options, probs_or_mask, node_timestamp, + edge_timestamp, picked_data_ptr); + if (type_per_edge) { + std::sort(picked_data_ptr, picked_data_ptr + num_sampled); + } + return num_sampled; + } + }; +} + template c10::intrusive_ptr FusedCSCSamplingGraph::SampleNeighborsImpl( @@ -579,14 +638,31 @@ FusedCSCSamplingGraph::TemporalSampleNeighbors( torch::optional probs_name, torch::optional node_timestamp_attr_name, torch::optional edge_timestamp_attr_name) const { - // TODO(zhenkun): // 1. Get probs_or_mask. + auto probs_or_mask = this->EdgeAttribute(probs_name); + if (probs_name.has_value()) { + // Note probs will be passed as input for 'torch.multinomial' in deeper + // stack, which doesn't support 'torch.half' and 'torch.bool' data types. To + // avoid crashes, convert 'probs_or_mask' to 'float32' data type. + if (probs_or_mask.value().dtype() == torch::kBool || + probs_or_mask.value().dtype() == torch::kFloat16) { + probs_or_mask = probs_or_mask.value().to(torch::kFloat32); + } + } // 2. Get the timestamp attribute for nodes of the graph + auto node_timestamp = this->NodeAttribute(node_timestamp_attr_name); // 3. Get the timestamp attribute for edges of the graph - // 4. GetTemporalNumPickFn (New implementation) - // 5. GetTemporalPickFn (New implementation) - // 6. Call SampleNeighborsImpl (Old implementation) - return c10::intrusive_ptr(); + auto edge_timestamp = this->EdgeAttribute(edge_timestamp_attr_name); + // 4. Call SampleNeighborsImpl + return SampleNeighborsImpl( + input_nodes, return_eids, + GetTemporalNumPickFn( + input_nodes_timestamp, this->indices_, fanouts, replace, + type_per_edge_, probs_or_mask, node_timestamp, edge_timestamp), + GetTemporalPickFn( + input_nodes_timestamp, this->indices_, fanouts, replace, + indptr_.options(), type_per_edge_, probs_or_mask, node_timestamp, + edge_timestamp)); } std::tuple @@ -669,6 +745,43 @@ int64_t NumPick( return replace ? fanout : std::min(fanout, num_valid_neighbors); } +torch::Tensor TemporalMask( + int64_t seed_timestamp, torch::Tensor csc_indices, + const torch::optional& probs_or_mask, + const torch::optional& node_timestamp, + const torch::optional& edge_timestamp, + std::pair edge_range) { + auto [l, r] = edge_range; + torch::Tensor mask = torch::ones({r - l}, torch::kBool); + if (node_timestamp.has_value()) { + auto neighbor_timestamp = + node_timestamp.value().index_select(0, csc_indices.slice(0, l, r)); + mask &= neighbor_timestamp <= seed_timestamp; + } + if (edge_timestamp.has_value()) { + mask &= edge_timestamp.value().slice(0, l, r) <= seed_timestamp; + } + if (probs_or_mask.has_value()) { + mask &= probs_or_mask.value().slice(0, l, r) != 0; + } + return mask; +} + +int64_t TemporalNumPick( + torch::Tensor seed_timestamp, torch::Tensor csc_indics, int64_t fanout, + bool replace, const torch::optional& probs_or_mask, + const torch::optional& node_timestamp, + const torch::optional& edge_timestamp, int64_t seed_offset, + int64_t offset, int64_t num_neighbors) { + auto mask = TemporalMask( + utils::GetValueByIndex(seed_timestamp, seed_offset), csc_indics, + probs_or_mask, node_timestamp, edge_timestamp, + {offset, offset + num_neighbors}); + int64_t num_valid_neighbors = utils::GetValueByIndex(mask.sum(), 0); + if (num_valid_neighbors == 0 || fanout == -1) return num_valid_neighbors; + return replace ? fanout : std::min(fanout, num_valid_neighbors); +} + int64_t NumPickByEtype( const std::vector& fanouts, bool replace, const torch::Tensor& type_per_edge, @@ -699,6 +812,40 @@ int64_t NumPickByEtype( return total_count; } +int64_t TemporalNumPickByEtype( + torch::Tensor seed_timestamp, torch::Tensor csc_indices, + const std::vector& fanouts, bool replace, + const torch::Tensor& type_per_edge, + const torch::optional& probs_or_mask, + const torch::optional& node_timestamp, + const torch::optional& edge_timestamp, int64_t seed_offset, + int64_t offset, int64_t num_neighbors) { + int64_t etype_begin = offset; + const int64_t end = offset + num_neighbors; + int64_t total_count = 0; + AT_DISPATCH_INTEGRAL_TYPES( + type_per_edge.scalar_type(), "TemporalNumPickFnByEtype", ([&] { + const scalar_t* type_per_edge_data = type_per_edge.data_ptr(); + while (etype_begin < end) { + scalar_t etype = type_per_edge_data[etype_begin]; + TORCH_CHECK( + etype >= 0 && etype < (int64_t)fanouts.size(), + "Etype values exceed the number of fanouts."); + auto etype_end_it = std::upper_bound( + type_per_edge_data + etype_begin, type_per_edge_data + end, + etype); + int64_t etype_end = etype_end_it - type_per_edge_data; + // Do sampling for one etype. + total_count += TemporalNumPick( + seed_timestamp, csc_indices, fanouts[etype], replace, + probs_or_mask, node_timestamp, edge_timestamp, seed_offset, + etype_begin, etype_end - etype_begin); + etype_begin = etype_end; + } + })); + return total_count; +} + /** * @brief Perform uniform sampling of elements and return the sampled indices. * @@ -983,6 +1130,35 @@ int64_t Pick( } } +template +int64_t TemporalPick( + torch::Tensor seed_timestamp, torch::Tensor csc_indices, + int64_t seed_offset, int64_t offset, int64_t num_neighbors, int64_t fanout, + bool replace, const torch::TensorOptions& options, + const torch::optional& probs_or_mask, + const torch::optional& node_timestamp, + const torch::optional& edge_timestamp, + PickedType* picked_data_ptr) { + auto mask = TemporalMask( + utils::GetValueByIndex(seed_timestamp, seed_offset), csc_indices, + probs_or_mask, node_timestamp, edge_timestamp, + {offset, offset + num_neighbors}); + torch::Tensor masked_prob; + if (probs_or_mask.has_value()) { + masked_prob = + probs_or_mask.value().slice(0, offset, offset + num_neighbors) * mask; + } else { + masked_prob = mask.to(torch::kFloat32); + } + auto picked_indices = NonUniformPickOp(masked_prob, fanout, replace); + auto picked_indices_ptr = picked_indices.data_ptr(); + for (int i = 0; i < picked_indices.numel(); ++i) { + picked_data_ptr[i] = + static_cast(picked_indices_ptr[i]) + offset; + } + return picked_indices.numel(); +} + template int64_t PickByEtype( int64_t offset, int64_t num_neighbors, const std::vector& fanouts, @@ -1020,6 +1196,48 @@ int64_t PickByEtype( return pick_offset; } +template +int64_t TemporalPickByEtype( + torch::Tensor seed_timestamp, torch::Tensor csc_indices, + int64_t seed_offset, int64_t offset, int64_t num_neighbors, + const std::vector& fanouts, bool replace, + const torch::TensorOptions& options, const torch::Tensor& type_per_edge, + const torch::optional& probs_or_mask, + const torch::optional& node_timestamp, + const torch::optional& edge_timestamp, + PickedType* picked_data_ptr) { + int64_t etype_begin = offset; + int64_t etype_end = offset; + int64_t pick_offset = 0; + AT_DISPATCH_INTEGRAL_TYPES( + type_per_edge.scalar_type(), "TemporalPickByEtype", ([&] { + const scalar_t* type_per_edge_data = type_per_edge.data_ptr(); + const auto end = offset + num_neighbors; + while (etype_begin < end) { + scalar_t etype = type_per_edge_data[etype_begin]; + TORCH_CHECK( + etype >= 0 && etype < (int64_t)fanouts.size(), + "Etype values exceed the number of fanouts."); + int64_t fanout = fanouts[etype]; + auto etype_end_it = std::upper_bound( + type_per_edge_data + etype_begin, type_per_edge_data + end, + etype); + etype_end = etype_end_it - type_per_edge_data; + // Do sampling for one etype. + if (fanout != 0) { + int64_t picked_count = TemporalPick( + seed_timestamp, csc_indices, seed_offset, etype_begin, + etype_end - etype_begin, fanout, replace, options, + probs_or_mask, node_timestamp, edge_timestamp, + picked_data_ptr + pick_offset); + pick_offset += picked_count; + } + etype_begin = etype_end; + } + })); + return pick_offset; +} + template int64_t Pick( int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace, diff --git a/graphbolt/src/utils.h b/graphbolt/src/utils.h index 96fa1cf10eea..093e920af017 100644 --- a/graphbolt/src/utils.h +++ b/graphbolt/src/utils.h @@ -19,6 +19,29 @@ inline bool is_accessible_from_gpu(torch::Tensor tensor) { return tensor.is_pinned() || tensor.device().type() == c10::DeviceType::CUDA; } +/** + * @brief Retrieves the value of the tensor at the given index. + * + * @note If the tensor is not contiguous, it will be copied to a contiguous + * tensor. + * + * @tparam T The type of the tensor. + * @param tensor The tensor. + * @param index The index. + * + * @return T The value of the tensor at the given index. + */ +template +T GetValueByIndex(const torch::Tensor& tensor, int64_t index) { + TORCH_CHECK( + index >= 0 && index < tensor.numel(), + "The index should be within the range of the tensor, but got index ", + index, " and tensor size ", tensor.numel()); + auto contiguous_tensor = tensor.contiguous(); + auto data_ptr = contiguous_tensor.data_ptr(); + return data_ptr[index]; +} + } // namespace utils } // namespace graphbolt diff --git a/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py b/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py index 676e6b5798b4..ab13b990f0fe 100644 --- a/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py +++ b/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py @@ -439,11 +439,18 @@ def _convert_to_fused_sampled_subgraph( node_pairs=node_pairs, original_edge_ids=original_edge_ids ) - def _convert_to_homogeneous_nodes(self, nodes): + def _convert_to_homogeneous_nodes(self, nodes, timestamps=None): homogeneous_nodes = [] + homogeneous_timestamps = [] for ntype, ids in nodes.items(): ntype_id = self.node_type_to_id[ntype] homogeneous_nodes.append(ids + self.node_type_offset[ntype_id]) + if timestamps is not None: + homogeneous_timestamps.append(timestamps[ntype]) + if timestamps is not None: + return torch.cat(homogeneous_nodes), torch.cat( + homogeneous_timestamps + ) return torch.cat(homogeneous_nodes) def _convert_to_sampled_subgraph( @@ -830,7 +837,7 @@ def sample_layer_neighbors( else: return self._convert_to_sampled_subgraph(C_sampled_subgraph) - def _temporal_sample_neighbors( + def temporal_sample_neighbors( self, nodes: torch.Tensor, input_nodes_timestamp: torch.Tensor, @@ -887,26 +894,39 @@ def _temporal_sample_neighbors( Returns ------- - torch.classes.graphbolt.SampledSubgraph - The sampled C subgraph. + FusedSampledSubgraphImpl + The sampled subgraph. """ + if isinstance(nodes, dict): + nodes, input_nodes_timestamp = self._convert_to_homogeneous_nodes( + nodes, input_nodes_timestamp + ) + # Ensure nodes is 1-D tensor. self._check_sampler_arguments(nodes, fanouts, probs_name) has_original_eids = ( self.edge_attributes is not None and ORIGINAL_EDGE_ID in self.edge_attributes ) - return self._c_csc_graph.temporal_sample_neighbors( + C_sampled_subgraph = self._c_csc_graph.temporal_sample_neighbors( nodes, input_nodes_timestamp, fanouts.tolist(), replace, - False, has_original_eids, probs_name, node_timestamp_attr_name, edge_timestamp_attr_name, ) + # Broadcast the input nodes' timestamp to the sampled neighbors. + sampled_count = torch.diff(C_sampled_subgraph.indptr) + neighbors_timestamp = input_nodes_timestamp.repeat_interleave( + sampled_count + ) + return ( + self._convert_to_sampled_subgraph(C_sampled_subgraph), + neighbors_timestamp, + ) def sample_negative_edges_uniform( self, edge_type, node_pairs, negative_ratio