Skip to content

Commit

Permalink
simplify inner loop
Browse files Browse the repository at this point in the history
  • Loading branch information
xrdaukar committed Feb 14, 2025
1 parent 65e83ab commit df9b84c
Showing 1 changed file with 7 additions and 12 deletions.
19 changes: 7 additions & 12 deletions src/oumi/utils/torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,18 +524,13 @@ def _pad_to_max_dim_and_stack(

target_view = target[...]

if pad_on_left_side:
for dim_idx, curr_size in enumerate(input_tensor.shape):
max_size = max_dim_sizes[dim_idx]
if curr_size < max_size:
target_view = target_view.narrow(
dim_idx, start=(max_size - curr_size), length=curr_size
)
else:
for dim_idx, curr_size in enumerate(input_tensor.shape):
max_size = max_dim_sizes[dim_idx]
if curr_size < max_size:
target_view = target_view.narrow(dim_idx, start=0, length=curr_size)
for dim_idx, curr_size in enumerate(input_tensor.shape):
max_size = max_dim_sizes[dim_idx]
if curr_size < max_size:
start_idx = (max_size - curr_size) if pad_on_left_side else 0
target_view = target_view.narrow(
dim_idx, start=start_idx, length=curr_size
)

assert target_view.shape == input_tensor.shape
target_view[...] = input_tensor
Expand Down

0 comments on commit df9b84c

Please sign in to comment.