Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Disallow mixing offset and non-offset axes in conv input #586

Merged
merged 1 commit into from
Nov 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ext/OffsetArraysExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@ module OffsetArraysExt
import DSP
import OffsetArrays

DSP.conv_with_offset(::OffsetArrays.IdOffsetRange) = true
DSP.conv_axis_with_offset(::OffsetArrays.IdOffsetRange) = true

end
34 changes: 21 additions & 13 deletions src/dspbase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -660,8 +660,16 @@
end

# whether the given axis are to be considered to carry an offset for `conv!` and `conv`
conv_with_offset(::Base.OneTo) = false
conv_with_offset(a::Any) = throw(ArgumentError("unsupported axis type $(typeof(a))"))
conv_axis_with_offset(::Base.OneTo) = false
conv_axis_with_offset(a::Any) = throw(ArgumentError("unsupported axis type $(typeof(a))"))

Check warning on line 664 in src/dspbase.jl

View check run for this annotation

Codecov / codecov/patch

src/dspbase.jl#L664

Added line #L664 was not covered by tests

function conv_axes_with_offset(as::Tuple...)
with_offset = ((map(a -> map(conv_axis_with_offset, a), as)...)...,)
if !allequal(with_offset)
throw(ArgumentError("cannot mix offset and non-offset axes"))
end
return !isempty(with_offset) && first(with_offset)
end

const FFTTypes = Union{Float32, Float64, ComplexF32, ComplexF64}

Expand All @@ -677,7 +685,7 @@
`size(out,d) ≥ size(u,d) + size(v,d) - 1` must hold. If both input and output
have offset axes, `firstindex(out,d) ≤ firstindex(u,d) + firstindex(v,d)` and
`lastindex(out,d) ≥ lastindex(u,d) + lastindex(v,d)` must hold (for d = 1,...,N).
A mix of offset and non-offset axes between input and output is not permitted.
A mix of offset and non-offset axes is not permitted.

The `algorithm` keyword allows choosing the algorithm to use:
* `:direct`: Evaluates the convolution sum in time domain.
Expand All @@ -704,12 +712,8 @@
v::AbstractArray{<:Number, N};
algorithm=:auto
) where {T<:Number, N}
offset = conv_axes_with_offset(axes(out), axes(u), axes(v)) ? 0 : 1
output_indices = CartesianIndices(map(axes(out), axes(u), axes(v)) do ao, au, av
input_has_offset = conv_with_offset(au) || conv_with_offset(av)
if input_has_offset !== conv_with_offset(ao)
throw(ArgumentError("output must have offset axes if and only if the input has"))
end
offset = input_has_offset ? 0 : 1
return (first(au)+first(av) : last(au)+last(av)) .- offset
end)

Expand Down Expand Up @@ -752,9 +756,13 @@
end
end

conv_output_axis(au, av) =
conv_with_offset(au) || conv_with_offset(av) ?
(first(au)+first(av):last(au)+last(av)) : Base.OneTo(last(au) + last(av) - 1)
function conv_output_axes(au::Tuple, av::Tuple)
if conv_axes_with_offset(au, av)
return map((au, av) -> first(au)+first(av):last(au)+last(av), au, av)
else
return map((au, av) -> Base.OneTo(last(au) + last(av) - 1), au, av)
end
end

"""
conv(u, v; algorithm)
Expand All @@ -768,7 +776,7 @@
u::AbstractArray{Tu, N}, v::AbstractArray{Tv, N}; kwargs...
) where {Tu<:Number, Tv<:Number, N}
T = promote_type(Tu, Tv)
out_axes = map(conv_output_axis, axes(u), axes(v))
out_axes = conv_output_axes(axes(u), axes(v))
out = similar(u, T, out_axes)
return conv!(out, u, v; kwargs...)
end
Expand All @@ -792,7 +800,7 @@
"""
function conv(u::AbstractVector{T}, v::Transpose{T,<:AbstractVector}, A::AbstractMatrix{T}) where T
# Arbitrary indexing offsets not implemented
if any(conv_with_offset, (axes(u)..., axes(v)..., axes(A)...))
if any(conv_axis_with_offset, (axes(u)..., axes(v)..., axes(A)...))
throw(ArgumentError("offset axes not supported"))
end
m = length(u)+size(A,1)-1
Expand Down
13 changes: 10 additions & 3 deletions test/dsp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,13 +72,19 @@ end

offset_arr = OffsetVector{Int}(undef, -1:2)
offset_arr[:] = a
@test conv(offset_arr, 1:3) == OffsetVector(expectation, 0:5)
@test_throws ArgumentError conv(offset_arr, 1:3)
@test conv(offset_arr, OffsetArray(1:3)) == OffsetVector(expectation, 0:5)
offset_arr_f = OffsetVector{Float64}(undef, -1:2)
offset_arr_f[:] = fa
@test conv(offset_arr_f, 1:3) ≈ OffsetVector(fexp, 0:5)
@test_throws ArgumentError conv(offset_arr_f, 1:3)
@test conv(offset_arr_f, OffsetArray(1:3)) ≈ OffsetVector(fexp, 0:5)
@test_throws ArgumentError conv!(zeros(6), offset_arr, 1:3) # output needs to be OA, too
@test_throws ArgumentError conv!(OffsetVector{Int}(undef, 1:6), 1:4, 1:3) # output mustn't be OA

@test conv(fa, fill(true)) == conv(fill(true), fa) == fa
@test_broken conv(offset_arr_f, fill(true)) == conv(fill(true), offset_arr_f) == offset_arr_f
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As fill(true) doesn't have any axes, this should arguably work. However, conv currently first brings all arguments to the same dimensionality, turning fill(true) into [true] here, and the axis compatibility check then says no. Fixing that is probably not important and certainly beyond the scope of this PR, but I thought leaving a @test_broken as a reminder might be in order.

@test conv(fill(true), fill(true)) == fill(true)

for M in [10, 200], N in [10, 200], T in [Float64, ComplexF64]
u = rand(T, M)
v = rand(T, N)
Expand Down Expand Up @@ -156,7 +162,8 @@ end

offset_arr = OffsetMatrix{Int}(undef, -1:1, -1:1)
offset_arr[:] = a
@test conv(offset_arr, b) == OffsetArray(expectation, 0:3, 0:3)
@test_throws ArgumentError conv(offset_arr, b)
@test conv(offset_arr, OffsetArray(b)) == OffsetArray(expectation, 0:3, 0:3)

for (M1, M2) in [(10, 20), (190, 200)], (N1, N2) in [(20, 10), (210, 200)], T in [Float64, ComplexF64]
u = rand(T, M1, M2)
Expand Down
Loading