diff --git a/NEWS.md b/NEWS.md index ba9ca1c521c55b..2c8bd77ac8278c 100644 --- a/NEWS.md +++ b/NEWS.md @@ -119,6 +119,7 @@ New library features * `Base.require_one_based_indexing` and `Base.has_offset_axes` are now public ([#56196]) * New `ltruncate`, `rtruncate` and `ctruncate` functions for truncating strings to text width, accounting for char widths ([#55351]) * `isless` (and thus `cmp`, sorting, etc.) is now supported for zero-dimensional `AbstractArray`s ([#55772]) +* `sort` now sorts tuples (#56425) Standard library changes ------------------------ diff --git a/base/Base.jl b/base/Base.jl index 3b56dca166cee1..8315c01a6a7dad 100644 --- a/base/Base.jl +++ b/base/Base.jl @@ -390,6 +390,8 @@ include("cartesian.jl") using .Cartesian include("multidimensional.jl") +include("typedomainnumbers.jl") + include("broadcast.jl") using .Broadcast using .Broadcast: broadcasted, broadcasted_kwsyntax, materialize, materialize!, diff --git a/base/sort.jl b/base/sort.jl index ef0f208209fc8d..fed295f72a7c01 100644 --- a/base/sort.jl +++ b/base/sort.jl @@ -1736,6 +1736,95 @@ julia> v """ sort(v::AbstractVector; kws...) = sort!(copymutable(v); kws...) +module _SortTupleStable + using + Base._TypeDomainNumbers.PositiveIntegers, Base._TypeDomainNumbers.IntegersGreaterThanOne, + Base._TypeDomainNumbers.Utils, Base._TypeDomainNumberTupleUtils, Base._TupleTypeByLength + using Base: tail + using Base.Order: Ordering, lt + export sort_tuple_stable + function merge_recursive((@nospecialize ord::Ordering), a::Tuple, b::Tuple) + ret = if a isa Tuple1OrMore + a + else + b + end + type_assert_tuple_0_or_more(ret) + end + function merge_recursive(ord::Ordering, a::Tuple1OrMore, b::Tuple1OrMore) + l = first(a) + r = first(b) + x = tail(a) + y = tail(b) + merged = if lt(ord, r, l) + let rec = type_assert_tuple_1_or_more(merge_recursive(ord, a, y)) + (r, rec...) + end + else + let rec = type_assert_tuple_1_or_more(merge_recursive(ord, x, b)) + (l, rec...) + end + end + type_assert_tuple_2_or_more(merged) + end + function merge_nontrivial(ord::Ordering, a::Tuple1OrMore, b::Tuple1OrMore) + ret = merge_recursive(ord, a, b) + type_assert_tuple_2_or_more(ret) + end + function split_tuple(@nospecialize tup::Tuple2OrMore) + len = type_assert_integer_greater_than_1(tuple_type_domain_length(tup)) + len_l = type_assert_positive_integer(half_floor_nontrivial(len)) + len_r = type_assert_positive_integer(half_ceiling_nontrivial(len)) + tup_l = type_assert_tuple_1_or_more(skip_from_tail_nontrivial(tup, len_r)) + tup_r = type_assert_tuple_1_or_more(skip_from_front_nontrivial(tup, len_l)) + (tup_l, tup_r) + end + function sort_recursive((@nospecialize ord::Ordering), @nospecialize tup::Tuple{Any}) + tup + end + function sort_recursive(ord::Ordering, tup::Tuple2OrMore) + (tup_l, tup_r) = split_tuple(tup) + sorted_l = type_assert_tuple_1_or_more(sort_recursive(ord, tup_l)) + sorted_r = type_assert_tuple_1_or_more(sort_recursive(ord, tup_r)) + type_assert_tuple_2_or_more(merge_nontrivial(ord, sorted_l, sorted_r)) + end + function sort_tuple_stable_2_or_more(ord::Ordering, tup::Tuple2OrMore) + ret = sort_recursive(ord, tup) + type_assert_tuple_2_or_more(ret) + end + function sort_tuple_array_fallback(ord::Ordering, tup::Tuple2OrMore) + vec = if tup isa NTuple + [tup...] + else + Any[tup...] + end + sort!(vec; order = ord) + (vec...,) + end + function sort_tuple_stable((@nospecialize ord::Ordering), @nospecialize tup::Tuple) + if tup isa Tuple2OrMore + if tup isa Tuple32OrMore + sort_tuple_array_fallback(ord, tup) + else + sort_tuple_stable_2_or_more(ord, tup) + end + else + tup + end + end +end + +function sort( + tup::Tuple; + lt = isless, + by = identity, + rev::Union{Nothing, Bool} = nothing, + order::Ordering = Forward, +) + o = ord(lt, by, rev, order) + _SortTupleStable.sort_tuple_stable(o, tup) +end + ## partialsortperm: the permutation to sort the first k elements of an array ## """ diff --git a/base/typedomainnumbers.jl b/base/typedomainnumbers.jl new file mode 100644 index 00000000000000..acd97d7dcbd636 --- /dev/null +++ b/base/typedomainnumbers.jl @@ -0,0 +1,194 @@ +# This file is a part of Julia. License is MIT: https://julialang.org/license + +# Adapted from the TypeDomainNaturalNumbers.jl package. +module _TypeDomainNumbers + module Zeros + export Zero + struct Zero end + end + + module PositiveIntegers + module RecursiveStep + using ...Zeros + export recursive_step + function recursive_step(@nospecialize t::Type) + Union{Zero, t} + end + end + module UpperBounds + using ..RecursiveStep + abstract type A end + abstract type B{P <: recursive_step(A)} <: A end + abstract type C{P <: recursive_step(B)} <: B{P} end + abstract type D{P <: recursive_step(C)} <: C{P} end + end + using .RecursiveStep + const PositiveIntegerUpperBound = UpperBounds.A + const PositiveIntegerUpperBoundTighter = UpperBounds.D + export + natural_successor, natural_predecessor, + NonnegativeInteger, NonnegativeIntegerUpperBound, + PositiveInteger, PositiveIntegerUpperBound, + type_assert_nonnegative_integer, type_assert_positive_integer + struct PositiveInteger{ + Predecessor <: recursive_step(PositiveIntegerUpperBoundTighter), + } <: PositiveIntegerUpperBoundTighter{Predecessor} + predecessor::Predecessor + global const NonnegativeInteger = recursive_step(PositiveInteger) + global const NonnegativeIntegerUpperBound = recursive_step(PositiveIntegerUpperBound) + global function natural_successor(p::P) where {P <: NonnegativeInteger} + ret = new{P}(p) + type_assert_positive_integer(ret) + end + end + function type_assert_nonnegative_integer(@nospecialize x::NonnegativeInteger) + x + end + function type_assert_positive_integer(@nospecialize x::PositiveInteger) + x + end + function natural_predecessor(@nospecialize o::PositiveInteger) + ret = getfield(o, :predecessor) # avoid specializing `getproperty` for each number + type_assert_nonnegative_integer(ret) + end + end + + module IntegersGreaterThanOne + using ..PositiveIntegers + export + IntegerGreaterThanOne, IntegerGreaterThanOneUpperBound, + type_assert_integer_greater_than_1 + const IntegerGreaterThanOne = let t = PositiveInteger + t{P} where {P <: t} + end + const IntegerGreaterThanOneUpperBound = let t = PositiveIntegerUpperBound + PositiveIntegers.UpperBounds.B{P} where {P <: t} + end + function type_assert_integer_greater_than_1(@nospecialize x::IntegerGreaterThanOne) + x + end + end + + module Constants + using ..Zeros, ..PositiveIntegers + export n0, n1 + const n0 = Zero() + const n1 = natural_successor(n0) + end + + module Utils + using ..PositiveIntegers, ..IntegersGreaterThanOne, ..Constants + using Base: @assume_effects + export minus_two, half_floor, half_ceiling, half_floor_nontrivial, half_ceiling_nontrivial + function minus_two(@nospecialize m::IntegerGreaterThanOne) + natural_predecessor(natural_predecessor(m)) + end + @assume_effects :foldable :nothrow function half_floor(@nospecialize m::NonnegativeInteger) + ret = if m isa IntegerGreaterThanOneUpperBound + let n = minus_two(m), rec = @inline half_floor(n) + type_assert_positive_integer(natural_successor(rec)) + end + else + n0 + end + type_assert_nonnegative_integer(ret) + end + @assume_effects :foldable :nothrow function half_ceiling(@nospecialize m::NonnegativeInteger) + ret = if m isa IntegerGreaterThanOneUpperBound + let n = minus_two(m), rec = @inline half_ceiling(n) + type_assert_positive_integer(natural_successor(rec)) + end + else + if m isa PositiveIntegerUpperBound + n1 + else + n0 + end + end + type_assert_nonnegative_integer(ret) + end + function half_floor_nontrivial(@nospecialize m::IntegerGreaterThanOne) + ret = half_floor(m) + type_assert_positive_integer(ret) + end + function half_ceiling_nontrivial(@nospecialize m::IntegerGreaterThanOne) + ret = half_ceiling(m) + type_assert_positive_integer(ret) + end + end +end + +module _TupleTypeByLength + export + Tuple1OrMore, Tuple2OrMore, Tuple3OrMore, Tuple4OrMore, Tuple32OrMore, + type_assert_tuple_0_or_more, type_assert_tuple_1_or_more, type_assert_tuple_2_or_more, + type_assert_tuple_3_or_more, type_assert_tuple_4_or_more, + type_assert_tuple_1 + const Tuple1OrMore = Tuple{Any, Vararg} + const Tuple2OrMore = Tuple{Any, Any, Vararg} + const Tuple3OrMore = Tuple{Any, Any, Any, Vararg} + const Tuple4OrMore = Tuple{Any, Any, Any, Any, Vararg} + const Tuple32OrMore = Base.Any32 + function type_assert_tuple_0_or_more(@nospecialize x::Tuple) + x + end + function type_assert_tuple_1_or_more(@nospecialize x::Tuple1OrMore) + x + end + function type_assert_tuple_2_or_more(@nospecialize x::Tuple2OrMore) + x + end + function type_assert_tuple_3_or_more(@nospecialize x::Tuple3OrMore) + x + end + function type_assert_tuple_4_or_more(@nospecialize x::Tuple4OrMore) + x + end +end + +module _TypeDomainNumberTupleUtils + using + .._TypeDomainNumbers.PositiveIntegers, .._TypeDomainNumbers.IntegersGreaterThanOne, + .._TypeDomainNumbers.Constants, .._TupleTypeByLength + using Base: @assume_effects, front, tail + export + tuple_type_domain_length, + skip_from_front, skip_from_tail, + skip_from_front_nontrivial, skip_from_tail_nontrivial + # The `@nospecialize` and `@inline` together should effectively result in specializing + # on the length, without specializing on the types of the elements. + @assume_effects :foldable :nothrow function tuple_type_domain_length(@nospecialize tup::Tuple) + ret = if tup isa Tuple1OrMore + let t = tail(tup), rec = @inline tuple_type_domain_length(t) + type_assert_positive_integer(natural_successor(rec)) + end + else + n0 + end + type_assert_nonnegative_integer(ret) + end + @assume_effects :foldable function skip_from_front((@nospecialize tup::Tuple), @nospecialize skip_count::NonnegativeInteger) + if skip_count isa PositiveIntegerUpperBound + let cm1 = natural_predecessor(skip_count), t = tail(tup) + @inline skip_from_front(t, cm1) + end + else + tup + end + end + @assume_effects :foldable function skip_from_tail((@nospecialize tup::Tuple), @nospecialize skip_count::NonnegativeInteger) + if skip_count isa PositiveIntegerUpperBound + let cm1 = natural_predecessor(skip_count), t = front(tup) + @inline skip_from_tail(t, cm1) + end + else + tup + end + end + function skip_from_front_nontrivial((@nospecialize tup::Tuple2OrMore), @nospecialize skip_count::PositiveInteger) + skip_from_front(tup, skip_count) + end + function skip_from_tail_nontrivial((@nospecialize tup::Tuple2OrMore), @nospecialize skip_count::PositiveInteger) + skip_from_tail(tup, skip_count) + end +end diff --git a/test/sorting.jl b/test/sorting.jl index 2714197f58823a..4118391abf5d8d 100644 --- a/test/sorting.jl +++ b/test/sorting.jl @@ -92,6 +92,40 @@ end end @test sort(1:2000, by=x->x÷100, rev=true) == sort(1:2000, by=x->-x÷100) == vcat(2000, (x:x+99 for x in 1900:-100:100)..., 1:99) + @testset "tuples" begin + tup = Tuple(0:9) + @test tup === sort(tup; by = _ -> 0) + @test (0, 2, 4, 6, 8, 1, 3, 5, 7, 9) === sort(tup; by = x -> isodd(x)) + @test (1, 3, 5, 7, 9, 0, 2, 4, 6, 8) === sort(tup; by = x -> iseven(x)) + end +end + +@testset "tuple sorting" begin + max_unrolled_length = 31 + @testset "correctness" begin + tup = Tuple(0:9) + tup_rev = reverse(tup) + @test tup === @inferred sort(tup) + @test tup === sort(tup; rev = false) + @test tup_rev === sort(tup; rev = true) + @test tup_rev === sort(tup; lt = >) + end + @testset "inference" begin + known_length = (Tuple{Vararg{Int, max_unrolled_length}}, Tuple{Vararg{Float64, max_unrolled_length}}) + unknown_length = (Tuple{Vararg{Int}}, Tuple{Vararg{Float64}}) + for Tup ∈ (known_length..., unknown_length...) + @test Tup == Base.infer_return_type(sort, Tuple{Tup}) + end + for Tup ∈ (known_length...,) + @test Core.Compiler.is_foldable(Base.infer_effects(sort, Tuple{Tup})) + end + end + @testset "alloc" begin + function test_zero_allocated(tup::Tuple) + @test iszero(@allocated sort(tup)) + end + test_zero_allocated(ntuple(identity, max_unrolled_length)) + end end @testset "partialsort" begin