Skip to content

Commit

Permalink
improved function partial application design
Browse files Browse the repository at this point in the history
This replaces `Fix` (xref #54653) with `fix`. The usage is similar: use
`fix(i)(f, x)` instead of `Fix{i}(f, x)`.

Benefits:
* Improved type safety: creating an invalid type such as
  `Fix{:some_symbol}` or `Fix{-7}` is not possible.
* The design should be friendlier to future extensions. E.g., suppose
  that publicly-facing functionality for fixing a keyword (instead of
  positional) argument was desired, it could be achieved by adding a
  new method to `fix` taking a `Symbol`, instead of adding new public
  names.

Lots of changes are shared with PR #56425, if one of them gets merged
the other will be greatly simplified.
  • Loading branch information
nsajko committed Nov 10, 2024
1 parent afdba95 commit 8b2e9ef
Show file tree
Hide file tree
Showing 11 changed files with 378 additions and 74 deletions.
2 changes: 1 addition & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ New library functions
* `waitany(tasks; throw=false)` and `waitall(tasks; failfast=false, throw=false)` which wait multiple tasks at once ([#53341]).
* `uuid7()` creates an RFC 9652 compliant UUID with version 7 ([#54834]).
* `insertdims(array; dims)` allows to insert singleton dimensions into an array which is the inverse operation to `dropdims`
* The new `Fix` type is a generalization of `Fix1/Fix2` for fixing a single argument ([#54653]).
* `Fix1`/`Fix2` are now generalized by `fix` ([#54653], [#56518]).

New library features
--------------------
Expand Down
1 change: 1 addition & 0 deletions base/Base_compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ include("error.jl")
include("bool.jl")
include("number.jl")
include("int.jl")
include("typedomainnumbers.jl")
include("operators.jl")
include("pointer.jl")
include("refvalue.jl")
Expand Down
158 changes: 126 additions & 32 deletions base/operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1153,55 +1153,149 @@ julia> filter(!isletter, str)
!(f::Function) = (!) f
!(f::ComposedFunction{typeof(!)}) = f.inner #allows !!f === f

const _PositiveInteger = _TypeDomainNumbers.PositiveIntegers.PositiveInteger

struct PartiallyAppliedFunction{Position <: _PositiveInteger, Func, Arg} <: Function
partially_applied_argument_position::Position
f::Func
x::Arg

function (::Type{PartiallyAppliedFunction{Position}})(func::Func, arg) where {Position <: _PositiveInteger, Func}
Pos = Position::DataType
pos = Pos.instance
new{Pos, _stable_typeof(func), _stable_typeof(arg)}(pos, func, arg)
end
end

function getproperty((@nospecialize v::PartiallyAppliedFunction), s::Symbol)
getfield(v, s)
end # avoid overspecialization

function Base.show(
(@nospecialize io::Base.IO),
(@nospecialize unused::Type{PartiallyAppliedFunction{Position}}),
) where {Position <: _PositiveInteger}
if Position isa DataType
print(io, "fix(")
show(io, Position.instance)
print(io, ')')
else
show(io, PartiallyAppliedFunction)
print(io, '{')
show(io, Position)
print(io, '}')
end
end

function Base.show(
(@nospecialize io::Base.IO),
(@nospecialize unused::Type{PartiallyAppliedFunction{Position, Func}}),
) where {Position <: _PositiveInteger, Func}
show(io, PartiallyAppliedFunction{Position})
print(io, '{')
show(io, Func)
print(io, '}')
end

function Base.show(
(@nospecialize io::Base.IO),
(@nospecialize unused::Type{PartiallyAppliedFunction{Position, Func, Arg}}),
) where {Position <: _PositiveInteger, Func, Arg}
show(io, PartiallyAppliedFunction{Position, Func})
print(io, '{')
show(io, Arg)
print(io, '}')
end

function Base.show((@nospecialize io::Base.IO), @nospecialize p::PartiallyAppliedFunction)
print(io, "fix(")
show(io, p.partially_applied_argument_position)
print(io, ")(")
show(io, p.f)
print(io, ", ")
show(io, p.x)
print(io, ')')
end

function _partially_applied_function_check(m::Int, nm1::Int)
if m < nm1
throw(ArgumentError(LazyString("expected at least ", nm1, " arguments to `fix(", nm1 + 1, ")`, but got ", m)))
end
end

function (partial::PartiallyAppliedFunction)(args::Vararg{Any,M}; kws...) where {M}
n = partial.partially_applied_argument_position
nm1 = _TypeDomainNumbers.PositiveIntegers.natural_predecessor(n)
_partially_applied_function_check(M, Int(nm1))
(args_left, args_right) = _TypeDomainNumberTupleUtils.split_tuple(args, nm1)
partial.f(args_left..., partial.x, args_right...; kws...)
end

"""
Fix{N}(f, x)
fix(::Integer)::UnionAll
Return a [`UnionAll`](@ref) type such that:
* It's a constructor taking two arguments:
1. A function to be partially applied
2. An argument of the above function to be fixed
* Its instances are partial applications of the function, with one positional argument fixed. The argument to `fix` is the one-based index of the position argument to be fixed.
A type representing a partially-applied version of a function `f`, with the argument
`x` fixed at position `N::Int`. In other words, `Fix{3}(f, x)` behaves similarly to
`(y1, y2, y3...; kws...) -> f(y1, y2, x, y3...; kws...)`.
For example, `fix(3)(f, x)` behaves similarly to `(y1, y2, y3...; kws...) -> f(y1, y2, x, y3...; kws...)`.
See also: [`Fix1`](@ref), [`Fix2`](@ref).
!!! compat "Julia 1.12"
This general functionality requires at least Julia 1.12, while `Fix1` and `Fix2`
are available earlier.
Requires at least Julia 1.12 (`Fix1` and `Fix2` are available earlier, too).
!!! note
When nesting multiple `Fix`, note that the `N` in `Fix{N}` is _relative_ to the current
When nesting multiple `fix`, note that the `n` in `fix(n)` is _relative_ to the current
available arguments, rather than an absolute ordering on the target function. For example,
`Fix{1}(Fix{2}(f, 4), 4)` fixes the first and second arg, while `Fix{2}(Fix{1}(f, 4), 4)`
`fix(1)(fix(2)(f, 4), 4)` fixes the first and second arg, while `fix(2)(fix(1)(f, 4), 4)`
fixes the first and third arg.
"""
struct Fix{N,F,T} <: Function
f::F
x::T
function Fix{N}(f::F, x) where {N,F}
if !(N isa Int)
throw(ArgumentError(LazyString("expected type parameter in `Fix` to be `Int`, but got `", N, "::", typeof(N), "`")))
elseif N < 1
throw(ArgumentError(LazyString("expected `N` in `Fix{N}` to be integer greater than 0, but got ", N)))
end
new{N,_stable_typeof(f),_stable_typeof(x)}(f, x)
end
end
### Examples
function (f::Fix{N})(args::Vararg{Any,M}; kws...) where {N,M}
M < N-1 && throw(ArgumentError(LazyString("expected at least ", N-1, " arguments to `Fix{", N, "}`, but got ", M)))
return f.f(args[begin:begin+(N-2)]..., f.x, args[begin+(N-1):end]...; kws...)
end
```jldoctest
julia> Base.fix(2)(-, 3)(7)
4
# Special cases for improved constant propagation
(f::Fix{1})(arg; kws...) = f.f(f.x, arg; kws...)
(f::Fix{2})(arg; kws...) = f.f(arg, f.x; kws...)
julia> Base.fix(2) === Base.Fix2
true
julia> Base.fix(1)(Base.fix(2)(muladd, 3), 2)(5) === (x -> muladd(2, 3, x))(5)
true
```
"""
Alias for `Fix{1}`. See [`Fix`](@ref Base.Fix).
function fix(@nospecialize m::Integer)
n = Int(m)::Int
if n 0
throw(ArgumentError("the index of the partially applied argument must be positive"))
end
k = _TypeDomainNumbers.Utils.from_abs_int(n)
PartiallyAppliedFunction{typeof(k)}
end

"""
Fix1::UnionAll
[`fix(1)`](@ref Base.fix).
"""
const Fix1{F,T} = Fix{1,F,T}
const Fix1 = fix(1)

"""
Alias for `Fix{2}`. See [`Fix`](@ref Base.Fix).
Fix2::UnionAll
[`fix(2)`](@ref Base.fix).
"""
const Fix2{F,T} = Fix{2,F,T}
const Fix2 = fix(2)

# Special cases for improved constant propagation
function (partial::Fix1)(x; kws...)
partial.f(partial.x, x; kws...)
end
function (partial::Fix2)(x; kws...)
partial.f(x, partial.x; kws...)
end


"""
Expand Down
2 changes: 1 addition & 1 deletion base/public.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ public
AsyncCondition,
CodeUnits,
Event,
Fix,
fix,
Fix1,
Fix2,
Generator,
Expand Down
19 changes: 14 additions & 5 deletions base/tuple.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,18 @@
# This file is a part of Julia. License is MIT: https://julialang.org/license

module _TupleTypeByLength
export Tuple1OrMore, Tuple2OrMore, Tuple32OrMore
const Tuple1OrMore = Tuple{Any, Vararg}
const Tuple2OrMore = Tuple{Any, Any, Vararg}
const Tuple32OrMore = Tuple{
Any, Any, Any, Any, Any, Any, Any, Any,
Any, Any, Any, Any, Any, Any, Any, Any,
Any, Any, Any, Any, Any, Any, Any, Any,
Any, Any, Any, Any, Any, Any, Any, Any,
Vararg{Any, N},
} where {N}
end

# Document NTuple here where we have everything needed for the doc system
"""
NTuple{N, T}
Expand Down Expand Up @@ -358,11 +371,7 @@ map(f, t::Tuple{Any, Any}) = (@inline; (f(t[1]), f(t[2])))
map(f, t::Tuple{Any, Any, Any}) = (@inline; (f(t[1]), f(t[2]), f(t[3])))
map(f, t::Tuple) = (@inline; (f(t[1]), map(f,tail(t))...))
# stop inlining after some number of arguments to avoid code blowup
const Any32{N} = Tuple{Any,Any,Any,Any,Any,Any,Any,Any,
Any,Any,Any,Any,Any,Any,Any,Any,
Any,Any,Any,Any,Any,Any,Any,Any,
Any,Any,Any,Any,Any,Any,Any,Any,
Vararg{Any,N}}
const Any32 = _TupleTypeByLength.Tuple32OrMore
const All32{T,N} = Tuple{T,T,T,T,T,T,T,T,
T,T,T,T,T,T,T,T,
T,T,T,T,T,T,T,T,
Expand Down
Loading

0 comments on commit 8b2e9ef

Please sign in to comment.