Skip to content

Commit

Permalink
support sorting tuples
Browse files Browse the repository at this point in the history
Uses merge sort, as an obvious choice for a stable sort of tuples.

A recursive data structure of singleton type, representing Peano
natural numbers, is used to help with splitting a tuple into two halves
in the merge sort. An alternative design would use a reference tuple,
but this would require relying on `tail`, which seems more harsh on the
compiler. With the recursive datastructure the predecessor operation
and the successor operation are both trivial.

Allows inference to preserve inferred element type even when tuple
length is not known.

Follow-up PRs may add further improvements, such as the ability to
select an unstable sorting algorithm.

The added file, typedomainnumbers.jl is not specific to sorting, thus
making it a separate file. Xref #55571.

Fixes #54489
  • Loading branch information
nsajko committed Nov 8, 2024
1 parent 683da41 commit 6f49f4c
Show file tree
Hide file tree
Showing 5 changed files with 328 additions and 0 deletions.
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ New library features
* `Base.require_one_based_indexing` and `Base.has_offset_axes` are now public ([#56196])
* New `ltruncate`, `rtruncate` and `ctruncate` functions for truncating strings to text width, accounting for char widths ([#55351])
* `isless` (and thus `cmp`, sorting, etc.) is now supported for zero-dimensional `AbstractArray`s ([#55772])
* `sort` now sorts tuples (#56425)

Standard library changes
------------------------
Expand Down
2 changes: 2 additions & 0 deletions base/Base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@ include("cartesian.jl")
using .Cartesian
include("multidimensional.jl")

include("typedomainnumbers.jl")

include("broadcast.jl")
using .Broadcast
using .Broadcast: broadcasted, broadcasted_kwsyntax, materialize, materialize!,
Expand Down
89 changes: 89 additions & 0 deletions base/sort.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1736,6 +1736,95 @@ julia> v
"""
sort(v::AbstractVector; kws...) = sort!(copymutable(v); kws...)

module _SortTupleStable
using
Base._TypeDomainNumbers.PositiveIntegers, Base._TypeDomainNumbers.IntegersGreaterThanOne,
Base._TypeDomainNumbers.Utils, Base._TypeDomainNumberTupleUtils, Base._TupleTypeByLength
using Base: tail
using Base.Order: Ordering, lt
export sort_tuple_stable
function merge_recursive((@nospecialize ord::Ordering), a::Tuple, b::Tuple)
ret = if a isa Tuple1OrMore
a
else
b
end
type_assert_tuple_0_or_more(ret)
end
function merge_recursive(ord::Ordering, a::Tuple1OrMore, b::Tuple1OrMore)
l = first(a)
r = first(b)
x = tail(a)
y = tail(b)
merged = if lt(ord, r, l)
let rec = type_assert_tuple_1_or_more(merge_recursive(ord, a, y))
(r, rec...)
end
else
let rec = type_assert_tuple_1_or_more(merge_recursive(ord, x, b))
(l, rec...)
end
end
type_assert_tuple_2_or_more(merged)
end
function merge_nontrivial(ord::Ordering, a::Tuple1OrMore, b::Tuple1OrMore)
ret = merge_recursive(ord, a, b)
type_assert_tuple_2_or_more(ret)
end
function split_tuple(@nospecialize tup::Tuple2OrMore)
len = type_assert_integer_greater_than_1(tuple_type_domain_length(tup))
len_l = type_assert_positive_integer(half_floor_nontrivial(len))
len_r = type_assert_positive_integer(half_ceiling_nontrivial(len))
tup_l = type_assert_tuple_1_or_more(skip_from_tail_nontrivial(tup, len_r))
tup_r = type_assert_tuple_1_or_more(skip_from_front_nontrivial(tup, len_l))
(tup_l, tup_r)
end
function sort_recursive((@nospecialize ord::Ordering), @nospecialize tup::Tuple{Any})
tup
end
function sort_recursive(ord::Ordering, tup::Tuple2OrMore)
(tup_l, tup_r) = split_tuple(tup)
sorted_l = type_assert_tuple_1_or_more(sort_recursive(ord, tup_l))
sorted_r = type_assert_tuple_1_or_more(sort_recursive(ord, tup_r))
type_assert_tuple_2_or_more(merge_nontrivial(ord, sorted_l, sorted_r))
end
function sort_tuple_stable_2_or_more(ord::Ordering, tup::Tuple2OrMore)
ret = sort_recursive(ord, tup)
type_assert_tuple_2_or_more(ret)
end
function sort_tuple_array_fallback(ord::Ordering, tup::Tuple2OrMore)
vec = if tup isa NTuple
[tup...]
else
Any[tup...]
end
sort!(vec; order = ord)
(vec...,)
end
function sort_tuple_stable((@nospecialize ord::Ordering), @nospecialize tup::Tuple)
if tup isa Tuple2OrMore
if tup isa Tuple32OrMore
sort_tuple_array_fallback(ord, tup)
else
sort_tuple_stable_2_or_more(ord, tup)
end
else
tup
end
end
end

function sort(
tup::Tuple;
lt = isless,
by = identity,
rev::Union{Nothing, Bool} = nothing,
order::Ordering = Forward,
)
o = ord(lt, by, rev, order)
_SortTupleStable.sort_tuple_stable(o, tup)
end

## partialsortperm: the permutation to sort the first k elements of an array ##

"""
Expand Down
192 changes: 192 additions & 0 deletions base/typedomainnumbers.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
# This file is a part of Julia. License is MIT: https://julialang.org/license

# Adapted from the TypeDomainNaturalNumbers.jl package.
module _TypeDomainNumbers
module Zeros
export Zero
struct Zero end
end

module PositiveIntegers
module RecursiveStep
using ...Zeros
export recursive_step
function recursive_step(@nospecialize t::Type)
Union{Zero, t}
end
end
module UpperBounds
using ..RecursiveStep
abstract type A end
abstract type B{P <: recursive_step(A)} <: A end
abstract type C{P <: recursive_step(B)} <: B{P} end
abstract type D{P <: recursive_step(C)} <: C{P} end
end
using .RecursiveStep
const PositiveIntegerUpperBound = UpperBounds.A
const PositiveIntegerUpperBoundTighter = UpperBounds.D
export
natural_successor, natural_predecessor,
NonnegativeInteger, NonnegativeIntegerUpperBound,
PositiveInteger, PositiveIntegerUpperBound,
type_assert_nonnegative_integer, type_assert_positive_integer
struct PositiveInteger{
Predecessor <: recursive_step(PositiveIntegerUpperBoundTighter),
} <: PositiveIntegerUpperBoundTighter{Predecessor}
predecessor::Predecessor
global const NonnegativeInteger = recursive_step(PositiveInteger)
global const NonnegativeIntegerUpperBound = recursive_step(PositiveIntegerUpperBound)
global function natural_successor(p::P) where {P <: NonnegativeInteger}
ret = new{P}(p)
type_assert_positive_integer(ret)
end
end
function type_assert_nonnegative_integer(@nospecialize x::NonnegativeInteger)
x
end
function type_assert_positive_integer(@nospecialize x::PositiveInteger)
x
end
function natural_predecessor(@nospecialize o::PositiveInteger)
ret = getfield(o, :predecessor) # avoid specializing `getproperty` for each number
type_assert_nonnegative_integer(ret)
end
end

module IntegersGreaterThanOne
using ..PositiveIntegers
export
IntegerGreaterThanOne, IntegerGreaterThanOneUpperBound,
type_assert_integer_greater_than_1
const IntegerGreaterThanOne = let t = PositiveInteger
t{P} where {P <: t}
end
const IntegerGreaterThanOneUpperBound = let t = PositiveIntegerUpperBound
PositiveIntegers.UpperBounds.B{P} where {P <: t}
end
function type_assert_integer_greater_than_1(@nospecialize x::IntegerGreaterThanOne)
x
end
end

module Constants
using ..Zeros, ..PositiveIntegers
export n0, n1
const n0 = Zero()
const n1 = natural_successor(n0)
end

module Utils
using ..PositiveIntegers, ..IntegersGreaterThanOne, ..Constants
using Base: @assume_effects
export minus_two, half_floor, half_ceiling, half_floor_nontrivial, half_ceiling_nontrivial
function minus_two(@nospecialize m::IntegerGreaterThanOne)
natural_predecessor(natural_predecessor(m))
end
@assume_effects :foldable :nothrow function half_floor(@nospecialize m::NonnegativeInteger)
ret = if m isa IntegerGreaterThanOneUpperBound
let n = minus_two(m), rec = half_floor(n)
type_assert_positive_integer(natural_successor(rec))
end
else
n0
end
type_assert_nonnegative_integer(ret)
end
@assume_effects :foldable :nothrow function half_ceiling(@nospecialize m::NonnegativeInteger)
ret = if m isa IntegerGreaterThanOneUpperBound
let n = minus_two(m), rec = half_ceiling(n)
type_assert_positive_integer(natural_successor(rec))
end
else
if m isa PositiveIntegerUpperBound
n1
else
n0
end
end
type_assert_nonnegative_integer(ret)
end
function half_floor_nontrivial(@nospecialize m::IntegerGreaterThanOne)
ret = half_floor(m)
type_assert_positive_integer(ret)
end
function half_ceiling_nontrivial(@nospecialize m::IntegerGreaterThanOne)
ret = half_ceiling(m)
type_assert_positive_integer(ret)
end
end
end

module _TupleTypeByLength
export
Tuple1OrMore, Tuple2OrMore, Tuple3OrMore, Tuple4OrMore, Tuple32OrMore,
type_assert_tuple_0_or_more, type_assert_tuple_1_or_more, type_assert_tuple_2_or_more,
type_assert_tuple_3_or_more, type_assert_tuple_4_or_more,
type_assert_tuple_1
const Tuple1OrMore = Tuple{Any, Vararg}
const Tuple2OrMore = Tuple{Any, Any, Vararg}
const Tuple3OrMore = Tuple{Any, Any, Any, Vararg}
const Tuple4OrMore = Tuple{Any, Any, Any, Any, Vararg}
const Tuple32OrMore = Base.Any32
function type_assert_tuple_0_or_more(@nospecialize x::Tuple)
x
end
function type_assert_tuple_1_or_more(@nospecialize x::Tuple1OrMore)
x
end
function type_assert_tuple_2_or_more(@nospecialize x::Tuple2OrMore)
x
end
function type_assert_tuple_3_or_more(@nospecialize x::Tuple3OrMore)
x
end
function type_assert_tuple_4_or_more(@nospecialize x::Tuple4OrMore)
x
end
end

module _TypeDomainNumberTupleUtils
using
.._TypeDomainNumbers.PositiveIntegers, .._TypeDomainNumbers.IntegersGreaterThanOne,
.._TypeDomainNumbers.Constants, .._TupleTypeByLength
using Base: @assume_effects, front, tail
export
tuple_type_domain_length,
skip_from_front, skip_from_tail,
skip_from_front_nontrivial, skip_from_tail_nontrivial
@assume_effects :foldable :nothrow function tuple_type_domain_length(@nospecialize tup::Tuple)
ret = if tup isa Tuple1OrMore
let t = tail(tup), rec = tuple_type_domain_length(t)
type_assert_positive_integer(natural_successor(rec))
end
else
n0
end
type_assert_nonnegative_integer(ret)
end
@assume_effects :foldable function skip_from_front((@nospecialize tup::Tuple), @nospecialize skip_count::NonnegativeInteger)
if skip_count isa PositiveIntegerUpperBound
let cm1 = natural_predecessor(skip_count), t = tail(tup)
@inline skip_from_front(t, cm1)
end
else
tup
end
end
@assume_effects :foldable function skip_from_tail((@nospecialize tup::Tuple), @nospecialize skip_count::NonnegativeInteger)
if skip_count isa PositiveIntegerUpperBound
let cm1 = natural_predecessor(skip_count), t = front(tup)
@inline skip_from_tail(t, cm1)
end
else
tup
end
end
function skip_from_front_nontrivial((@nospecialize tup::Tuple2OrMore), @nospecialize skip_count::PositiveInteger)
skip_from_front(tup, skip_count)
end
function skip_from_tail_nontrivial((@nospecialize tup::Tuple2OrMore), @nospecialize skip_count::PositiveInteger)
skip_from_tail(tup, skip_count)
end
end
44 changes: 44 additions & 0 deletions test/sorting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,50 @@ end
end
@test sort(1:2000, by=x->x÷100, rev=true) == sort(1:2000, by=x->-x÷100) ==
vcat(2000, (x:x+99 for x in 1900:-100:100)..., 1:99)
@testset "tuples" begin
tup = Tuple(0:9)
@test tup === sort(tup; by = _ -> 0)
@test (0, 2, 4, 6, 8, 1, 3, 5, 7, 9) === sort(tup; by = x -> isodd(x))
@test (1, 3, 5, 7, 9, 0, 2, 4, 6, 8) === sort(tup; by = x -> iseven(x))
end
end

@testset "tuple sorting" begin
max_unrolled_length = 31
@testset "correctness" begin
tup = Tuple(0:9)
tup_rev = reverse(tup)
@test tup === @inferred sort(tup)
@test tup === sort(tup; rev = false)
@test tup_rev === sort(tup; rev = true)
@test tup_rev === sort(tup; lt = >)
end
@testset "inference" begin
known_length = (Tuple{Vararg{Int, max_unrolled_length}}, Tuple{Vararg{Float64, max_unrolled_length}})
unknown_length = (Tuple{Vararg{Int}}, Tuple{Vararg{Float64}})
for Tup (known_length..., unknown_length...)
@test Tup == Base.infer_return_type(sort, Tuple{Tup})
end
for Tup (known_length...,)
@test Core.Compiler.is_foldable(Base.infer_effects(sort, Tuple{Tup}))
end
end
@testset "alloc" begin
function test_zero_allocated(tup::Tuple)
@test iszero(@allocated sort(tup))
end
test_zero_allocated(ntuple(identity, max_unrolled_length))
end
@testset "heterogeneous" begin
@testset "stability" begin
tup = (0, 0x0, 0x000)
@test tup === sort(tup)
end
tup = (1, 2, 3, missing, missing)
for t (tup, (1, missing, 2, missing, 3), (missing, missing, 1, 2, 3))
@test tup === @inferred sort(t)
end
end
end

@testset "partialsort" begin
Expand Down

0 comments on commit 6f49f4c

Please sign in to comment.