Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix a bug where wait_to_thread() would await coroutines using the wrong event loops #2037

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion betty/ancestry/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ async def linked_data_schema(cls, project: Project) -> JsonLdObject:
Enum(
*[
presence_role.plugin_id()
for presence_role in wait_to_thread(EVENT_TYPE_REPOSITORY.select())
for presence_role in wait_to_thread(EVENT_TYPE_REPOSITORY.select)
],
title="Event type",
),
Expand Down
2 changes: 1 addition & 1 deletion betty/ancestry/presence_role/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def __init__(self):
super().__init__(
*[
presence_role.plugin_id()
for presence_role in wait_to_thread(PRESENCE_ROLE_REPOSITORY.select())
for presence_role in wait_to_thread(PRESENCE_ROLE_REPOSITORY.select)
],
def_name="presenceRole",
title="Presence role",
Expand Down
7 changes: 3 additions & 4 deletions betty/app/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,8 @@ def localizer(self) -> Localizer:
if self._localizer is None:
self._assert_bootstrapped()
self._localizer = wait_to_thread(
self.localizers.get_negotiated(
self.configuration.locale or DEFAULT_LOCALE
)
self.localizers.get_negotiated,
self.configuration.locale or DEFAULT_LOCALE,
)
return self._localizer

Expand Down Expand Up @@ -150,7 +149,7 @@ def http_client(self) -> aiohttp.ClientSession:
},
)
wait_to_thread(
self._async_exit_stack.enter_async_context(self._http_client)
self._async_exit_stack.enter_async_context, self._http_client
)
return self._http_client

Expand Down
41 changes: 28 additions & 13 deletions betty/asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,26 @@

from __future__ import annotations

from asyncio import TaskGroup, run
from asyncio import TaskGroup, get_running_loop, run_coroutine_threadsafe
from concurrent.futures import wait
from threading import Thread
from typing import (
Awaitable,
TypeVar,
Generic,
cast,
Coroutine,
Any,
ParamSpec,
TYPE_CHECKING,
cast,
)
from typing_extensions import override

if TYPE_CHECKING:
from collections.abc import Callable

_T = TypeVar("_T")
_P = ParamSpec("_P")


async def gather(*coroutines: Coroutine[Any, None, _T]) -> tuple[_T, ...]:
Expand All @@ -31,20 +39,24 @@ async def gather(*coroutines: Coroutine[Any, None, _T]) -> tuple[_T, ...]:
return tuple(task.result() for task in tasks)


def wait_to_thread(f: Awaitable[_T]) -> _T:
"""
Wait for an awaitable in another thread.
"""
synced = _WaiterThread(f)
def wait_to_thread(
Copy link
Owner Author

@bartfeenstra bartfeenstra Sep 26, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Extract this API change into a separate PR

f: Callable[_P, Awaitable[_T]], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
synced = _WaiterThread(f, *args, **kwargs)
synced.start()
synced.join()
return synced.return_value


class _WaiterThread(Thread, Generic[_T]):
def __init__(self, awaitable: Awaitable[_T]):
def __init__(
self, f: Callable[_P, Awaitable[_T]], *args: _P.args, **kwargs: _P.kwargs
):
super().__init__()
self._awaitable = awaitable
self._loop = get_running_loop()
self._f = f
self._args = args
self._kwargs = kwargs
self._return_value: _T | None = None
self._e: BaseException | None = None

Expand All @@ -54,12 +66,15 @@ def return_value(self) -> _T:
raise self._e
return cast(_T, self._return_value)

@override
def run(self) -> None:
run(self._run())

async def _run(self) -> None:
try:
self._return_value = await self._awaitable
future = run_coroutine_threadsafe(
self._f(*self._args, **self._kwargs), self._loop
)
wait([future])
self._return_value = future.result()
foo()
except BaseException as e: # noqa: B036
# Store the exception, so it can be reraised when the calling thread
# gets self.return_value.
Expand Down
4 changes: 2 additions & 2 deletions betty/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def list_commands(self, ctx: click.Context) -> Iterable[str]:
self._bootstrap()
return [
command.plugin_id()
for command in wait_to_thread(commands.COMMAND_REPOSITORY.select())
for command in wait_to_thread(commands.COMMAND_REPOSITORY.select)
]

@override
Expand All @@ -95,7 +95,7 @@ def get_command(self, ctx: click.Context, cmd_name: str) -> click.Command | None
self._bootstrap()
try:
return wait_to_thread(
commands.COMMAND_REPOSITORY.get(cmd_name)
commands.COMMAND_REPOSITORY.get, cmd_name
).click_command()
except PluginNotFound:
return None
Expand Down
23 changes: 10 additions & 13 deletions betty/cli/commands/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
if TYPE_CHECKING:
from betty.app import App
from betty.machine_name import MachineName
from collections.abc import Callable, Coroutine
from collections.abc import Callable

_T = TypeVar("_T")
_P = ParamSpec("_P")
Expand Down Expand Up @@ -164,7 +164,7 @@ def invoke(self, ctx: click.Context) -> Any:


@overload
def command(name: Callable[..., Coroutine[Any, Any, Any]]) -> BettyCommand:
def command(name: Callable[..., Any]) -> BettyCommand:
pass


Expand All @@ -173,7 +173,7 @@ def command(
name: str | None,
cls: type[_BettyCommandT],
**attrs: Any,
) -> Callable[[Callable[..., Coroutine[Any, Any, Any]]], _BettyCommandT]:
) -> Callable[[Callable[..., Any]], _BettyCommandT]:
pass


Expand All @@ -183,25 +183,22 @@ def command(
*,
cls: type[_BettyCommandT],
**attrs: Any,
) -> Callable[[Callable[..., Coroutine[Any, Any, Any]]], _BettyCommandT]:
) -> Callable[[Callable[..., Any]], _BettyCommandT]:
pass


@overload
def command(
name: str | None = None, cls: None = None, **attrs: Any
) -> Callable[[Callable[..., Coroutine[Any, Any, Any]]], BettyCommand]:
) -> Callable[[Callable[..., Any]], BettyCommand]:
pass


def command(
name: str | None | Callable[..., Coroutine[Any, Any, Any]] = None,
name: str | None | Callable[..., Any] = None,
cls: type[BettyCommand] | None = None,
**attrs: Any,
) -> (
click.Command
| Callable[[Callable[..., Coroutine[Any, Any, Any]]], click.Command | BettyCommand]
):
) -> click.Command | Callable[[Callable[..., Any]], click.Command | BettyCommand]:
"""
Mark something a Betty command.

Expand All @@ -216,7 +213,7 @@ def command(
if cls is None:
cls = BettyCommand

def decorator(f: Callable[..., Coroutine[Any, Any, Any]]) -> BettyCommand:
def decorator(f: Callable[..., Any]) -> BettyCommand:
@click.command(cast(str | None, name), cls, **attrs)
@click.option(
"-v",
Expand Down Expand Up @@ -252,7 +249,7 @@ def decorator(f: Callable[..., Coroutine[Any, Any, Any]]) -> BettyCommand:
)
@wraps(f)
def _command(*args: _P.args, **kwargs: _P.kwargs) -> None:
wait_to_thread(f(*args, **kwargs))
f(*args, **kwargs)

return _command # type: ignore[return-value]

Expand Down Expand Up @@ -364,7 +361,7 @@ def _project(
project: Project = ctx.with_resource( # type: ignore[attr-defined]
SynchronizedContextManager(Project.new_temporary(ctx_app(ctx)))
)
wait_to_thread(_read_project_configuration(project, configuration_file_path))
wait_to_thread(_read_project_configuration, project, configuration_file_path)
ctx.with_resource( # type: ignore[attr-defined]
SynchronizedContextManager(project)
)
Expand Down
5 changes: 3 additions & 2 deletions betty/cli/commands/clear_caches.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import logging
from typing import TYPE_CHECKING

from betty.asyncio import wait_to_thread
from betty.cli.commands import command, pass_app
from betty.typing import internal

Expand All @@ -13,6 +14,6 @@
@internal
@command(help="Clear all caches.")
@pass_app
async def clear_caches(app: App) -> None: # noqa D103
await app.cache.clear()
def clear_caches(app: App) -> None: # noqa D103
wait_to_thread(app.cache.clear)
logging.getLogger(__name__).info(app.localizer._("All caches cleared."))
7 changes: 6 additions & 1 deletion betty/cli/commands/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import click

from betty.app import config as app_config
from betty.asyncio import wait_to_thread
from betty.cli.commands import command, pass_app
from betty.config import write_configuration_file
from betty.locale import DEFAULT_LOCALE, get_display_name
Expand All @@ -24,7 +25,11 @@
help="Set the locale for Betty's user interface. This must be an IETF BCP 47 language tag.",
)
@pass_app
async def config(app: App, *, locale: str) -> None: # noqa D103
def config(app: App, *, locale: str) -> None: # noqa D103
wait_to_thread(_config, app, locale=locale)


async def _config(app: App, *, locale: str) -> None:
logger = getLogger(__name__)
app.configuration.locale = locale
localizer = await app.localizers.get(locale)
Expand Down
7 changes: 6 additions & 1 deletion betty/cli/commands/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import asyncio
from typing import TYPE_CHECKING

from betty.asyncio import wait_to_thread
from betty.cli.commands import command, pass_app
from betty.typing import internal

Expand All @@ -13,7 +14,11 @@
@internal
@command(help="Explore a demonstration site.")
@pass_app
async def demo(app: App) -> None: # noqa D103
def demo(app: App) -> None: # noqa D103
wait_to_thread(_demo, app)


async def _demo(app: App) -> None:
from betty.project.extension.demo import DemoServer

async with DemoServer(app=app) as server:
Expand Down
5 changes: 3 additions & 2 deletions betty/cli/commands/dev_new_translation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import click

from betty.assertion import assert_locale
from betty.asyncio import wait_to_thread
from betty.cli.commands import command, parameter_callback
from betty.locale import translation
from betty.typing import internal
Expand All @@ -14,5 +15,5 @@
help="Create a new translation.\n\nThis is available only when developing Betty.",
)
@click.argument("locale", required=True, callback=parameter_callback(assert_locale()))
async def dev_new_translation(locale: str) -> None: # noqa D103
await translation.new_dev_translation(locale)
def dev_new_translation(locale: str) -> None: # noqa D103
wait_to_thread(translation.new_dev_translation, locale)
5 changes: 3 additions & 2 deletions betty/cli/commands/dev_update_translations.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations # noqa D100

from betty.asyncio import wait_to_thread
from betty.cli.commands import command
from betty.locale import translation
from betty.typing import internal
Expand All @@ -10,5 +11,5 @@
short_help="Update all existing translations for Betty itself",
help="Update all existing translations.\n\nThis is available only when developing Betty.",
)
async def dev_update_translations() -> None: # noqa D103
await translation.update_dev_translations()
def dev_update_translations() -> None: # noqa D103
wait_to_thread(translation.update_dev_translations)
7 changes: 6 additions & 1 deletion betty/cli/commands/docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import TYPE_CHECKING

from betty import documentation
from betty.asyncio import wait_to_thread
from betty.cli.commands import command, pass_app
from betty.typing import internal

Expand All @@ -14,7 +15,11 @@
@internal
@command(help="View the documentation.")
@pass_app
async def docs(app: App): # noqa D103
def docs(app: App): # noqa D103
wait_to_thread(_docs, app)


async def _docs(app: App):
server = documentation.DocumentationServer(
app.binary_file_cache.path,
localizer=app.localizer,
Expand Down
6 changes: 3 additions & 3 deletions betty/cli/commands/extension_new_translation.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,14 @@
required=True,
callback=parameter_callback(
lambda extension_id: wait_to_thread(
extension.EXTENSION_REPOSITORY.get(extension_id)
extension.EXTENSION_REPOSITORY.get, extension_id
)
),
)
@click.argument("locale", required=True, callback=parameter_callback(assert_locale()))
@pass_app
async def extension_new_translation( # noqa D103
def extension_new_translation( # noqa D103
app: App, extension: type[Extension], locale: str
) -> None:
with user_facing_error_to_bad_parameter(app.localizer):
await translation.new_extension_translation(locale, extension)
wait_to_thread(translation.new_extension_translation, locale, extension)
8 changes: 5 additions & 3 deletions betty/cli/commands/extension_update_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
required=True,
callback=parameter_callback(
lambda extension_id: wait_to_thread(
extension.EXTENSION_REPOSITORY.get(extension_id)
extension.EXTENSION_REPOSITORY.get, extension_id
)
),
)
Expand All @@ -47,8 +47,10 @@
callback=parameter_callback(assert_sequence(assert_directory_path())),
)
@pass_app
async def extension_update_translations( # noqa D103
def extension_update_translations( # noqa D103
app: App, extension: type[Extension], source: Path, exclude: tuple[Path]
) -> None:
with user_facing_error_to_bad_parameter(app.localizer):
await translation.update_extension_translations(extension, source, set(exclude))
wait_to_thread(
translation.update_extension_translations, extension, source, set(exclude)
)
7 changes: 6 additions & 1 deletion betty/cli/commands/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from typing import TYPE_CHECKING

from betty.asyncio import wait_to_thread
from betty.cli.commands import command, pass_project
from betty.typing import internal

Expand All @@ -12,7 +13,11 @@
@internal
@command(help="Generate a static site.")
@pass_project
async def generate(project: Project) -> None: # noqa D103
def generate(project: Project) -> None: # noqa D103
wait_to_thread(_generate, project)


async def _generate(project: Project) -> None:
from betty.project import generate, load

await load.load(project)
Expand Down
Loading
Loading