diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 15a088a..a7fe8c0 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -42,7 +42,7 @@ jobs: dotnet-version: '6.0.x' - name: Build wheels - uses: pypa/cibuildwheel@v2.16.2 + uses: pypa/cibuildwheel@v2.16.5 env: CIBW_BUILD: ${{ matrix.pyver }}-* CIBW_ARCHS: ${{ matrix.platform.arch }} diff --git a/.gitmodules b/.gitmodules index 19cd038..3638b37 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,3 @@ [submodule "vibrio/vendor/vibrio"] path = vibrio/vendor/vibrio - url = git@github.com:notjagan/vibrio.git + url = https://github.com/notjagan/vibrio.git diff --git a/.readthedocs.yaml b/.readthedocs.yaml new file mode 100644 index 0000000..4a37bd7 --- /dev/null +++ b/.readthedocs.yaml @@ -0,0 +1,21 @@ +version: 2 + +submodules: + include: all + recursive: true + +build: + os: ubuntu-22.04 + tools: + python: "3.12" + apt_packages: + - dotnet-sdk-6.0 + +sphinx: + configuration: docs/conf.py + +python: + install: + - method: pip + path: . + - requirements: docs/requirements.txt diff --git a/README.md b/README.md index 9272a76..864fd79 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,21 @@ -# vibrio-python -python bindings for https://github.com/notjagan/vibrio +# vibrio-python ([documentation](https://vibrio-python.readthedocs.io/en/latest/)) + +[![PyPI](https://img.shields.io/pypi/v/vibrio.svg)](https://pypi.org/project/vibrio/) +[![Build](https://github.com/notjagan/vibrio-python/actions/workflows/build.yml/badge.svg)](https://github.com/notjagan/vibrio-python/actions/workflows/build.yml) +[![CodeFactor](https://www.codefactor.io/repository/github/notjagan/vibrio-python/badge)](https://www.codefactor.io/repository/github/notjagan/vibrio-python) + +Python library for interfacing with osu!lazer functionality. Acts as bindings for https://github.com/notjagan/vibrio under the hood. + +# Installation + +`pip install vibrio` + +Supports Python 3.9+. + +Tested (through `cibuildwheel` deployment) and published on `pip` on the following platforms: +- Ubuntu (via manylinux and musl) (x86) +- macOS (x86, arm64) + - arm is currently untested due to unavailability through GitHub actions hosting +- Windows (x86 and AMD64) + +If you do not have one of the supported platforms (or otherwise want to build from source), simply clone the repository and use the `build` subcommand (and/or any superset of `build` like `sdist`) in `setup.py` to produce an installable package, then use `pip install` on the result. This will require having the `dotnet` C# SDK on the system path to compile the server solution. diff --git a/docs/.gitignore b/docs/.gitignore new file mode 100644 index 0000000..05d9664 --- /dev/null +++ b/docs/.gitignore @@ -0,0 +1,41 @@ +# sphinx build folder +_build + +# Compiled source # +################### +*.com +*.class +*.dll +*.exe +*.o +*.so + +# Packages # +############ +# it's better to unpack these files and commit the raw source +# git has its own built in compression methods +*.7z +*.dmg +*.gz +*.iso +*.jar +*.rar +*.tar +*.zip + +# Logs and databases # +###################### +*.log +*.sql +*.sqlite + +# OS generated files # +###################### +.DS_Store? +ehthumbs.db +Icon? +Thumbs.db + +# Editor backup files # +####################### +*~ diff --git a/docs/Makefile b/docs/Makefile new file mode 100644 index 0000000..d4bb2cb --- /dev/null +++ b/docs/Makefile @@ -0,0 +1,20 @@ +# Minimal makefile for Sphinx documentation +# + +# You can set these variables from the command line, and also +# from the environment for the first two. +SPHINXOPTS ?= +SPHINXBUILD ?= sphinx-build +SOURCEDIR = . +BUILDDIR = _build + +# Put it first so that "make" without argument is like "make help". +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +.PHONY: help Makefile + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/docs/conf.py b/docs/conf.py new file mode 100644 index 0000000..59c2543 --- /dev/null +++ b/docs/conf.py @@ -0,0 +1,21 @@ +import os +import sys + +sys.path.insert(0, os.path.abspath("..")) + +project = "vibrio" +copyright = "2024, notjagan" +author = "notjagan" + +extensions = ["sphinx.ext.autodoc", "numpydoc"] + +numpydoc_class_members_toctree = False +numpydoc_show_inherited_class_members = False +autodoc_default_options = {"members": True, "undoc-members": True} + +templates_path = ["_templates"] +exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] + +html_theme = "pydata_sphinx_theme" +html_static_path = ["_static"] +html_sidebars = {"**": ["localtoc.html"]} diff --git a/docs/index.rst b/docs/index.rst new file mode 100644 index 0000000..384be45 --- /dev/null +++ b/docs/index.rst @@ -0,0 +1,39 @@ +Welcome to vibrio's documentation! +================================== +All normal use-cases should be available through a top-level import +(`from vibrio import ...`); see :mod:`vibrio`. Internal documentation is +provided at :mod:`vibrio.lazer` and :mod:`vibrio.types`, although use of these modules +should not be necessary. + +Quickstart +---------- +>>> from vibrio import HitStatistics, Lazer, OsuMod +>>> with Lazer() as lazer: +... attributes = lazer.calculate_performance( +... beatmap_id=1001682, +... mods=[OsuMod.HIDDEN, OsuMod.DOUBLE_TIME], +... hitstats=HitStatistics( +... count_300=2019, count_100=104, count_50=0, count_miss=3, combo=3141 +... ), +... ) +... attributes.total +1304.35 + +See :class:`vibrio.Lazer` (or :class:`vibrio.LazerAsync` for the asynchronous +implementation) for more details. + +.. toctree:: + :maxdepth: 2 + :caption: Contents: + + vibrio + vibrio.lazer + vibrio.types + + +Indices and tables +================== + +* :ref:`genindex` +* :ref:`modindex` +* :ref:`search` diff --git a/docs/make.bat b/docs/make.bat new file mode 100644 index 0000000..32bb245 --- /dev/null +++ b/docs/make.bat @@ -0,0 +1,35 @@ +@ECHO OFF + +pushd %~dp0 + +REM Command file for Sphinx documentation + +if "%SPHINXBUILD%" == "" ( + set SPHINXBUILD=sphinx-build +) +set SOURCEDIR=. +set BUILDDIR=_build + +%SPHINXBUILD% >NUL 2>NUL +if errorlevel 9009 ( + echo. + echo.The 'sphinx-build' command was not found. Make sure you have Sphinx + echo.installed, then set the SPHINXBUILD environment variable to point + echo.to the full path of the 'sphinx-build' executable. Alternatively you + echo.may add the Sphinx directory to PATH. + echo. + echo.If you don't have Sphinx installed, grab it from + echo.https://www.sphinx-doc.org/ + exit /b 1 +) + +if "%1" == "" goto help + +%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% +goto end + +:help +%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% + +:end +popd diff --git a/docs/requirements.in b/docs/requirements.in new file mode 100644 index 0000000..639c79a --- /dev/null +++ b/docs/requirements.in @@ -0,0 +1,3 @@ +sphinx==7.2.6 +numpydoc==1.6.0 +pydata-sphinx-theme==0.15.2 diff --git a/docs/requirements.txt b/docs/requirements.txt new file mode 100644 index 0000000..aaba4ac --- /dev/null +++ b/docs/requirements.txt @@ -0,0 +1,78 @@ +# +# This file is autogenerated by pip-compile with Python 3.11 +# by the following command: +# +# pip-compile '.\requirements.in' +# +accessible-pygments==0.0.4 + # via pydata-sphinx-theme +alabaster==0.7.16 + # via sphinx +babel==2.14.0 + # via + # pydata-sphinx-theme + # sphinx +beautifulsoup4==4.12.3 + # via pydata-sphinx-theme +certifi==2024.2.2 + # via requests +charset-normalizer==3.3.2 + # via requests +colorama==0.4.6 + # via sphinx +docutils==0.20.1 + # via + # pydata-sphinx-theme + # sphinx +idna==3.6 + # via requests +imagesize==1.4.1 + # via sphinx +jinja2==3.1.3 + # via + # numpydoc + # sphinx +markupsafe==2.1.5 + # via jinja2 +numpydoc==1.6.0 + # via -r .\requirements.in +packaging==23.2 + # via + # pydata-sphinx-theme + # sphinx +pydata-sphinx-theme==0.15.2 + # via -r .\requirements.in +pygments==2.17.2 + # via + # accessible-pygments + # pydata-sphinx-theme + # sphinx +requests==2.31.0 + # via sphinx +snowballstemmer==2.2.0 + # via sphinx +soupsieve==2.5 + # via beautifulsoup4 +sphinx==7.2.6 + # via + # -r .\requirements.in + # numpydoc + # pydata-sphinx-theme +sphinxcontrib-applehelp==1.0.8 + # via sphinx +sphinxcontrib-devhelp==1.0.6 + # via sphinx +sphinxcontrib-htmlhelp==2.0.5 + # via sphinx +sphinxcontrib-jsmath==1.0.1 + # via sphinx +sphinxcontrib-qthelp==1.0.7 + # via sphinx +sphinxcontrib-serializinghtml==1.1.10 + # via sphinx +tabulate==0.9.0 + # via numpydoc +typing-extensions==4.9.0 + # via pydata-sphinx-theme +urllib3==2.2.1 + # via requests diff --git a/docs/vibrio.lazer.rst b/docs/vibrio.lazer.rst new file mode 100644 index 0000000..aaf1072 --- /dev/null +++ b/docs/vibrio.lazer.rst @@ -0,0 +1,4 @@ +vibrio.lazer +============ +.. automodule:: vibrio.lazer + :members: diff --git a/docs/vibrio.rst b/docs/vibrio.rst new file mode 100644 index 0000000..b5d8c72 --- /dev/null +++ b/docs/vibrio.rst @@ -0,0 +1,4 @@ +vibrio +====== +.. automodule:: vibrio + :members: diff --git a/docs/vibrio.types.rst b/docs/vibrio.types.rst new file mode 100644 index 0000000..4ab986c --- /dev/null +++ b/docs/vibrio.types.rst @@ -0,0 +1,4 @@ +vibrio.types +============ +.. automodule:: vibrio.types + :members: diff --git a/pyproject.toml b/pyproject.toml index 4653937..2dc1a3b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "vibrio" -version = "0.2.0" +version = "0.3.0" readme = "README.md" requires-python = ">=3.9" license = { file = "LICENSE" } diff --git a/setup.py b/setup.py index 6b7086a..c2eea97 100644 --- a/setup.py +++ b/setup.py @@ -67,18 +67,18 @@ def finalize_options(self) -> None: def run(self) -> None: def onerror( - func: Callable[..., Any], path: str, ex_info: tuple[Exception, ...] + func: Callable[[str], Any], path: str, ex_info: tuple[BaseException, ...] ) -> None: - ex, *_ = ex_info + ex_type, *_ = ex_info # resolve any permission issues - if ex is PermissionError and not os.access(path, os.W_OK): - os.chmod(path, stat.S_IWUSR) + if ex_type is PermissionError and not os.access(path, os.W_OK): + os.chmod(path, os.stat(path).st_mode | stat.S_IWUSR) func(path) # ignore missing file - elif ex is FileNotFoundError: + elif ex_type is FileNotFoundError: pass else: - raise + raise ex_type shutil.rmtree(EXTENSION_DIR, onerror=onerror) EXTENSION_DIR.mkdir(parents=True, exist_ok=True) @@ -101,13 +101,15 @@ def onerror( publish_dir = server_dir / "publish" for path in publish_dir.glob("*.zip"): with ZipFile(path, "r") as zip_file: - zip_file.extractall(EXTENSION_DIR) + for filename in zip_file.filelist: + executable = Path(zip_file.extract(filename, EXTENSION_DIR)) + executable.chmod(executable.stat().st_mode | stat.S_IEXEC) class CustomBuild(build): """Build process including compiling server executable.""" - sub_commands = [("build_vendor", None)] + build.sub_commands + sub_commands = [("build_vendor", None)] + build.sub_commands # type: ignore setup( diff --git a/vibrio/__init__.py b/vibrio/__init__.py index 37a83b6..1a9516a 100644 --- a/vibrio/__init__.py +++ b/vibrio/__init__.py @@ -1,2 +1,20 @@ -from vibrio.lazer import Lazer as Lazer -from vibrio.lazer import LazerAsync as LazerAsync +""" +Top-level imports and types; see :class:`Lazer` and :class:`LazerAsync`. +""" + +from vibrio.lazer import Lazer, LazerAsync +from vibrio.types import ( + HitStatistics, + OsuDifficultyAttributes, + OsuMod, + OsuPerformanceAttributes, +) + +__all__ = [ + "Lazer", + "LazerAsync", + "HitStatistics", + "OsuMod", + "OsuPerformanceAttributes", + "OsuDifficultyAttributes", +] diff --git a/vibrio/lazer.py b/vibrio/lazer.py index 15ac01b..be74f23 100644 --- a/vibrio/lazer.py +++ b/vibrio/lazer.py @@ -1,33 +1,46 @@ +""" +Module for interacting with osu!lazer functionality (see :class:`Lazer`, +:class:`LazerAsync`). +""" + from __future__ import annotations import asyncio import atexit import io import logging +import os import platform import signal import socket -import stat import subprocess -import tempfile +import threading import time import urllib.parse from abc import ABC from pathlib import Path -from typing import Any, BinaryIO, Optional, TextIO +from typing import IO, Any, BinaryIO, Callable import aiohttp import psutil import requests from typing_extensions import Self -from vibrio.types import OsuDifficultyAttributes, OsuMod +from vibrio.types import ( + HitStatistics, + OsuDifficultyAttributes, + OsuMod, + OsuPerformanceAttributes, +) PACKAGE_DIR = Path(__file__).absolute().parent -class ServerStateError(Exception): - """Exception due to attempting to induce an invalid server state transition.""" +class StateError(Exception): + """ + Exception due to attempting to induce an invalid state transition e.g. attempting to + launch the server when an instance is already tied to the current object. + """ class ServerError(Exception): @@ -42,7 +55,8 @@ def find_open_port() -> int: """Returns a port not currently in use on the system.""" with socket.socket() as sock: sock.bind(("", 0)) - return sock.getsockname()[1] + _, port = sock.getsockname() + return port def get_vibrio_path(platform: str) -> Path: @@ -55,48 +69,96 @@ def get_vibrio_path(platform: str) -> Path: return PACKAGE_DIR / "lib" / f"Vibrio{suffix}" +class LogPipe(IO[str]): + """IO wrapper around a thread for piping output to log function.""" + + def __init__(self, log_func: Callable[[str], None]) -> None: + self.log_func = log_func + self.fd_read, self.fd_write = os.pipe() + + class LogThread(threading.Thread): + def run(_self) -> None: + with os.fdopen(self.fd_read) as pipe_reader: + for line in iter(pipe_reader.readline, ""): + self.log_func(line.strip("\n")) + + self.thread = LogThread() + self.thread.daemon = True + self.thread.start() + + def fileno(self) -> int: + return self.fd_write + + def close(self) -> None: + os.close(self.fd_write) + + class LazerBase(ABC): - """Shared functionality for lazer wrappers.""" + """Abstract base class for `Lazer` and `LazerAsync`.""" - STARTUP_DELAY = 0.05 # Amount of time (seconds) between requests during startup + STARTUP_DELAY = 0.05 + """Amount of time (seconds) between requests during startup.""" + + def __init__( + self, + *, + port: int | None = None, + self_hosted: bool = False, + log_level: logging._Level = logging.NOTSET, + ) -> None: + self.self_hosted = self_hosted + if self.self_hosted and port is None: + raise ValueError("`port` must be provided if self-hosting") - def __init__(self, port: Optional[int] = None, use_logging: bool = True) -> None: if port is None: self.port = find_open_port() else: self.port = port - self.use_logging = use_logging - self.running = False - self.server_path = get_vibrio_path(platform.system()) - if not self.server_path.exists(): - raise FileNotFoundError(f'No executable found at "{self.server_path}".') - self.server_path.chmod(self.server_path.stat().st_mode | stat.S_IEXEC) + self.connected = False + self._server_path = get_vibrio_path(platform.system()) + if not self._server_path.exists(): + raise FileNotFoundError(f'No executable found at "{self._server_path}"') + + self._logger = logging.getLogger(str(id(self))) + self._logger.setLevel(log_level) + + self._info_pipe: LogPipe | None + self._error_pipe: LogPipe | None - self.log: Optional[tempfile._TemporaryFileWrapper[bytes]] = None + def args(self) -> list[str]: + """Produces the command line arguments for the server executable.""" + return [str(self._server_path), "--urls", self.address()] def address(self) -> str: """Constructs the base URL for the web server.""" return f"http://localhost:{self.port}" def _start(self) -> None: - if self.running: - raise ServerStateError("Server is already running") + if self.connected: + raise StateError("Already connected to server") - self.running = True + self._info_pipe = LogPipe(self._logger.info) + self._error_pipe = LogPipe(self._logger.error) - if self.use_logging: - logging.info(f"Launching server on port {self.port}") - self.log = tempfile.NamedTemporaryFile(delete=False) + if not self.self_hosted: + self._logger.info(f"Hosting server on port {self.port}.") def _stop(self) -> None: - if self.log is not None: - logging.info(f"Server output logged at {self.log.file.name}") - self.log.close() - self.log = None + self._logger.info("Shutting down...") + if self._info_pipe is not None: + self._info_pipe.close() + if self._error_pipe is not None: + self._error_pipe.close() + + @staticmethod + def _not_found_error(beatmap_id: int) -> BeatmapNotFound: + return BeatmapNotFound(f"No beatmap found for id {beatmap_id}") class BaseUrlSession(requests.Session): + """Request session with a base URL as used internally in `Lazer`.""" + def __init__(self, base_url: str) -> None: super().__init__() self.base_url = base_url @@ -104,74 +166,142 @@ def __init__(self, base_url: str) -> None: def request( self, method: str | bytes, url: str | bytes, *args: Any, **kwargs: Any ) -> requests.Response: + """ + Makes a request using the currently stored base URL. + + See Also + -------- + requests.Session.request + """ full_url = urllib.parse.urljoin(self.base_url, str(url)) return super().request(method, full_url, *args, **kwargs) class Lazer(LazerBase): - """Synchronous implementation for interfacing with osu!lazer functionality.""" - - def __init__(self, port: int | None = None, use_logging: bool = True) -> None: - super().__init__(port, use_logging) + """ + Context manager for interfacing with osu!lazer functionality (synchronously). + + Attributes + ---------- + connected : bool + Whether the class instance is currently connected to a server. + + Examples + -------- + >>> from vibrio import HitStatistics, Lazer, OsuMod + >>> with Lazer() as lazer: + ... attributes = lazer.calculate_performance( + ... beatmap_id=1001682, + ... mods=[OsuMod.HIDDEN, OsuMod.DOUBLE_TIME], + ... hitstats=HitStatistics( + ... count_300=2019, count_100=104, count_50=0, count_miss=3, combo=3141 + ... ), + ... ) + ... attributes.total + 1304.35 + + Notes + ----- + This class can be used traditionally instead of as a context manager, in which case + the use of `start()` and `stop()` are left up to the user. Only do this if you know + what you are doing; failing to call `stop()` appropriately may leave an instance of + the server dangling. `start()` will attempt to create a callback that culls any + server instances on program shutdown, but proceed with caution and call `stop()` as + necessary to avoid any possible memory leaks. + + See Also + -------- + LazerAsync : asynchronous implementation of the same functionality + """ + + def __init__( + self, + *, + port: int | None = None, + self_hosted: bool = False, + log_level: logging._Level = logging.NOTSET, + ) -> None: + """ + Constructs a `Lazer` instance. + + This *does not* launch/connect to a server instance. If you are using this class + as a context manager, feel free to ignore this as the `__enter__()` and + `__exit__()` methods will handle this for you. Otherwise, see `start()` and + `stop()` and the notes in the class docstring. + + Parameters + ---------- + port : int, optional + Port to run/connect to the server on. Automatically generates an unused + port if left unset. + self_hosted : bool, default False + Whether the user is hosting their own server instance. Requires + specification of a port if set to `True`. + log_level : logging level, default `logging.NOTSET` + Mininum severity level for logging, as found in the `logging` standard + library. + + Returns + ------- + Lazer + """ + super().__init__(port=port, self_hosted=self_hosted, log_level=log_level) self.session = None self.process = None @property def session(self) -> BaseUrlSession: + """Request session; errors if unset.""" if self._session is None: - raise ServerStateError("Session has not been initialized") + raise StateError("Session has not been initialized") return self._session @session.setter - def session(self, value: Optional[BaseUrlSession]) -> None: + def session(self, value: BaseUrlSession | None) -> None: self._session = value @property def process(self) -> subprocess.Popen[bytes]: + """Executable process; errors if unset.""" if self._process is None: - raise ServerStateError("Process has not been initialized") + raise StateError("Process has not been initialized") return self._process @process.setter - def process(self, value: Optional[subprocess.Popen[bytes]]) -> None: + def process(self, value: subprocess.Popen[bytes] | None) -> None: self._process = value def start(self) -> None: - """Launches server executable.""" + """Launches and connects to `vibrio` server executable.""" self._start() - if self.log is not None: - out = self.log - else: - out = subprocess.DEVNULL - - self.process = subprocess.Popen( - [self.server_path, "--urls", self.address()], - stdout=out, - stderr=out, - ) + if not self.self_hosted: + self.process = subprocess.Popen( + self.args(), + stdout=self._info_pipe, + stderr=self._error_pipe, + ) self.session = BaseUrlSession(self.address()) - - # block until webserver has launched - while True: + while True: # block until webserver has launched try: with self.session.get("/api/status") as response: if response.status_code == 200: break except (ConnectionError, IOError): - pass - finally: time.sleep(self.STARTUP_DELAY) + self.connected = True atexit.register(self.stop) def stop(self) -> None: - """Cleans up server executable.""" + """Cleans up server executable and related periphery.""" + if not self.connected: + return self._stop() - try: + if not self.self_hosted: parent = psutil.Process(self.process.pid) for child in parent.children(recursive=True): child.terminate() @@ -180,15 +310,16 @@ def stop(self) -> None: self.process = None if status != 0 and status != signal.SIGTERM: - raise SystemError( - f"Could not cleanly shutdown server subprocess; received return code {status}" + self._logger.error( + "Could not cleanly shutdown server subprocess; received return code" + f" {status}" ) - self.session.close() - self.session = None + self.session.close() + self.session = None - except ServerStateError: - pass + self.connected = False + self._logger.info("Connection closed.") def __enter__(self) -> Self: self.start() @@ -198,16 +329,24 @@ def __exit__(self, *_) -> bool: self.stop() return False + @staticmethod + def _status_error(response: requests.Response) -> ServerError: + """Emits an error based on the status of the provided request.""" + if response.text: + return ServerError( + f"Unexpected status code {response.status_code}: {response.text}" + ) + else: + return ServerError(f"Unexpected status code {response.status_code}") + def has_beatmap(self, beatmap_id: int) -> bool: - """Checks if given beatmap is cached/available locally.""" + """Returns true if the given beatmap is currently stored locally.""" with self.session.get(f"/api/beatmaps/{beatmap_id}/status") as response: if response.status_code == 200: return True elif response.status_code == 404: return False - raise ServerError( - f"Unexpected status code {response.status_code}; check server logs for error details" - ) + raise self._status_error(response) def get_beatmap(self, beatmap_id: int) -> BinaryIO: """Returns a file stream for the given beatmap.""" @@ -218,127 +357,312 @@ def get_beatmap(self, beatmap_id: int) -> BinaryIO: stream.seek(0) return stream elif response.status_code == 404: - raise BeatmapNotFound(f"No beatmap found for id {beatmap_id}") + raise self._not_found_error(beatmap_id) else: - raise ServerError( - f"Unexpected status code {response.status_code}; check server logs for error details" - ) + raise self._status_error(response) def clear_cache(self) -> None: """Clears beatmap cache (if applicable).""" with self.session.delete("/api/beatmaps/cache") as response: if response.status_code != 200: - raise ServerError( - f"Unexpected status code {response.status_code}; check server logs for error details" - ) + raise self._status_error(response) def calculate_difficulty( self, - mods: list[OsuMod], - beatmap_id: Optional[int] = None, - beatmap: Optional[TextIO] = None, + *, + beatmap_id: int | None = None, + beatmap: BinaryIO | None = None, + mods: list[OsuMod] | None = None, ) -> OsuDifficultyAttributes: - params = {"mods": [mod.value for mod in mods]} + """ + Calculates the difficulty parameters for a beatmap and optional mod combination. + + `beatmap_id` and `beatmap` specify the beatmap to be queried; exactly one of + the two must be set during difficulty calculation. + + Parameters + ---------- + beatmap_id : int, optional + beatmap : binary file stream, optional + mods : list of OsuMod enums, optional + + Returns + ------- + OsuDifficultyAttributes + Dataclass encoding the difficulty attributes of the requested map. + """ + params: dict[str, Any] = {} + if mods is not None: + params["mods"] = [mod.value for mod in mods] if beatmap_id is not None: if beatmap is not None: raise ValueError( - "Exactly one of `beatmap_id` and `beatmap_data` should be set" + "Exactly one of `beatmap_id` and `beatmap` must be populated" ) - - with self.session.get( - f"/api/difficulty/{beatmap_id}", params=params - ) as response: - if response.status_code == 200: - return OsuDifficultyAttributes.from_json(response.json()) - elif response.status_code == 404: - raise BeatmapNotFound(f"No beatmap found for id {beatmap_id}") - else: - raise ServerError( - f"Unexpected status code {response.status_code}; check server logs for error details" - ) - + response = self.session.get(f"/api/difficulty/{beatmap_id}", params=params) elif beatmap is not None: - with self.session.post( + response = self.session.post( "/api/difficulty", params=params, files={"beatmap": beatmap} - ) as response: - if response.status_code == 200: - return OsuDifficultyAttributes.from_json(response.json()) - else: - raise ServerError( - f"Unexpected status code {response.status_code}; check server logs for error details" - ) + ) + else: + raise ValueError( + "Exactly one of `beatmap_id` and `beatmap` must be populated" + ) + with response: + if response.status_code == 200: + return OsuDifficultyAttributes.from_dict(response.json()) + elif response.status_code == 404 and beatmap_id is not None: + raise self._not_found_error(beatmap_id) + else: + raise self._status_error(response) + + def _request_performance_beatmap_id( + self, + beatmap_id: int, + mods: list[OsuMod] | None, + hit_stats: HitStatistics | None = None, + replay: BinaryIO | None = None, + ) -> requests.Response: + """Queries for the performance of a play given a beatmap ID.""" + if hit_stats is not None: + params = hit_stats.to_dict() + if mods is not None: + params["mods"] = [mod.value for mod in mods] + return self.session.get(f"/api/performance/{beatmap_id}", params=params) + elif replay is not None: + return self.session.post( + f"/api/performance/replay/{beatmap_id}", files={"replay": replay} + ) else: raise ValueError( - "Exactly one of `beatmap_id` and `beatmap_data` should be set" + "Exactly one of `hit_stats` and `replay` must be populated when" + " calculating performance with a beatmap ID" ) + def _request_performance_beatmap( + self, + beatmap: BinaryIO, + mods: list[OsuMod] | None, + hit_stats: HitStatistics | None = None, + replay: BinaryIO | None = None, + ) -> requests.Response: + """Queries for the performance of a play given a beatmap ID.""" + if hit_stats is not None: + params = hit_stats.to_dict() + if mods is not None: + params["mods"] = [mod.value for mod in mods] + return self.session.post( + "/api/performance", params=params, files={"beatmap": beatmap} + ) + elif replay is not None: + return self.session.post( + "/api/performance/replay", + files={"beatmap": beatmap, "replay": replay}, + ) + else: + raise ValueError( + "Exactly one of `hit_stats` and `replay` must be populated when" + " calculating performance with a beatmap" + ) -class LazerAsync(LazerBase): - """Asynchronous implementation for interfacing with osu!lazer functionality.""" + def calculate_performance( + self, + *, + beatmap_id: int | None = None, + beatmap: BinaryIO | None = None, + mods: list[OsuMod] | None = None, + difficulty: OsuDifficultyAttributes | None = None, + hit_stats: HitStatistics | None = None, + replay: BinaryIO | None = None, + ) -> OsuPerformanceAttributes: + """ + Calculates the performance values for a given play on a provided beatmap. + + Each query essentially requires a method of specifying the beatmap the play was + made on (through exactly one of `beatmap_id`, `beatmap` or `difficulty`) and + a method of describing the play itself (through exactly one of `hit_stats` and + `replay`). However, processing a replay is not possible with difficulty + attributes alone, so using `difficulty` requires the use of `hit_stats`. + + Parameters + ---------- + beatmap_id : int, optional + beatmap : binary file stream, optional + mods : list of OsuMod enums, optional + For use with either `beatmap` or `beatmap_id`. + difficulty : OsuDifficultyAttributes, optional + Difficulty attribute instance, as returned by `calculate_difficulty()`. + hit_stats : HitStatistics, optional + replay : binary file stream, optional + + Returns + ------- + OsuPerformanceAttributes + Dataclass encoding the performance values of the requested play. + """ + if beatmap_id is not None: + response = self._request_performance_beatmap_id( + beatmap_id, mods, hit_stats, replay + ) + elif beatmap is not None: + response = self._request_performance_beatmap( + beatmap, mods, hit_stats, replay + ) + elif difficulty is not None: + if hit_stats is not None: + response = self.session.get( + "/api/performance", + params=difficulty.to_dict() | hit_stats.to_dict(), + ) + else: + raise ValueError( + "`hit_stats` must be populated when querying with `difficulty`" + ) + else: + raise ValueError( + "Exactly one of `beatmap_id`, `beatmap`, and `difficulty` must be" + " populated" + ) + + with response: + if response.status_code == 200: + return OsuPerformanceAttributes.from_dict(response.json()) + elif response.status_code == 404 and beatmap_id is not None: + raise self._not_found_error(beatmap_id) + else: + raise self._status_error(response) - def __init__(self, port: int | None = None, use_logging: bool = True) -> None: - super().__init__(port, use_logging) + +class LazerAsync(LazerBase): + """ + Context manager for interfacing with osu!lazer functionality asynchronously. + + Attributes + ---------- + connected : bool + Whether the class instance is currently connected to a server. + + Examples + -------- + Note: the following example would not execute in a REPL environment as async + statements must occur within async functions, but the principle still holds. + + >>> from vibrio import HitStatistics, Lazer, OsuMod + >>> async with Lazer() as lazer: + ... attributes = await lazer.calculate_performance( + ... beatmap_id=1001682, + ... mods=[OsuMod.HIDDEN, OsuMod.DOUBLE_TIME], + ... hitstats=HitStatistics( + ... count_300=2019, count_100=104, count_50=0, count_miss=3, combo=3141 + ... ), + ... ) + ... attributes.total + 1304.35 + + Notes + ----- + This class can be used traditionally instead of as a context manager, in which case + the use of `start()` and `stop()` are left up to the user. Only do this if you know + what you are doing; failing to call `stop()` appropriately may leave an instance of + the server dangling. `start()` will attempt to create a callback that culls any + server instances on program shutdown, but proceed with caution and call `stop()` as + necessary to avoid any possible memory leaks. + + See Also + -------- + Lazer : synchronous implementation of the same functionality + """ + + def __init__( + self, + *, + port: int | None = None, + self_hosted: bool = False, + log_level: logging._Level = logging.NOTSET, + ) -> None: + """ + Constructs a `LazerAsync` instance. + + This *does not* launch/connect to a server instance. If you are using this class + as a context manager, feel free to ignore this as the `__aenter__()` and + `__aexit__()` methods will handle this for you. Otherwise, see `start()` and + `stop()` and the notes in the class docstring. + + Parameters + ---------- + port : int, optional + Port to run/connect to the server on. Automatically generates an unused + port if left unset. + self_hosted : bool, default False + Whether the user is hosting their own server instance. Requires + specification of a port if set to `True`. + log_level : logging level, default `logging.NOTSET` + Mininum severity level for logging, as found in the `logging` standard + library. + + Returns + ------- + LazerAsync + """ + super().__init__(port=port, self_hosted=self_hosted, log_level=log_level) self.session = None self.process = None @property def session(self) -> aiohttp.ClientSession: + """Request session; errors if unset.""" if self._session is None: - raise ServerStateError("Session has not been initialized") + raise StateError("Session has not been initialized") return self._session @session.setter - def session(self, value: Optional[aiohttp.ClientSession]) -> None: + def session(self, value: aiohttp.ClientSession | None) -> None: self._session = value @property def process(self) -> asyncio.subprocess.Process: + """Executable process; errors if unset.""" if self._process is None: - raise ServerStateError("Process has not been initialized") + raise StateError("Process has not been initialized") return self._process @process.setter - def process(self, value: Optional[asyncio.subprocess.Process]) -> None: + def process(self, value: asyncio.subprocess.Process | None) -> None: self._process = value async def start(self) -> None: - """Launches server executable.""" + """Launches and connects to `vibrio` server executable.""" self._start() - if self.log is not None: - out = self.log - else: - out = subprocess.DEVNULL - - self.process = await asyncio.create_subprocess_shell( - f"{self.server_path} --urls {self.address()}", - stdout=out, - stderr=out, - ) + if not self.self_hosted: + self.process = await asyncio.create_subprocess_shell( + " ".join(self.args()), + stdout=self._info_pipe, + stderr=self._error_pipe, + ) self.session = aiohttp.ClientSession(self.address()) - - # block until webserver has launched - while True: + while True: # block until webserver has launched try: async with self.session.get("/api/status") as response: if response.status == 200: break except (ConnectionError, aiohttp.ClientConnectionError): - pass - finally: await asyncio.sleep(self.STARTUP_DELAY) + self.connected = True atexit.register(lambda: asyncio.run(self.stop())) async def stop(self) -> None: - """Cleans up server executable.""" + """Cleans up server executable and related periphery.""" + if not self.connected: + return self._stop() - try: + if not self.self_hosted: parent = psutil.Process(self.process.pid) for child in parent.children(recursive=True): child.terminate() @@ -346,16 +670,17 @@ async def stop(self) -> None: status = await self.process.wait() self.process = None - await self.session.close() - self.session = None - if status != 0 and status != signal.SIGTERM: - raise SystemError( - f"Could not cleanly shutdown server subprocess; received return code {status}" + self._logger.error( + "Could not cleanly shutdown server subprocess; received return code" + f" {status}" ) - except ServerStateError: - pass + await self.session.close() + self.session = None + + self.connected = False + self._logger.info("Connection closed.") async def __aenter__(self) -> Self: await self.start() @@ -365,16 +690,25 @@ async def __aexit__(self, *_) -> bool: await self.stop() return False + @staticmethod + async def _status_error(response: aiohttp.ClientResponse) -> ServerError: + """Emits an error based on the status of the provided request.""" + text = await response.text() + if text: + return ServerError( + f"Unexpected status code {response.status}: {response.text}" + ) + else: + return ServerError(f"Unexpected status code {response.status}") + async def has_beatmap(self, beatmap_id: int) -> bool: - """Checks if given beatmap is cached/available locally.""" + """Returns true if the given beatmap is currently stored locally.""" async with self.session.get(f"/api/beatmaps/{beatmap_id}/status") as response: if response.status == 200: return True elif response.status == 404: return False - raise ServerError( - f"Unexpected status code {response.status}; check server logs for error details" - ) + raise await self._status_error(response) async def get_beatmap(self, beatmap_id: int) -> BinaryIO: """Returns a file stream for the given beatmap.""" @@ -385,65 +719,183 @@ async def get_beatmap(self, beatmap_id: int) -> BinaryIO: stream.seek(0) return stream elif response.status == 404: - raise BeatmapNotFound(f"No beatmap found for id {beatmap_id}") + raise self._not_found_error(beatmap_id) else: - raise ServerError( - f"Unexpected status code {response.status}; check server logs for error details" - ) + raise await self._status_error(response) async def clear_cache(self) -> None: """Clears beatmap cache (if applicable).""" async with self.session.delete("/api/beatmaps/cache") as response: if response.status != 200: - raise ServerError( - f"Unexpected status code {response.status}; check server logs for error details" - ) + raise await self._status_error(response) async def calculate_difficulty( self, - mods: list[OsuMod], - beatmap_id: Optional[int] = None, - beatmap: Optional[TextIO] = None, + *, + beatmap_id: int | None = None, + beatmap: BinaryIO | None = None, + mods: list[OsuMod] | None = None, ) -> OsuDifficultyAttributes: - params = {"mods": [mod.value for mod in mods]} + """ + Calculates the difficulty parameters for a beatmap and optional mod combination. + + `beatmap_id` and `beatmap` specify the beatmap to be queried; exactly one of + the two must be set during difficulty calculation. + + Parameters + ---------- + beatmap_id : int, optional + beatmap : binary file stream, optional + mods : list of OsuMod enums, optional + + Returns + ------- + OsuDifficultyAttributes + Dataclass encoding the difficulty attributes of the requested map. + """ + params = {} + if mods is not None: + params["mods"] = [mod.value for mod in mods] if beatmap_id is not None: if beatmap is not None: raise ValueError( - "Exactly one of `beatmap_id` and `beatmap_data` should be set" + "Exactly one of `beatmap_id` and `beatmap` must be populated" ) - - async with self.session.get( + response = await self.session.get( f"/api/difficulty/{beatmap_id}", params=params - ) as response: - if response.status == 200: - return OsuDifficultyAttributes.from_json(await response.json()) - elif response.status == 404: - raise BeatmapNotFound(f"No beatmap found for id {beatmap_id}") - else: - raise ServerError( - f"Unexpected status code {response.status}; check server logs for error details" - ) - + ) elif beatmap is not None: - data = aiohttp.FormData() - data.add_field( - "beatmap", - beatmap.read(), - content_type="multipart/form-data", - filename="beatmap", + response = await self.session.post( + "/api/difficulty", params=params, data={"beatmap": beatmap} + ) + else: + raise ValueError( + "Exactly one of `beatmap_id` and `beatmap` must be populated" ) - async with self.session.post( - "/api/difficulty", params=params, data=data - ) as response: - if response.status == 200: - return OsuDifficultyAttributes.from_json(await response.json()) - else: - raise ServerError( - f"Unexpected status code {response.status}; check server logs for error details" - ) + async with response: + if response.status == 200: + return OsuDifficultyAttributes.from_dict(await response.json()) + elif response.status == 404 and beatmap_id is not None: + raise self._not_found_error(beatmap_id) + else: + raise await self._status_error(response) + + async def _request_performance_beatmap_id( + self, + beatmap_id: int, + mods: list[OsuMod] | None, + hit_stats: HitStatistics | None = None, + replay: BinaryIO | None = None, + ) -> aiohttp.ClientResponse: + """Queries for the performance of a play given a beatmap ID.""" + if hit_stats is not None: + params = hit_stats.to_dict() + if mods is not None: + params["mods"] = [mod.value for mod in mods] + return await self.session.get( + f"/api/performance/{beatmap_id}", params=params + ) + elif replay is not None: + return await self.session.post( + f"/api/performance/replay/{beatmap_id}", data={"replay": replay} + ) + else: + raise ValueError( + "Exactly one of `hit_stats` and `replay` must be populated when" + " calculating performance with a beatmap ID" + ) + + async def _request_performance_beatmap( + self, + beatmap: BinaryIO, + mods: list[OsuMod] | None, + hit_stats: HitStatistics | None = None, + replay: BinaryIO | None = None, + ) -> aiohttp.ClientResponse: + """Queries for the performance of a play given a beatmap ID.""" + if hit_stats is not None: + params = hit_stats.to_dict() + if mods is not None: + params["mods"] = [mod.value for mod in mods] + return await self.session.post( + "/api/performance", params=params, data={"beatmap": beatmap} + ) + elif replay is not None: + return await self.session.post( + "/api/performance/replay", + data={"beatmap": beatmap, "replay": replay}, + ) + else: + raise ValueError( + "Exactly one of `hit_stats` and `replay` must be populated when" + " calculating performance with a beatmap" + ) + + async def calculate_performance( + self, + *, + beatmap_id: int | None = None, + beatmap: BinaryIO | None = None, + mods: list[OsuMod] | None = None, + difficulty: OsuDifficultyAttributes | None = None, + hit_stats: HitStatistics | None = None, + replay: BinaryIO | None = None, + ) -> OsuPerformanceAttributes: + """ + Calculates the performance values for a given play on a provided beatmap. + + Each query essentially requires a method of specifying the beatmap the play was + made on (through exactly one of `beatmap_id`, `beatmap` or `difficulty`) and + a method of describing the play itself (through exactly one of `hit_stats` and + `replay`). However, processing a replay is not possible with difficulty + attributes alone, so using `difficulty` requires the use of `hit_stats`. + + Parameters + ---------- + beatmap_id : int, optional + beatmap : binary file stream, optional + mods : list of OsuMod enums, optional + For use with either `beatmap` or `beatmap_id`. + difficulty : OsuDifficultyAttributes, optional + Difficulty attribute instance, as returned by `calculate_difficulty()`. + hit_stats : HitStatistics, optional + replay : binary file stream, optional + + Returns + ------- + OsuPerformanceAttributes + Dataclass encoding the performance values of the requested play. + """ + if beatmap_id is not None: + response = await self._request_performance_beatmap_id( + beatmap_id, mods, hit_stats, replay + ) + elif beatmap is not None: + response = await self._request_performance_beatmap( + beatmap, mods, hit_stats, replay + ) + elif difficulty is not None: + if hit_stats is not None: + response = await self.session.get( + "/api/performance", + params=difficulty.to_dict() | hit_stats.to_dict(), + ) + else: + raise ValueError( + "`hit_stats` must be populated when querying with `difficulty`" + ) else: raise ValueError( - "Exactly one of `beatmap_id` and `beatmap_data` should be set" + "Exactly one of `beatmap_id`, `beatmap`, and `difficulty` must be" + " populated" ) + + async with response: + if response.status == 200: + return OsuPerformanceAttributes.from_dict(await response.json()) + elif response.status == 404 and beatmap_id is not None: + raise self._not_found_error(beatmap_id) + else: + raise await self._status_error(response) diff --git a/vibrio/tests/resources/4429758207.osr b/vibrio/tests/resources/4429758207.osr new file mode 100644 index 0000000..83f8a65 Binary files /dev/null and b/vibrio/tests/resources/4429758207.osr differ diff --git a/vibrio/tests/test_lazer.py b/vibrio/tests/test_lazer.py index fae782c..bc39522 100644 --- a/vibrio/tests/test_lazer.py +++ b/vibrio/tests/test_lazer.py @@ -1,116 +1,287 @@ +from dataclasses import dataclass from pathlib import Path import pytest from pytest import approx # type: ignore from vibrio import Lazer, LazerAsync -from vibrio.types import OsuMod +from vibrio.types import HitStatistics, OsuMod RESOURCES_DIR = Path(__file__).absolute().parent / "resources" +EPSILON = 1e-3 pytest_plugins = ("pytest_asyncio",) @pytest.mark.parametrize("beatmap_id", [1001682]) -def test_get_beatmap(beatmap_id: int): - beatmap = None - with Lazer() as lazer: - beatmap = lazer.get_beatmap(beatmap_id) +class TestSelfHosted: + def test_get_beatmap(self, beatmap_id: int): + beatmap = None + with Lazer() as lazer1: + with Lazer(port=lazer1.port, self_hosted=True) as lazer2: + beatmap = lazer2.get_beatmap(beatmap_id) + assert beatmap is not None - assert beatmap is not None - for line in beatmap.readlines(): - if line.startswith(b"BeatmapID"): - _, found_id = line.split(b":") - assert beatmap_id == int(found_id) - break + @pytest.mark.asyncio + async def test_get_beatmap_async(self, beatmap_id: int): + beatmap = None + async with LazerAsync() as lazer1: + async with LazerAsync(port=lazer1.port, self_hosted=True) as lazer2: + beatmap = await lazer2.get_beatmap(beatmap_id) + assert beatmap is not None @pytest.mark.parametrize("beatmap_id", [1001682]) -def test_cache_status(beatmap_id: int): - with Lazer() as lazer: - assert not lazer.has_beatmap(beatmap_id) - lazer.get_beatmap(beatmap_id) - assert lazer.has_beatmap(beatmap_id) - lazer.clear_cache() - assert not lazer.has_beatmap(beatmap_id) +class TestBeatmap: + def test_get_beatmap(self, beatmap_id: int): + beatmap = None + with Lazer() as lazer: + beatmap = lazer.get_beatmap(beatmap_id) + assert beatmap is not None + for line in beatmap.readlines(): + if line.startswith(b"BeatmapID"): + _, found_id = line.split(b":") + assert beatmap_id == int(found_id) + break -@pytest.mark.parametrize("beatmap_id", [1001682]) -@pytest.mark.parametrize("mods", [[OsuMod.DOUBLE_TIME]]) -@pytest.mark.parametrize("star_rating", [9.7]) -@pytest.mark.parametrize("max_combo", [3220]) -def test_calculate_difficulty_id( - beatmap_id: int, mods: list[OsuMod], star_rating: float, max_combo: int -): - with Lazer() as lazer: - attributes = lazer.calculate_difficulty(mods, beatmap_id=beatmap_id) - assert attributes.star_rating == approx(star_rating, 0.03) - assert attributes.max_combo == max_combo - - -@pytest.mark.parametrize("beatmap_filename", ["1001682.osu"]) -@pytest.mark.parametrize("mods", [[OsuMod.DOUBLE_TIME]]) -@pytest.mark.parametrize("star_rating", [9.7]) -@pytest.mark.parametrize("max_combo", [3220]) -def test_calculate_difficulty_file( - beatmap_filename: str, mods: list[OsuMod], star_rating: float, max_combo: int -): - with Lazer() as lazer, open(RESOURCES_DIR / beatmap_filename) as beatmap: - attributes = lazer.calculate_difficulty(mods, beatmap=beatmap) - assert attributes.star_rating == approx(star_rating, 0.03) - assert attributes.max_combo == max_combo - - -@pytest.mark.asyncio -@pytest.mark.parametrize("beatmap_id", [1001682]) -async def test_get_beatmap_async(beatmap_id: int): - beatmap = None - async with LazerAsync() as lazer: - beatmap = await lazer.get_beatmap(beatmap_id) + def test_cache_status(self, beatmap_id: int): + with Lazer() as lazer: + assert not lazer.has_beatmap(beatmap_id) + lazer.get_beatmap(beatmap_id) + assert lazer.has_beatmap(beatmap_id) + lazer.clear_cache() + assert not lazer.has_beatmap(beatmap_id) - assert beatmap is not None - for line in beatmap.readlines(): - if line.startswith(b"BeatmapID"): - _, found_id = line.split(b":") - assert beatmap_id == int(found_id) - break + @pytest.mark.asyncio + async def test_get_beatmap_async(self, beatmap_id: int): + beatmap = None + async with LazerAsync() as lazer: + beatmap = await lazer.get_beatmap(beatmap_id) + assert beatmap is not None + for line in beatmap.readlines(): + if line.startswith(b"BeatmapID"): + _, found_id = line.split(b":") + assert beatmap_id == int(found_id) + break -@pytest.mark.asyncio -@pytest.mark.parametrize("beatmap_id", [1001682]) -async def test_cache_status_async(beatmap_id: int): - async with LazerAsync() as lazer: - assert not await lazer.has_beatmap(beatmap_id) - await lazer.get_beatmap(beatmap_id) - assert await lazer.has_beatmap(beatmap_id) - await lazer.clear_cache() - assert not await lazer.has_beatmap(beatmap_id) + @pytest.mark.asyncio + async def test_cache_status_async(self, beatmap_id: int): + async with LazerAsync() as lazer: + assert not await lazer.has_beatmap(beatmap_id) + await lazer.get_beatmap(beatmap_id) + assert await lazer.has_beatmap(beatmap_id) + await lazer.clear_cache() + assert not await lazer.has_beatmap(beatmap_id) -@pytest.mark.asyncio -@pytest.mark.parametrize("beatmap_id", [1001682]) -@pytest.mark.parametrize("mods", [[OsuMod.DOUBLE_TIME]]) -@pytest.mark.parametrize("star_rating", [9.7]) -@pytest.mark.parametrize("max_combo", [3220]) -async def test_calculate_difficulty_id_async( - beatmap_id: int, mods: list[OsuMod], star_rating: float, max_combo: int -): - async with LazerAsync() as lazer: - attributes = await lazer.calculate_difficulty(mods, beatmap_id=beatmap_id) - assert attributes.star_rating == approx(star_rating, 0.03) - assert attributes.max_combo == max_combo - - -@pytest.mark.asyncio -@pytest.mark.parametrize("beatmap_filename", ["1001682.osu"]) -@pytest.mark.parametrize("mods", [[OsuMod.DOUBLE_TIME]]) -@pytest.mark.parametrize("star_rating", [9.7]) -@pytest.mark.parametrize("max_combo", [3220]) -async def test_calculate_difficulty_file_async( - beatmap_filename: str, mods: list[OsuMod], star_rating: float, max_combo: int -): - async with LazerAsync() as lazer: - with open(RESOURCES_DIR / beatmap_filename) as beatmap: - attributes = await lazer.calculate_difficulty(mods, beatmap=beatmap) - assert attributes.star_rating == approx(star_rating, 0.03) - assert attributes.max_combo == max_combo +@dataclass +class DifficultyTestCase: + beatmap_id: int + beatmap_filename: str + mods: list[OsuMod] + star_rating: float + max_combo: int + + +@pytest.mark.parametrize( + "test_case", + [ + DifficultyTestCase( + beatmap_id=1001682, + beatmap_filename="1001682.osu", + mods=[OsuMod.DOUBLE_TIME], + star_rating=9.7, + max_combo=3220, + ) + ], +) +class TestDifficulty: + def test_calculate_difficulty_id(self, test_case: DifficultyTestCase): + with Lazer() as lazer: + attributes = lazer.calculate_difficulty( + mods=test_case.mods, beatmap_id=test_case.beatmap_id + ) + assert attributes.star_rating == approx(test_case.star_rating, EPSILON) + assert attributes.max_combo == test_case.max_combo + + def test_calculate_difficulty_beatmap(self, test_case: DifficultyTestCase): + with ( + Lazer() as lazer, + open(RESOURCES_DIR / test_case.beatmap_filename, "rb") as beatmap, + ): + attributes = lazer.calculate_difficulty( + mods=test_case.mods, beatmap=beatmap + ) + assert attributes.star_rating == approx(test_case.star_rating, EPSILON) + assert attributes.max_combo == test_case.max_combo + + @pytest.mark.asyncio + async def test_calculate_difficulty_id_async(self, test_case: DifficultyTestCase): + async with LazerAsync() as lazer: + attributes = await lazer.calculate_difficulty( + mods=test_case.mods, beatmap_id=test_case.beatmap_id + ) + assert attributes.star_rating == approx(test_case.star_rating, EPSILON) + assert attributes.max_combo == test_case.max_combo + + @pytest.mark.asyncio + async def test_calculate_difficulty_beatmap_async( + self, test_case: DifficultyTestCase + ): + async with LazerAsync() as lazer: + with open(RESOURCES_DIR / test_case.beatmap_filename, "rb") as beatmap: + attributes = await lazer.calculate_difficulty( + mods=test_case.mods, beatmap=beatmap + ) + assert attributes.star_rating == approx(test_case.star_rating, EPSILON) + assert attributes.max_combo == test_case.max_combo + + +@dataclass +class PerformanceTestCase: + beatmap_id: int + beatmap_filename: str + mods: list[OsuMod] + hit_stats: HitStatistics + replay_filename: str + pp: float + + +@pytest.mark.parametrize( + "test_case", + [ + PerformanceTestCase( + beatmap_id=1001682, + beatmap_filename="1001682.osu", + mods=[OsuMod.HIDDEN, OsuMod.DOUBLE_TIME], + hit_stats=HitStatistics( + count_300=2019, count_100=104, count_50=0, count_miss=3, combo=3141 + ), + replay_filename="4429758207.osr", + pp=1304.35, + ) + ], +) +class TestPerformance: + def test_calculate_performance_id_hitstat(self, test_case: PerformanceTestCase): + with Lazer() as lazer: + attributes = lazer.calculate_performance( + beatmap_id=test_case.beatmap_id, + mods=test_case.mods, + hit_stats=test_case.hit_stats, + ) + assert attributes.total == approx(test_case.pp, EPSILON) + + def test_calculate_performance_beatmap_hitstat( + self, test_case: PerformanceTestCase + ): + with ( + Lazer() as lazer, + open(RESOURCES_DIR / test_case.beatmap_filename, "rb") as beatmap, + ): + attributes = lazer.calculate_performance( + beatmap=beatmap, + mods=test_case.mods, + hit_stats=test_case.hit_stats, + ) + assert attributes.total == approx(test_case.pp, EPSILON) + + def test_calculate_performance_difficulty(self, test_case: PerformanceTestCase): + with Lazer() as lazer: + attributes = lazer.calculate_performance( + difficulty=lazer.calculate_difficulty( + mods=test_case.mods, beatmap_id=test_case.beatmap_id + ), + hit_stats=test_case.hit_stats, + ) + assert attributes.total == approx(test_case.pp, EPSILON) + + def test_calculate_performance_id_replay(self, test_case: PerformanceTestCase): + with ( + Lazer() as lazer, + open(RESOURCES_DIR / test_case.replay_filename, "rb") as replay, + ): + attributes = lazer.calculate_performance( + beatmap_id=test_case.beatmap_id, + replay=replay, + ) + assert attributes.total == approx(test_case.pp, EPSILON) + + def test_calculate_performance_beatmap_replay(self, test_case: PerformanceTestCase): + with ( + Lazer() as lazer, + open(RESOURCES_DIR / test_case.beatmap_filename, "rb") as beatmap, + open(RESOURCES_DIR / test_case.replay_filename, "rb") as replay, + ): + attributes = lazer.calculate_performance( + beatmap=beatmap, + replay=replay, + ) + assert attributes.total == approx(test_case.pp, EPSILON) + + @pytest.mark.asyncio + async def test_calculate_performance_id_hitstat_async( + self, test_case: PerformanceTestCase + ): + async with LazerAsync() as lazer: + attributes = await lazer.calculate_performance( + beatmap_id=test_case.beatmap_id, + mods=test_case.mods, + hit_stats=test_case.hit_stats, + ) + assert attributes.total == approx(test_case.pp, EPSILON) + + @pytest.mark.asyncio + async def test_calculate_performance_beatmap_hitstat_async( + self, test_case: PerformanceTestCase + ): + async with LazerAsync() as lazer: + with open(RESOURCES_DIR / test_case.beatmap_filename, "rb") as beatmap: + attributes = await lazer.calculate_performance( + beatmap=beatmap, + mods=test_case.mods, + hit_stats=test_case.hit_stats, + ) + assert attributes.total == approx(test_case.pp, EPSILON) + + @pytest.mark.asyncio + async def test_calculate_performance_difficulty_async( + self, test_case: PerformanceTestCase + ): + async with LazerAsync() as lazer: + attributes = await lazer.calculate_performance( + difficulty=await lazer.calculate_difficulty( + mods=test_case.mods, beatmap_id=test_case.beatmap_id + ), + hit_stats=test_case.hit_stats, + ) + assert attributes.total == approx(test_case.pp, EPSILON) + + @pytest.mark.asyncio + async def test_calculate_performance_id_replay_async( + self, test_case: PerformanceTestCase + ): + async with LazerAsync() as lazer: + with open(RESOURCES_DIR / test_case.replay_filename, "rb") as replay: + attributes = await lazer.calculate_performance( + beatmap_id=test_case.beatmap_id, + replay=replay, + ) + assert attributes.total == approx(test_case.pp, EPSILON) + + @pytest.mark.asyncio + async def test_calculate_performance_beatmap_replay_async( + self, test_case: PerformanceTestCase + ): + async with LazerAsync() as lazer: + with ( + open(RESOURCES_DIR / test_case.beatmap_filename, "rb") as beatmap, + open(RESOURCES_DIR / test_case.replay_filename, "rb") as replay, + ): + attributes = await lazer.calculate_performance( + beatmap=beatmap, + replay=replay, + ) + assert attributes.total == approx(test_case.pp, EPSILON) diff --git a/vibrio/types.py b/vibrio/types.py index 3dd3e22..09a0c83 100644 --- a/vibrio/types.py +++ b/vibrio/types.py @@ -1,4 +1,9 @@ -from dataclasses import dataclass, fields +""" +Custom types used with :class:`~vibrio.Lazer` and :class:`~vibrio.LazerAsync`. +""" + +from abc import ABC +from dataclasses import asdict, dataclass, fields from enum import Enum from typing import Any @@ -6,6 +11,8 @@ class OsuMod(Enum): + """Enum representing the osu!standard mods as two-letter string codes.""" + NO_FAIL = "NF" EASY = "EZ" TOUCH_DEVICE = "TD" @@ -24,7 +31,68 @@ class OsuMod(Enum): @dataclass -class OsuDifficultyAttributes: +class SerializableDataclass(ABC): + """ + Abstract base for dataclasses supporting serialization to and deserialization + from a dictionary. + + This base class is intended for cross-compatible use with return types from C# code, + so key names are expected in camel case during deserialization and are normalized to + remove underscores during serialization. + """ + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> Self: + """Instantiates a dataclass from the provided dictionary.""" + values: dict[str, Any] = {} + data_lowercase = {k.lower(): v for k, v in data.items()} + for field in fields(cls): + name = field.name.replace("_", "") + value = data_lowercase[name] + if field.type is list[OsuMod]: + value = [OsuMod(acronym) for acronym in value] + + values[field.name] = value + + return cls(**values) + + @staticmethod + def _factory(items: list[tuple[str, Any]]) -> dict[str, Any]: + data: dict[str, Any] = {} + for k, v in items: + if type(v) is list[OsuMod]: + v = [mod.value for mod in v] + data[k.replace("_", "")] = v + return data + + def to_dict(self) -> dict[str, Any]: + """Serializes dataclass values to a dictionary.""" + return asdict(self, dict_factory=self._factory) + + +@dataclass +class HitStatistics(SerializableDataclass): + """Dataclass representing an osu! play in terms of individual hit statistics.""" + + count_300: int + count_100: int + count_50: int + count_miss: int + combo: int + + +@dataclass +class OsuDifficultyAttributes(SerializableDataclass): + """ + osu!standard difficulty attributes, as produced by osu!lazer's internal difficulty + calculation. + + See Also + -------- + Lazer.calculate_difficulty + LazerAsync.calculate_difficulty + """ + mods: list[OsuMod] star_rating: float max_combo: int @@ -40,17 +108,23 @@ class OsuDifficultyAttributes: slider_count: int spinner_count: int - @classmethod - def from_json(cls, data: dict[str, Any]) -> Self: - values: dict[str, Any] = {} - data_lowercase = {k.lower(): v for k, v in data.items()} - for field in fields(cls): - name = field.name.replace("_", "") - if field.type is list[OsuMod]: - value = [OsuMod(acronym) for acronym in data_lowercase[name]] - else: - value = data_lowercase[name] - values[field.name] = value +@dataclass +class OsuPerformanceAttributes(SerializableDataclass): + """ + osu!standard performance attributes, as produced by osu!lazer's internal performance + calculation. - return cls(**values) + See Also + -------- + Lazer.calculate_difficulty + LazerAsync.calculate_difficulty + """ + + total: float + """The play's total pp amount.""" + aim: float + speed: float + accuracy: float + flashlight: float + effective_miss_count: float