Skip to content

Commit

Permalink
Merge pull request #2361 from natalie-lang/cond-var-sig-spec
Browse files Browse the repository at this point in the history
Fix Mutex#sleep race condition
  • Loading branch information
seven1m authored Dec 1, 2024
2 parents 73c613c + e853dba commit a1de4cc
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 45 deletions.
17 changes: 9 additions & 8 deletions include/natalie/thread_object.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ class ThreadObject : public Object {
Value run(Env *);
Value wakeup(Env *);

Value sleep(Env *, float);
Value sleep(Env *, float, Thread::MutexObject * = nullptr);

void set_value(Value value) { m_value = value; }
Value value(Env *);
Expand Down Expand Up @@ -252,10 +252,10 @@ class ThreadObject : public Object {

static void check_current_exception(Env *env);

static void setup_interrupt_pipe(Env *env);
static int interrupt_read_fileno() { return s_interrupt_read_fileno; }
static void interrupt();
static void clear_interrupt();
static void setup_wake_pipe(Env *env);
static int wake_pipe_read_fileno() { return s_wake_pipe_read_fileno; }
static void wake_all();
static void clear_wake_pipe();

static ClassObject *thread_kill_class() { return s_thread_kill_class; }
static ClassObject *thread_kill_class(Env *env) {
Expand Down Expand Up @@ -318,6 +318,7 @@ class ThreadObject : public Object {

// This condition variable is used to wake a sleeping thread,
// i.e. a thread where Kernel#sleep has been called.
bool m_wakeup { false };
std::condition_variable m_sleep_cond;
std::mutex m_sleep_lock;

Expand All @@ -329,13 +330,13 @@ class ThreadObject : public Object {
// In addition to m_sleep_cond which can wake a sleeping thread,
// we also need a way to wake a thread that is blocked on reading
// from a file descriptor. Any time select(2) is called, we can
// add this s_interrupt_read_fileno to the fd_set so we have
// add this s_wake_pipe_read_fileno to the fd_set so we have
// a way to unblock the call. This is also signaled when any
// IO object is closed, since we'll need to wake up any blocking
// select() calls and check if the IO object was closed.
// TODO: we'll need to rebuild these after a fork :-/
inline static int s_interrupt_read_fileno { -1 };
inline static int s_interrupt_write_fileno { -1 };
inline static int s_wake_pipe_read_fileno { -1 };
inline static int s_wake_pipe_write_fileno { -1 };

// We use this special class as an off-the-books exception class
// for killing threads. It cannot be rescued in user code, but it
Expand Down
24 changes: 12 additions & 12 deletions src/io_object.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -703,7 +703,7 @@ Value IoObject::close(Env *env) {
// Wake up all threads in case one is blocking on a read to this fd.
// It is undefined behavior on Linux to continue a read() or select()
// on a closed file descriptor.
ThreadObject::interrupt();
ThreadObject::wake_all();

int result;
if (m_fileptr && m_pid > 0) {
Expand Down Expand Up @@ -1070,15 +1070,15 @@ Value IoObject::select(Env *env, Value read_ios, Value write_ios, Value error_io
auto write_ios_ary = write_ios && !write_ios->is_nil() ? write_ios->to_ary(env) : new ArrayObject {};
auto error_ios_ary = error_ios && !error_ios->is_nil() ? error_ios->to_ary(env) : new ArrayObject {};

auto interrupt_fileno = ThreadObject::interrupt_read_fileno();
auto wake_pipe_fileno = ThreadObject::wake_pipe_read_fileno();

int nfds = 0;
auto read_fds = create_fd_set(env, read_ios_ary, &nfds);
auto write_fds = create_fd_set(env, write_ios_ary, &nfds);
auto error_fds = create_fd_set(env, error_ios_ary, &nfds);

FD_SET(interrupt_fileno, &read_fds);
nfds = std::max(nfds, interrupt_fileno + 1);
FD_SET(wake_pipe_fileno, &read_fds);
nfds = std::max(nfds, wake_pipe_fileno + 1);

fd_set read_fds_copy = read_fds;
fd_set write_fds_copy = write_fds;
Expand All @@ -1096,10 +1096,10 @@ Value IoObject::select(Env *env, Value read_ios, Value write_ios, Value error_io
} else if (result == -1) {
// An error the user needs to handle.
break;
} else if (FD_ISSET(interrupt_fileno, &read_fds)) {
} else if (FD_ISSET(wake_pipe_fileno, &read_fds)) {
// Interrupted by our thread file descriptor.
// This thread may need to raise or exit.
ThreadObject::clear_interrupt();
ThreadObject::clear_wake_pipe();
ThreadObject::check_current_exception(env);
if (any_closed(read_ios_ary) || any_closed(write_ios_ary) || any_closed(error_ios_ary))
env->raise("IOError", "closed stream");
Expand All @@ -1120,7 +1120,7 @@ Value IoObject::select(Env *env, Value read_ios, Value write_ios, Value error_io
if (result == 0)
return NilObject::the();

FD_CLR(interrupt_fileno, &read_fds);
FD_CLR(wake_pipe_fileno, &read_fds);

auto readable_ios = create_output_fds(env, &read_fds, read_ios_ary);
auto writeable_ios = create_output_fds(env, &write_fds, write_ios_ary);
Expand All @@ -1132,9 +1132,9 @@ void IoObject::select_read(Env *env, timeval *timeout) const {
fd_set readfds;
FD_ZERO(&readfds);
FD_SET(m_fileno, &readfds);
auto interrupt_fileno = ThreadObject::interrupt_read_fileno();
FD_SET(interrupt_fileno, &readfds);
auto nfds = std::max(m_fileno, interrupt_fileno) + 1;
auto wake_pipe_fileno = ThreadObject::wake_pipe_read_fileno();
FD_SET(wake_pipe_fileno, &readfds);
auto nfds = std::max(m_fileno, wake_pipe_fileno) + 1;

fd_set readfds_copy = readfds;

Expand All @@ -1161,8 +1161,8 @@ void IoObject::select_read(Env *env, timeval *timeout) const {
}
}

if (FD_ISSET(interrupt_fileno, &readfds)) {
ThreadObject::clear_interrupt();
if (FD_ISSET(wake_pipe_fileno, &readfds)) {
ThreadObject::clear_wake_pipe();
ThreadObject::check_current_exception(env);
if (m_closed)
env->raise("IOError", "closed stream");
Expand Down
13 changes: 9 additions & 4 deletions src/thread/conditionvariable.rb
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@ def initialize

def broadcast
@mutex.synchronize do
while !@waiting.empty?
until @waiting.empty?
thread = @waiting.shift
thread.wakeup if thread&.status == 'sleep'
if thread.status != 'dead'
thread.wakeup
end
end
end
end
Expand All @@ -21,10 +23,13 @@ def marshal_dump
def signal
@mutex.synchronize do
thread = nil
while !@waiting.empty? && thread&.status != 'sleep'
until @waiting.empty?
thread = @waiting.shift
if thread.status != 'dead'
thread.wakeup
break
end
end
thread&.wakeup
end
end

Expand Down
6 changes: 2 additions & 4 deletions src/thread/mutex_object.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,7 @@ Value MutexObject::lock(Env *env) {

Value MutexObject::sleep(Env *env, Value timeout) {
if (!timeout || timeout->is_nil()) {
unlock(env);
ThreadObject::current()->sleep(env, -1.0);
ThreadObject::current()->sleep(env, -1.0, this);
lock(env);
return this;
}
Expand All @@ -43,9 +42,8 @@ Value MutexObject::sleep(Env *env, Value timeout) {
if (timeout_int < 0)
env->raise("ArgumentError", "timeout must be positive");

unlock(env);
const auto timeout_float = timeout->is_float() ? static_cast<float>(timeout->as_float()->to_double()) : static_cast<float>(timeout_int);
ThreadObject::current()->sleep(env, timeout_float);
ThreadObject::current()->sleep(env, timeout_float, this);
lock(env);

return Value::integer(timeout_int);
Expand Down
46 changes: 29 additions & 17 deletions src/thread_object.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include <signal.h>

#include "natalie.hpp"
#include "natalie/thread/mutex_object.hpp"
#include "natalie/thread_object.hpp"

static void set_stack_for_thread(Natalie::ThreadObject *thread_object) {
Expand Down Expand Up @@ -197,7 +198,7 @@ void ThreadObject::finish_main_thread_setup(Env *env, void *start_of_stack) {
set_stack_for_thread(thread);
thread->build_main_fiber();
add_to_list(thread);
setup_interrupt_pipe(env);
setup_wake_pipe(env);
}

void ThreadObject::build_main_fiber() {
Expand Down Expand Up @@ -354,7 +355,7 @@ Value ThreadObject::kill(Env *env) {
} else {
m_exception = exception;
wakeup(env);
ThreadObject::interrupt();
ThreadObject::wake_all();
}

return this;
Expand All @@ -379,7 +380,7 @@ Value ThreadObject::raise(Env *env, Args &&args) {

// In case this thread is blocking on read/select/whatever,
// we may need to interrupt it (and all other threads, incidentally).
ThreadObject::interrupt();
ThreadObject::wake_all();

return NilObject::the();
}
Expand All @@ -397,13 +398,14 @@ Value ThreadObject::wakeup(Env *env) {

{
std::unique_lock sleep_lock { m_sleep_lock };
m_wakeup = true;
m_sleep_cond.notify_one();
}

return this;
}

Value ThreadObject::sleep(Env *env, float timeout) {
Value ThreadObject::sleep(Env *env, float timeout, Thread::MutexObject *mutex_to_unlock) {
timespec t_begin;
if (::clock_gettime(CLOCK_MONOTONIC, &t_begin) < 0)
env->raise_errno();
Expand All @@ -417,16 +419,23 @@ Value ThreadObject::sleep(Env *env, float timeout) {
return Value::integer(elapsed);
};

m_wakeup = false;
if (mutex_to_unlock)
mutex_to_unlock->unlock(env);

if (timeout < 0.0) {
{
std::unique_lock sleep_lock { m_sleep_lock };

check_exception(env);

Defer done_sleeping([] { ThreadObject::set_current_sleeping(false); });
ThreadObject::set_current_sleeping(true);
if (!m_wakeup) {
Defer done_sleeping([] { ThreadObject::set_current_sleeping(false); });
ThreadObject::set_current_sleeping(true);

m_sleep_cond.wait(sleep_lock);
m_sleep_cond.wait(sleep_lock);
}
m_wakeup = false;
}

check_exception(env);
Expand All @@ -440,10 +449,13 @@ Value ThreadObject::sleep(Env *env, float timeout) {

check_exception(env);

Defer done_sleeping([] { ThreadObject::set_current_sleeping(false); });
ThreadObject::set_current_sleeping(true);
if (!m_wakeup) {
Defer done_sleeping([] { ThreadObject::set_current_sleeping(false); });
ThreadObject::set_current_sleeping(true);

m_sleep_cond.wait_for(sleep_lock, wait);
m_sleep_cond.wait_for(sleep_lock, wait);
}
m_wakeup = false;
}

check_exception(env);
Expand Down Expand Up @@ -666,27 +678,27 @@ void ThreadObject::wait_until_running() const {
sched_yield();
}

void ThreadObject::setup_interrupt_pipe(Env *env) {
void ThreadObject::setup_wake_pipe(Env *env) {
int pipefd[2];
if (pipe2(pipefd, O_CLOEXEC | O_NONBLOCK) < 0)
env->raise_errno();
s_interrupt_read_fileno = pipefd[0];
s_interrupt_write_fileno = pipefd[1];
s_wake_pipe_read_fileno = pipefd[0];
s_wake_pipe_write_fileno = pipefd[1];
}

void ThreadObject::interrupt() {
assert(::write(s_interrupt_write_fileno, "!", 1) != -1);
void ThreadObject::wake_all() {
assert(::write(s_wake_pipe_write_fileno, "!", 1) != -1);
}

void ThreadObject::clear_interrupt() {
void ThreadObject::clear_wake_pipe() {
// arbitrarily-chosen buffer size just to avoid several single-byte reads
constexpr int BUF_SIZE = 8;
char buf[BUF_SIZE];
ssize_t bytes;
do {
// This fd is non-blocking, so this can set errno to EAGAIN,
// but we don't care -- we just want the buffer to be empty.
bytes = ::read(s_interrupt_read_fileno, buf, BUF_SIZE);
bytes = ::read(s_wake_pipe_read_fileno, buf, BUF_SIZE);
} while (bytes > 0);
}

Expand Down

0 comments on commit a1de4cc

Please sign in to comment.