From 481fa25a36b5dffed2f087454dcbecd9490e4bcf Mon Sep 17 00:00:00 2001
From: Philipp Holl
Date: Mon, 25 Nov 2024 17:29:26 +0100
Subject: [PATCH] Fix cumulative_sum() for non-uniform Tensors

---
 phiml/math/_ops.py | 35 ++++++++++++++++++++---------------
 1 file changed, 20 insertions(+), 15 deletions(-)

diff --git a/phiml/math/_ops.py b/phiml/math/_ops.py
index 20298762..21f78005 100644
--- a/phiml/math/_ops.py
+++ b/phiml/math/_ops.py
@@ -648,21 +648,26 @@ def cumulative_sum(x: Tensor, dim: DimFilter, include_0=False, include_sum=True,
     dim = x.shape.only(dim, reorder=True)
     assert dim.rank >= 1, f"dim must contain at least one dimension."
     assert dim.rank == 1 or include_0 + include_sum == 1, f"When summing over multiple flattened dims, exaclty one of (include_0, include_sum) must be True but got include_0={include_0}, include_sum={include_sum}"
-    native_x = reshaped_native(x, [x.shape - dim, dim])
-    b = choose_backend(native_x)
-    native_result = b.cumsum(native_x, 1)
-    if include_0:
-        native_result = b.pad(native_result, ((0, 0), (1, 0)))
-    if not include_sum:
-        native_result = native_result[:, :-1]
-    result = reshaped_tensor(native_result, [x.shape - dim, dim + (include_0 + include_sum) - 1])
-    if index_dim is not None:
-        assert dim.rank == 1, f"multi-dimensional indices not yet supported"
-        if isinstance(index_dim, str):
-            index_dim = auto(index_dim, channel)
-        index_dim = index_dim.with_size(dim.name_list)
-        result = expand(result, index_dim)
-    return result
+    broadcast = broadcast_dims(x)
+    assert dim.only(broadcast).is_empty, f"Cannot compute cumulative sum along {dim} because input is not uniform along that dimension."
+    def uniform_cumulative_sum(x: Tensor, index_dim=index_dim, dim=dim.names):
+        dim = x.shape.only(dim, reorder=True)
+        native_x = reshaped_native(x, [x.shape - dim, dim])
+        b = choose_backend(native_x)
+        native_result = b.cumsum(native_x, 1)
+        if include_0:
+            native_result = b.pad(native_result, ((0, 0), (1, 0)))
+        if not include_sum:
+            native_result = native_result[:, :-1]
+        result = reshaped_tensor(native_result, [x.shape - dim, dim + (include_0 + include_sum) - 1])
+        if index_dim is not None:
+            assert dim.rank == 1, f"multi-dimensional indices not yet supported"
+            if isinstance(index_dim, str):
+                index_dim = auto(index_dim, channel)
+            index_dim = index_dim.with_size(dim.name_list)
+            result = expand(result, index_dim)
+        return result
+    return broadcast_op(uniform_cumulative_sum, [x], broadcast)
 
 
 def fftfreq(resolution: Shape, dx: Union[Tensor, float] = 1, dtype: DType = None):
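
A minimal usage sketch of the case this patch targets. The tensor values and the dimension names 'rows' and 'x' are hypothetical illustrations, not taken from the patch; it assumes the public phiml.math API (tensor, stack, cumulative_sum).

    from phiml import math
    from phiml.math import spatial, batch

    # Hypothetical non-uniform tensor: rows of different lengths stacked along 'rows',
    # so the size of 'x' varies per row.
    row_a = math.tensor([1., 2., 3.], spatial('x'))
    row_b = math.tensor([4., 5.], spatial('x'))
    nonuniform = math.stack([row_a, row_b], batch('rows'))

    # With this patch, the uniform implementation is broadcast over 'rows' via
    # broadcast_op(), so each row gets its own cumulative sum along 'x'.
    result = math.cumulative_sum(nonuniform, 'x')

    # Summing along 'rows' itself is rejected by the new assert, since the
    # slices along that dimension have different shapes.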