Skip to content

Make take_along_axis with TF backend compilable. #21239

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 4, 2025

Conversation

hertschuh
Copy link
Collaborator

When there are dynamic dimensions, like typically the batch size, tf.broadcast_dynamic_shape is not always compilable.

Replace with an adhoc implementation for dynamic dimensions where we rely on the broadcast itself to fail when the shapes are not broadcastable.

Tested with https://github.com/keras-team/keras-rs/blob/main/examples/listwise_ranking.py on GPU as I was not able to distill a simple reproduction of this.

When there are dynamic dimensions, like typically the batch size, `tf.broadcast_dynamic_shape` is not always compilable.

Replace with an adhoc implementation for dynamic dimensions where we rely on the broadcast itself to fail when the shapes are not broadcastable.

Tested with https://github.com/keras-team/keras-rs/blob/main/examples/listwise_ranking.py on GPU as I was not able to distill a simple reproduction of this.
@codecov-commenter
Copy link

codecov-commenter commented May 2, 2025

Codecov Report

Attention: Patch coverage is 76.47059% with 4 lines in your changes missing coverage. Please review.

Project coverage is 82.59%. Comparing base (48a6692) to head (b456b0d).

Files with missing lines Patch % Lines
keras/src/backend/tensorflow/numpy.py 76.47% 2 Missing and 2 partials ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #21239      +/-   ##
==========================================
- Coverage   82.60%   82.59%   -0.01%     
==========================================
  Files         564      564              
  Lines       54501    54509       +8     
  Branches     8469     8471       +2     
==========================================
+ Hits        45020    45024       +4     
- Misses       7397     7399       +2     
- Partials     2084     2086       +2     
Flag Coverage Δ
keras 82.41% <76.47%> (-0.01%) ⬇️
keras-jax 63.67% <0.00%> (-0.01%) ⬇️
keras-numpy 58.80% <0.00%> (-0.01%) ⬇️
keras-openvino 32.99% <0.00%> (-0.01%) ⬇️
keras-tensorflow 64.09% <76.47%> (-0.01%) ⬇️
keras-torch 63.75% <0.00%> (-0.01%) ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@google-ml-butler google-ml-butler bot added kokoro:force-run ready to pull Ready to be merged into the codebase labels May 4, 2025
@fchollet fchollet merged commit ce95589 into keras-team:master May 4, 2025
10 of 11 checks passed
@google-ml-butler google-ml-butler bot removed awaiting review ready to pull Ready to be merged into the codebase kokoro:force-run labels May 4, 2025
@hertschuh hertschuh deleted the tf_take_along_axis branch May 4, 2025 20:56
hertschuh added a commit to hertschuh/keras that referenced this pull request May 7, 2025
Change keras-team#21239 broke one use case when the axis dimension is dynamic, the type of the indices is not int32, and the op is run in graph mode.

Note that the additional unit tests don't actually cover this.
fchollet pushed a commit that referenced this pull request May 7, 2025
Change #21239 broke one use case when the axis dimension is dynamic, the type of the indices is not int32, and the op is run in graph mode.

Note that the additional unit tests don't actually cover this.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants