From a7427c9c1cb7a1aa588e454deccfe42bedb3ba9d Mon Sep 17 00:00:00 2001 From: Han Wang Date: Wed, 23 Oct 2024 09:59:04 +0800 Subject: [PATCH] fix index type --- deepmd/dpmodel/atomic_model/pairtab_atomic_model.py | 6 ++---- deepmd/pt/model/atomic_model/pairtab_atomic_model.py | 2 +- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/deepmd/dpmodel/atomic_model/pairtab_atomic_model.py b/deepmd/dpmodel/atomic_model/pairtab_atomic_model.py index ee9c7a9a76..2899f106bc 100644 --- a/deepmd/dpmodel/atomic_model/pairtab_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/pairtab_atomic_model.py @@ -204,11 +204,9 @@ def forward_atomic( self.tab.ntypes, self.tab.ntypes, self.tab.nspline, 4 ) - # (nframes, nloc, nnei), dtype is the same as atype. + # (nframes, nloc, nnei), index type is int64. j_type = extended_atype[ - np.arange(extended_atype.shape[0], dtype=extended_atype.dtype)[ - :, None, None - ], + np.arange(extended_atype.shape[0], dtype=np.int64)[:, None, None], masked_nlist, ] diff --git a/deepmd/pt/model/atomic_model/pairtab_atomic_model.py b/deepmd/pt/model/atomic_model/pairtab_atomic_model.py index ae8c2d5cb0..2bedccbd43 100644 --- a/deepmd/pt/model/atomic_model/pairtab_atomic_model.py +++ b/deepmd/pt/model/atomic_model/pairtab_atomic_model.py @@ -272,7 +272,7 @@ def forward_atomic( torch.arange( extended_atype.size(0), device=extended_coord.device, - dtype=extended_atype.dtype, + dtype=torch.int64, )[:, None, None], masked_nlist, ]