Skip to content

Commit

Permalink
Merge pull request #116 from cscherrer/static-integer
Browse files Browse the repository at this point in the history
Use StaticInteger instead of StaticInt and add tests for static util functions
  • Loading branch information
theogf authored May 31, 2023
2 parents 5bba40f + 4dac4fc commit 7cb87d5
Show file tree
Hide file tree
Showing 10 changed files with 47 additions and 12 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MeasureBase"
uuid = "fa1605e6-acd5-459c-a1e6-7e635759db14"
authors = ["Chad Scherrer <chad.scherrer@gmail.com> and contributors"]
version = "0.14.6"
version = "0.14.7"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
1 change: 1 addition & 0 deletions src/MeasureBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ const Pretty = PrettyPrinting
using ChainRulesCore
import FillArrays
using Static
using Static: StaticInteger
using FunctionChains

export
Expand Down
2 changes: 1 addition & 1 deletion src/density-core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ end
@generated function _logdensity_rel(
μs::Tμ,
νs::Tν,
::Tuple{StaticInt{M},StaticInt{N}},
::Tuple{<:StaticInteger{M},<:StaticInteger{N}},
x::X,
) where {Tμ,Tν,M,N,X}
= schema(Tμ)
Expand Down
2 changes: 1 addition & 1 deletion src/standard/stdmeasure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ end

# Implement transport_to(NU::Type{<:StdMeasure}, μ) and transport_to(ν, MU::Type{<:StdMeasure}):

_std_measure(::Type{M}, ::StaticInt{1}) where {M<:StdMeasure} = M()
_std_measure(::Type{M}, ::StaticInteger{1}) where {M<:StdMeasure} = M()
_std_measure(::Type{M}, dof::IntegerLike) where {M<:StdMeasure} = M()^dof
_std_measure_for(::Type{M}, μ::Any) where {M<:StdMeasure} = _std_measure(M, getdof(μ))

Expand Down
8 changes: 4 additions & 4 deletions src/static.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
"""
MeasureBase.IntegerLike
Equivalent to `Union{Integer,Static.StaticInt}`.
Equivalent to `Union{Integer,Static.StaticInteger}`.
"""
const IntegerLike = Union{Integer,Static.StaticInt}
const IntegerLike = Union{Integer,Static.StaticInteger}

"""
MeasureBase.one_to(n::IntegerLike)
Expand All @@ -14,7 +14,7 @@ Returns an instance of `Base.OneTo` or `Static.SOneTo`, depending
on the type of `n`.
"""
@inline one_to(n::Integer) = Base.OneTo(n)
@inline one_to(::Static.StaticInt{N}) where {N} = Static.SOneTo{N}()
@inline one_to(::Static.StaticInteger{N}) where {N} = Static.SOneTo{N}()

_dynamic(x::Number) = dynamic(x)
_dynamic(::Static.SOneTo{N}) where {N} = Base.OneTo(N)
Expand Down Expand Up @@ -49,7 +49,7 @@ Returns the length of `x` as a dynamic or static integer.
"""
maybestatic_length(x) = length(x)
maybestatic_length(x::AbstractUnitRange) = length(x)
function maybestatic_length(::Static.OptionallyStaticUnitRange{StaticInt{A},StaticInt{B}}) where {A,B}
function maybestatic_length(::Static.OptionallyStaticUnitRange{<:StaticInteger{A},<:StaticInteger{B}}) where {A,B}
StaticInt{B - A + 1}()
end

Expand Down
8 changes: 4 additions & 4 deletions src/transport.jl
Original file line number Diff line number Diff line change
Expand Up @@ -139,14 +139,14 @@ _origin_depth_pullback(ΔΩ) = NoTangent(), NoTangent()
ChainRulesCore.rrule(::typeof(_origin_depth), ν) = _origin_depth(ν), _origin_depth_pullback

# If both both measures have no origin:
function _transport_between_origins(ν, ::StaticInt{0}, ::StaticInt{0}, μ, x)
function _transport_between_origins(ν, ::StaticInteger{0}, ::StaticInteger{0}, μ, x)
_transport_with_intermediate(ν, _transport_intermediate(ν, μ), μ, x)
end

@generated function _transport_between_origins(
ν,
::StaticInt{n_ν},
::StaticInt{n_μ},
::StaticInteger{n_ν},
::StaticInteger{n_μ},
μ,
x,
) where {n_ν,n_μ}
Expand Down Expand Up @@ -188,7 +188,7 @@ end

@inline _transport_intermediate(ν, μ) = _transport_intermediate(getdof(ν), getdof(μ))
@inline _transport_intermediate(::Integer, n_μ::Integer) = StdUniform()^n_μ
@inline _transport_intermediate(::StaticInt{1}, ::StaticInt{1}) = StdUniform()
@inline _transport_intermediate(::StaticInteger{1}, ::StaticInteger{1}) = StdUniform()

_call_transport_def(ν, μ, x) = transport_def(ν, μ, x)
_call_transport_def(::Any, ::Any, x::NoTransportOrigin) = x
Expand Down
2 changes: 1 addition & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ repeatedly until there's no change. That's what this does.
_rootmeasure(μ, static(n))
end

@generated function _rootmeasure(μ, ::StaticInt{n}) where {n}
@generated function _rootmeasure(μ, ::StaticInteger{n}) where {n}
q = quote end
foreach(1:n) do _
push!(q.args, :(μ = basemeasure(μ)))
Expand Down
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0"
DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112"
IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand Down
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ using MeasureBase: test_interface, test_smf
using Aqua
Aqua.test_all(MeasureBase; ambiguities = false)

include("static.jl")

# Aqua._test_ambiguities(
# Aqua.aspkgids(MeasureBase);
# exclude = [LogarithmicNumbers.Logarithmic],
Expand Down
31 changes: 31 additions & 0 deletions test/static.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
using Test

import MeasureBase

import Static
using Static: static
import FillArrays

@testset "static" begin
@test 2 isa MeasureBase.IntegerLike
@test static(2) isa MeasureBase.IntegerLike
@test true isa MeasureBase.IntegerLike
@test static(true) isa MeasureBase.IntegerLike

@test @inferred(MeasureBase.one_to(7)) isa Base.OneTo
@test @inferred(MeasureBase.one_to(7)) == 1:7
@test @inferred(MeasureBase.one_to(static(7))) isa Static.SOneTo
@test @inferred(MeasureBase.one_to(static(7))) == static(1):static(7)

@test @inferred(MeasureBase.fill_with(4.2, (7,))) == FillArrays.Fill(4.2, 7)
@test @inferred(MeasureBase.fill_with(4.2, (static(7),))) == FillArrays.Fill(4.2, 7)
@test @inferred(MeasureBase.fill_with(4.2, (3, static(7)))) == FillArrays.Fill(4.2, 3, 7)
@test @inferred(MeasureBase.fill_with(4.2, (3:7,))) == FillArrays.Fill(4.2, (3:7,))
@test @inferred(MeasureBase.fill_with(4.2, (static(3):static(7),))) == FillArrays.Fill(4.2, (3:7,))
@test @inferred(MeasureBase.fill_with(4.2, (3:7, static(2):static(5)))) == FillArrays.Fill(4.2, (3:7, 2:5))

@test MeasureBase.maybestatic_length(MeasureBase.one_to(7)) isa Int
@test MeasureBase.maybestatic_length(MeasureBase.one_to(7)) == 7
@test MeasureBase.maybestatic_length(MeasureBase.one_to(static(7))) isa Static.StaticInt
@test MeasureBase.maybestatic_length(MeasureBase.one_to(static(7))) == static(7)
end

2 comments on commit 7cb87d5

@theogf
Copy link
Collaborator Author

@theogf theogf commented on 7cb87d5 May 31, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/84600

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.14.7 -m "<description of version>" 7cb87d589699a835013d98a98649310ca2946808
git push origin v0.14.7

Please sign in to comment.