diff --git a/deepmd/dpmodel/atomic_model/pairtab_atomic_model.py b/deepmd/dpmodel/atomic_model/pairtab_atomic_model.py index ee9c7a9a76..2899f106bc 100644 --- a/deepmd/dpmodel/atomic_model/pairtab_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/pairtab_atomic_model.py @@ -204,11 +204,9 @@ def forward_atomic( self.tab.ntypes, self.tab.ntypes, self.tab.nspline, 4 ) - # (nframes, nloc, nnei), dtype is the same as atype. + # (nframes, nloc, nnei), index type is int64. j_type = extended_atype[ - np.arange(extended_atype.shape[0], dtype=extended_atype.dtype)[ - :, None, None - ], + np.arange(extended_atype.shape[0], dtype=np.int64)[:, None, None], masked_nlist, ] diff --git a/deepmd/pt/model/atomic_model/pairtab_atomic_model.py b/deepmd/pt/model/atomic_model/pairtab_atomic_model.py index ae8c2d5cb0..2bedccbd43 100644 --- a/deepmd/pt/model/atomic_model/pairtab_atomic_model.py +++ b/deepmd/pt/model/atomic_model/pairtab_atomic_model.py @@ -272,7 +272,7 @@ def forward_atomic( torch.arange( extended_atype.size(0), device=extended_coord.device, - dtype=extended_atype.dtype, + dtype=torch.int64, )[:, None, None], masked_nlist, ]