From d3964b600a5e596bd47cf04b3da4a1e78d4aecaf Mon Sep 17 00:00:00 2001 From: Neven Sajko <4944410+nsajko@users.noreply.github.com> Date: Fri, 10 Jan 2025 10:00:22 +0100 Subject: [PATCH] broadcast: align `ndims` implementation with intent behind code (#56999) The `N<:Integer` constraint was nonsensical, given that `(N === Any) || (N isa Int)`. N5N3 noticed this back in 2022: https://github.com/JuliaLang/julia/pull/44061#discussion_r883261337 Follow up on #44061. Also xref #45477. --- base/broadcast.jl | 9 +++++++-- test/broadcast.jl | 2 ++ 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/base/broadcast.jl b/base/broadcast.jl index 7f278fc1ca66f..512b397352040 100644 --- a/base/broadcast.jl +++ b/base/broadcast.jl @@ -280,9 +280,14 @@ Base.@propagate_inbounds function Base.iterate(bc::Broadcasted, s) end Base.IteratorSize(::Type{T}) where {T<:Broadcasted} = Base.HasShape{ndims(T)}() -Base.ndims(BC::Type{<:Broadcasted{<:Any,Nothing}}) = _maxndims(fieldtype(BC, :args)) -Base.ndims(::Type{<:Broadcasted{<:AbstractArrayStyle{N},Nothing}}) where {N<:Integer} = N +Base.ndims(BC::Type{<:Broadcasted{<:Any,Nothing}}) = _maxndims_broadcasted(BC) +# the `AbstractArrayStyle` type parameter is required to be either equal to `Any` or be an `Int` value +Base.ndims(BC::Type{<:Broadcasted{<:AbstractArrayStyle{Any},Nothing}}) = _maxndims_broadcasted(BC) +Base.ndims(::Type{<:Broadcasted{<:AbstractArrayStyle{N},Nothing}}) where {N} = N::Int +function _maxndims_broadcasted(BC::Type{<:Broadcasted}) + _maxndims(fieldtype(BC, :args)) +end _maxndims(::Type{T}) where {T<:Tuple} = reduce(max, ntuple(n -> (F = fieldtype(T, n); F <: Tuple ? 1 : ndims(F)), Base._counttuple(T))) _maxndims(::Type{<:Tuple{T}}) where {T} = T <: Tuple ? 1 : ndims(T) function _maxndims(::Type{<:Tuple{T, S}}) where {T, S} diff --git a/test/broadcast.jl b/test/broadcast.jl index 57441de67c0ce..0f5bdf7a40bb1 100644 --- a/test/broadcast.jl +++ b/test/broadcast.jl @@ -930,6 +930,8 @@ let @test @inferred(Base.IteratorSize(Broadcast.broadcasted(+, (1,2,3), a1, zeros(3,3,3)))) === Base.HasShape{3}() + @test @inferred(Base.IteratorSize(Base.broadcasted(randn))) === Base.HasShape{0}() + # inference on nested bc = Base.broadcasted(+, AD1(randn(3)), AD1(randn(3))) bc_nest = Base.broadcasted(+, bc , bc)