From 59a5c2d3cb6dbcbc385c8397b422f2bd9c15a794 Mon Sep 17 00:00:00 2001 From: Neven Sajko Date: Sat, 2 Nov 2024 21:41:04 +0100 Subject: [PATCH] support sorting tuples Uses merge sort, as an obvious choice for a stable sort of tuples. A recursive data structure of singleton type, representing Peano natural numbers, is used to help with splitting a tuple into two halves in the merge sort. An alternative design would use a reference tuple, but this would require relying on `tail`, which seems more harsh on the compiler. With the recursive datastructure the predecessor operation and the successor operation are both trivial. Allows inference to preserve inferred element type even when tuple length is not known. Follow-up PRs may add further improvements, such as the ability to select an unstable sorting algorithm. The added file, typedomainnumbers.jl is not specific to sorting, thus making it a separate file. Xref #55571. Fixes #54489 --- NEWS.md | 1 + base/Base.jl | 2 + base/sort.jl | 76 +++++++++++++++++++++ base/tuple.jl | 19 ++++-- base/typedomainnumbers.jl | 140 ++++++++++++++++++++++++++++++++++++++ test/choosetests.jl | 2 +- test/sorting.jl | 44 ++++++++++++ test/typedomainnumbers.jl | 31 +++++++++ 8 files changed, 309 insertions(+), 6 deletions(-) create mode 100644 base/typedomainnumbers.jl create mode 100644 test/typedomainnumbers.jl diff --git a/NEWS.md b/NEWS.md index 74cda05e9d0e1..efb2a1b2009a3 100644 --- a/NEWS.md +++ b/NEWS.md @@ -119,6 +119,7 @@ New library features * `Base.require_one_based_indexing` and `Base.has_offset_axes` are now public ([#56196]) * New `ltruncate`, `rtruncate` and `ctruncate` functions for truncating strings to text width, accounting for char widths ([#55351]) * `isless` (and thus `cmp`, sorting, etc.) is now supported for zero-dimensional `AbstractArray`s ([#55772]) +* `sort` now sorts tuples (#56425) Standard library changes ------------------------ diff --git a/base/Base.jl b/base/Base.jl index 39507b625660d..c737bb2c69f96 100644 --- a/base/Base.jl +++ b/base/Base.jl @@ -106,6 +106,8 @@ include("cartesian.jl") using .Cartesian include("multidimensional.jl") +include("typedomainnumbers.jl") + include("broadcast.jl") using .Broadcast using .Broadcast: broadcasted, broadcasted_kwsyntax, materialize, materialize!, diff --git a/base/sort.jl b/base/sort.jl index 6991f12551ab4..d60ff71de7afb 100644 --- a/base/sort.jl +++ b/base/sort.jl @@ -1736,6 +1736,82 @@ julia> v """ sort(v::AbstractVector; kws...) = sort!(copymutable(v); kws...) +module _SortTupleStable + using + Base._TypeDomainNumbers.PositiveIntegers, Base._TypeDomainNumbers.IntegersGreaterThanOne, + Base._TypeDomainNumberTupleUtils, Base._TupleTypeByLength + using Base: tail + using Base.Order: Ordering, lt + function merge_recursive((@nospecialize ord::Ordering), a::Tuple, b::Tuple) + if a isa Tuple1OrMore + a + else + b + end + end + function merge_recursive(ord::Ordering, a::Tuple1OrMore, b::Tuple1OrMore) + l = first(a) + r = first(b) + x = tail(a) + y = tail(b) + if lt(ord, r, l) + let rec = merge_recursive(ord, a, y) + (r, rec...) + end + else + let rec = merge_recursive(ord, x, b) + (l, rec...) + end + end + end + function merge_nontrivial(ord::Ordering, a::Tuple1OrMore, b::Tuple1OrMore) + merge_recursive(ord, a, b) + end + function sort_recursive((@nospecialize ord::Ordering), @nospecialize tup::Tuple{Any}) + tup + end + function sort_recursive(ord::Ordering, tup::Tuple2OrMore) + (tup_l, tup_r) = split_tuple_into_halves(tup) + sorted_l = sort_recursive(ord, tup_l) + sorted_r = sort_recursive(ord, tup_r) + merge_nontrivial(ord, sorted_l, sorted_r) + end + function sort_tuple_stable_2_or_more(ord::Ordering, tup::Tuple2OrMore) + sort_recursive(ord, tup) + end + function sort_tuple_array_fallback(ord::Ordering, tup::Tuple2OrMore) + vec = if tup isa NTuple + [tup...] + else + Any[tup...] + end + sort!(vec; order = ord) + (vec...,) + end + function sort_tuple_stable((@nospecialize ord::Ordering), @nospecialize tup::Tuple) + if tup isa Tuple2OrMore + if tup isa Tuple32OrMore + sort_tuple_array_fallback(ord, tup) + else + sort_tuple_stable_2_or_more(ord, tup) + end + else + tup + end + end +end + +function sort( + tup::Tuple; + lt = isless, + by = identity, + rev::Union{Nothing, Bool} = nothing, + order::Ordering = Forward, +) + o = ord(lt, by, rev, order) + _SortTupleStable.sort_tuple_stable(o, tup) +end + ## partialsortperm: the permutation to sort the first k elements of an array ## """ diff --git a/base/tuple.jl b/base/tuple.jl index 3791d74bfc698..a3e9c92252ff9 100644 --- a/base/tuple.jl +++ b/base/tuple.jl @@ -1,5 +1,18 @@ # This file is a part of Julia. License is MIT: https://julialang.org/license +module _TupleTypeByLength + export Tuple1OrMore, Tuple2OrMore, Tuple32OrMore + const Tuple1OrMore = Tuple{Any, Vararg} + const Tuple2OrMore = Tuple{Any, Any, Vararg} + const Tuple32OrMore = Tuple{ + Any, Any, Any, Any, Any, Any, Any, Any, + Any, Any, Any, Any, Any, Any, Any, Any, + Any, Any, Any, Any, Any, Any, Any, Any, + Any, Any, Any, Any, Any, Any, Any, Any, + Vararg{Any, N}, + } where {N} +end + # Document NTuple here where we have everything needed for the doc system """ NTuple{N, T} @@ -358,11 +371,7 @@ map(f, t::Tuple{Any, Any}) = (@inline; (f(t[1]), f(t[2]))) map(f, t::Tuple{Any, Any, Any}) = (@inline; (f(t[1]), f(t[2]), f(t[3]))) map(f, t::Tuple) = (@inline; (f(t[1]), map(f,tail(t))...)) # stop inlining after some number of arguments to avoid code blowup -const Any32{N} = Tuple{Any,Any,Any,Any,Any,Any,Any,Any, - Any,Any,Any,Any,Any,Any,Any,Any, - Any,Any,Any,Any,Any,Any,Any,Any, - Any,Any,Any,Any,Any,Any,Any,Any, - Vararg{Any,N}} +const Any32 = _TupleTypeByLength.Tuple32OrMore const All32{T,N} = Tuple{T,T,T,T,T,T,T,T, T,T,T,T,T,T,T,T, T,T,T,T,T,T,T,T, diff --git a/base/typedomainnumbers.jl b/base/typedomainnumbers.jl new file mode 100644 index 0000000000000..4a2b02b203c06 --- /dev/null +++ b/base/typedomainnumbers.jl @@ -0,0 +1,140 @@ +# This file is a part of Julia. License is MIT: https://julialang.org/license + +# Adapted from the TypeDomainNaturalNumbers.jl package. +module _TypeDomainNumbers + module Zeros + export Zero + struct Zero end + end + + module PositiveIntegers + module RecursiveStep + using ...Zeros + export recursive_step + function recursive_step(@nospecialize t::Type) + Union{Zero, t} + end + end + module UpperBounds + using ..RecursiveStep + abstract type A end + abstract type B{P <: recursive_step(A)} <: A end + abstract type C{P <: recursive_step(B)} <: B{P} end + abstract type D{P <: recursive_step(C)} <: C{P} end + end + using .RecursiveStep + const PositiveIntegerUpperBound = UpperBounds.A + const PositiveIntegerUpperBoundTighter = UpperBounds.D + export + natural_successor, natural_predecessor, + NonnegativeInteger, NonnegativeIntegerUpperBound, + PositiveInteger, PositiveIntegerUpperBound + struct PositiveInteger{ + Predecessor <: recursive_step(PositiveIntegerUpperBoundTighter), + } <: PositiveIntegerUpperBoundTighter{Predecessor} + predecessor::Predecessor + global const NonnegativeInteger = recursive_step(PositiveInteger) + global const NonnegativeIntegerUpperBound = recursive_step(PositiveIntegerUpperBound) + global function natural_successor(p::P) where {P <: NonnegativeInteger} + new{P}(p) + end + end + function natural_predecessor(@nospecialize o::PositiveInteger) + getfield(o, :predecessor) # avoid specializing `getproperty` for each number + end + end + + module IntegersGreaterThanOne + using ..PositiveIntegers + export + IntegerGreaterThanOne, IntegerGreaterThanOneUpperBound, + natural_predecessor_predecessor + const IntegerGreaterThanOne = let t = PositiveInteger + t{P} where {P <: t} + end + const IntegerGreaterThanOneUpperBound = let t = PositiveIntegerUpperBound + PositiveIntegers.UpperBounds.B{P} where {P <: t} + end + function natural_predecessor_predecessor(@nospecialize x::IntegerGreaterThanOne) + natural_predecessor(natural_predecessor(x)) + end + end + + module Constants + using ..Zeros, ..PositiveIntegers + export n0, n1 + const n0 = Zero() + const n1 = natural_successor(n0) + end + + module Utils + using ..PositiveIntegers, ..IntegersGreaterThanOne, ..Constants + using Base: @assume_effects + export half_floor, half_ceiling + @assume_effects :foldable :nothrow function half_floor(@nospecialize m::NonnegativeInteger) + if m isa IntegerGreaterThanOneUpperBound + let n = natural_predecessor_predecessor(m), rec = half_floor(n) + natural_successor(rec) + end + else + n0 + end + end + @assume_effects :foldable :nothrow function half_ceiling(@nospecialize m::NonnegativeInteger) + if m isa IntegerGreaterThanOneUpperBound + let n = natural_predecessor_predecessor(m), rec = half_ceiling(n) + natural_successor(rec) + end + else + if m isa PositiveIntegerUpperBound + n1 + else + n0 + end + end + end + end +end + +module _TypeDomainNumberTupleUtils + using + .._TypeDomainNumbers.PositiveIntegers, .._TypeDomainNumbers.IntegersGreaterThanOne, + .._TypeDomainNumbers.Constants, .._TypeDomainNumbers.Utils, .._TupleTypeByLength + using Base: @assume_effects, front, tail + export tuple_type_domain_length, split_tuple_into_halves, skip_from_front, skip_from_tail + @assume_effects :foldable :nothrow function tuple_type_domain_length(@nospecialize tup::Tuple) + if tup isa Tuple1OrMore + let t = tail(tup), rec = tuple_type_domain_length(t) + natural_successor(rec) + end + else + n0 + end + end + @assume_effects :foldable function skip_from_front((@nospecialize tup::Tuple), @nospecialize skip_count::NonnegativeInteger) + if skip_count isa PositiveIntegerUpperBound + let cm1 = natural_predecessor(skip_count), t = tail(tup) + @inline skip_from_front(t, cm1) + end + else + tup + end + end + @assume_effects :foldable function skip_from_tail((@nospecialize tup::Tuple), @nospecialize skip_count::NonnegativeInteger) + if skip_count isa PositiveIntegerUpperBound + let cm1 = natural_predecessor(skip_count), t = front(tup) + @inline skip_from_tail(t, cm1) + end + else + tup + end + end + function split_tuple_into_halves(@nospecialize tup::Tuple) + len = tuple_type_domain_length(tup) + len_l = half_floor(len) + len_r = half_ceiling(len) + tup_l = skip_from_tail(tup, len_r) + tup_r = skip_from_front(tup, len_l) + (tup_l, tup_r) + end +end diff --git a/test/choosetests.jl b/test/choosetests.jl index affdee412bd86..07bbdd47f1b85 100644 --- a/test/choosetests.jl +++ b/test/choosetests.jl @@ -15,7 +15,7 @@ const TESTNAMES = [ "bitarray", "copy", "math", "fastmath", "functional", "iterators", "operators", "ordering", "path", "ccall", "parse", "loading", "gmp", "sorting", "spawn", "backtrace", "exceptions", - "file", "read", "version", "namedtuple", + "file", "read", "version", "namedtuple", "typedomainnumbers", "mpfr", "broadcast", "complex", "floatapprox", "stdlib", "reflection", "regex", "float16", "combinatorics", "sysinfo", "env", "rounding", "ranges", "mod2pi", diff --git a/test/sorting.jl b/test/sorting.jl index 93e0cdd7de5ba..d6c5f36df992a 100644 --- a/test/sorting.jl +++ b/test/sorting.jl @@ -92,6 +92,50 @@ end end @test sort(1:2000, by=x->x÷100, rev=true) == sort(1:2000, by=x->-x÷100) == vcat(2000, (x:x+99 for x in 1900:-100:100)..., 1:99) + @testset "tuples" begin + tup = Tuple(0:9) + @test tup === sort(tup; by = _ -> 0) + @test (0, 2, 4, 6, 8, 1, 3, 5, 7, 9) === sort(tup; by = x -> isodd(x)) + @test (1, 3, 5, 7, 9, 0, 2, 4, 6, 8) === sort(tup; by = x -> iseven(x)) + end +end + +@testset "tuple sorting" begin + max_unrolled_length = 31 + @testset "correctness" begin + tup = Tuple(0:9) + tup_rev = reverse(tup) + @test tup === @inferred sort(tup) + @test tup === sort(tup; rev = false) + @test tup_rev === sort(tup; rev = true) + @test tup_rev === sort(tup; lt = >) + end + @testset "inference" begin + known_length = (Tuple{Vararg{Int, max_unrolled_length}}, Tuple{Vararg{Float64, max_unrolled_length}}) + unknown_length = (Tuple{Vararg{Int}}, Tuple{Vararg{Float64}}) + for Tup ∈ (known_length..., unknown_length...) + @test Tup == Base.infer_return_type(sort, Tuple{Tup}) + end + for Tup ∈ (known_length...,) + @test Core.Compiler.is_foldable(Base.infer_effects(sort, Tuple{Tup})) + end + end + @testset "alloc" begin + function test_zero_allocated(tup::Tuple) + @test iszero(@allocated sort(tup)) + end + test_zero_allocated(ntuple(identity, max_unrolled_length)) + end + @testset "heterogeneous" begin + @testset "stability" begin + tup = (0, 0x0, 0x000) + @test tup === sort(tup) + end + tup = (1, 2, 3, missing, missing) + for t ∈ (tup, (1, missing, 2, missing, 3), (missing, missing, 1, 2, 3)) + @test tup === @inferred sort(t) + end + end end @testset "partialsort" begin diff --git a/test/typedomainnumbers.jl b/test/typedomainnumbers.jl new file mode 100644 index 0000000000000..e5f2e3d09b01c --- /dev/null +++ b/test/typedomainnumbers.jl @@ -0,0 +1,31 @@ +# This file is a part of Julia. License is MIT: https://julialang.org/license + +using + Test, + Base._TypeDomainNumbers.PositiveIntegers, + Base._TypeDomainNumbers.IntegersGreaterThanOne, + Base._TypeDomainNumbers.Constants, + Base._TypeDomainNumberTupleUtils + +@testset "type domain numbers" begin + @test n0 isa NonnegativeInteger + @test n1 isa NonnegativeInteger + @test n1 isa PositiveInteger + @testset "succ" begin + for x ∈ (n0, n1) + @test x === natural_predecessor(@inferred natural_successor(x)) + @test x === natural_predecessor_predecessor(natural_successor(natural_successor(x))) + end + end + @testset "type safety" begin + @test_throws TypeError PositiveInteger{Int} + end + @testset "tuple utils" begin + @test n0 === @inferred tuple_type_domain_length(()) + @test n1 === @inferred tuple_type_domain_length((7,)) + @test ((), ()) === @inferred split_tuple_into_halves(()) + @test ((), (7,)) === @inferred split_tuple_into_halves((7,)) + @test ((3,), (7,)) === @inferred split_tuple_into_halves((3, 7)) + @test ((3,), (7, 9)) === @inferred split_tuple_into_halves((3, 7, 9)) + end +end