Skip to content

Commit

Permalink
Update nx/lib/nx/defn/grad.ex
Browse files Browse the repository at this point in the history
Co-authored-by: José Valim <jose.valim@dashbit.co>
  • Loading branch information
polvalente and josevalim authored Sep 24, 2024
1 parent affdc90 commit add0134
Showing 1 changed file with 1 addition and 10 deletions.
11 changes: 1 addition & 10 deletions nx/lib/nx/defn/grad.ex
Original file line number Diff line number Diff line change
Expand Up @@ -272,16 +272,7 @@ defmodule Nx.Defn.Grad do
%{vectorized_axes: vectorized_axes, names: names},
parent_names
) do
reversed_inner_axes =
Enum.reduce(names, [], fn name, acc ->
if name in parent_names do
[name | acc]
else
acc
end
end)

Keyword.keys(vectorized_axes) ++ Enum.reverse(reversed_inner_axes)
Keyword.keys(vectorized_axes) ++ Enum.filter(names, & &1 in parent_names)
end

defp revectorize_node(node, vectorized_names) do
Expand Down

0 comments on commit add0134

Please sign in to comment.