Skip to content

Commit

Permalink
Refactor scheduler and switch to a spinner thread concept for wakeups
Browse files Browse the repository at this point in the history
This also adds a counter for idle/sleeping threads to avoid checking every thread when everyone is running.
  • Loading branch information
gbaraldi committed Jan 17, 2025
1 parent 3d85309 commit 0d72173
Show file tree
Hide file tree
Showing 9 changed files with 238 additions and 74 deletions.
2 changes: 1 addition & 1 deletion base/Base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ const liblapack_name = libblas_name
# Note that `atomics.jl` here should be deprecated
Core.eval(Threads, :(include("atomics.jl")))
include("channels.jl")
include("partr.jl")
include("scheduler/scheduler.jl")
include("task.jl")
include("threads_overloads.jl")
include("weakkeydict.jl")
Expand Down
84 changes: 23 additions & 61 deletions base/partr.jl → base/scheduler/partr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,63 +19,6 @@ const heap_d = UInt32(8)
const heaps = [Vector{taskheap}(undef, 0), Vector{taskheap}(undef, 0)]
const heaps_lock = [SpinLock(), SpinLock()]


"""
cong(max::UInt32)
Return a random UInt32 in the range `1:max` except if max is 0, in that case return 0.
"""
cong(max::UInt32) = iszero(max) ? UInt32(0) : rand_ptls(max) + UInt32(1) #TODO: make sure users don't use 0 and remove this check

"""
    get_ptls_rng()

Return the `UInt64` reported by the runtime's `jl_get_ptls_rng` entry point. `rand_ptls`
uses this value as its seed.
"""
function get_ptls_rng()
    return ccall(:jl_get_ptls_rng, UInt64, ())
end

"""
    set_ptls_rng(seed::UInt64)

Pass `seed` to the runtime's `jl_set_ptls_rng` entry point. Returns `nothing`.
"""
function set_ptls_rng(seed::UInt64)
    return ccall(:jl_set_ptls_rng, Cvoid, (UInt64,), seed)
end

"""
rand_ptls(max::UInt32)
Return a random UInt32 in the range `0:max-1` using the thread-local RNG
state. Max must be greater than 0.
"""
Base.@assume_effects :removable :inaccessiblememonly :notaskstate function rand_ptls(max::UInt32)
rngseed = get_ptls_rng()
val, seed = rand_uniform_max_int32(max, rngseed)
set_ptls_rng(seed)
return val % UInt32
end

# This implementation is based on OpenSSL's implementation of rand_uniform
# https://github.com/openssl/openssl/blob/1d2cbd9b5a126189d5e9bc78a3bdb9709427d02b/crypto/rand/rand_uniform.c#L13-L99
# Comments are vendored from their implementation as well.
# For the original developer check the PR to swift https://github.com/apple/swift/pull/39143.

# Essentially it boils down to incrementally generating a fixed point
# number on the interval [0, 1) and multiplying this number by the upper
# range limit. Once it is certain what the fractional part contributes to
# the integral part of the product, the algorithm has produced a definitive
# result.
"""
    rand_uniform_max_int32(max::UInt32, seed::UInt64)

Return a tuple `(value, newseed)` where `value` is a pseudo-random `UInt32` in the
range `0:max-1` derived from `seed`, and `newseed` is the seed to use for the next draw.
`max` must be greater than 0. When `max` is 1 the seed is returned unchanged.
"""
Base.@assume_effects :total function rand_uniform_max_int32(max::UInt32, seed::UInt64)
    # Only one possible outcome: answer immediately without advancing the seed.
    if isone(max)
        return (UInt32(0), seed)
    end
    # We are generating a fixed point number on the interval [0, 1).
    # Multiplying this by the range gives us a number on [0, upper).
    # The high word of the multiplication result represents the integral part.
    # This is not completely unbiased as it's missing the fractional part of the
    # original implementation but it's good enough for our purposes.
    next_seed = UInt64(69069) * seed + UInt64(362437)
    low_word = next_seed % UInt32
    product = UInt64(max) * low_word # 64 bit product
    integral = (product >> 32) % UInt32 # integral part
    return (integral, next_seed)
end



function multiq_sift_up(heap::taskheap, idx::Int32)
while idx > Int32(1)
parent = (idx - Int32(2)) ÷ heap_d + Int32(1)
Expand Down Expand Up @@ -147,10 +90,10 @@ function multiq_insert(task::Task, priority::UInt16)

task.priority = priority

rn = cong(heap_p)
rn = Base.Scheduler.cong(heap_p)
tpheaps = heaps[tp]
while !trylock(tpheaps[rn].lock)
rn = cong(heap_p)
rn = Base.Scheduler.cong(heap_p)
end

heap = tpheaps[rn]
Expand Down Expand Up @@ -190,8 +133,8 @@ function multiq_deletemin()
if i == heap_p
return nothing
end
rn1 = cong(heap_p)
rn2 = cong(heap_p)
rn1 = Base.Scheduler.cong(heap_p)
rn2 = Base.Scheduler.cong(heap_p)
prio1 = tpheaps[rn1].priority
prio2 = tpheaps[rn2].priority
if prio1 > prio2
Expand All @@ -211,7 +154,21 @@ function multiq_deletemin()
heap = tpheaps[rn1]
task = heap.tasks[1]
if ccall(:jl_set_task_tid, Cint, (Any, Cint), task, tid-1) == 0
# This task is stuck to a thread that's likely sleeping, move the task to its private queue and wake it up
# We move this out of the queue to avoid spinning on it
ntasks = heap.ntasks
@atomic :monotonic heap.ntasks = ntasks - Int32(1)
heap.tasks[1] = heap.tasks[ntasks]
Base._unsetindex!(heap.tasks, Int(ntasks))
prio1 = typemax(UInt16)
if ntasks > 1
multiq_sift_down(heap, Int32(1))
prio1 = heap.tasks[1].priority
end
@atomic :monotonic heap.priority = prio1
push!(workqueue_for(tid), t)
unlock(heap.lock)
ccall(:jl_wakeup_thread, Cvoid, (Int16,), (tid - 1) % Int16)
@goto retry
end
ntasks = heap.ntasks
Expand Down Expand Up @@ -243,4 +200,9 @@ function multiq_check_empty()
return true
end


# Adapters that expose the multiqueue under the generic scheduler interface names.

# Insert `t` into the multiqueue at its own priority.
function enqueue!(t::Task)
    return multiq_insert(t, t.priority)
end

# Forward to `multiq_deletemin`.
function dequeue!()
    return multiq_deletemin()
end

# Forward to `multiq_check_empty`.
function checktaskempty()
    return multiq_check_empty()
end

end
74 changes: 74 additions & 0 deletions base/scheduler/scheduler.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# This file is a part of Julia. License is MIT: https://julialang.org/license

module Scheduler

"""
cong(max::UInt32)
Return a random UInt32 in the range `1:max` except if max is 0, in that case return 0.
"""
cong(max::UInt32) = iszero(max) ? UInt32(0) : rand_ptls(max) + UInt32(1) #TODO: make sure users don't use 0 and remove this check

"""
    get_ptls_rng()

Return the `UInt64` reported by the runtime's `jl_get_ptls_rng` entry point. `rand_ptls`
uses this value as its seed.
"""
function get_ptls_rng()
    return ccall(:jl_get_ptls_rng, UInt64, ())
end

"""
    set_ptls_rng(seed::UInt64)

Pass `seed` to the runtime's `jl_set_ptls_rng` entry point. Returns `nothing`.
"""
function set_ptls_rng(seed::UInt64)
    return ccall(:jl_set_ptls_rng, Cvoid, (UInt64,), seed)
end

"""
rand_ptls(max::UInt32)
Return a random UInt32 in the range `0:max-1` using the thread-local RNG
state. Max must be greater than 0.
"""
Base.@assume_effects :removable :inaccessiblememonly :notaskstate function rand_ptls(max::UInt32)
rngseed = get_ptls_rng()
val, seed = rand_uniform_max_int32(max, rngseed)
set_ptls_rng(seed)
return val % UInt32
end

# This implementation is based on OpenSSL's implementation of rand_uniform
# https://github.com/openssl/openssl/blob/1d2cbd9b5a126189d5e9bc78a3bdb9709427d02b/crypto/rand/rand_uniform.c#L13-L99
# Comments are vendored from their implementation as well.
# For the original developer check the PR to swift https://github.com/apple/swift/pull/39143.

# Essentially it boils down to incrementally generating a fixed point
# number on the interval [0, 1) and multiplying this number by the upper
# range limit. Once it is certain what the fractional part contributes to
# the integral part of the product, the algorithm has produced a definitive
# result.
"""
    rand_uniform_max_int32(max::UInt32, seed::UInt64)

Return a tuple `(value, newseed)` where `value` is a pseudo-random `UInt32` in the
range `0:max-1` derived from `seed`, and `newseed` is the seed to use for the next draw.
`max` must be greater than 0. When `max` is 1 the seed is returned unchanged.
"""
Base.@assume_effects :total function rand_uniform_max_int32(max::UInt32, seed::UInt64)
    # Only one possible outcome: answer immediately without advancing the seed.
    if isone(max)
        return (UInt32(0), seed)
    end
    # We are generating a fixed point number on the interval [0, 1).
    # Multiplying this by the range gives us a number on [0, upper).
    # The high word of the multiplication result represents the integral part.
    # This is not completely unbiased as it's missing the fractional part of the
    # original implementation but it's good enough for our purposes.
    next_seed = UInt64(69069) * seed + UInt64(362437)
    low_word = next_seed % UInt32
    product = UInt64(max) * low_word # 64 bit product
    integral = (product >> 32) % UInt32 # integral part
    return (integral, next_seed)
end

# Load the Partr backend (the multiqueue-based implementation in scheduler/partr.jl).
include("scheduler/partr.jl")

# Backend that the `enqueue!` / `dequeue!` / `checktaskempty` forwarders in this module
# delegate to.
const ChosenScheduler = Partr



# Scheduler interface:
#   enqueue!        pushes a runnable Task into the scheduler
#   dequeue!        pops a runnable Task from the scheduler
#   checktaskempty  returns true if the scheduler has no available Tasks
#
# Each entry point simply forwards to the backend bound to `ChosenScheduler`.

function enqueue!(t::Task)
    return ChosenScheduler.enqueue!(t)
end

function dequeue!()
    return ChosenScheduler.dequeue!()
end

function checktaskempty()
    return ChosenScheduler.checktaskempty()
end

end
38 changes: 31 additions & 7 deletions base/task.jl
Original file line number Diff line number Diff line change
Expand Up @@ -937,7 +937,6 @@ end

function enq_work(t::Task)
(t._state === task_state_runnable && t.queue === nothing) || error("schedule: Task not runnable")

# Sticky tasks go into their thread's work queue.
if t.sticky
tid = Threads.threadid(t)
Expand Down Expand Up @@ -968,19 +967,44 @@ function enq_work(t::Task)
ccall(:jl_set_task_tid, Cint, (Any, Cint), t, tid-1)
push!(workqueue_for(tid), t)
else
# Otherwise, put the task in the multiqueue.
Partr.multiq_insert(t, t.priority)
# Otherwise, push the task to the scheduler
Scheduler.enqueue!(t)
tid = 0
end
end
ccall(:jl_wakeup_thread, Cvoid, (Int16,), (tid - 1) % Int16)

if (tid == 0)
Core.Intrinsics.atomic_fence(:sequentially_consistent)
n_spinning = Core.Intrinsics.atomic_pointerref(cglobal(:jl_n_threads_spinning, Cint), :monotonic)
n_spinning == 0 && ccall(:jl_add_spinner, Cvoid, ())
else
ccall(:jl_wakeup_thread, Cvoid, (Int16,), (tid - 1) % Int16)
end
# n_spinning = Core.Intrinsics.atomic_pointerref(cglobal(:jl_n_threads, Cint), :acquire)
# n_spinning == 0 && ccall(:jl_add_spinner, Cvoid, ())
return t
end

const ChildFirst = false

# Make `t` runnable by handing it to the work queues and return `t`.
#
# With `ChildFirst` disabled, `t` is simply enqueued. With `ChildFirst` enabled and
# neither the current task nor `t` sticky, the current task is enqueued instead and
# control switches to `t` immediately; if either task is sticky, `t` is enqueued as usual.
function schedule(t::Task)
    # [task] created -scheduled-> wait_time
    # Record and enqueue exactly once per call: doing it both before and inside the
    # branch below would enqueue `t` twice.
    maybe_record_enqueued!(t)
    if ChildFirst
        ct = current_task()
        if ct.sticky || t.sticky
            # A sticky task is involved, so do not switch tasks here; just queue `t`.
            enq_work(t)
        else
            # Child-first: queue the current task and run `t` right away.
            enq_work(ct)
            yieldto(t)
        end
    else
        enq_work(t)
    end
    return t
end

"""
Expand Down Expand Up @@ -1176,10 +1200,10 @@ function trypoptask(W::StickyWorkqueue)
end
return t
end
return Partr.multiq_deletemin()
return Scheduler.dequeue!()
end

checktaskempty = Partr.multiq_check_empty
checktaskempty = Scheduler.checktaskempty

@noinline function poptask(W::StickyWorkqueue)
task = trypoptask(W)
Expand Down
1 change: 1 addition & 0 deletions src/jl_exported_data.inc
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@
#define JL_EXPORTED_DATA_SYMBOLS(XX) \
XX(jl_n_threadpools, int) \
XX(jl_n_threads, _Atomic(int)) \
XX(jl_n_threads_spinning, _Atomic(int)) \
XX(jl_n_gcthreads, int) \
XX(jl_options, jl_options_t) \
XX(jl_task_gcstack_offset, int) \
Expand Down
1 change: 1 addition & 0 deletions src/jl_exported_funcs.inc
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,7 @@
XX(jl_tagged_gensym) \
XX(jl_take_buffer) \
XX(jl_task_get_next) \
XX(jl_add_spinner) \
XX(jl_task_stack_buffer) \
XX(jl_termios_size) \
XX(jl_test_cpu_feature) \
Expand Down
1 change: 1 addition & 0 deletions src/julia.h
Original file line number Diff line number Diff line change
Expand Up @@ -2055,6 +2055,7 @@ JL_DLLEXPORT jl_sym_t *jl_get_ARCH(void) JL_NOTSAFEPOINT;
JL_DLLIMPORT jl_value_t *jl_get_libllvm(void) JL_NOTSAFEPOINT;
extern JL_DLLIMPORT int jl_n_threadpools;
extern JL_DLLIMPORT _Atomic(int) jl_n_threads;
extern JL_DLLIMPORT _Atomic(int) jl_n_threads_spinning; // Scheduler internal counter
extern JL_DLLIMPORT int jl_n_gcthreads;
extern int jl_n_markthreads;
extern int jl_n_sweepthreads;
Expand Down
1 change: 1 addition & 0 deletions src/julia_threads.h
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@ typedef struct _jl_tls_states_t {
uint64_t uv_run_leave;
uint64_t sleep_enter;
uint64_t sleep_leave;
uint64_t woken_up;
)

// some hidden state (usually just because we don't have the type's size declaration)
Expand Down
Loading

0 comments on commit 0d72173

Please sign in to comment.