From 8bcf83b5f17788f08a533d4af4fcbdbeff6421f6 Mon Sep 17 00:00:00 2001 From: Bart Feenstra Date: Wed, 25 Sep 2024 17:02:13 +0100 Subject: [PATCH 1/4] Fix a bug where wait_to_thread() would await coroutines using the wrong event loops --- betty/ancestry/event.py | 2 +- betty/ancestry/presence_role/__init__.py | 2 +- betty/app/__init__.py | 7 +++--- betty/asyncio.py | 22 ++++++++++++++---- betty/cli/__init__.py | 4 ++-- betty/cli/commands/__init__.py | 23 ++++++++----------- betty/cli/commands/clear_caches.py | 5 ++-- betty/cli/commands/config.py | 7 +++++- betty/cli/commands/demo.py | 7 +++++- betty/cli/commands/dev_new_translation.py | 5 ++-- betty/cli/commands/dev_update_translations.py | 5 ++-- betty/cli/commands/docs.py | 7 +++++- .../cli/commands/extension_new_translation.py | 6 ++--- .../commands/extension_update_translations.py | 8 ++++--- betty/cli/commands/generate.py | 7 +++++- betty/cli/commands/new.py | 7 +++++- betty/cli/commands/new_translation.py | 5 ++-- betty/cli/commands/serve.py | 7 +++++- betty/cli/commands/update_translations.py | 10 +++++--- betty/config/__init__.py | 2 +- betty/contextlib.py | 4 ++-- betty/jinja2/__init__.py | 2 +- betty/jinja2/test.py | 2 +- betty/json/linked_data.py | 2 +- betty/model/collections.py | 4 ++-- betty/plugin/assertion.py | 4 ++-- betty/project/__init__.py | 12 +++++----- betty/project/config.py | 2 +- betty/project/extension/requirement.py | 6 ++--- betty/project/extension/wikipedia/__init__.py | 2 +- betty/requirement.py | 4 ++-- betty/tests/test_asyncio.py | 2 +- documentation/conf.py | 2 +- 33 files changed, 122 insertions(+), 74 deletions(-) diff --git a/betty/ancestry/event.py b/betty/ancestry/event.py index e08aec71c..bd3464fd7 100644 --- a/betty/ancestry/event.py +++ b/betty/ancestry/event.py @@ -206,7 +206,7 @@ async def linked_data_schema(cls, project: Project) -> JsonLdObject: Enum( *[ presence_role.plugin_id() - for presence_role in wait_to_thread(EVENT_TYPE_REPOSITORY.select()) + for presence_role in wait_to_thread(EVENT_TYPE_REPOSITORY.select) ], title="Event type", ), diff --git a/betty/ancestry/presence_role/__init__.py b/betty/ancestry/presence_role/__init__.py index c282f6349..ada85564d 100644 --- a/betty/ancestry/presence_role/__init__.py +++ b/betty/ancestry/presence_role/__init__.py @@ -39,7 +39,7 @@ def __init__(self): super().__init__( *[ presence_role.plugin_id() - for presence_role in wait_to_thread(PRESENCE_ROLE_REPOSITORY.select()) + for presence_role in wait_to_thread(PRESENCE_ROLE_REPOSITORY.select) ], def_name="presenceRole", title="Presence role", diff --git a/betty/app/__init__.py b/betty/app/__init__.py index d4fc47370..429c71ee4 100644 --- a/betty/app/__init__.py +++ b/betty/app/__init__.py @@ -120,9 +120,8 @@ def localizer(self) -> Localizer: if self._localizer is None: self._assert_bootstrapped() self._localizer = wait_to_thread( - self.localizers.get_negotiated( - self.configuration.locale or DEFAULT_LOCALE - ) + self.localizers.get_negotiated, + self.configuration.locale or DEFAULT_LOCALE, ) return self._localizer @@ -150,7 +149,7 @@ def http_client(self) -> aiohttp.ClientSession: }, ) wait_to_thread( - self._async_exit_stack.enter_async_context(self._http_client) + self._async_exit_stack.enter_async_context, self._http_client ) return self._http_client diff --git a/betty/asyncio.py b/betty/asyncio.py index ba1d50c5b..8aaa3f82a 100644 --- a/betty/asyncio.py +++ b/betty/asyncio.py @@ -13,9 +13,15 @@ cast, Coroutine, Any, + ParamSpec, + TYPE_CHECKING, ) +if TYPE_CHECKING: + from collections.abc import Callable + _T = TypeVar("_T") +_P = ParamSpec("_P") async def gather(*coroutines: Coroutine[Any, None, _T]) -> tuple[_T, ...]: @@ -31,20 +37,26 @@ async def gather(*coroutines: Coroutine[Any, None, _T]) -> tuple[_T, ...]: return tuple(task.result() for task in tasks) -def wait_to_thread(f: Awaitable[_T]) -> _T: +def wait_to_thread( + f: Callable[_P, Awaitable[_T]], *args: _P.args, **kwargs: _P.kwargs +) -> _T: """ Wait for an awaitable in another thread. """ - synced = _WaiterThread(f) + synced = _WaiterThread(f, *args, **kwargs) synced.start() synced.join() return synced.return_value class _WaiterThread(Thread, Generic[_T]): - def __init__(self, awaitable: Awaitable[_T]): + def __init__( + self, f: Callable[_P, Awaitable[_T]], *args: _P.args, **kwargs: _P.kwargs + ): super().__init__() - self._awaitable = awaitable + self._f = f + self._args = args + self._kwargs = kwargs self._return_value: _T | None = None self._e: BaseException | None = None @@ -59,7 +71,7 @@ def run(self) -> None: async def _run(self) -> None: try: - self._return_value = await self._awaitable + self._return_value = await self._f(*self._args, **self._kwargs) except BaseException as e: # noqa: B036 # Store the exception, so it can be reraised when the calling thread # gets self.return_value. diff --git a/betty/cli/__init__.py b/betty/cli/__init__.py index 14e8162b2..4c8fbbc40 100644 --- a/betty/cli/__init__.py +++ b/betty/cli/__init__.py @@ -85,7 +85,7 @@ def list_commands(self, ctx: click.Context) -> Iterable[str]: self._bootstrap() return [ command.plugin_id() - for command in wait_to_thread(commands.COMMAND_REPOSITORY.select()) + for command in wait_to_thread(commands.COMMAND_REPOSITORY.select) ] @override @@ -95,7 +95,7 @@ def get_command(self, ctx: click.Context, cmd_name: str) -> click.Command | None self._bootstrap() try: return wait_to_thread( - commands.COMMAND_REPOSITORY.get(cmd_name) + commands.COMMAND_REPOSITORY.get, cmd_name ).click_command() except PluginNotFound: return None diff --git a/betty/cli/commands/__init__.py b/betty/cli/commands/__init__.py index 5edf7f382..e12b137d2 100644 --- a/betty/cli/commands/__init__.py +++ b/betty/cli/commands/__init__.py @@ -42,7 +42,7 @@ if TYPE_CHECKING: from betty.app import App from betty.machine_name import MachineName - from collections.abc import Callable, Coroutine + from collections.abc import Callable _T = TypeVar("_T") _P = ParamSpec("_P") @@ -164,7 +164,7 @@ def invoke(self, ctx: click.Context) -> Any: @overload -def command(name: Callable[..., Coroutine[Any, Any, Any]]) -> BettyCommand: +def command(name: Callable[..., Any]) -> BettyCommand: pass @@ -173,7 +173,7 @@ def command( name: str | None, cls: type[_BettyCommandT], **attrs: Any, -) -> Callable[[Callable[..., Coroutine[Any, Any, Any]]], _BettyCommandT]: +) -> Callable[[Callable[..., Any]], _BettyCommandT]: pass @@ -183,25 +183,22 @@ def command( *, cls: type[_BettyCommandT], **attrs: Any, -) -> Callable[[Callable[..., Coroutine[Any, Any, Any]]], _BettyCommandT]: +) -> Callable[[Callable[..., Any]], _BettyCommandT]: pass @overload def command( name: str | None = None, cls: None = None, **attrs: Any -) -> Callable[[Callable[..., Coroutine[Any, Any, Any]]], BettyCommand]: +) -> Callable[[Callable[..., Any]], BettyCommand]: pass def command( - name: str | None | Callable[..., Coroutine[Any, Any, Any]] = None, + name: str | None | Callable[..., Any] = None, cls: type[BettyCommand] | None = None, **attrs: Any, -) -> ( - click.Command - | Callable[[Callable[..., Coroutine[Any, Any, Any]]], click.Command | BettyCommand] -): +) -> click.Command | Callable[[Callable[..., Any]], click.Command | BettyCommand]: """ Mark something a Betty command. @@ -216,7 +213,7 @@ def command( if cls is None: cls = BettyCommand - def decorator(f: Callable[..., Coroutine[Any, Any, Any]]) -> BettyCommand: + def decorator(f: Callable[..., Any]) -> BettyCommand: @click.command(cast(str | None, name), cls, **attrs) @click.option( "-v", @@ -252,7 +249,7 @@ def decorator(f: Callable[..., Coroutine[Any, Any, Any]]) -> BettyCommand: ) @wraps(f) def _command(*args: _P.args, **kwargs: _P.kwargs) -> None: - wait_to_thread(f(*args, **kwargs)) + f(*args, **kwargs) return _command # type: ignore[return-value] @@ -364,7 +361,7 @@ def _project( project: Project = ctx.with_resource( # type: ignore[attr-defined] SynchronizedContextManager(Project.new_temporary(ctx_app(ctx))) ) - wait_to_thread(_read_project_configuration(project, configuration_file_path)) + wait_to_thread(_read_project_configuration, project, configuration_file_path) ctx.with_resource( # type: ignore[attr-defined] SynchronizedContextManager(project) ) diff --git a/betty/cli/commands/clear_caches.py b/betty/cli/commands/clear_caches.py index d0d3bb5d9..b5b566418 100644 --- a/betty/cli/commands/clear_caches.py +++ b/betty/cli/commands/clear_caches.py @@ -3,6 +3,7 @@ import logging from typing import TYPE_CHECKING +from betty.asyncio import wait_to_thread from betty.cli.commands import command, pass_app from betty.typing import internal @@ -13,6 +14,6 @@ @internal @command(help="Clear all caches.") @pass_app -async def clear_caches(app: App) -> None: # noqa D103 - await app.cache.clear() +def clear_caches(app: App) -> None: # noqa D103 + wait_to_thread(app.cache.clear) logging.getLogger(__name__).info(app.localizer._("All caches cleared.")) diff --git a/betty/cli/commands/config.py b/betty/cli/commands/config.py index dad6b6f61..af4e7fb97 100644 --- a/betty/cli/commands/config.py +++ b/betty/cli/commands/config.py @@ -6,6 +6,7 @@ import click from betty.app import config as app_config +from betty.asyncio import wait_to_thread from betty.cli.commands import command, pass_app from betty.config import write_configuration_file from betty.locale import DEFAULT_LOCALE, get_display_name @@ -24,7 +25,11 @@ help="Set the locale for Betty's user interface. This must be an IETF BCP 47 language tag.", ) @pass_app -async def config(app: App, *, locale: str) -> None: # noqa D103 +def config(app: App, *, locale: str) -> None: # noqa D103 + wait_to_thread(_config, app, locale=locale) + + +async def _config(app: App, *, locale: str) -> None: logger = getLogger(__name__) app.configuration.locale = locale localizer = await app.localizers.get(locale) diff --git a/betty/cli/commands/demo.py b/betty/cli/commands/demo.py index 9ac2b2722..f37df013b 100644 --- a/betty/cli/commands/demo.py +++ b/betty/cli/commands/demo.py @@ -3,6 +3,7 @@ import asyncio from typing import TYPE_CHECKING +from betty.asyncio import wait_to_thread from betty.cli.commands import command, pass_app from betty.typing import internal @@ -13,7 +14,11 @@ @internal @command(help="Explore a demonstration site.") @pass_app -async def demo(app: App) -> None: # noqa D103 +def demo(app: App) -> None: # noqa D103 + wait_to_thread(_demo, app) + + +async def _demo(app: App) -> None: from betty.project.extension.demo import DemoServer async with DemoServer(app=app) as server: diff --git a/betty/cli/commands/dev_new_translation.py b/betty/cli/commands/dev_new_translation.py index 6d96d0528..d03084b9a 100644 --- a/betty/cli/commands/dev_new_translation.py +++ b/betty/cli/commands/dev_new_translation.py @@ -3,6 +3,7 @@ import click from betty.assertion import assert_locale +from betty.asyncio import wait_to_thread from betty.cli.commands import command, parameter_callback from betty.locale import translation from betty.typing import internal @@ -14,5 +15,5 @@ help="Create a new translation.\n\nThis is available only when developing Betty.", ) @click.argument("locale", required=True, callback=parameter_callback(assert_locale())) -async def dev_new_translation(locale: str) -> None: # noqa D103 - await translation.new_dev_translation(locale) +def dev_new_translation(locale: str) -> None: # noqa D103 + wait_to_thread(translation.new_dev_translation, locale) diff --git a/betty/cli/commands/dev_update_translations.py b/betty/cli/commands/dev_update_translations.py index 1ebd32a9f..45a2984e7 100644 --- a/betty/cli/commands/dev_update_translations.py +++ b/betty/cli/commands/dev_update_translations.py @@ -1,5 +1,6 @@ from __future__ import annotations # noqa D100 +from betty.asyncio import wait_to_thread from betty.cli.commands import command from betty.locale import translation from betty.typing import internal @@ -10,5 +11,5 @@ short_help="Update all existing translations for Betty itself", help="Update all existing translations.\n\nThis is available only when developing Betty.", ) -async def dev_update_translations() -> None: # noqa D103 - await translation.update_dev_translations() +def dev_update_translations() -> None: # noqa D103 + wait_to_thread(translation.update_dev_translations) diff --git a/betty/cli/commands/docs.py b/betty/cli/commands/docs.py index 675b8d4df..c392ae3a6 100644 --- a/betty/cli/commands/docs.py +++ b/betty/cli/commands/docs.py @@ -4,6 +4,7 @@ from typing import TYPE_CHECKING from betty import documentation +from betty.asyncio import wait_to_thread from betty.cli.commands import command, pass_app from betty.typing import internal @@ -14,7 +15,11 @@ @internal @command(help="View the documentation.") @pass_app -async def docs(app: App): # noqa D103 +def docs(app: App): # noqa D103 + wait_to_thread(_docs, app) + + +async def _docs(app: App): server = documentation.DocumentationServer( app.binary_file_cache.path, localizer=app.localizer, diff --git a/betty/cli/commands/extension_new_translation.py b/betty/cli/commands/extension_new_translation.py index e2ca7d21f..fcee692c8 100644 --- a/betty/cli/commands/extension_new_translation.py +++ b/betty/cli/commands/extension_new_translation.py @@ -26,14 +26,14 @@ required=True, callback=parameter_callback( lambda extension_id: wait_to_thread( - extension.EXTENSION_REPOSITORY.get(extension_id) + extension.EXTENSION_REPOSITORY.get, extension_id ) ), ) @click.argument("locale", required=True, callback=parameter_callback(assert_locale())) @pass_app -async def extension_new_translation( # noqa D103 +def extension_new_translation( # noqa D103 app: App, extension: type[Extension], locale: str ) -> None: with user_facing_error_to_bad_parameter(app.localizer): - await translation.new_extension_translation(locale, extension) + wait_to_thread(translation.new_extension_translation, locale, extension) diff --git a/betty/cli/commands/extension_update_translations.py b/betty/cli/commands/extension_update_translations.py index f4220979d..ad8e41ac2 100644 --- a/betty/cli/commands/extension_update_translations.py +++ b/betty/cli/commands/extension_update_translations.py @@ -32,7 +32,7 @@ required=True, callback=parameter_callback( lambda extension_id: wait_to_thread( - extension.EXTENSION_REPOSITORY.get(extension_id) + extension.EXTENSION_REPOSITORY.get, extension_id ) ), ) @@ -47,8 +47,10 @@ callback=parameter_callback(assert_sequence(assert_directory_path())), ) @pass_app -async def extension_update_translations( # noqa D103 +def extension_update_translations( # noqa D103 app: App, extension: type[Extension], source: Path, exclude: tuple[Path] ) -> None: with user_facing_error_to_bad_parameter(app.localizer): - await translation.update_extension_translations(extension, source, set(exclude)) + wait_to_thread( + translation.update_extension_translations, extension, source, set(exclude) + ) diff --git a/betty/cli/commands/generate.py b/betty/cli/commands/generate.py index d8f6faa15..a4153e397 100644 --- a/betty/cli/commands/generate.py +++ b/betty/cli/commands/generate.py @@ -2,6 +2,7 @@ from typing import TYPE_CHECKING +from betty.asyncio import wait_to_thread from betty.cli.commands import command, pass_project from betty.typing import internal @@ -12,7 +13,11 @@ @internal @command(help="Generate a static site.") @pass_project -async def generate(project: Project) -> None: # noqa D103 +def generate(project: Project) -> None: # noqa D103 + wait_to_thread(_generate, project) + + +async def _generate(project: Project) -> None: from betty.project import generate, load await load.load(project) diff --git a/betty/cli/commands/new.py b/betty/cli/commands/new.py index f49b2c7d4..325a2672d 100644 --- a/betty/cli/commands/new.py +++ b/betty/cli/commands/new.py @@ -5,6 +5,7 @@ import click from betty.assertion import assert_path, assert_str, assert_locale +from betty.asyncio import wait_to_thread from betty.cli.commands import command, pass_app from betty.cli.error import user_facing_error_to_bad_parameter from betty.config import write_configuration_file @@ -54,7 +55,11 @@ def _assert_url(value: Any) -> str: @internal @command(help="Create a new project.") @pass_app -async def new(app: App) -> None: # noqa D103 +def new(app: App) -> None: # noqa D103 + wait_to_thread(_new, app) + + +async def _new(app: App) -> None: configuration_file_path = click.prompt( app.localizer._("Where do you want to save your project's configuration file?"), value_proc=user_facing_error_to_bad_parameter(app.localizer)( diff --git a/betty/cli/commands/new_translation.py b/betty/cli/commands/new_translation.py index 094af2d6b..7c8903c41 100644 --- a/betty/cli/commands/new_translation.py +++ b/betty/cli/commands/new_translation.py @@ -5,6 +5,7 @@ import click from betty.assertion import assert_locale +from betty.asyncio import wait_to_thread from betty.cli.commands import command, pass_project, parameter_callback from betty.locale import translation from betty.typing import internal @@ -17,5 +18,5 @@ @command(short_help="Create a new translation") @click.argument("locale", required=True, callback=parameter_callback(assert_locale())) @pass_project -async def new_translation(project: Project, locale: str) -> None: # noqa D103 - await translation.new_project_translation(locale, project) +def new_translation(project: Project, locale: str) -> None: # noqa D103 + wait_to_thread(translation.new_project_translation, locale, project) diff --git a/betty/cli/commands/serve.py b/betty/cli/commands/serve.py index ded8cb3f0..9b3694c6f 100644 --- a/betty/cli/commands/serve.py +++ b/betty/cli/commands/serve.py @@ -3,6 +3,7 @@ import asyncio from typing import TYPE_CHECKING +from betty.asyncio import wait_to_thread from betty.cli.commands import command, pass_project from betty.typing import internal @@ -13,7 +14,11 @@ @internal @command(help="Serve a generated site.") @pass_project -async def serve(project: Project) -> None: # noqa D103 +def serve(project: Project) -> None: # noqa D103 + wait_to_thread(_serve, project) + + +async def _serve(project: Project) -> None: from betty import serve async with serve.BuiltinProjectServer(project) as server: diff --git a/betty/cli/commands/update_translations.py b/betty/cli/commands/update_translations.py index 47ffd53a0..0f86f7cb4 100644 --- a/betty/cli/commands/update_translations.py +++ b/betty/cli/commands/update_translations.py @@ -10,6 +10,7 @@ assert_none, assert_sequence, ) +from betty.asyncio import wait_to_thread from betty.cli.commands import command, pass_project, parameter_callback from betty.locale import translation from betty.typing import internal @@ -31,9 +32,12 @@ callback=parameter_callback(assert_sequence(assert_directory_path())), ) @pass_project -async def update_translations( # noqa D103 +def update_translations( # noqa D103 project: Project, source: Path | None, exclude: tuple[Path] ) -> None: - await translation.update_project_translations( - project.configuration.project_directory_path, source, set(exclude) + wait_to_thread( + translation.update_project_translations, + project.configuration.project_directory_path, + source, + set(exclude), ) diff --git a/betty/config/__init__.py b/betty/config/__init__.py index 28ee3fc8d..943904b7c 100644 --- a/betty/config/__init__.py +++ b/betty/config/__init__.py @@ -107,7 +107,7 @@ async def _assert(configuration_file_path: Path) -> _ConfigurationT: configuration.load(serde_format.load(read_configuration)) return configuration - return assert_file_path().chain(lambda value: wait_to_thread(_assert(value))) + return assert_file_path().chain(lambda value: wait_to_thread(_assert, value)) async def write_configuration_file( diff --git a/betty/contextlib.py b/betty/contextlib.py index 1183dc8e3..3b05d02f3 100644 --- a/betty/contextlib.py +++ b/betty/contextlib.py @@ -19,7 +19,7 @@ def __init__(self, context_manager: AsyncContextManager[_ContextT]): self._context_manager = context_manager def __enter__(self) -> _ContextT: - return wait_to_thread(self._context_manager.__aenter__()) + return wait_to_thread(self._context_manager.__aenter__) def __exit__( self, @@ -28,5 +28,5 @@ def __exit__( exc_tb: TracebackType | None, ) -> bool | None: return wait_to_thread( - self._context_manager.__aexit__(exc_type, exc_val, exc_tb) + self._context_manager.__aexit__, exc_type, exc_val, exc_tb ) diff --git a/betty/jinja2/__init__.py b/betty/jinja2/__init__.py index 75026aaae..55782cb5a 100644 --- a/betty/jinja2/__init__.py +++ b/betty/jinja2/__init__.py @@ -165,7 +165,7 @@ def __getitem__( ) -> Entity | None: if isinstance(entity_type_or_type_name, str): entity_type = wait_to_thread( - model.ENTITY_TYPE_REPOSITORY.get(entity_type_or_type_name) + model.ENTITY_TYPE_REPOSITORY.get, entity_type_or_type_name ) else: entity_type = entity_type_or_type_name diff --git a/betty/jinja2/test.py b/betty/jinja2/test.py index 7e07af584..6a15e5d2e 100644 --- a/betty/jinja2/test.py +++ b/betty/jinja2/test.py @@ -45,7 +45,7 @@ def test_entity( :param entity_type_id: If given, additionally ensure the value is an entity of this type. """ if isinstance(entity_type_identifier, str): - entity_type = wait_to_thread(ENTITY_TYPE_REPOSITORY.get(entity_type_identifier)) + entity_type = wait_to_thread(ENTITY_TYPE_REPOSITORY.get, entity_type_identifier) elif entity_type_identifier: entity_type = entity_type_identifier else: diff --git a/betty/json/linked_data.py b/betty/json/linked_data.py index d7a407de9..c0c2ccadc 100644 --- a/betty/json/linked_data.py +++ b/betty/json/linked_data.py @@ -80,7 +80,7 @@ def __init__( title=title, description=description, ) - self._schema["allOf"] = [wait_to_thread(JsonLdSchema.new()).embed(self)] + self._schema["allOf"] = [wait_to_thread(JsonLdSchema.new).embed(self)] class LinkedDataDumpableJsonLdObject( diff --git a/betty/model/collections.py b/betty/model/collections.py index 6d3f08a11..2ead6a67e 100644 --- a/betty/model/collections.py +++ b/betty/model/collections.py @@ -328,7 +328,7 @@ def _getitem_by_entity_type_id( self, entity_type_id: MachineName ) -> SingleTypeEntityCollection[Entity]: return self._get_collection( - wait_to_thread(model.ENTITY_TYPE_REPOSITORY.get(entity_type_id)), + wait_to_thread(model.ENTITY_TYPE_REPOSITORY.get, entity_type_id), ) def _getitem_by_index(self, index: int) -> _TargetT & Entity: @@ -362,7 +362,7 @@ def _delitem_by_entity(self, entity: _TargetT & Entity) -> None: def _delitem_by_entity_type_id(self, entity_type_id: MachineName) -> None: self._delitem_by_entity_type( - wait_to_thread(model.ENTITY_TYPE_REPOSITORY.get(entity_type_id)), # type: ignore[arg-type] + wait_to_thread(model.ENTITY_TYPE_REPOSITORY.get, entity_type_id), # type: ignore[arg-type] ) @override diff --git a/betty/plugin/assertion.py b/betty/plugin/assertion.py index 2202809bc..741229557 100644 --- a/betty/plugin/assertion.py +++ b/betty/plugin/assertion.py @@ -25,7 +25,7 @@ def _assert( ) -> type[_PluginT]: plugin_id = assert_str()(value) try: - return wait_to_thread(plugin_repository.get(plugin_id)) + return wait_to_thread(plugin_repository.get, plugin_id) except PluginNotFound: raise AssertionFailed( join( @@ -35,7 +35,7 @@ def _assert( do_you_mean( *( f'"{plugin.plugin_id()}"' - for plugin in wait_to_thread(plugin_repository.select()) + for plugin in wait_to_thread(plugin_repository.select) ) ), ) diff --git a/betty/project/__init__.py b/betty/project/__init__.py index 82982dc73..fb99bcd52 100644 --- a/betty/project/__init__.py +++ b/betty/project/__init__.py @@ -220,7 +220,7 @@ def localized_url_generator(self) -> LocalizedUrlGenerator: if self._url_generator is None: self._assert_bootstrapped() self._url_generator = wait_to_thread( - ProjectLocalizedUrlGenerator.new_for_project(self) + ProjectLocalizedUrlGenerator.new_for_project, self ) return self._url_generator @@ -232,7 +232,7 @@ def static_url_generator(self) -> StaticUrlGenerator: if self._static_url_generator is None: self._assert_bootstrapped() self._static_url_generator = wait_to_thread( - ProjectStaticUrlGenerator.new_for_project(self) + ProjectStaticUrlGenerator.new_for_project, self ) return self._static_url_generator @@ -256,7 +256,7 @@ def renderer(self) -> Renderer: """ if not self._renderer: self._assert_bootstrapped() - self._renderer = wait_to_thread(self._init_renderer()) + self._renderer = wait_to_thread(self._init_renderer) return self._renderer @@ -272,7 +272,7 @@ def extensions(self) -> ProjectExtensions: """ if self._extensions is None: self._assert_bootstrapped() - self._extensions = wait_to_thread(self._init_extensions()) + self._extensions = wait_to_thread(self._init_extensions) return self._extensions @@ -373,7 +373,7 @@ def copyright_notice(self) -> CopyrightNotice: The overall project copyright. """ if self._copyright_notice is None: - self._copyright_notice = wait_to_thread(self._init_copyright()) + self._copyright_notice = wait_to_thread(self._init_copyright) return self._copyright_notice async def _init_copyright(self) -> CopyrightNotice: @@ -511,7 +511,7 @@ def __getitem__( ) -> Extension: if isinstance(extension_identifier, str): extension_type = wait_to_thread( - extension.EXTENSION_REPOSITORY.get(extension_identifier) + extension.EXTENSION_REPOSITORY.get, extension_identifier ) else: extension_type = extension_identifier diff --git a/betty/project/config.py b/betty/project/config.py index d07dce0c0..f95a6469c 100644 --- a/betty/project/config.py +++ b/betty/project/config.py @@ -996,7 +996,7 @@ def configuration_file_path(self) -> Path: def configuration_file_path(self, configuration_file_path: Path) -> None: if configuration_file_path == self._configuration_file_path: return - wait_to_thread(FORMAT_REPOSITORY.format_for(configuration_file_path.suffix)) + wait_to_thread(FORMAT_REPOSITORY.format_for, configuration_file_path.suffix) self._configuration_file_path = configuration_file_path @property diff --git a/betty/project/extension/requirement.py b/betty/project/extension/requirement.py index c87774f6a..f4e440f65 100644 --- a/betty/project/extension/requirement.py +++ b/betty/project/extension/requirement.py @@ -33,7 +33,7 @@ def _get_requirements(self) -> Sequence[Requirement]: return [ ( wait_to_thread( - extension.EXTENSION_REPOSITORY.get(dependency_identifier) + extension.EXTENSION_REPOSITORY.get, dependency_identifier ) if isinstance(dependency_identifier, str) else dependency_identifier @@ -51,7 +51,7 @@ async def summary(self) -> Localizable: lambda localizer: ", ".join( ( wait_to_thread( - extension.EXTENSION_REPOSITORY.get(dependency_identifier) + extension.EXTENSION_REPOSITORY.get, dependency_identifier ) if isinstance(dependency_identifier, str) else dependency_identifier @@ -90,7 +90,7 @@ async def summary(self) -> Localizable: dependent_labels=call( lambda localizer: ", ".join( dependent.plugin_label().localize(localizer) - for dependent in wait_to_thread(self._dependents()) + for dependent in wait_to_thread(self._dependents) ) ), ) diff --git a/betty/project/extension/wikipedia/__init__.py b/betty/project/extension/wikipedia/__init__.py index be73b0f83..255388b5c 100644 --- a/betty/project/extension/wikipedia/__init__.py +++ b/betty/project/extension/wikipedia/__init__.py @@ -85,7 +85,7 @@ def retriever(self) -> _Retriever: @override @property def globals(self) -> Globals: - return wait_to_thread(self._init_globals()) + return wait_to_thread(self._init_globals) async def _init_globals(self) -> Globals: return { diff --git a/betty/requirement.py b/betty/requirement.py index e6e41c830..bde9dbb35 100644 --- a/betty/requirement.py +++ b/betty/requirement.py @@ -54,8 +54,8 @@ async def details(self) -> Localizable | None: @override def localize(self, localizer: Localizer) -> LocalizedStr: - super_localized = wait_to_thread(self.summary()).localize(localizer) - details = wait_to_thread(self.details()) + super_localized = wait_to_thread(self.summary).localize(localizer) + details = wait_to_thread(self.details) localized: str = super_localized if details is not None: localized += f'\n{"-" * len(localized)}' diff --git a/betty/tests/test_asyncio.py b/betty/tests/test_asyncio.py index 94d735360..95aa6064c 100644 --- a/betty/tests/test_asyncio.py +++ b/betty/tests/test_asyncio.py @@ -8,5 +8,5 @@ async def test(self) -> None: async def _async() -> str: return expected - actual = wait_to_thread(_async()) + actual = wait_to_thread(_async) assert actual == expected diff --git a/documentation/conf.py b/documentation/conf.py index 315fc8916..4713effe5 100644 --- a/documentation/conf.py +++ b/documentation/conf.py @@ -17,7 +17,7 @@ assets = AssetRepository(fs.ASSETS_DIRECTORY_PATH) localizers = LocalizerRepository(assets) for locale in localizers.locales: - coverage = wait_to_thread(localizers.coverage(locale)) + coverage = wait_to_thread(localizers.coverage, locale) betty_replacements[f"translation-coverage-{locale}"] = str( int(round(100 / (coverage[1] / coverage[0]))) ) From 3fa79bff807ef2f741884cf6df61da7de6b75985 Mon Sep 17 00:00:00 2001 From: Bart Feenstra Date: Thu, 26 Sep 2024 09:57:26 +0100 Subject: [PATCH 2/4] meh --- betty/asyncio.py | 21 +++++++++-------- test.py | 43 ++++++++++++++++++++++++++++++++++ test2.py | 60 ++++++++++++++++++++++++++++++++++++++++++++++++ test3.py | 25 ++++++++++++++++++++ 4 files changed, 140 insertions(+), 9 deletions(-) create mode 100644 test.py create mode 100644 test2.py create mode 100644 test3.py diff --git a/betty/asyncio.py b/betty/asyncio.py index 8aaa3f82a..1646e0527 100644 --- a/betty/asyncio.py +++ b/betty/asyncio.py @@ -4,18 +4,20 @@ from __future__ import annotations -from asyncio import TaskGroup, run +from asyncio import TaskGroup, get_running_loop, run_coroutine_threadsafe +from concurrent.futures import wait from threading import Thread from typing import ( Awaitable, TypeVar, Generic, - cast, Coroutine, Any, ParamSpec, TYPE_CHECKING, + cast, ) +from typing_extensions import override if TYPE_CHECKING: from collections.abc import Callable @@ -40,9 +42,6 @@ async def gather(*coroutines: Coroutine[Any, None, _T]) -> tuple[_T, ...]: def wait_to_thread( f: Callable[_P, Awaitable[_T]], *args: _P.args, **kwargs: _P.kwargs ) -> _T: - """ - Wait for an awaitable in another thread. - """ synced = _WaiterThread(f, *args, **kwargs) synced.start() synced.join() @@ -54,6 +53,7 @@ def __init__( self, f: Callable[_P, Awaitable[_T]], *args: _P.args, **kwargs: _P.kwargs ): super().__init__() + self._loop = get_running_loop() self._f = f self._args = args self._kwargs = kwargs @@ -66,12 +66,15 @@ def return_value(self) -> _T: raise self._e return cast(_T, self._return_value) + @override def run(self) -> None: - run(self._run()) - - async def _run(self) -> None: try: - self._return_value = await self._f(*self._args, **self._kwargs) + future = run_coroutine_threadsafe( + self._f(*self._args, **self._kwargs), self._loop + ) + wait([future]) + self._return_value = future.result() + foo() except BaseException as e: # noqa: B036 # Store the exception, so it can be reraised when the calling thread # gets self.return_value. diff --git a/test.py b/test.py new file mode 100644 index 000000000..24fa46014 --- /dev/null +++ b/test.py @@ -0,0 +1,43 @@ +from __future__ import annotations + +from asyncio import get_running_loop, run_coroutine_threadsafe, run +from concurrent.futures.thread import ThreadPoolExecutor +from typing import Awaitable, TypeVar, ParamSpec, TYPE_CHECKING + +if TYPE_CHECKING: + from collections.abc import Callable + +_T = TypeVar("_T") +_P = ParamSpec("_P") + +THREAD_POOL = ThreadPoolExecutor() + + +def wait_to_thread( + f: Callable[_P, Awaitable[_T]], *args: _P.args, **kwargs: _P.kwargs +) -> _T: + loop = get_running_loop() + + def _wait_to_thread(): + print("PRE CORO THREADSAFE") + coroutine = f(*args, **kwargs) + return run_coroutine_threadsafe(coroutine, loop).result() + + return THREAD_POOL.submit(_wait_to_thread).result() + + +async def main_async() -> None: + print("THE END") + + +def main_sync() -> None: + wait_to_thread(main_async) + print("POST WAIT TO THREAD") + + +async def main() -> None: + main_sync() + + +if __name__ == "__main__": + run(main()) diff --git a/test2.py b/test2.py new file mode 100644 index 000000000..800a393d0 --- /dev/null +++ b/test2.py @@ -0,0 +1,60 @@ +import threading +import time +import asyncio +from asyncio import run + + +# another task coroutine +async def task_coro2(): + # report a message + print(f">>task2 running") + # block for a moment + await asyncio.sleep(2) + # report a message + print(f">>task2 done") + + +# task coroutine +async def task_coro(): + # loop a few times + for i in range(5): + # report a message + print(f">task at {i}") + # block a moment + await asyncio.sleep(1) + + +# function executed in another thread +def task_thread(loop): + # report a message + print("thread running") + # wait a moment + time.sleep(1) + # create a coroutine + coro = task_coro2() + # execute a coroutine + future = asyncio.run_coroutine_threadsafe(coro, loop) + # wait for the task to finish + future.result() + # report a message + print("thread done") + + +# entry point to the asyncio program +async def main(): + # report a message + print("asyncio running") + # get the event loop + loop = asyncio.get_running_loop() + # start a new thread + thread = threading.Thread(target=task_thread, args=(loop,)) + thread.start() + thread.join() + # execute a task + await task_coro() + # report a message + print("asyncio done") + + +if __name__ == "__main__": + run(main()) diff --git a/test3.py b/test3.py new file mode 100644 index 000000000..a6670ce7c --- /dev/null +++ b/test3.py @@ -0,0 +1,25 @@ +from asyncio import run_coroutine_threadsafe, get_running_loop, run +from threading import Thread + + +async def main_async() -> None: + foo() + pass + + +def _main_async(loop) -> None: + run_coroutine_threadsafe(main_async(), loop).result() + + +def main_sync() -> None: + thread = Thread(target=_main_async, args=[get_running_loop()]) + thread.start() + thread.join() + + +async def main() -> None: + main_sync() + + +if __name__ == "__main__": + run(main()) From 11a26964f50932f64813b9600201582cf116a149 Mon Sep 17 00:00:00 2001 From: Bart Feenstra Date: Mon, 30 Sep 2024 10:31:55 +0100 Subject: [PATCH 3/4] meh --- test3.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/test3.py b/test3.py index a6670ce7c..ab45d6b77 100644 --- a/test3.py +++ b/test3.py @@ -1,25 +1,26 @@ -from asyncio import run_coroutine_threadsafe, get_running_loop, run +from asyncio import run_coroutine_threadsafe, get_running_loop, run, AbstractEventLoop +from collections.abc import Callable, Awaitable from threading import Thread -async def main_async() -> None: - foo() - pass +async def _some_other_async_function() -> None: + print("BUT THIS IS NOT PRINTED") -def _main_async(loop) -> None: - run_coroutine_threadsafe(main_async(), loop).result() +def _helper(loop: AbstractEventLoop, f: Callable[[], Awaitable[None]]) -> None: + print("THIS IS PRINTED") + run_coroutine_threadsafe(f(), loop).result() -def main_sync() -> None: - thread = Thread(target=_main_async, args=[get_running_loop()]) +def _some_sync_function(loop: AbstractEventLoop) -> None: + thread = Thread(target=_helper, args=[loop, _some_other_async_function]) thread.start() thread.join() -async def main() -> None: - main_sync() +async def _async_main() -> None: + _some_sync_function(get_running_loop()) if __name__ == "__main__": - run(main()) + run(_async_main()) From 523fbcc99b86a44d90df9174bc3992309ab0af4a Mon Sep 17 00:00:00 2001 From: Bart Feenstra Date: Mon, 30 Sep 2024 18:23:58 +0100 Subject: [PATCH 4/4] meh --- test3.py | 49 +++++++++++++++++++++++++++++++++++++------------ 1 file changed, 37 insertions(+), 12 deletions(-) diff --git a/test3.py b/test3.py index ab45d6b77..dda640c43 100644 --- a/test3.py +++ b/test3.py @@ -1,26 +1,51 @@ from asyncio import run_coroutine_threadsafe, get_running_loop, run, AbstractEventLoop from collections.abc import Callable, Awaitable -from threading import Thread +from concurrent.futures.thread import ThreadPoolExecutor +from typing import TypeVar +_T = TypeVar("_T") -async def _some_other_async_function() -> None: - print("BUT THIS IS NOT PRINTED") +def _await_in_thread(loop: AbstractEventLoop, f: Callable[[], Awaitable[_T]]) -> _T: + """ + Await something inside a thread, using the main thread's loop. + """ -def _helper(loop: AbstractEventLoop, f: Callable[[], Awaitable[None]]) -> None: print("THIS IS PRINTED") - run_coroutine_threadsafe(f(), loop).result() + return run_coroutine_threadsafe(f(), loop).result() + + +def _await_to_thread(pool: ThreadPoolExecutor, f: Callable[[], Awaitable[_T]]) -> _T: + """ + Await something by moving it to a thread. + """ + return pool.submit(_await_in_thread, get_running_loop(), f).result() + + +async def _async_main(pool: ThreadPoolExecutor) -> None: + """ + Run the main application, which is asynchronous. + """ + # Eventually, the application calls a function that due to its nature (maybe a third-party API) is synchronous. + _some_sync_function(pool) -def _some_sync_function(loop: AbstractEventLoop) -> None: - thread = Thread(target=_helper, args=[loop, _some_other_async_function]) - thread.start() - thread.join() +def _sync_main() -> None: + with ThreadPoolExecutor() as pool: + run(_async_main(pool)) -async def _async_main() -> None: - _some_sync_function(get_running_loop()) + +def _some_sync_function(pool: ThreadPoolExecutor) -> None: + # This synchronous function then has to call a function that is asynchronous. + result = _await_to_thread(pool, _some_async_function) + assert result == 123 + + +async def _some_async_function() -> int: + print("BUT THIS IS NOT PRINTED") + return 123 if __name__ == "__main__": - run(_async_main()) + _sync_main()