Skip to content

Commit

Permalink
feat: support multiple frame for edge_index OP (#26)
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
  • Loading branch information
njzjz authored Oct 7, 2024
1 parent ceec468 commit 83afb8b
Show file tree
Hide file tree
Showing 2 changed files with 151 additions and 34 deletions.
92 changes: 58 additions & 34 deletions op/edge_index.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,51 +3,75 @@

#include <iostream>

torch::Tensor edge_index(const torch::Tensor &nlist_tensor,
const torch::Tensor &atype_tensor,
const torch::Tensor &mm_tensor) {
torch::Tensor edge_index_kernel(const torch::Tensor &nlist_tensor,
const torch::Tensor &atype_tensor,
const torch::Tensor &mm_tensor) {
torch::Tensor nlist_tensor_ = nlist_tensor.cpu().contiguous();
torch::Tensor atype_tensor_ = atype_tensor.cpu().contiguous();
torch::Tensor mm_tensor_ = mm_tensor.cpu().contiguous();
const int64_t nloc = nlist_tensor_.size(0);
const int64_t nnei = nlist_tensor_.size(1);
const int64_t nall = atype_tensor_.size(0);
if (nlist_tensor_.dim() == 2) {
nlist_tensor_ =
nlist_tensor_.view({1, nlist_tensor_.size(0), nlist_tensor_.size(1)});
if (atype_tensor_.dim() != 1) {
throw std::invalid_argument("atype_tensor must be 1D");
}
atype_tensor_ = atype_tensor_.view({1, atype_tensor_.size(0)});
} else if (nlist_tensor_.dim() == 3) {
if (atype_tensor_.dim() != 2) {
throw std::invalid_argument("atype_tensor must be 2D");
}
} else {
throw std::invalid_argument("nlist_tensor must be 2D or 3D");
}

const int64_t nf = nlist_tensor_.size(0);
const int64_t nloc = nlist_tensor_.size(1);
const int64_t nnei = nlist_tensor_.size(2);
if (atype_tensor_.size(0) != nf) {
throw std::invalid_argument(
"atype_tensor must have the same size as nlist_tensor");
}
const int64_t nall = atype_tensor_.size(1);
const int64_t nmm = mm_tensor_.size(0);
int64_t *nlist = nlist_tensor_.view({-1}).data_ptr<int64_t>();
int64_t *atype = atype_tensor_.view({-1}).data_ptr<int64_t>();
int64_t *mm = mm_tensor_.view({-1}).data_ptr<int64_t>();

std::vector<int64_t> edge_index;
edge_index.reserve(nloc * nnei * 2);
edge_index.reserve(nf * nloc * nnei * 2);

for (int64_t ii = 0; ii < nloc; ii++) {
for (int64_t jj = 0; jj < nnei; jj++) {
int64_t idx = ii * nnei + jj;
int64_t kk = nlist[idx];
if (kk < 0) {
continue;
}
// check if both atype[ii] and atype[kk] are in mm
bool in_mm1 = false;
for (int64_t mm_idx = 0; mm_idx < nmm; mm_idx++) {
if (atype[ii] == mm[mm_idx]) {
in_mm1 = true;
break;
for (int64_t ff = 0; ff < nf; ff++) {
for (int64_t ii = 0; ii < nloc; ii++) {
for (int64_t jj = 0; jj < nnei; jj++) {
int64_t idx = ii * nnei + jj;
int64_t kk = nlist[idx];
if (kk < 0) {
continue;
}
}
bool in_mm2 = false;
for (int64_t mm_idx = 0; mm_idx < nmm; mm_idx++) {
if (atype[kk] == mm[mm_idx]) {
in_mm2 = true;
break;
int64_t global_kk = ff * nall + kk;
int64_t global_ii = ff * nall + ii;
// check if both atype[ii] and atype[kk] are in mm
bool in_mm1 = false;
for (int64_t mm_idx = 0; mm_idx < nmm; mm_idx++) {
if (atype[global_ii] == mm[mm_idx]) {
in_mm1 = true;
break;
}
}
bool in_mm2 = false;
for (int64_t mm_idx = 0; mm_idx < nmm; mm_idx++) {
if (atype[global_kk] == mm[mm_idx]) {
in_mm2 = true;
break;
}
}
if (in_mm1 && in_mm2) {
continue;
}
// add edge
edge_index.push_back(global_kk);
edge_index.push_back(global_ii);
}
if (in_mm1 && in_mm2) {
continue;
}
// add edge
edge_index.push_back(kk);
edge_index.push_back(ii);
}
}
// convert to tensor
Expand All @@ -58,6 +82,6 @@ torch::Tensor edge_index(const torch::Tensor &nlist_tensor,
return edge_index_tensor.to(nlist_tensor.device());
}

TORCH_LIBRARY(deepmd_gnn, m) { m.def("edge_index", edge_index); }
TORCH_LIBRARY(deepmd_gnn, m) { m.def("edge_index", edge_index_kernel); }
// compatbility with old models freezed by deepmd_mace package
TORCH_LIBRARY(deepmd_mace, m) { m.def("mace_edge_index", edge_index); }
TORCH_LIBRARY(deepmd_mace, m) { m.def("mace_edge_index", edge_index_kernel); }
93 changes: 93 additions & 0 deletions tests/test_op.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
"""Test custom operations."""

import torch

import deepmd_gnn.op # noqa: F401


def test_one_frame() -> None:
"""Test one frame."""
nlist_ff = torch.tensor(
[
[1, 2, -1, -1],
[2, 0, -1, -1],
[0, 1, -1, -1],
],
dtype=torch.int64,
device="cpu",
)
extended_atype_ff = torch.tensor(
[0, 1, 2],
dtype=torch.int64,
device="cpu",
)
mm_types = [1, 2]
expected_edge_index = torch.tensor(
[
[1, 0],
[2, 0],
[0, 1],
[0, 2],
],
dtype=torch.int64,
device="cpu",
)

edge_index = torch.ops.deepmd_gnn.edge_index(
nlist_ff,
extended_atype_ff,
torch.tensor(mm_types, dtype=torch.int64, device="cpu"),
)

assert torch.equal(edge_index, expected_edge_index)


def test_two_frame() -> None:
"""Test one frame."""
nlist = torch.tensor(
[
[
[1, 2, -1, -1],
[2, 0, -1, -1],
[0, 1, -1, -1],
],
[
[1, 2, -1, -1],
[2, 0, -1, -1],
[0, 1, -1, -1],
],
],
dtype=torch.int64,
device="cpu",
)
extended_atype = torch.tensor(
[
[0, 1, 2],
[0, 1, 2],
],
dtype=torch.int64,
device="cpu",
)
mm_types = [1, 2]
expected_edge_index = torch.tensor(
[
[1, 0],
[2, 0],
[0, 1],
[0, 2],
[4, 3],
[5, 3],
[3, 4],
[3, 5],
],
dtype=torch.int64,
device="cpu",
)

edge_index = torch.ops.deepmd_gnn.edge_index(
nlist,
extended_atype,
torch.tensor(mm_types, dtype=torch.int64, device="cpu"),
)

assert torch.equal(edge_index, expected_edge_index)

0 comments on commit 83afb8b

Please sign in to comment.