From d592df19cb94625519215453edc58542a4886106 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 26 Sep 2023 10:44:26 +0100 Subject: [PATCH] [LSC] Ignore incorrect type annotations related to jax.numpy APIs PiperOrigin-RevId: 568475283 Change-Id: Ice9d8b610d2d8ab1c541679590e184ab91cbc05e --- ferminet/jastrows.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ferminet/jastrows.py b/ferminet/jastrows.py index dbb6465..2b88e0c 100644 --- a/ferminet/jastrows.py +++ b/ferminet/jastrows.py @@ -41,8 +41,8 @@ def _jastrow_ee( for r in jnp.split(r_ee, nspins[0:1], axis=0) ] r_ees_parallel = jnp.concatenate([ - r_ees[0][0][jnp.triu_indices(nspins[0], k=1)], - r_ees[1][1][jnp.triu_indices(nspins[1], k=1)], + r_ees[0][0][jnp.triu_indices(nspins[0], k=1)], # pytype: disable=wrong-arg-types # jnp-type + r_ees[1][1][jnp.triu_indices(nspins[1], k=1)], # pytype: disable=wrong-arg-types # jnp-type ]) if r_ees_parallel.shape[0] > 0: