From 83afb8b005f455190a97a5beafe7f84d2f110d31 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Mon, 7 Oct 2024 17:11:06 -0400 Subject: [PATCH] feat: support multiple frame for edge_index OP (#26) Signed-off-by: Jinzhe Zeng --- op/edge_index.cc | 92 +++++++++++++++++++++++++++++------------------ tests/test_op.py | 93 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 151 insertions(+), 34 deletions(-) create mode 100644 tests/test_op.py diff --git a/op/edge_index.cc b/op/edge_index.cc index 88c706c..240794d 100644 --- a/op/edge_index.cc +++ b/op/edge_index.cc @@ -3,51 +3,75 @@ #include -torch::Tensor edge_index(const torch::Tensor &nlist_tensor, - const torch::Tensor &atype_tensor, - const torch::Tensor &mm_tensor) { +torch::Tensor edge_index_kernel(const torch::Tensor &nlist_tensor, + const torch::Tensor &atype_tensor, + const torch::Tensor &mm_tensor) { torch::Tensor nlist_tensor_ = nlist_tensor.cpu().contiguous(); torch::Tensor atype_tensor_ = atype_tensor.cpu().contiguous(); torch::Tensor mm_tensor_ = mm_tensor.cpu().contiguous(); - const int64_t nloc = nlist_tensor_.size(0); - const int64_t nnei = nlist_tensor_.size(1); - const int64_t nall = atype_tensor_.size(0); + if (nlist_tensor_.dim() == 2) { + nlist_tensor_ = + nlist_tensor_.view({1, nlist_tensor_.size(0), nlist_tensor_.size(1)}); + if (atype_tensor_.dim() != 1) { + throw std::invalid_argument("atype_tensor must be 1D"); + } + atype_tensor_ = atype_tensor_.view({1, atype_tensor_.size(0)}); + } else if (nlist_tensor_.dim() == 3) { + if (atype_tensor_.dim() != 2) { + throw std::invalid_argument("atype_tensor must be 2D"); + } + } else { + throw std::invalid_argument("nlist_tensor must be 2D or 3D"); + } + + const int64_t nf = nlist_tensor_.size(0); + const int64_t nloc = nlist_tensor_.size(1); + const int64_t nnei = nlist_tensor_.size(2); + if (atype_tensor_.size(0) != nf) { + throw std::invalid_argument( + "atype_tensor must have the same size as nlist_tensor"); + } + const int64_t nall = atype_tensor_.size(1); const int64_t nmm = mm_tensor_.size(0); int64_t *nlist = nlist_tensor_.view({-1}).data_ptr(); int64_t *atype = atype_tensor_.view({-1}).data_ptr(); int64_t *mm = mm_tensor_.view({-1}).data_ptr(); std::vector edge_index; - edge_index.reserve(nloc * nnei * 2); + edge_index.reserve(nf * nloc * nnei * 2); - for (int64_t ii = 0; ii < nloc; ii++) { - for (int64_t jj = 0; jj < nnei; jj++) { - int64_t idx = ii * nnei + jj; - int64_t kk = nlist[idx]; - if (kk < 0) { - continue; - } - // check if both atype[ii] and atype[kk] are in mm - bool in_mm1 = false; - for (int64_t mm_idx = 0; mm_idx < nmm; mm_idx++) { - if (atype[ii] == mm[mm_idx]) { - in_mm1 = true; - break; + for (int64_t ff = 0; ff < nf; ff++) { + for (int64_t ii = 0; ii < nloc; ii++) { + for (int64_t jj = 0; jj < nnei; jj++) { + int64_t idx = ii * nnei + jj; + int64_t kk = nlist[idx]; + if (kk < 0) { + continue; } - } - bool in_mm2 = false; - for (int64_t mm_idx = 0; mm_idx < nmm; mm_idx++) { - if (atype[kk] == mm[mm_idx]) { - in_mm2 = true; - break; + int64_t global_kk = ff * nall + kk; + int64_t global_ii = ff * nall + ii; + // check if both atype[ii] and atype[kk] are in mm + bool in_mm1 = false; + for (int64_t mm_idx = 0; mm_idx < nmm; mm_idx++) { + if (atype[global_ii] == mm[mm_idx]) { + in_mm1 = true; + break; + } } + bool in_mm2 = false; + for (int64_t mm_idx = 0; mm_idx < nmm; mm_idx++) { + if (atype[global_kk] == mm[mm_idx]) { + in_mm2 = true; + break; + } + } + if (in_mm1 && in_mm2) { + continue; + } + // add edge + edge_index.push_back(global_kk); + edge_index.push_back(global_ii); } - if (in_mm1 && in_mm2) { - continue; - } - // add edge - edge_index.push_back(kk); - edge_index.push_back(ii); } } // convert to tensor @@ -58,6 +82,6 @@ torch::Tensor edge_index(const torch::Tensor &nlist_tensor, return edge_index_tensor.to(nlist_tensor.device()); } -TORCH_LIBRARY(deepmd_gnn, m) { m.def("edge_index", edge_index); } +TORCH_LIBRARY(deepmd_gnn, m) { m.def("edge_index", edge_index_kernel); } // compatbility with old models freezed by deepmd_mace package -TORCH_LIBRARY(deepmd_mace, m) { m.def("mace_edge_index", edge_index); } +TORCH_LIBRARY(deepmd_mace, m) { m.def("mace_edge_index", edge_index_kernel); } diff --git a/tests/test_op.py b/tests/test_op.py new file mode 100644 index 0000000..4a1a495 --- /dev/null +++ b/tests/test_op.py @@ -0,0 +1,93 @@ +"""Test custom operations.""" + +import torch + +import deepmd_gnn.op # noqa: F401 + + +def test_one_frame() -> None: + """Test one frame.""" + nlist_ff = torch.tensor( + [ + [1, 2, -1, -1], + [2, 0, -1, -1], + [0, 1, -1, -1], + ], + dtype=torch.int64, + device="cpu", + ) + extended_atype_ff = torch.tensor( + [0, 1, 2], + dtype=torch.int64, + device="cpu", + ) + mm_types = [1, 2] + expected_edge_index = torch.tensor( + [ + [1, 0], + [2, 0], + [0, 1], + [0, 2], + ], + dtype=torch.int64, + device="cpu", + ) + + edge_index = torch.ops.deepmd_gnn.edge_index( + nlist_ff, + extended_atype_ff, + torch.tensor(mm_types, dtype=torch.int64, device="cpu"), + ) + + assert torch.equal(edge_index, expected_edge_index) + + +def test_two_frame() -> None: + """Test one frame.""" + nlist = torch.tensor( + [ + [ + [1, 2, -1, -1], + [2, 0, -1, -1], + [0, 1, -1, -1], + ], + [ + [1, 2, -1, -1], + [2, 0, -1, -1], + [0, 1, -1, -1], + ], + ], + dtype=torch.int64, + device="cpu", + ) + extended_atype = torch.tensor( + [ + [0, 1, 2], + [0, 1, 2], + ], + dtype=torch.int64, + device="cpu", + ) + mm_types = [1, 2] + expected_edge_index = torch.tensor( + [ + [1, 0], + [2, 0], + [0, 1], + [0, 2], + [4, 3], + [5, 3], + [3, 4], + [3, 5], + ], + dtype=torch.int64, + device="cpu", + ) + + edge_index = torch.ops.deepmd_gnn.edge_index( + nlist, + extended_atype, + torch.tensor(mm_types, dtype=torch.int64, device="cpu"), + ) + + assert torch.equal(edge_index, expected_edge_index)