Hetero subgraph with dispatching #43
base: master
Conversation
Codecov Report
@@ Coverage Diff @@
## master #43 +/- ##
==========================================
- Coverage 97.27% 96.51% -0.76%
==========================================
Files 10 12 +2
Lines 220 287 +67
==========================================
+ Hits 214 277 +63
- Misses 6 10 +4
Continue to review full report at Codecov.
for more information, see https://pre-commit.ci
@@ -23,26 +25,28 @@ class Mapper {

   void fill(const scalar_t* nodes_data, const scalar_t size) {
     if (use_vec) {
-      for (scalar_t i = 0; i < size; ++i)
+      for (scalar_t i = 0; i < size; ++i) {
Let me post my question here. From the documents I read, my understanding is that `scalar_t` covers float, double, int32 and int64 during compilation, but in a lot of our use cases we are iterating over integers. How does PyTorch avoid compiling the float instantiations of these functions? Is there a better way to be more specific about the data types here?
There are some helper functions like `is_integral` for a dtype, but IMO that is mostly runtime checking. We could also use some STL type traits for compile-time checks.
The `AT_DISPATCH_INTEGRAL_TYPES` call controls which types `scalar_t` can take (at compile time).
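For context, here is a minimal sketch of how such a dispatch macro works. This mimics the ATen pattern but is not the real implementation: `ScalarType`, `DISPATCH_INTEGRAL_TYPES`, and the reduced two-type list are illustrative stand-ins. The macro switches on the runtime dtype tag and instantiates the lambda body once per integral type, so no floating-point instantiations are ever compiled:

```cpp
#include <cstdint>

// Hypothetical stand-in for at::ScalarType, reduced to two integral types.
enum class ScalarType { Int, Long };

// Simplified mimic of AT_DISPATCH_INTEGRAL_TYPES: switch on the runtime
// dtype tag and compile the lambda body once per integral scalar_t.
#define DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
  [&] {                                          \
    switch (TYPE) {                              \
      case ScalarType::Int: {                    \
        using scalar_t = int32_t;                \
        return __VA_ARGS__();                    \
      }                                          \
      case ScalarType::Long:                     \
      default: {                                 \
        using scalar_t = int64_t;                \
        return __VA_ARGS__();                    \
      }                                          \
    }                                            \
  }()

// Example entry point: inside the lambda, scalar_t is a concrete type.
int scalar_width(ScalarType dtype) {
  return DISPATCH_INTEGRAL_TYPES(dtype, "scalar_width", [&] {
    return static_cast<int>(sizeof(scalar_t));
  });
}
```

Because the dtype list is fixed in the macro, restricting the generated code to integer types is just a matter of picking the right `AT_DISPATCH_*` variant.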
   });

-  return std::make_tuple(out_rowptr, out_col, out_edge_id);
+  return subgraph_bipartite(rowptr, col, nodes, nodes, return_edge_id);
The code structure looks a little weird to me, because csrc/sampler/cpu/subgraph_kernel exists to register TORCH_LIBRARY_IMPL, yet it uses a general implementation in csrc/sampler/subgraph.cpp. How about reorganizing the code like this:

csrc
- ops      # all ops exposed to PyTorch
  - sampler
- sampler  # all general graph operations

We don't need to refactor the code structure now, but I'd like to hear your opinion.
nvm, it seems subgraph.cpp also defines the library. Why not merge them together, since sampler/subgraph.cpp also runs on CPU only?
We could follow the style of the other PyG repos: put CPU/GPU-specific implementations in separate folders and provide a common interface in a higher-level directory.
  auto res = subgraph_with_mapper<scalar_t>(rowptr, col, src_nodes,
                                            mapper, return_edge_id);
  out_rowptr = std::get<0>(res);
Or maybe we could do `std::tie(out_rowptr, out_col, out_edge_id) = res`?
+1
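For illustration, a tiny sketch of the `std::tie` suggestion, with plain `int`s standing in for the actual tensors and `make_res` as a hypothetical stand-in for `subgraph_with_mapper`'s return value:

```cpp
#include <tuple>

// Hypothetical stand-in for subgraph_with_mapper's returned tuple.
std::tuple<int, int, int> make_res() { return {1, 2, 3}; }

// std::tie unpacks all three elements in one statement, replacing three
// separate std::get<i>(res) assignments into the output variables.
void unpack(int& out_rowptr, int& out_col, int& out_edge_id) {
  std::tie(out_rowptr, out_col, out_edge_id) = make_res();
}
```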
pyg_lib/csrc/sampler/subgraph.cpp
Outdated
for (const auto& kv : rowptr) {
  const auto& edge_type = kv.key();
  bool pass = filter_args_by_edge(edge_type, src_nodes_arg, dst_nodes_arg,
I'd still prefer

pass = src_nodes_args.filter_by_edge(edge_type) && dst_nodes_args.filter_by_edge(edge_type) && edge_id_arg.filter_by_edge(edge_type)

or, from an efficiency point of view:

auto dst = get_dst(edge_type);
auto src = get_src(edge_type);
bool pass = return_edge_id.counts(edge_type) > 0 && src_nodes.counts(src) > 0 && dst_nodes.counts(dst) > 0;
pyg_lib/csrc/sampler/subgraph.cpp
Outdated
  const auto& r = rowptr.at(edge_type);
  const auto& c = col.at(edge_type);
  res.insert(edge_type,
             subgraph_bipartite(r, c, std::get<0>(vals), std::get<1>(vals),
and here it would just be `subgraph_bipartite(r, c, src_nodes.at(src), dst_nodes.at(dst), return_edge_id.at(edge_type));`
pyg_lib/csrc/sampler/subgraph.cpp
Outdated
@@ -25,10 +28,42 @@ std::tuple<at::Tensor, at::Tensor, c10::optional<at::Tensor>> subgraph(
   return op.call(rowptr, col, nodes, return_edge_id);
 }

+c10::Dict<utils::edge_t,
I actually would have expected us to return a tuple of dictionaries, similar to how the input looks.
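To make the two candidate signatures concrete, here is a sketch with `std::map` and `std::vector` standing in for `c10::Dict` and `at::Tensor`; all type and function names here are illustrative, not the PR's actual API:

```cpp
#include <map>
#include <string>
#include <tuple>
#include <vector>

using Tensor = std::vector<int>;  // placeholder for at::Tensor
using EdgeType = std::string;

// Current shape: one dict whose values bundle (rowptr, col, edge_id).
using DictOfTuples = std::map<EdgeType, std::tuple<Tensor, Tensor, Tensor>>;

// Suggested shape: a tuple of dicts, mirroring the per-component inputs.
using TupleOfDicts = std::tuple<std::map<EdgeType, Tensor>,
                                std::map<EdgeType, Tensor>,
                                std::map<EdgeType, Tensor>>;

// Converting between the two shapes is mechanical: split each bundled
// value into the three per-component dictionaries.
TupleOfDicts to_tuple_of_dicts(const DictOfTuples& in) {
  TupleOfDicts out;
  for (const auto& [edge_type, vals] : in) {
    std::get<0>(out)[edge_type] = std::get<0>(vals);
    std::get<1>(out)[edge_type] = std::get<1>(vals);
    std::get<2>(out)[edge_type] = std::get<2>(vals);
  }
  return out;
}
```

The tuple-of-dicts shape keeps the output symmetric with the `(rowptr, col, ...)` inputs, which is the symmetry the comment above is asking for.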
      offset++;
    }
  AT_DISPATCH_INTEGRAL_TYPES(
      nodes.scalar_type(), "subgraph_kernel_with_mapper", [&] {
Can we make this a one-liner again?
} // namespace

TORCH_LIBRARY_IMPL(pyg, CPU, m) {
  m.impl(TORCH_SELECTIVE_NAME("pyg::subgraph"), TORCH_FN(subgraph_kernel));
  m.impl(TORCH_SELECTIVE_NAME("pyg::subgraph_bipartite"),
Any reason we want to expose that? Looks more like an internal function to me.
If the user wants to build a subgraph of a bipartite graph, they can use it.
}

c10::Dict<utils::EdgeType,
          std::tuple<at::Tensor, at::Tensor, c10::optional<at::Tensor>>>
IMO, the output should be a tuple of dictionaries (similar to the input).
if (pass) {
  const auto& r = rowptr.at(edge_type);
  const auto& c = col.at(edge_type);
  res.insert(edge_type, subgraph_bipartite(
Shouldn't we use the mapper here? Otherwise, we will re-map across every edge type.
Yes, it has a cost, but the mapper is more read-intensive. I will add a TODO here.
inline NodeType get_dst(const EdgeType& e) {
  return e.substr(e.find_last_of(SPLIT_TOKEN) + 1);
}
We could also add a function that maps tuples to strings and vice versa.
Good idea.
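A sketch of what such a round-trip helper could look like, assuming the "src__rel__dst" string format implied by `get_dst` above; the `"__"` token value and the helper names `to_tuple`/`to_string` are assumptions for illustration:

```cpp
#include <string>
#include <tuple>

// Assumed separator between the (src, rel, dst) components of an edge type.
const std::string SPLIT_TOKEN = "__";

using NodeType = std::string;
using RelType = std::string;
using EdgeType = std::string;

// String -> tuple: split "src__rel__dst" at the first and last separator.
std::tuple<NodeType, RelType, NodeType> to_tuple(const EdgeType& e) {
  const auto first = e.find(SPLIT_TOKEN);
  const auto last = e.rfind(SPLIT_TOKEN);
  return {e.substr(0, first),
          e.substr(first + SPLIT_TOKEN.size(),
                   last - first - SPLIT_TOKEN.size()),
          e.substr(last + SPLIT_TOKEN.size())};
}

// Tuple -> string: join the three components back with the separator.
EdgeType to_string(const std::tuple<NodeType, RelType, NodeType>& t) {
  return std::get<0>(t) + SPLIT_TOKEN + std::get<1>(t) + SPLIT_TOKEN +
         std::get<2>(t);
}
```

This keeps `get_src`/`get_dst` trivial (just `std::get<0>`/`std::get<2>` on the parsed tuple) and gives callers a single place that knows about the separator.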
    std::tuple<at::Tensor, at::Tensor, c10::optional<at::Tensor>>>
hetero_subgraph(const utils::EdgeTensorDict& rowptr,
                const utils::EdgeTensorDict& col,
                const utils::NodeTensorDict& src_nodes,
Not sure why we have both `src_nodes` and `dst_nodes`. IMO, these can safely be merged, as in https://pytorch-geometric.readthedocs.io/en/latest/modules/data.html#torch_geometric.data.HeteroData.subgraph.
Separating `src` and `dst` is just to give some flexibility. We could also have the merged API, though.
Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de>
No description provided.