Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Condense server code #6

Merged
merged 6 commits into from
Dec 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ jobs:

steps:
- name: Check build matrix status
if: failure() || cancelled()
if: ${{ !success() }}
run: exit 1

publish:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "vibrio"
version = "0.1.3"
version = "0.1.4"
readme = "README.md"
requires-python = ">=3.9"
license = { file = "LICENSE" }
Expand Down
269 changes: 218 additions & 51 deletions vibrio/lazer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,24 @@
import asyncio
import atexit
import io
import platform
import signal
import socket
import stat
import subprocess
import tempfile
import time
import urllib.parse
from abc import ABC
from typing import BinaryIO, Generic, Optional, TypeVar
from pathlib import Path
from typing import Any, BinaryIO, Optional

import aiohttp
import psutil
import requests
from typing_extensions import Self

from vibrio.server import Server, ServerAsync, ServerBase
PACKAGE_DIR = Path(__file__).absolute().parent


class ServerStateError(Exception):
Expand All @@ -31,38 +42,149 @@ def find_open_port() -> int:
return sock.getsockname()[1]


T = TypeVar("T", bound=ServerBase)
def get_vibrio_path(platform: str) -> Path:
"""Determines path to server executable on a given platform."""
if platform == "Windows":
suffix = ".exe"
else:
suffix = ""

return PACKAGE_DIR / "lib" / f"Vibrio{suffix}"

class LazerBase(ABC, Generic[T]):

class LazerBase(ABC):
"""Shared functionality for lazer wrappers."""

STARTUP_DELAY = 0.05 # Amount of time (seconds) between requests during startup

def __init__(self, port: Optional[int] = None, use_logging: bool = True) -> None:
if port is None:
self.port = find_open_port()
else:
self.port = port
self.use_logging = use_logging
self.running = False

self.server_path = get_vibrio_path(platform.system())
if not self.server_path.exists():
raise FileNotFoundError(f'No executable found at "{self.server_path}".')
self.server_path.chmod(self.server_path.stat().st_mode | stat.S_IEXEC)

self.log: Optional[tempfile._TemporaryFileWrapper[bytes]] = None

def address(self) -> str:
"""Constructs the base URL for the web server."""
return f"http://localhost:{self.port}"

def _start(self) -> None:
if self.running:
raise ServerStateError("Server is already running")

self.running = True

if self.use_logging:
self.log = tempfile.NamedTemporaryFile(delete=False)

def _stop(self) -> None:
if self.log is not None:
print(f"Server output logged at {self.log.file.name}")
self.log.close()
self.log = None


class BaseUrlSession(requests.Session):
def __init__(self, base_url: str) -> None:
super().__init__()
self.base_url = base_url

self.server: Optional[T] = None
def request(
self, method: str | bytes, url: str | bytes, *args: Any, **kwargs: Any
) -> requests.Response:
full_url = urllib.parse.urljoin(self.base_url, str(url))
return super().request(method, full_url, *args, **kwargs)


class Lazer(LazerBase[Server]):
class Lazer(LazerBase):
"""Synchronous implementation for interfacing with osu!lazer functionality."""

def __init__(self, port: int | None = None, use_logging: bool = True) -> None:
super().__init__(port, use_logging)

self.session = None
self.process = None

@property
def session(self) -> BaseUrlSession:
if self._session is None:
raise ServerStateError("Session has not been initialized")
return self._session

@session.setter
def session(self, value: Optional[BaseUrlSession]) -> None:
self._session = value

@property
def process(self) -> subprocess.Popen[bytes]:
if self._process is None:
raise ServerStateError("Process has not been initialized")
return self._process

@process.setter
def process(self, value: Optional[subprocess.Popen[bytes]]) -> None:
self._process = value

def start(self) -> None:
"""Creates and starts up server."""
if self.server is None:
self.server = Server.create(self.port, self.use_logging)
atexit.register(self.stop)
"""Launches server executable."""
self._start()

if self.log is not None:
out = self.log
else:
raise ServerStateError("Server is already running")
out = subprocess.DEVNULL

self.process = subprocess.Popen(
[self.server_path, "--urls", self.address()],
stdout=out,
stderr=out,
)

self.session = BaseUrlSession(self.address())

# block until webserver has launched
while True:
try:
with self.session.get("/api/status") as response:
if response.status_code == 200:
break
except (ConnectionError, IOError):
pass
finally:
time.sleep(self.STARTUP_DELAY)

atexit.register(self.stop)

def stop(self) -> None:
"""Cleans up and removes server."""
if self.server is not None:
self.server.stop()
self.server = None
"""Cleans up server executable."""
self._stop()

try:
parent = psutil.Process(self.process.pid)
for child in parent.children(recursive=True):
child.terminate()
parent.terminate()
status = self.process.wait()
self.process = None

if status != 0 and status != signal.SIGTERM:
raise SystemError(
f"Could not cleanly shutdown server subprocess; received return code {status}"
)

self.session.close()
self.session = None

except ServerStateError:
pass

def __enter__(self) -> Self:
self.start()
Expand All @@ -74,10 +196,7 @@ def __exit__(self, *_) -> bool:

def has_beatmap(self, beatmap_id: int) -> bool:
"""Checks if given beatmap is cached/available locally."""
if self.server is None:
raise ServerStateError("Server is not currently active")

with self.server.session.get(f"/api/beatmaps/{beatmap_id}/status") as response:
with self.session.get(f"/api/beatmaps/{beatmap_id}/status") as response:
if response.status_code == 200:
return True
elif response.status_code == 404:
Expand All @@ -88,10 +207,7 @@ def has_beatmap(self, beatmap_id: int) -> bool:

def get_beatmap(self, beatmap_id: int) -> BinaryIO:
"""Returns a file stream for the given beatmap."""
if self.server is None:
raise ServerStateError("Server is not currently active")

with self.server.session.get(f"/api/beatmaps/{beatmap_id}") as response:
with self.session.get(f"/api/beatmaps/{beatmap_id}") as response:
if response.status_code == 200:
stream = io.BytesIO()
stream.write(response.content)
Expand All @@ -106,32 +222,94 @@ def get_beatmap(self, beatmap_id: int) -> BinaryIO:

def clear_cache(self) -> None:
"""Clears beatmap cache (if applicable)."""
if self.server is None:
raise ServerStateError("Server is not currently active")

with self.server.session.delete("/api/beatmaps/cache") as response:
with self.session.delete("/api/beatmaps/cache") as response:
if response.status_code != 200:
raise ServerError(
f"Unexpected status code {response.status_code}; check server logs for error details"
)


class LazerAsync(LazerBase[ServerAsync]):
class LazerAsync(LazerBase):
"""Asynchronous implementation for interfacing with osu!lazer functionality."""

def __init__(self, port: int | None = None, use_logging: bool = True) -> None:
super().__init__(port, use_logging)

self.session = None
self.process = None

@property
def session(self) -> aiohttp.ClientSession:
if self._session is None:
raise ServerStateError("Session has not been initialized")
return self._session

@session.setter
def session(self, value: Optional[aiohttp.ClientSession]) -> None:
self._session = value

@property
def process(self) -> asyncio.subprocess.Process:
if self._process is None:
raise ServerStateError("Process has not been initialized")
return self._process

@process.setter
def process(self, value: Optional[asyncio.subprocess.Process]) -> None:
self._process = value

async def start(self) -> None:
"""Creates and starts up server."""
if self.server is None:
self.server = await ServerAsync.create(self.port, self.use_logging)
atexit.register(lambda: asyncio.run(self.stop()))
"""Launches server executable."""
self._start()

if self.log is not None:
out = self.log
else:
raise ServerStateError("Server is already running")
out = subprocess.DEVNULL

self.process = await asyncio.create_subprocess_shell(
f"{self.server_path} --urls {self.address()}",
stdout=out,
stderr=out,
)

self.session = aiohttp.ClientSession(self.address())

# block until webserver has launched
while True:
try:
async with self.session.get("/api/status") as response:
if response.status == 200:
break
except (ConnectionError, aiohttp.ClientConnectionError):
pass
finally:
await asyncio.sleep(self.STARTUP_DELAY)

atexit.register(lambda: asyncio.run(self.stop()))

async def stop(self) -> None:
"""Cleans up and removes server."""
if self.server is not None:
await self.server.stop()
self.server = None
"""Cleans up server executable."""
self._stop()

try:
parent = psutil.Process(self.process.pid)
for child in parent.children(recursive=True):
child.terminate()
parent.terminate()
status = await self.process.wait()
self.process = None

await self.session.close()
self.session = None

if status != 0 and status != signal.SIGTERM:
raise SystemError(
f"Could not cleanly shutdown server subprocess; received return code {status}"
)

except ServerStateError:
pass

async def __aenter__(self) -> Self:
await self.start()
Expand All @@ -143,12 +321,7 @@ async def __aexit__(self, *_) -> bool:

async def has_beatmap(self, beatmap_id: int) -> bool:
"""Checks if given beatmap is cached/available locally."""
if self.server is None:
raise ServerStateError("Server is not currently active")

async with self.server.session.get(
f"/api/beatmaps/{beatmap_id}/status"
) as response:
async with self.session.get(f"/api/beatmaps/{beatmap_id}/status") as response:
if response.status == 200:
return True
elif response.status == 404:
Expand All @@ -159,10 +332,7 @@ async def has_beatmap(self, beatmap_id: int) -> bool:

async def get_beatmap(self, beatmap_id: int) -> BinaryIO:
"""Returns a file stream for the given beatmap."""
if self.server is None:
raise ServerStateError("Server is not currently active")

async with self.server.session.get(f"/api/beatmaps/{beatmap_id}") as response:
async with self.session.get(f"/api/beatmaps/{beatmap_id}") as response:
if response.status == 200:
stream = io.BytesIO()
stream.write(await response.read())
Expand All @@ -177,10 +347,7 @@ async def get_beatmap(self, beatmap_id: int) -> BinaryIO:

async def clear_cache(self) -> None:
"""Clears beatmap cache (if applicable)."""
if self.server is None:
raise ServerStateError("Server is not currently active")

async with self.server.session.delete("/api/beatmaps/cache") as response:
async with self.session.delete("/api/beatmaps/cache") as response:
if response.status != 200:
raise ServerError(
f"Unexpected status code {response.status}; check server logs for error details"
Expand Down
Loading
Loading