Skip to content

Commit

Permalink
fix index type
Browse files Browse the repository at this point in the history
  • Loading branch information
Han Wang committed Oct 23, 2024
1 parent 1a1f145 commit a7427c9
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 5 deletions.
6 changes: 2 additions & 4 deletions deepmd/dpmodel/atomic_model/pairtab_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,11 +204,9 @@ def forward_atomic(
self.tab.ntypes, self.tab.ntypes, self.tab.nspline, 4
)

# (nframes, nloc, nnei), dtype is the same as atype.
# (nframes, nloc, nnei), index type is int64.
j_type = extended_atype[
np.arange(extended_atype.shape[0], dtype=extended_atype.dtype)[
:, None, None
],
np.arange(extended_atype.shape[0], dtype=np.int64)[:, None, None],
masked_nlist,
]

Expand Down
2 changes: 1 addition & 1 deletion deepmd/pt/model/atomic_model/pairtab_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ def forward_atomic(
torch.arange(
extended_atype.size(0),
device=extended_coord.device,
dtype=extended_atype.dtype,
dtype=torch.int64,
)[:, None, None],
masked_nlist,
]
Expand Down

0 comments on commit a7427c9

Please sign in to comment.