Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

thread utils respect contextvars #4074

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 7 additions & 15 deletions backend/onyx/utils/threadpool_concurrency.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import contextvars
import threading
import uuid
from collections.abc import Callable
Expand All @@ -14,10 +15,6 @@
R = TypeVar("R")


# WARNING: it is not currently well understood whether we lose access to contextvars when functions are
# executed through this wrapper. Do NOT try to acquire a db session in a function run through this unless
# you have heavily tested that multi-tenancy is respected. If/when we know for sure that it is or
# is not safe, update this comment.
def run_functions_tuples_in_parallel(
functions_with_args: list[tuple[Callable, tuple]],
allow_failures: bool = False,
Expand Down Expand Up @@ -46,7 +43,7 @@ def run_functions_tuples_in_parallel(
results = []
with ThreadPoolExecutor(max_workers=workers) as executor:
future_to_index = {
executor.submit(func, *args): i
executor.submit(contextvars.copy_context().run, func, *args): i
for i, (func, args) in enumerate(functions_with_args)
}

Expand Down Expand Up @@ -83,10 +80,6 @@ def execute(self) -> R:
return self.func(*self.args, **self.kwargs)


# WARNING: it is not currently well understood whether we lose access to contextvars when functions are
# executed through this wrapper. Do NOT try to acquire a db session in a function run through this unless
# you have heavily tested that multi-tenancy is respected. If/when we know for sure that it is or
# is not safe, update this comment.
def run_functions_in_parallel(
function_calls: list[FunctionCall],
allow_failures: bool = False,
Expand All @@ -102,7 +95,9 @@ def run_functions_in_parallel(

with ThreadPoolExecutor(max_workers=len(function_calls)) as executor:
future_to_id = {
executor.submit(func_call.execute): func_call.result_id
executor.submit(
contextvars.copy_context().run, func_call.execute
): func_call.result_id
for func_call in function_calls
}

Expand Down Expand Up @@ -143,18 +138,15 @@ def end(self) -> None:
)


# WARNING: it is not currently well understood whether we lose access to contextvars when functions are
# executed through this wrapper. Do NOT try to acquire a db session in a function run through this unless
# you have heavily tested that multi-tenancy is respected. If/when we know for sure that it is or
# is not safe, update this comment.
def run_with_timeout(
timeout: float, func: Callable[..., R], *args: Any, **kwargs: Any
) -> R:
"""
Executes a function with a timeout. If the function doesn't complete within the specified
timeout, raises TimeoutError.
"""
task = TimeoutThread(timeout, func, *args, **kwargs)
context = contextvars.copy_context()
task = TimeoutThread(timeout, context.run, func, *args, **kwargs)
task.start()
task.join(timeout)

Expand Down
131 changes: 131 additions & 0 deletions backend/tests/unit/onyx/utils/test_threadpool_contextvars.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
import contextvars
import time

from onyx.utils.threadpool_concurrency import FunctionCall
from onyx.utils.threadpool_concurrency import run_functions_in_parallel
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
from onyx.utils.threadpool_concurrency import run_with_timeout

# Module-level ContextVar read and written by the tests in this file; a context
# that has never set it observes the fallback value "default".
test_var = contextvars.ContextVar("test_var", default="default")


def get_contextvar_value() -> str:
    """Return the value of ``test_var`` visible to whichever context runs this call."""
    # Pause briefly before reading the variable.
    # NOTE(review): the sleep does not by itself move the call onto another
    # thread; presumably it exists so concurrently submitted calls overlap in
    # time - confirm the intent.
    time.sleep(0.1)
    observed = test_var.get()
    return observed


def test_run_with_timeout_preserves_contextvar() -> None:
    """Test that run_with_timeout preserves contextvar values"""
    # Set a value in the main thread. Keep the token so the previous value can
    # be restored afterwards and this test does not leak state into the next one.
    token = test_var.set("test_value")
    try:
        # Run function with timeout and verify the value is preserved
        result = run_with_timeout(1.0, get_contextvar_value)
        assert result == "test_value"
    finally:
        test_var.reset(token)


def test_run_functions_in_parallel_preserves_contextvar() -> None:
    """Test that run_functions_in_parallel preserves contextvar values"""
    # Set a value in the main thread. Keep the token so the previous value can
    # be restored afterwards and this test does not leak state into the next one.
    token = test_var.set("parallel_test")
    try:
        # Create multiple function calls
        function_calls = [
            FunctionCall(get_contextvar_value),
            FunctionCall(get_contextvar_value),
        ]

        # Run in parallel
        results = run_functions_in_parallel(function_calls)

        # One result per call is expected; without this check an empty result
        # mapping would let the loop below pass without asserting anything.
        assert len(results) == len(function_calls)

        # Verify every worker observed the value set in the main thread
        for value in results.values():
            assert value == "parallel_test"
    finally:
        test_var.reset(token)


def test_run_functions_tuples_preserves_contextvar() -> None:
    """Test that run_functions_tuples_in_parallel preserves contextvar values"""
    # Set a value in the main thread. Keep the token so the previous value can
    # be restored afterwards and this test does not leak state into the next one.
    token = test_var.set("tuple_test")
    try:
        # Create list of function tuples
        functions_with_args = [
            (get_contextvar_value, ()),
            (get_contextvar_value, ()),
        ]

        # Run in parallel
        results = run_functions_tuples_in_parallel(functions_with_args)

        # One result per submitted function is expected; without this check an
        # empty result list would let the loop below pass vacuously.
        assert len(results) == len(functions_with_args)

        # Verify every worker observed the value set in the main thread
        for result in results:
            assert result == "tuple_test"
    finally:
        test_var.reset(token)


def test_nested_contextvar_modifications() -> None:
    """Test that modifications to contextvars in threads don't affect other threads"""

    def modify_and_return_contextvar(new_value: str) -> tuple[str, str]:
        """Helper that modifies the contextvar and returns both values"""
        original = test_var.get()
        test_var.set(new_value)
        time.sleep(0.1)  # Keep the call in flight so the workers can overlap
        return original, test_var.get()

    # Set initial value. Keep the token so the previous value can be restored
    # afterwards and this test does not leak state into the next one.
    token = test_var.set("initial")
    try:
        # Run multiple functions that modify the contextvar
        functions_with_args = [
            (modify_and_return_contextvar, ("thread1",)),
            (modify_and_return_contextvar, ("thread2",)),
        ]

        results = run_functions_tuples_in_parallel(functions_with_args)

        # Each worker should have started from the main thread's value
        assert [original for original, _ in results] == ["initial", "initial"]

        # Each worker should read back its own modification. Comparing the
        # sorted values (rather than membership in a list) also catches the
        # case where one worker's write was visible to the other, and does not
        # rely on the order in which results are returned.
        assert sorted(modified for _, modified in results) == ["thread1", "thread2"]

        # Verify the main thread's value wasn't affected
        assert test_var.get() == "initial"
    finally:
        test_var.reset(token)


def test_contextvar_isolation_between_runs() -> None:
    """Test that contextvar changes don't leak between separate parallel runs"""

    def set_and_return_contextvar(value: str) -> str:
        """Helper that overwrites the contextvar and returns what it then reads back"""
        test_var.set(value)
        return test_var.get()

    # Keep the token from the first set() so the value that was in place before
    # this test can be restored, whatever happens in between.
    token = test_var.set("first_run")
    try:
        # First run
        first_results = run_functions_tuples_in_parallel(
            [
                (set_and_return_contextvar, ("thread1",)),
                (set_and_return_contextvar, ("thread2",)),
            ]
        )

        # Each worker must report exactly its own value. Sorting makes the
        # check independent of result order, and an exact comparison cannot
        # pass on an empty or duplicated result list.
        assert sorted(first_results) == ["thread1", "thread2"]

        # The workers' writes must not have reached the main thread
        assert test_var.get() == "first_run"

        # Second run with different value
        test_var.set("second_run")
        second_results = run_functions_tuples_in_parallel(
            [
                (set_and_return_contextvar, ("thread3",)),
                (set_and_return_contextvar, ("thread4",)),
            ]
        )

        # Verify second run results
        assert sorted(second_results) == ["thread3", "thread4"]

        # Again, the main thread keeps the value it set itself
        assert test_var.get() == "second_run"
    finally:
        test_var.reset(token)
Loading