
Hetero subgraph with dispatching #43

Open: wants to merge 16 commits into base: master
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [Unreleased]
### Added
- Added hetero subgraph kernel ([#43](https://github.com/pyg-team/pyg-lib/pull/43))
- Added `pyg::sampler::Mapper` utility for mapping global to local node indices ([#45](https://github.com/pyg-team/pyg-lib/pull/45))
- Added benchmark script ([#45](https://github.com/pyg-team/pyg-lib/pull/45))
- Added download script for benchmark data ([#44](https://github.com/pyg-team/pyg-lib/pull/44))
12 changes: 8 additions & 4 deletions pyg_lib/csrc/sampler/cpu/mapper.h
@@ -10,6 +10,8 @@ namespace sampler {
template <typename scalar_t>
class Mapper {
public:
using type = scalar_t;

Mapper(scalar_t num_nodes, scalar_t num_entries)
: num_nodes(num_nodes), num_entries(num_entries) {
// Use a simple heuristic to determine whether we can use a std::vector
@@ -23,26 +25,28 @@ class Mapper {

void fill(const scalar_t* nodes_data, const scalar_t size) {
if (use_vec) {
for (scalar_t i = 0; i < size; ++i)
for (scalar_t i = 0; i < size; ++i) {
Contributor: Let me post my question here. I read some documents, and based on my understanding, scalar_t covers float, double, int32, and int64 at compile time. But in a lot of our use cases we are iterating over integers. How does PyTorch avoid compiling the float variants of these functions? Is there a better way to be more specific about the data types here?

Member (author): There are some helper functions like is_integral for a dtype, but IMO that is mostly runtime checking. We could also use some STL type checking at compile time.

Member: The AT_DISPATCH_INTEGRAL_TYPES call handles which types scalar_t can take (at compile time).
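As a minimal illustration of that point (the `sum_indices` helper below is hypothetical and not part of this PR), the macro only instantiates the lambda body for the integral dtypes (uint8, int8, int16, int32, int64); floating-point tensors hit a runtime error instead of producing extra template instantiations:

```cpp
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>

// Sums the entries of an index tensor. `scalar_t` inside the lambda is only
// ever an integral type, so no float/double code paths are compiled.
int64_t sum_indices(const at::Tensor& indices) {
  int64_t total = 0;
  AT_DISPATCH_INTEGRAL_TYPES(indices.scalar_type(), "sum_indices", [&] {
    const auto* data = indices.data_ptr<scalar_t>();
    for (int64_t i = 0; i < indices.numel(); ++i)
      total += static_cast<int64_t>(data[i]);
  });
  return total;
}
```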

to_local_vec[nodes_data[i]] = i;
}
} else {
for (scalar_t i = 0; i < size; ++i)
for (scalar_t i = 0; i < size; ++i) {
to_local_map.insert({nodes_data[i], i});
}
}
}

void fill(const at::Tensor& nodes) {
fill(nodes.data_ptr<scalar_t>(), nodes.numel());
}

bool exists(const scalar_t& node) {
bool exists(const scalar_t& node) const {
if (use_vec)
return to_local_vec[node] >= 0;
else
return to_local_map.count(node) > 0;
}

scalar_t map(const scalar_t& node) {
scalar_t map(const scalar_t& node) const {
if (use_vec)
return to_local_vec[node];
else {
74 changes: 2 additions & 72 deletions pyg_lib/csrc/sampler/cpu/subgraph_kernel.cpp
@@ -3,6 +3,7 @@
#include <torch/library.h>

#include "pyg_lib/csrc/sampler/cpu/mapper.h"
#include "pyg_lib/csrc/sampler/subgraph.h"
#include "pyg_lib/csrc/utils/cpu/convert.h"

namespace pyg {
@@ -15,78 +16,7 @@ std::tuple<at::Tensor, at::Tensor, c10::optional<at::Tensor>> subgraph_kernel(
const at::Tensor& col,
const at::Tensor& nodes,
const bool return_edge_id) {
TORCH_CHECK(rowptr.is_cpu(), "'rowptr' must be a CPU tensor");
TORCH_CHECK(col.is_cpu(), "'col' must be a CPU tensor");
TORCH_CHECK(nodes.is_cpu(), "'nodes' must be a CPU tensor");

const auto num_nodes = rowptr.size(0) - 1;
const auto out_rowptr = rowptr.new_empty({nodes.size(0) + 1});
at::Tensor out_col;
c10::optional<at::Tensor> out_edge_id = c10::nullopt;

AT_DISPATCH_INTEGRAL_TYPES(nodes.scalar_type(), "subgraph_kernel", [&] {
auto mapper = pyg::sampler::Mapper<scalar_t>(num_nodes, nodes.size(0));
mapper.fill(nodes);

const auto rowptr_data = rowptr.data_ptr<scalar_t>();
const auto col_data = col.data_ptr<scalar_t>();
const auto nodes_data = nodes.data_ptr<scalar_t>();

// We first iterate over all nodes and collect information about the number
// of edges in the induced subgraph.
const auto deg = rowptr.new_empty({nodes.size(0)});
auto deg_data = deg.data_ptr<scalar_t>();
auto grain_size = at::internal::GRAIN_SIZE;
at::parallel_for(0, nodes.size(0), grain_size, [&](int64_t _s, int64_t _e) {
for (size_t i = _s; i < _e; ++i) {
const auto v = nodes_data[i];
// Iterate over all neighbors and check if they are part of `nodes`:
scalar_t d = 0;
for (size_t j = rowptr_data[v]; j < rowptr_data[v + 1]; ++j) {
if (mapper.exists(col_data[j]))
d++;
}
deg_data[i] = d;
}
});

auto out_rowptr_data = out_rowptr.data_ptr<scalar_t>();
out_rowptr_data[0] = 0;
auto tmp = out_rowptr.narrow(0, 1, nodes.size(0));
at::cumsum_out(tmp, deg, /*dim=*/0);

out_col = col.new_empty({out_rowptr_data[nodes.size(0)]});
auto out_col_data = out_col.data_ptr<scalar_t>();
scalar_t* out_edge_id_data;
if (return_edge_id) {
out_edge_id = col.new_empty({out_rowptr_data[nodes.size(0)]});
out_edge_id_data = out_edge_id.value().data_ptr<scalar_t>();
}

// Customize `grain_size` based on the work each thread does (it will need
// to find `col.size(0) / nodes.size(0)` neighbors on average).
// TODO Benchmark this customization
grain_size = std::max<int64_t>(out_col.size(0) / nodes.size(0), 1);
grain_size = at::internal::GRAIN_SIZE / grain_size;
at::parallel_for(0, nodes.size(0), grain_size, [&](int64_t _s, int64_t _e) {
for (scalar_t i = _s; i < _e; ++i) {
const auto v = nodes_data[i];
// Iterate over all neighbors and check if they are part of `nodes`:
scalar_t offset = out_rowptr_data[i];
for (scalar_t j = rowptr_data[v]; j < rowptr_data[v + 1]; ++j) {
const auto w = mapper.map(col_data[j]);
if (w >= 0) {
out_col_data[offset] = w;
if (return_edge_id)
out_edge_id_data[offset] = j;
offset++;
}
}
}
});
});

return std::make_tuple(out_rowptr, out_col, out_edge_id);
return subgraph_bipartite(rowptr, col, nodes, nodes, return_edge_id);
Contributor: The code structure looks a little weird to me, because csrc/sampler/cpu/subgraph_kernel.cpp exists to register TORCH_LIBRARY_IMPL, yet it uses a general implementation from csrc/sampler/subgraph.cpp. How about reorganizing the code like this:

csrc
  - ops
    # all ops exposed to PyTorch
    - sampler
  # all general graph operations
  - sampler

We don't need to refactor the code structure now, but I want to hear your opinion.

Contributor: Never mind, it seems subgraph.cpp also defines a library fragment. Why not merge them together, since sampler/subgraph.cpp also runs on CPU only?

Member (author): We could follow the style in other pyg repos: put CPU/GPU-specific implementations in separate folders and provide a common interface in a higher-level directory.
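For illustration, a minimal sketch of that split, using a hypothetical `pyg::my_op` (not an op in this PR): the device-agnostic wrapper defines the schema and calls into the dispatcher, while the CPU kernel lives in its own file and registers itself for the CPU dispatch key only.

```cpp
// ops/my_op.cpp -- device-agnostic interface + schema definition (hypothetical)
#include <ATen/ATen.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <torch/library.h>

namespace pyg {

at::Tensor my_op(const at::Tensor& input) {
  // Look up the registered op and dispatch based on the input's device.
  static auto op = c10::Dispatcher::singleton()
                       .findSchemaOrThrow("pyg::my_op", "")
                       .typed<decltype(my_op)>();
  return op.call(input);
}

TORCH_LIBRARY_FRAGMENT(pyg, m) {
  m.def(TORCH_SELECTIVE_SCHEMA("pyg::my_op(Tensor input) -> Tensor"));
}

}  // namespace pyg

// cpu/my_op_kernel.cpp -- CPU-specific implementation
namespace pyg {
namespace {

at::Tensor my_op_kernel(const at::Tensor& input) {
  return input.clone();  // placeholder body
}

}  // namespace

TORCH_LIBRARY_IMPL(pyg, CPU, m) {
  m.impl(TORCH_SELECTIVE_NAME("pyg::my_op"), TORCH_FN(my_op_kernel));
}

}  // namespace pyg
```

A CUDA backend would then add its own `TORCH_LIBRARY_IMPL(pyg, CUDA, m)` in a sibling folder without touching the common interface.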

}

} // namespace
154 changes: 153 additions & 1 deletion pyg_lib/csrc/sampler/subgraph.cpp
@@ -1,17 +1,99 @@
#include "subgraph.h"
#include <pyg_lib/csrc/utils/hetero_dispatch.h>

#include <ATen/core/dispatch/Dispatcher.h>
#include <torch/library.h>

#include <functional>

namespace pyg {
namespace sampler {

template <typename T>
std::tuple<at::Tensor, at::Tensor, c10::optional<at::Tensor>>
subgraph_with_mapper(const at::Tensor& rowptr,
const at::Tensor& col,
const at::Tensor& nodes,
const Mapper<T>& mapper,
const bool return_edge_id) {
const auto num_nodes = rowptr.size(0) - 1;
const auto out_rowptr = rowptr.new_empty({nodes.size(0) + 1});
at::Tensor out_col;
c10::optional<at::Tensor> out_edge_id = c10::nullopt;

AT_DISPATCH_INTEGRAL_TYPES(
nodes.scalar_type(), "subgraph_kernel_with_mapper", [&] {
const auto rowptr_data = rowptr.data_ptr<scalar_t>();
const auto col_data = col.data_ptr<scalar_t>();
const auto nodes_data = nodes.data_ptr<scalar_t>();

// We first iterate over all nodes and collect information about the
// number of edges in the induced subgraph.
const auto deg = rowptr.new_empty({nodes.size(0)});
auto deg_data = deg.data_ptr<scalar_t>();
auto grain_size = at::internal::GRAIN_SIZE;
at::parallel_for(
0, nodes.size(0), grain_size, [&](int64_t _s, int64_t _e) {
for (size_t i = _s; i < _e; ++i) {
const auto v = nodes_data[i];
// Iterate over all neighbors and check if they are part of
// `nodes`:
scalar_t d = 0;
for (size_t j = rowptr_data[v]; j < rowptr_data[v + 1]; ++j) {
if (mapper.exists(col_data[j]))
d++;
}
deg_data[i] = d;
}
});

auto out_rowptr_data = out_rowptr.data_ptr<scalar_t>();
out_rowptr_data[0] = 0;
auto tmp = out_rowptr.narrow(0, 1, nodes.size(0));
at::cumsum_out(tmp, deg, /*dim=*/0);

out_col = col.new_empty({out_rowptr_data[nodes.size(0)]});
auto out_col_data = out_col.data_ptr<scalar_t>();
scalar_t* out_edge_id_data;
if (return_edge_id) {
out_edge_id = col.new_empty({out_rowptr_data[nodes.size(0)]});
out_edge_id_data = out_edge_id.value().data_ptr<scalar_t>();
}

// Customize `grain_size` based on the work each thread does (it will
// need to find `col.size(0) / nodes.size(0)` neighbors on average).
// TODO Benchmark this customization
grain_size = std::max<int64_t>(out_col.size(0) / nodes.size(0), 1);
grain_size = at::internal::GRAIN_SIZE / grain_size;
at::parallel_for(
0, nodes.size(0), grain_size, [&](int64_t _s, int64_t _e) {
for (scalar_t i = _s; i < _e; ++i) {
const auto v = nodes_data[i];
// Iterate over all neighbors and check if they
// are part of `nodes`:
scalar_t offset = out_rowptr_data[i];
for (scalar_t j = rowptr_data[v]; j < rowptr_data[v + 1]; ++j) {
const auto w = mapper.map(col_data[j]);
if (w >= 0) {
out_col_data[offset] = w;
if (return_edge_id)
out_edge_id_data[offset] = j;
offset++;
}
}
}
});
});

return std::make_tuple(out_rowptr, out_col, out_edge_id);
}

std::tuple<at::Tensor, at::Tensor, c10::optional<at::Tensor>> subgraph(
const at::Tensor& rowptr,
const at::Tensor& col,
const at::Tensor& nodes,
const bool return_edge_id) {
at::TensorArg rowptr_t{rowptr, "rowtpr", 1};
at::TensorArg rowptr_t{rowptr, "rowptr", 1};
at::TensorArg col_t{col, "col", 1};
at::TensorArg nodes_t{nodes, "nodes", 1};

@@ -25,10 +107,80 @@ std::tuple<at::Tensor, at::Tensor, c10::optional<at::Tensor>> subgraph(
return op.call(rowptr, col, nodes, return_edge_id);
}

std::tuple<at::Tensor, at::Tensor, c10::optional<at::Tensor>>
subgraph_bipartite(const at::Tensor& rowptr,
const at::Tensor& col,
const at::Tensor& src_nodes,
const at::Tensor& dst_nodes,
const bool return_edge_id) {
TORCH_CHECK(rowptr.is_cpu(), "'rowptr' must be a CPU tensor");
TORCH_CHECK(col.is_cpu(), "'col' must be a CPU tensor");
TORCH_CHECK(src_nodes.is_cpu(), "'src_nodes' must be a CPU tensor");
TORCH_CHECK(dst_nodes.is_cpu(), "'dst_nodes' must be a CPU tensor");

const auto num_nodes = rowptr.size(0) - 1;
at::Tensor out_rowptr, out_col;
c10::optional<at::Tensor> out_edge_id;

AT_DISPATCH_INTEGRAL_TYPES(
src_nodes.scalar_type(), "subgraph_bipartite", [&] {
// TODO: at::max parallel but still a little expensive
Mapper<scalar_t> mapper(at::max(col).item<scalar_t>() + 1,
dst_nodes.size(0));
mapper.fill(dst_nodes);

auto res = subgraph_with_mapper<scalar_t>(rowptr, col, src_nodes,
mapper, return_edge_id);
out_rowptr = std::get<0>(res);
out_col = std::get<1>(res);
out_edge_id = std::get<2>(res);
});

return {out_rowptr, out_col, out_edge_id};
}

c10::Dict<utils::edge_t,
Member: I actually would have expected that we return a tuple of dictionaries, similar to how the input looks.

std::tuple<at::Tensor, at::Tensor, c10::optional<at::Tensor>>>
Member: IMO, the output should be a tuple of dictionaries (similar to the input).

hetero_subgraph(const utils::edge_tensor_dict_t& rowptr,
const utils::edge_tensor_dict_t& col,
const utils::node_tensor_dict_t& src_nodes,
const utils::node_tensor_dict_t& dst_nodes,
const c10::Dict<utils::edge_t, bool>& return_edge_id) {
// Define the bipartite implementation as a std function to pass the type
// check
std::function<std::tuple<at::Tensor, at::Tensor, c10::optional<at::Tensor>>(
const at::Tensor&, const at::Tensor&, const at::Tensor&,
const at::Tensor&, bool)>
func = subgraph_bipartite;

// Construct an operator
utils::HeteroDispatchOp<decltype(func)> op(rowptr, col, func);

// Construct dispatchable arguments
utils::HeteroDispatchArg<utils::node_tensor_dict_t, at::Tensor,
utils::NodeSrcMode>
src_nodes_arg(src_nodes);
utils::HeteroDispatchArg<utils::node_tensor_dict_t, at::Tensor,
utils::NodeDstMode>
dst_nodes_arg(dst_nodes);
utils::HeteroDispatchArg<c10::Dict<utils::edge_t, bool>, bool,
utils::EdgeMode>
edge_id_arg(return_edge_id);
return op(src_nodes_arg, dst_nodes_arg, edge_id_arg);
}

TORCH_LIBRARY_FRAGMENT(pyg, m) {
m.def(TORCH_SELECTIVE_SCHEMA(
"pyg::subgraph(Tensor rowptr, Tensor col, Tensor "
"nodes, bool return_edge_id) -> (Tensor, Tensor, Tensor?)"));
m.def(TORCH_SELECTIVE_SCHEMA(
"pyg::subgraph_bipartite(Tensor rowptr, Tensor col, Tensor "
"src_nodes, Tensor dst_nodes, bool return_edge_id) -> (Tensor, Tensor, "
"Tensor?)"));
m.def(TORCH_SELECTIVE_SCHEMA(
"pyg::hetero_subgraph(Dict(str, Tensor) rowptr, Dict(str, "
"Tensor) col, Dict(str, Tensor) nodes, Dict(str, bool) "
"return_edge_id) -> Dict(str, (Tensor, Tensor, Tensor?))"));
}

} // namespace sampler
30 changes: 30 additions & 0 deletions pyg_lib/csrc/sampler/subgraph.h
@@ -1,7 +1,11 @@
#pragma once

#include <ATen/ATen.h>
#include <ATen/Parallel.h>

#include "pyg_lib/csrc/macros.h"
#include "pyg_lib/csrc/sampler/cpu/mapper.h"
#include "pyg_lib/csrc/utils/types.h"

namespace pyg {
namespace sampler {
@@ -15,5 +19,31 @@ PYG_API std::tuple<at::Tensor, at::Tensor, c10::optional<at::Tensor>> subgraph(
const at::Tensor& nodes,
const bool return_edge_id = true);

// A bipartite version of the above function.
PYG_API std::tuple<at::Tensor, at::Tensor, c10::optional<at::Tensor>>
subgraph_bipartite(const at::Tensor& rowptr,
const at::Tensor& col,
const at::Tensor& src_nodes,
const at::Tensor& dst_nodes,
const bool return_edge_id);

// A heterogeneous version of the above function.
// Returns a dict from each relation type to its result
PYG_API c10::Dict<utils::edge_t,
std::tuple<at::Tensor, at::Tensor, c10::optional<at::Tensor>>>
hetero_subgraph(const utils::edge_tensor_dict_t& rowptr,
const utils::edge_tensor_dict_t& col,
const utils::node_tensor_dict_t& src_nodes,
const utils::node_tensor_dict_t& dst_nodes,
const c10::Dict<utils::edge_t, bool>& return_edge_id);

template <typename T>
std::tuple<at::Tensor, at::Tensor, c10::optional<at::Tensor>>
subgraph_with_mapper(const at::Tensor& rowptr,
const at::Tensor& col,
const at::Tensor& nodes,
const Mapper<T>& mapper,
const bool return_edge_id);

} // namespace sampler
} // namespace pyg
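A hypothetical usage sketch of the bipartite variant declared above (the graph, node sets, and `main` scaffolding are made up for illustration; only the `pyg::sampler::subgraph_bipartite` call reflects this PR):

```cpp
#include <vector>

#include <ATen/ATen.h>

#include "pyg_lib/csrc/sampler/subgraph.h"

int main() {
  // CSR graph with 4 nodes and 6 edges:
  //   0 -> {1, 2}, 1 -> {0, 3}, 2 -> {3}, 3 -> {0}
  std::vector<int64_t> rowptr_v{0, 2, 4, 5, 6};
  std::vector<int64_t> col_v{1, 2, 0, 3, 3, 0};
  auto rowptr = at::from_blob(rowptr_v.data(), {5}, at::kLong).clone();
  auto col = at::from_blob(col_v.data(), {6}, at::kLong).clone();

  // Keep only edges from src_nodes {0, 1} into dst_nodes {0, 1, 2}:
  std::vector<int64_t> src_v{0, 1}, dst_v{0, 1, 2};
  auto src_nodes = at::from_blob(src_v.data(), {2}, at::kLong).clone();
  auto dst_nodes = at::from_blob(dst_v.data(), {3}, at::kLong).clone();

  at::Tensor out_rowptr, out_col;
  c10::optional<at::Tensor> edge_id;
  std::tie(out_rowptr, out_col, edge_id) = pyg::sampler::subgraph_bipartite(
      rowptr, col, src_nodes, dst_nodes, /*return_edge_id=*/true);

  // `out_col` holds indices local to `dst_nodes`; `edge_id` (if requested)
  // maps each kept edge back to its position in the original `col`.
  return 0;
}
```

The heterogeneous variant applies the same bipartite routine once per relation type, dispatching over the input dictionaries via the `HeteroDispatchOp` utility.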