Fix cumulative_sum() for non-uniform Tensors
Philipp Holl committed Nov 25, 2024
1 parent 33a6026 commit 481fa25
Showing 1 changed file with 20 additions and 15 deletions.
35 changes: 20 additions & 15 deletions

phiml/math/_ops.py
@@ -648,21 +648,26 @@ def cumulative_sum(x: Tensor, dim: DimFilter, include_0=False, include_sum=True,
     dim = x.shape.only(dim, reorder=True)
     assert dim.rank >= 1, f"dim must contain at least one dimension."
     assert dim.rank == 1 or include_0 + include_sum == 1, f"When summing over multiple flattened dims, exactly one of (include_0, include_sum) must be True but got include_0={include_0}, include_sum={include_sum}"
-    native_x = reshaped_native(x, [x.shape - dim, dim])
-    b = choose_backend(native_x)
-    native_result = b.cumsum(native_x, 1)
-    if include_0:
-        native_result = b.pad(native_result, ((0, 0), (1, 0)))
-    if not include_sum:
-        native_result = native_result[:, :-1]
-    result = reshaped_tensor(native_result, [x.shape - dim, dim + (include_0 + include_sum) - 1])
-    if index_dim is not None:
-        assert dim.rank == 1, f"multi-dimensional indices not yet supported"
-        if isinstance(index_dim, str):
-            index_dim = auto(index_dim, channel)
-        index_dim = index_dim.with_size(dim.name_list)
-        result = expand(result, index_dim)
-    return result
+    broadcast = broadcast_dims(x)
+    assert dim.only(broadcast).is_empty, f"Cannot compute cumulative sum along {dim} because input is not uniform along that dimension."
+    def uniform_cumulative_sum(x: Tensor, index_dim=index_dim, dim=dim.names):
+        dim = x.shape.only(dim, reorder=True)
+        native_x = reshaped_native(x, [x.shape - dim, dim])
+        b = choose_backend(native_x)
+        native_result = b.cumsum(native_x, 1)
+        if include_0:
+            native_result = b.pad(native_result, ((0, 0), (1, 0)))
+        if not include_sum:
+            native_result = native_result[:, :-1]
+        result = reshaped_tensor(native_result, [x.shape - dim, dim + (include_0 + include_sum) - 1])
+        if index_dim is not None:
+            assert dim.rank == 1, f"multi-dimensional indices not yet supported"
+            if isinstance(index_dim, str):
+                index_dim = auto(index_dim, channel)
+            index_dim = index_dim.with_size(dim.name_list)
+            result = expand(result, index_dim)
+        return result
+    return broadcast_op(uniform_cumulative_sum, [x], broadcast)


def fftfreq(resolution: Shape, dx: Union[Tensor, float] = 1, dtype: DType = None):
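For context, a minimal usage sketch of the behaviour this commit fixes. It is not part of the commit: the dim names 'x' and 'b' and the values are illustrative, and it assumes the usual Φ-ML public API (phiml.math.tensor, stack, spatial, batch and a re-exported cumulative_sum).

from phiml import math
from phiml.math import spatial, batch

# Hypothetical example (not from the commit): two slices with different 'x'
# sizes form a non-uniform tensor when stacked along a batch dim.
a = math.tensor([1., 2., 3.], spatial('x'))
b = math.tensor([4., 5.], spatial('x'))
t = math.stack([a, b], batch('b'))

# 'b' is a broadcast dim of t, so cumulative_sum now maps uniform_cumulative_sum
# over the 'b'-slices via broadcast_op instead of calling reshaped_native on the
# non-uniform tensor directly, which the pre-commit code would have attempted.
result = math.cumulative_sum(t, 'x')
# Expected slice results: (1, 3, 6) for b=0 and (4, 9) for b=1.

Summing along the broadcast dim itself (dim='b' in this sketch) should still be rejected by the new assert, since the input is not uniform with respect to that dimension.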
