Skip to content

Commit

Permalink
Simplify the asyncio API (#1393)
Browse files Browse the repository at this point in the history
  • Loading branch information
bartfeenstra authored Apr 11, 2024
1 parent 0c67525 commit cc2fc0d
Show file tree
Hide file tree
Showing 17 changed files with 162 additions and 94 deletions.
9 changes: 6 additions & 3 deletions betty/_package/pyinstaller/main.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
import sys
from asyncio import run

from betty.app import App
from betty.asyncio import sync
from betty.gui import BettyApplication
from betty.gui.app import WelcomeWindow


@sync
async def main() -> None:
def main() -> None:
"""
Launch Betty for PyInstaller builds.
"""
run(_main())


async def _main() -> None:
async with App.new_from_environment() as app:
async with BettyApplication([sys.argv[0]]).with_app(app) as qapp:
window = WelcomeWindow(app)
Expand Down
14 changes: 7 additions & 7 deletions betty/app/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from betty import fs
from betty.app.extension import ListExtensions, Extension, Extensions, build_extension_type_graph, \
CyclicDependencyError, ExtensionDispatcher, ConfigurableExtension, discover_extension_types
from betty.asyncio import sync, wait
from betty.asyncio import wait_to_thread
from betty.cache import Cache, FileCache
from betty.cache.file import BinaryFileCache, PickledFileCache
from betty.config import Configurable, FileBasedConfiguration
Expand Down Expand Up @@ -167,7 +167,7 @@ def __init__(
self._localizer: Localizer | None = None
self._localizers: LocalizerRepository | None = None
with suppress(FileNotFoundError):
wait(self.configuration.read())
wait_to_thread(self.configuration.read())
self._project = project or Project()
self.project.configuration.extensions.on_change(self._update_extensions)

Expand Down Expand Up @@ -353,7 +353,7 @@ def localizer(self) -> Localizer:
Get the application's localizer.
"""
if self._localizer is None:
self._localizer = wait(self.localizers.get_negotiated(self.configuration.locale or DEFAULT_LOCALE))
self._localizer = wait_to_thread(self.localizers.get_negotiated(self.configuration.locale or DEFAULT_LOCALE))
return self._localizer

@localizer.deleter
Expand Down Expand Up @@ -408,21 +408,21 @@ def http_client(self) -> aiohttp.ClientSession:
'User-Agent': f'Betty (https://github.com/bartfeenstra/betty) on behalf of {self._project.configuration.base_url}{self._project.configuration.root_path}',
},
)
weakref.finalize(self, sync(self._http_client.close))
weakref.finalize(self, lambda: None if self._http_client is None else wait_to_thread(self._http_client.close()))
return self._http_client

@http_client.deleter
def http_client(self) -> None:
if self._http_client is not None:
wait(self._http_client.close())
wait_to_thread(self._http_client.close())
self._http_client = None

@property
def entity_types(self) -> set[type[Entity]]:
if self._entity_types is None:
from betty.model.ancestry import Citation, Enclosure, Event, File, Note, Person, PersonName, Presence, Place, Source

self._entity_types = reduce(operator.or_, wait(self.dispatcher.dispatch(EntityTypeProvider)()), set()) | {
self._entity_types = reduce(operator.or_, wait_to_thread(self.dispatcher.dispatch(EntityTypeProvider)()), set()) | {
Citation,
Enclosure,
Event,
Expand All @@ -443,7 +443,7 @@ def entity_types(self) -> None:
@property
def event_types(self) -> set[type[EventType]]:
if self._event_types is None:
self._event_types = set(wait(self.dispatcher.dispatch(EventTypeProvider)())) | {
self._event_types = set(wait_to_thread(self.dispatcher.dispatch(EventTypeProvider)())) | {
Birth,
Baptism,
Adoption,
Expand Down
36 changes: 24 additions & 12 deletions betty/asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@
"""
from __future__ import annotations

import asyncio
from asyncio import TaskGroup
from asyncio import TaskGroup, get_running_loop, run
from functools import wraps
from threading import Thread
from typing import Callable, Awaitable, TypeVar, Generic, cast, ParamSpec, Coroutine, Any

from betty.warnings import deprecated

P = ParamSpec('P')
T = TypeVar('T')

Expand All @@ -30,25 +31,34 @@ async def gather(*coroutines: Coroutine[Any, None, T]) -> tuple[T, ...]:
)


@deprecated('This function is deprecated as of Betty 0.3.3, and will be removed in Betty 0.4.x. Instead, use `betty.asyncio.wait_to_thread()` or `asyncio.run()`.')
def wait(f: Awaitable[T]) -> T:
"""
Wait for an awaitable.
Wait for an awaitable, either in a new event loop or another thread.
"""
try:
loop = asyncio.get_running_loop()
loop = get_running_loop()
except RuntimeError:
loop = None
if loop:
synced = _SyncedAwaitable(f)
synced.start()
synced.join()
return synced.return_value
return wait_to_thread(f)
else:
return asyncio.run(
return run(
f, # type: ignore[arg-type]
)


def wait_to_thread(f: Awaitable[T]) -> T:
"""
Wait for an awaitable in another thread.
"""
synced = _WaiterThread(f)
synced.start()
synced.join()
return synced.return_value


@deprecated('This function is deprecated as of Betty 0.3.3, and will be removed in Betty 0.4.x. Instead, use `betty.asyncio.wait_to_thread()` or `asyncio.run()`.')
def sync(f: Callable[P, Awaitable[T]]) -> Callable[P, T]:
"""
Decorate an asynchronous callable to become synchronous.
Expand All @@ -59,7 +69,7 @@ def _synced(*args: P.args, **kwargs: P.kwargs) -> T:
return _synced


class _SyncedAwaitable(Thread, Generic[T]):
class _WaiterThread(Thread, Generic[T]):
def __init__(self, awaitable: Awaitable[T]):
super().__init__()
self._awaitable = awaitable
Expand All @@ -72,8 +82,10 @@ def return_value(self) -> T:
raise self._e
return cast(T, self._return_value)

@sync
async def run(self) -> None:
def run(self) -> None:
run(self._run())

async def _run(self) -> None:
try:
self._return_value = await self._awaitable
except BaseException as e:
Expand Down
27 changes: 16 additions & 11 deletions betty/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from betty import about, generate, load, documentation
from betty.app import App
from betty.asyncio import sync, wait
from betty.asyncio import wait_to_thread
from betty.contextlib import SynchronizedContextManager
from betty.error import UserFacingError
from betty.extension import demo
Expand Down Expand Up @@ -49,8 +49,7 @@ def catch_exceptions() -> Iterator[None]:
except KeyboardInterrupt:
print('Quitting...')
sys.exit(0)
pass
except Exception as e:
except BaseException as e:
logger = logging.getLogger(__name__)
if isinstance(e, UserFacingError):
logger.error(str(e))
Expand All @@ -74,8 +73,8 @@ def _command(*args: P.args, **kwargs: P.kwargs) -> None:
async def _app_command():
async with app:
await f(app, *args, **kwargs)
return wait(_app_command())
return wait(f(*args, **kwargs))
return wait_to_thread(_app_command())
return wait_to_thread(f(*args, **kwargs))
return _command


Expand All @@ -94,11 +93,17 @@ def app_command(f: Callable[Concatenate[App, P], Awaitable[None]]) -> Callable[P


@catch_exceptions()
@sync
async def _init_ctx_app(
def _init_ctx_app(
ctx: Context,
__: Option | Parameter | None = None,
configuration_file_path: str | None = None,
) -> None:
wait_to_thread(__init_ctx_app(ctx, configuration_file_path))


async def __init_ctx_app(
ctx: Context,
configuration_file_path: str | None = None,
) -> None:
ctx.ensure_object(dict)

Expand All @@ -118,7 +123,7 @@ async def _init_ctx_app(
'demo': _demo,
'gui': _gui,
}
if wait(about.is_development()):
if wait_to_thread(about.is_development()):
ctx.obj['commands']['init-translation'] = _init_translation
ctx.obj['commands']['update-translations'] = _update_translations
ctx.obj['app'] = app
Expand Down Expand Up @@ -231,8 +236,8 @@ def get_command(self, ctx: Context, cmd_name: str) -> Command | None:
callback=_build_init_ctx_verbosity(logging.NOTSET, logging.NOTSET),
)
@click.version_option(
wait(about.version_label()),
message=wait(about.report()),
wait_to_thread(about.version_label()),
message=wait_to_thread(about.report()),
prog_name='Betty',
)
def main(app: App, verbose: bool, more_verbose: bool, most_verbose: bool) -> None:
Expand Down Expand Up @@ -311,7 +316,7 @@ async def _docs():
await asyncio.sleep(999)


if wait(about.is_development()):
if wait_to_thread(about.is_development()):
@click.command(short_help='Initialize a new translation', help='Initialize a new translation.\n\nThis is available only when developing Betty.')
@click.argument('locale')
@global_command
Expand Down
9 changes: 4 additions & 5 deletions betty/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import aiofiles
from aiofiles.os import makedirs

from betty.asyncio import wait, sync
from betty.asyncio import wait_to_thread
from betty.classtools import repr_instance
from betty.functools import slice_to_range
from betty.locale import Str
Expand Down Expand Up @@ -105,9 +105,8 @@ def autowrite(self, autowrite: bool) -> None:
self.remove_on_change(self._on_change_write)
self._autowrite = autowrite

@sync
async def _on_change_write(self) -> None:
await self.write()
def _on_change_write(self) -> None:
wait_to_thread(self.write())

async def write(self, configuration_file_path: Path | None = None) -> None:
if configuration_file_path is not None:
Expand Down Expand Up @@ -158,7 +157,7 @@ def configuration_file_path(self) -> Path:
if self._configuration_file_path is None:
if self._configuration_directory is None:
self._configuration_directory = TemporaryDirectory()
wait(self._write(Path(self._configuration_directory.name) / f'{type(self).__name__}.json'))
wait_to_thread(self._write(Path(self._configuration_directory.name) / f'{type(self).__name__}.json'))
return cast(Path, self._configuration_file_path)

@configuration_file_path.setter
Expand Down
6 changes: 3 additions & 3 deletions betty/contextlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from types import TracebackType
from typing import AsyncContextManager, TypeVar, Generic

from betty.asyncio import wait
from betty.asyncio import wait_to_thread

ContextT = TypeVar('ContextT')

Expand All @@ -14,7 +14,7 @@ def __init__(self, context_manager: AsyncContextManager[ContextT]):
self._context_manager = context_manager

def __enter__(self) -> ContextT:
return wait(self._context_manager.__aenter__())
return wait_to_thread(self._context_manager.__aenter__())

def __exit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None) -> bool | None:
return wait(self._context_manager.__aexit__(exc_type, exc_val, exc_tb))
return wait_to_thread(self._context_manager.__aexit__(exc_type, exc_val, exc_tb))
4 changes: 2 additions & 2 deletions betty/extension/npm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from betty.app.extension import Extension, discover_extension_types
from betty.app.extension.requirement import Requirement, AnyRequirement, AllRequirements
from betty.asyncio import wait
from betty.asyncio import wait_to_thread
from betty.cache.file import BinaryFileCache
from betty.fs import iterfiles
from betty.locale import Str, DEFAULT_LOCALIZER
Expand Down Expand Up @@ -62,7 +62,7 @@ def _unmet_summary(cls) -> Str:
@classmethod
def check(cls) -> _NpmRequirement:
try:
wait(npm(['--version']))
wait_to_thread(npm(['--version']))
logging.getLogger(__name__).debug(cls._met_summary().localize(DEFAULT_LOCALIZER))
return cls(True)
except (CalledProcessError, FileNotFoundError):
Expand Down
14 changes: 8 additions & 6 deletions betty/gui/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from betty import about
from betty.about import report
from betty.app import App
from betty.asyncio import sync, wait
from betty.asyncio import wait_to_thread
from betty.gui import get_configuration_file_filter
from betty.gui.error import ExceptionCatcher
from betty.gui.locale import TranslationsLocaleCollector
Expand Down Expand Up @@ -178,7 +178,7 @@ def open_project(self) -> None:
)
if not configuration_file_path_str:
return
wait(self._app.project.configuration.read(Path(configuration_file_path_str)))
wait_to_thread(self._app.project.configuration.read(Path(configuration_file_path_str)))
project_window = ProjectWindow(self._app)
project_window.show()
self.close()
Expand All @@ -196,7 +196,7 @@ def new_project(self) -> None:
if not configuration_file_path_str:
return
configuration = ProjectConfiguration()
wait(configuration.write(Path(configuration_file_path_str)))
wait_to_thread(configuration.write(Path(configuration_file_path_str)))
project_window = ProjectWindow(self._app)
project_window.show()
self.close()
Expand All @@ -206,8 +206,10 @@ def _demo(self) -> None:
serve_window = ServeDemoWindow(self._app, parent=self)
serve_window.show()

@sync
async def clear_caches(self) -> None:
def clear_caches(self) -> None:
wait_to_thread(self._clear_caches())

async def _clear_caches(self) -> None:
async with ExceptionCatcher(self):
await self._app.cache.clear()

Expand Down Expand Up @@ -315,7 +317,7 @@ def _set_translatables(self) -> None:
super()._set_translatables()
self._label.setText(''.join(map(lambda x: '<p>%s</p>' % x, [
self._app.localizer._('Version: {version}').format(
version=wait(about.version_label()),
version=wait_to_thread(about.version_label()),
),
self._app.localizer._('Copyright 2019-{year} Bart Feenstra & contributors. Betty is made available to you under the <a href="https://www.gnu.org/licenses/gpl-3.0.en.html">GNU General Public License, Version 3</a> (GPLv3).').format(
year=datetime.now().year,
Expand Down
4 changes: 2 additions & 2 deletions betty/gui/locale.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from PyQt6.QtWidgets import QComboBox, QLabel, QWidget

from betty.app import App
from betty.asyncio import wait
from betty.asyncio import wait_to_thread
from betty.gui.text import Caption
from betty.locale import negotiate_locale, get_display_name

Expand Down Expand Up @@ -87,7 +87,7 @@ def _set_translatables(self) -> None:
locale_name=get_display_name(locale, localizer.locale),
))
else:
negotiated_locale_translations_coverage = wait(localizers.coverage(translations_locale))
negotiated_locale_translations_coverage = wait_to_thread(localizers.coverage(translations_locale))
if negotiated_locale_translations_coverage[1] == 0:
negotiated_locale_translations_coverage_percentage = 0
else:
Expand Down
Loading

0 comments on commit cc2fc0d

Please sign in to comment.