-
Notifications
You must be signed in to change notification settings - Fork 44
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Hetero subgraph with dispatching #43
base: master
Are you sure you want to change the base?
Changes from all commits
2a50f4c
cba41dd
7494c2d
f87e8f3
8a1cb14
08faf36
c59be3f
dd2e8b6
d1c98cc
0a4bc01
1f20313
f8f9059
c4c446a
487eaf6
48c115d
663a675
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,96 +3,139 @@ | |
#include <torch/library.h> | ||
|
||
#include "pyg_lib/csrc/sampler/cpu/mapper.h" | ||
#include "pyg_lib/csrc/sampler/subgraph.h" | ||
#include "pyg_lib/csrc/utils/cpu/convert.h" | ||
|
||
namespace pyg { | ||
namespace sampler { | ||
|
||
namespace { | ||
|
||
std::tuple<at::Tensor, at::Tensor, c10::optional<at::Tensor>> subgraph_kernel( | ||
const at::Tensor& rowptr, | ||
const at::Tensor& col, | ||
const at::Tensor& nodes, | ||
const bool return_edge_id) { | ||
TORCH_CHECK(rowptr.is_cpu(), "'rowptr' must be a CPU tensor"); | ||
TORCH_CHECK(col.is_cpu(), "'col' must be a CPU tensor"); | ||
TORCH_CHECK(nodes.is_cpu(), "'nodes' must be a CPU tensor"); | ||
|
||
template <typename T> | ||
std::tuple<at::Tensor, at::Tensor, c10::optional<at::Tensor>> | ||
subgraph_with_mapper(const at::Tensor& rowptr, | ||
const at::Tensor& col, | ||
const at::Tensor& nodes, | ||
const Mapper<T>& mapper, | ||
const bool return_edge_id) { | ||
const auto num_nodes = rowptr.size(0) - 1; | ||
const auto out_rowptr = rowptr.new_empty({nodes.size(0) + 1}); | ||
at::Tensor out_col; | ||
c10::optional<at::Tensor> out_edge_id = c10::nullopt; | ||
|
||
AT_DISPATCH_INTEGRAL_TYPES(nodes.scalar_type(), "subgraph_kernel", [&] { | ||
auto mapper = pyg::sampler::Mapper<scalar_t>(num_nodes, nodes.size(0)); | ||
mapper.fill(nodes); | ||
|
||
const auto rowptr_data = rowptr.data_ptr<scalar_t>(); | ||
const auto col_data = col.data_ptr<scalar_t>(); | ||
const auto nodes_data = nodes.data_ptr<scalar_t>(); | ||
|
||
// We first iterate over all nodes and collect information about the number | ||
// of edges in the induced subgraph. | ||
const auto deg = rowptr.new_empty({nodes.size(0)}); | ||
auto deg_data = deg.data_ptr<scalar_t>(); | ||
auto grain_size = at::internal::GRAIN_SIZE; | ||
at::parallel_for(0, nodes.size(0), grain_size, [&](int64_t _s, int64_t _e) { | ||
for (size_t i = _s; i < _e; ++i) { | ||
const auto v = nodes_data[i]; | ||
// Iterate over all neighbors and check if they are part of `nodes`: | ||
scalar_t d = 0; | ||
for (size_t j = rowptr_data[v]; j < rowptr_data[v + 1]; ++j) { | ||
if (mapper.exists(col_data[j])) | ||
d++; | ||
} | ||
deg_data[i] = d; | ||
} | ||
}); | ||
|
||
auto out_rowptr_data = out_rowptr.data_ptr<scalar_t>(); | ||
out_rowptr_data[0] = 0; | ||
auto tmp = out_rowptr.narrow(0, 1, nodes.size(0)); | ||
at::cumsum_out(tmp, deg, /*dim=*/0); | ||
|
||
out_col = col.new_empty({out_rowptr_data[nodes.size(0)]}); | ||
auto out_col_data = out_col.data_ptr<scalar_t>(); | ||
scalar_t* out_edge_id_data; | ||
if (return_edge_id) { | ||
out_edge_id = col.new_empty({out_rowptr_data[nodes.size(0)]}); | ||
out_edge_id_data = out_edge_id.value().data_ptr<scalar_t>(); | ||
} | ||
|
||
// Customize `grain_size` based on the work each thread does (it will need | ||
// to find `col.size(0) / nodes.size(0)` neighbors on average). | ||
// TODO Benchmark this customization | ||
grain_size = std::max<int64_t>(out_col.size(0) / nodes.size(0), 1); | ||
grain_size = at::internal::GRAIN_SIZE / grain_size; | ||
at::parallel_for(0, nodes.size(0), grain_size, [&](int64_t _s, int64_t _e) { | ||
for (scalar_t i = _s; i < _e; ++i) { | ||
const auto v = nodes_data[i]; | ||
// Iterate over all neighbors and check if they are part of `nodes`: | ||
scalar_t offset = out_rowptr_data[i]; | ||
for (scalar_t j = rowptr_data[v]; j < rowptr_data[v + 1]; ++j) { | ||
const auto w = mapper.map(col_data[j]); | ||
if (w >= 0) { | ||
out_col_data[offset] = w; | ||
if (return_edge_id) | ||
out_edge_id_data[offset] = j; | ||
offset++; | ||
} | ||
AT_DISPATCH_INTEGRAL_TYPES( | ||
nodes.scalar_type(), "subgraph_kernel_with_mapper", [&] { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we make this a one-liner again? |
||
const auto rowptr_data = rowptr.data_ptr<scalar_t>(); | ||
const auto col_data = col.data_ptr<scalar_t>(); | ||
const auto nodes_data = nodes.data_ptr<scalar_t>(); | ||
|
||
// We first iterate over all nodes and collect information about the | ||
// number of edges in the induced subgraph. | ||
const auto deg = rowptr.new_empty({nodes.size(0)}); | ||
auto deg_data = deg.data_ptr<scalar_t>(); | ||
auto grain_size = at::internal::GRAIN_SIZE; | ||
at::parallel_for( | ||
0, nodes.size(0), grain_size, [&](int64_t _s, int64_t _e) { | ||
for (size_t i = _s; i < _e; ++i) { | ||
const auto v = nodes_data[i]; | ||
// Iterate over all neighbors and check if they are part of | ||
// `nodes`: | ||
scalar_t d = 0; | ||
for (size_t j = rowptr_data[v]; j < rowptr_data[v + 1]; ++j) { | ||
if (mapper.exists(col_data[j])) | ||
d++; | ||
} | ||
deg_data[i] = d; | ||
} | ||
}); | ||
|
||
auto out_rowptr_data = out_rowptr.data_ptr<scalar_t>(); | ||
out_rowptr_data[0] = 0; | ||
auto tmp = out_rowptr.narrow(0, 1, nodes.size(0)); | ||
at::cumsum_out(tmp, deg, /*dim=*/0); | ||
|
||
out_col = col.new_empty({out_rowptr_data[nodes.size(0)]}); | ||
auto out_col_data = out_col.data_ptr<scalar_t>(); | ||
scalar_t* out_edge_id_data; | ||
if (return_edge_id) { | ||
out_edge_id = col.new_empty({out_rowptr_data[nodes.size(0)]}); | ||
out_edge_id_data = out_edge_id.value().data_ptr<scalar_t>(); | ||
} | ||
} | ||
}); | ||
}); | ||
|
||
// Customize `grain_size` based on the work each thread does (it will | ||
// need to find `col.size(0) / nodes.size(0)` neighbors on average). | ||
// TODO Benchmark this customization | ||
grain_size = std::max<int64_t>(out_col.size(0) / nodes.size(0), 1); | ||
grain_size = at::internal::GRAIN_SIZE / grain_size; | ||
at::parallel_for( | ||
0, nodes.size(0), grain_size, [&](int64_t _s, int64_t _e) { | ||
for (scalar_t i = _s; i < _e; ++i) { | ||
const auto v = nodes_data[i]; | ||
// Iterate over all neighbors and check if they | ||
// are part of `nodes`: | ||
scalar_t offset = out_rowptr_data[i]; | ||
for (scalar_t j = rowptr_data[v]; j < rowptr_data[v + 1]; ++j) { | ||
const auto w = mapper.map(col_data[j]); | ||
if (w >= 0) { | ||
out_col_data[offset] = w; | ||
if (return_edge_id) | ||
out_edge_id_data[offset] = j; | ||
offset++; | ||
} | ||
} | ||
} | ||
}); | ||
}); | ||
|
||
return std::make_tuple(out_rowptr, out_col, out_edge_id); | ||
} | ||
|
||
std::tuple<at::Tensor, at::Tensor, c10::optional<at::Tensor>> | ||
subgraph_bipartite_kernel(const at::Tensor& rowptr, | ||
const at::Tensor& col, | ||
const at::Tensor& src_nodes, | ||
const at::Tensor& dst_nodes, | ||
const bool return_edge_id) { | ||
TORCH_CHECK(rowptr.is_cpu(), "'rowptr' must be a CPU tensor"); | ||
TORCH_CHECK(col.is_cpu(), "'col' must be a CPU tensor"); | ||
TORCH_CHECK(src_nodes.is_cpu(), "'src_nodes' must be a CPU tensor"); | ||
TORCH_CHECK(dst_nodes.is_cpu(), "'dst_nodes' must be a CPU tensor"); | ||
|
||
const auto num_nodes = rowptr.size(0) - 1; | ||
at::Tensor out_rowptr, out_col; | ||
c10::optional<at::Tensor> out_edge_id; | ||
|
||
AT_DISPATCH_INTEGRAL_TYPES( | ||
src_nodes.scalar_type(), "subgraph_bipartite_kernel", [&] { | ||
// TODO: at::max parallel but still a little expensive | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's add a |
||
Mapper<scalar_t> mapper(at::max(col).item<scalar_t>() + 1, | ||
dst_nodes.size(0)); | ||
mapper.fill(dst_nodes); | ||
|
||
auto res = subgraph_with_mapper<scalar_t>(rowptr, col, src_nodes, | ||
mapper, return_edge_id); | ||
out_rowptr = std::get<0>(res); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. or maybe we could do There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. +1 |
||
out_col = std::get<1>(res); | ||
out_edge_id = std::get<2>(res); | ||
}); | ||
|
||
return {out_rowptr, out_col, out_edge_id}; | ||
} | ||
|
||
std::tuple<at::Tensor, at::Tensor, c10::optional<at::Tensor>> subgraph_kernel( | ||
const at::Tensor& rowptr, | ||
const at::Tensor& col, | ||
const at::Tensor& nodes, | ||
const bool return_edge_id) { | ||
return subgraph_bipartite_kernel(rowptr, col, nodes, nodes, return_edge_id); | ||
} | ||
|
||
} // namespace | ||
|
||
TORCH_LIBRARY_IMPL(pyg, CPU, m) { | ||
m.impl(TORCH_SELECTIVE_NAME("pyg::subgraph"), TORCH_FN(subgraph_kernel)); | ||
m.impl(TORCH_SELECTIVE_NAME("pyg::subgraph_bipartite"), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Any reason we want to expose that? Looks more like an internal function to me. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If the user wants to build a subgraph of a bipartite graph, they can use it. |
||
TORCH_FN(subgraph_bipartite_kernel)); | ||
} | ||
|
||
} // namespace sampler | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,11 @@ | ||
#include "subgraph.h" | ||
#include <pyg_lib/csrc/utils/hetero_dispatch.h> | ||
|
||
#include <ATen/core/dispatch/Dispatcher.h> | ||
#include <torch/library.h> | ||
|
||
#include <functional> | ||
|
||
namespace pyg { | ||
namespace sampler { | ||
|
||
|
@@ -11,7 +14,7 @@ std::tuple<at::Tensor, at::Tensor, c10::optional<at::Tensor>> subgraph( | |
const at::Tensor& col, | ||
const at::Tensor& nodes, | ||
const bool return_edge_id) { | ||
at::TensorArg rowptr_t{rowptr, "rowtpr", 1}; | ||
at::TensorArg rowptr_t{rowptr, "rowptr", 1}; | ||
at::TensorArg col_t{col, "col", 1}; | ||
at::TensorArg nodes_t{nodes, "nodes", 1}; | ||
|
||
|
@@ -25,10 +28,76 @@ std::tuple<at::Tensor, at::Tensor, c10::optional<at::Tensor>> subgraph( | |
return op.call(rowptr, col, nodes, return_edge_id); | ||
} | ||
|
||
std::tuple<at::Tensor, at::Tensor, c10::optional<at::Tensor>> | ||
subgraph_bipartite(const at::Tensor& rowptr, | ||
const at::Tensor& col, | ||
const at::Tensor& src_nodes, | ||
const at::Tensor& dst_nodes, | ||
const bool return_edge_id) { | ||
at::TensorArg rowptr_t{rowptr, "rowptr", 1}; | ||
at::TensorArg col_t{col, "col", 1}; | ||
at::TensorArg src_nodes_t{src_nodes, "src_nodes", 1}; | ||
at::TensorArg dst_nodes_t{dst_nodes, "dst_nodes", 1}; | ||
|
||
at::CheckedFrom c = "subgraph_bipartite"; | ||
at::checkAllDefined(c, {rowptr_t, col_t, src_nodes_t, dst_nodes_t}); | ||
at::checkAllSameType(c, {rowptr_t, col_t, src_nodes_t, dst_nodes_t}); | ||
|
||
static auto op = c10::Dispatcher::singleton() | ||
.findSchemaOrThrow("pyg::subgraph_bipartite", "") | ||
.typed<decltype(subgraph_bipartite)>(); | ||
return op.call(rowptr, col, src_nodes, dst_nodes, return_edge_id); | ||
} | ||
|
||
c10::Dict<utils::EdgeType, | ||
std::tuple<at::Tensor, at::Tensor, c10::optional<at::Tensor>>> | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. IMO, the output should be a tuple of dictionaries (similar to the input). |
||
hetero_subgraph(const utils::EdgeTensorDict& rowptr, | ||
const utils::EdgeTensorDict& col, | ||
const utils::NodeTensorDict& src_nodes, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not sure why we have both There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Separating |
||
const utils::NodeTensorDict& dst_nodes, | ||
const c10::Dict<utils::EdgeType, bool>& return_edge_id) { | ||
c10::Dict<utils::EdgeType, | ||
std::tuple<at::Tensor, at::Tensor, c10::optional<at::Tensor>>> | ||
res; | ||
|
||
// Construct dispatchable arguments | ||
utils::HeteroDispatchArg<utils::NodeTensorDict, at::Tensor, | ||
utils::NodeSrcMode> | ||
src_nodes_arg(src_nodes); | ||
utils::HeteroDispatchArg<utils::NodeTensorDict, at::Tensor, | ||
utils::NodeDstMode> | ||
dst_nodes_arg(dst_nodes); | ||
utils::HeteroDispatchArg<c10::Dict<utils::EdgeType, bool>, bool, | ||
utils::EdgeMode> | ||
edge_id_arg(return_edge_id); | ||
|
||
for (const auto& kv : rowptr) { | ||
const auto& edge_type = kv.key(); | ||
bool pass = src_nodes_arg.filter_by_edge(edge_type) && | ||
dst_nodes_arg.filter_by_edge(edge_type) && | ||
edge_id_arg.filter_by_edge(edge_type); | ||
if (pass) { | ||
const auto& r = rowptr.at(edge_type); | ||
const auto& c = col.at(edge_type); | ||
res.insert(edge_type, subgraph_bipartite( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Shouldn't we user the mapper here? Other-wise, we will re-map across every edge type. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes it has a cost, but the mapper is more read-intensive. I will add a TODO here. |
||
r, c, src_nodes_arg.value_by_edge(edge_type), | ||
dst_nodes_arg.value_by_edge(edge_type), | ||
edge_id_arg.value_by_edge(edge_type))); | ||
} | ||
} | ||
|
||
return res; | ||
} | ||
|
||
TORCH_LIBRARY_FRAGMENT(pyg, m) { | ||
m.def(TORCH_SELECTIVE_SCHEMA( | ||
"pyg::subgraph(Tensor rowptr, Tensor col, Tensor " | ||
"nodes, bool return_edge_id) -> (Tensor, Tensor, Tensor?)")); | ||
m.def(TORCH_SELECTIVE_SCHEMA( | ||
"pyg::subgraph_bipartite(Tensor rowptr, Tensor col, Tensor " | ||
"src_nodes, Tensor dst_nodes, bool return_edge_id) -> (Tensor, Tensor, " | ||
"Tensor?)")); | ||
m.def("hetero_subgraph", hetero_subgraph); | ||
} | ||
|
||
} // namespace sampler | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let me post my question here: I read some documents, and based on my understanding, scalar_t includes float, double, int32, and int64 during compilation. But in a lot of our use cases we are iterating over integers. How does PyTorch avoid compiling the float types for these functions? Is there a better way to be more specific about the data types here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are some helper functions like is_integral for a dtype, but IMO it is mostly runtime checking. We can also use some STL type checking for compile time.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The AT_DISPATCH_INTEGRAL_TYPES call handles which types scalar_t can take (during compile time).