diff --git a/scico/trace.py b/scico/trace.py index d58924ee..2c38f69f 100644 --- a/scico/trace.py +++ b/scico/trace.py @@ -116,11 +116,11 @@ def _trace_arg_repr(val: Any) -> str: if isinstance(val, jax.Array) and not isinstance( val, jax._src.interpreters.partial_eval.JaxprTracer ): - if call_trace.show_jax_device: + if call_trace.show_jax_device: # type: ignore platform = list(val.devices())[0].platform # assume all of same type devices = ",".join(map(str, sorted([d.id for d in val.devices()]))) dev_str = f"{clr_devc}{{dev={platform}({devices})}}{clr_args}" - if call_trace.show_jax_sharding and isinstance( + if call_trace.show_jax_sharding and isinstance( # type: ignore val.sharding, jax._src.sharding_impls.PositionalSharding ): shard_str = f"{clr_devc}{{shard={val.sharding.shape}}}{clr_args}"