Skip to content

Commit

Permalink
breaking: rename OP (#22)
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
  • Loading branch information
njzjz authored Oct 3, 2024
1 parent 8275a7d commit 8fae7d2
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 8 deletions.
2 changes: 1 addition & 1 deletion deepmd_gnn/mace.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,7 +614,7 @@ def forward_lower_common(
extended_coord_ff = extended_coord[ff]
extended_atype_ff = extended_atype[ff]
nlist_ff = nlist[ff]
edge_index = torch.ops.deepmd_gnn.mace_edge_index(
edge_index = torch.ops.deepmd_gnn.edge_index(
nlist_ff,
extended_atype_ff,
torch.tensor(self.mm_types, dtype=torch.int64, device="cpu"),
Expand Down
2 changes: 1 addition & 1 deletion deepmd_gnn/nequip.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,7 +497,7 @@ def forward_lower_common(
extended_coord_ff = extended_coord[ff]
extended_atype_ff = extended_atype[ff]
nlist_ff = nlist[ff]
edge_index = torch.ops.deepmd_gnn.mace_edge_index(
edge_index = torch.ops.deepmd_gnn.edge_index(
nlist_ff,
extended_atype_ff,
torch.tensor(self.mm_types, dtype=torch.int64, device="cpu"),
Expand Down
2 changes: 1 addition & 1 deletion op/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
file(GLOB OP_SRC mace.cc)
file(GLOB OP_SRC edge_index.cc)

add_library(deepmd_gnn MODULE ${OP_SRC})
# link: libdeepmd libtorch
Expand Down
10 changes: 5 additions & 5 deletions op/mace.cc → op/edge_index.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@

#include <iostream>

torch::Tensor mace_edge_index(const torch::Tensor &nlist_tensor,
const torch::Tensor &atype_tensor,
const torch::Tensor &mm_tensor) {
torch::Tensor edge_index(const torch::Tensor &nlist_tensor,
const torch::Tensor &atype_tensor,
const torch::Tensor &mm_tensor) {
torch::Tensor nlist_tensor_ = nlist_tensor.cpu().contiguous();
torch::Tensor atype_tensor_ = atype_tensor.cpu().contiguous();
torch::Tensor mm_tensor_ = mm_tensor.cpu().contiguous();
Expand Down Expand Up @@ -58,6 +58,6 @@ torch::Tensor mace_edge_index(const torch::Tensor &nlist_tensor,
return edge_index_tensor.to(nlist_tensor.device());
}

TORCH_LIBRARY(deepmd_gnn, m) { m.def("mace_edge_index", mace_edge_index); }
TORCH_LIBRARY(deepmd_gnn, m) { m.def("edge_index", edge_index); }
// compatbility with old models freezed by deepmd_mace package
TORCH_LIBRARY(deepmd_mace, m) { m.def("mace_edge_index", mace_edge_index); }
TORCH_LIBRARY(deepmd_mace, m) { m.def("mace_edge_index", edge_index); }

0 comments on commit 8fae7d2

Please sign in to comment.