Skip to content

Commit

Permalink
Fix gather() without index dim
Browse files Browse the repository at this point in the history
  • Loading branch information
holl- committed Aug 14, 2024
1 parent 3992091 commit 0d70ee0
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion phiml/math/_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2468,7 +2468,8 @@ def gather(values, indices: Tensor, dims: Union[DimFilter, None] = None, pref_in
indices = to_int32(indices)
if values._is_tracer or is_sparse(values):
if not index_dim:
indices = expand(indices, channel(gather=dims))
index_dim = channel(gather=dims)
indices = expand(indices, index_dim)
if not index_dim.item_names[0]:
indices = indices._with_shape_replaced(indices.shape.with_dim_size(index_dim, dims))
if values._is_tracer:
Expand Down

0 comments on commit 0d70ee0

Please sign in to comment.