Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add difficulty calculation #12

Merged
merged 9 commits into from
Dec 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitattributes
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
*.osu binary
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "vibrio"
version = "0.1.8"
version = "0.2.0"
readme = "README.md"
requires-python = ">=3.9"
license = { file = "LICENSE" }
Expand Down
99 changes: 97 additions & 2 deletions vibrio/lazer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import asyncio
import atexit
import io
import logging
import platform
import signal
import socket
Expand All @@ -13,13 +14,15 @@
import urllib.parse
from abc import ABC
from pathlib import Path
from typing import Any, BinaryIO, Optional
from typing import Any, BinaryIO, Optional, TextIO

import aiohttp
import psutil
import requests
from typing_extensions import Self

from vibrio.types import OsuDifficultyAttributes, OsuMod

PACKAGE_DIR = Path(__file__).absolute().parent


Expand Down Expand Up @@ -83,11 +86,12 @@ def _start(self) -> None:
self.running = True

if self.use_logging:
logging.info(f"Launching server on port {self.port}")
self.log = tempfile.NamedTemporaryFile(delete=False)

def _stop(self) -> None:
if self.log is not None:
print(f"Server output logged at {self.log.file.name}")
logging.info(f"Server output logged at {self.log.file.name}")
self.log.close()
self.log = None

Expand Down Expand Up @@ -228,6 +232,48 @@ def clear_cache(self) -> None:
f"Unexpected status code {response.status_code}; check server logs for error details"
)

def calculate_difficulty(
self,
mods: list[OsuMod],
beatmap_id: Optional[int] = None,
beatmap: Optional[TextIO] = None,
) -> OsuDifficultyAttributes:
params = {"mods": [mod.value for mod in mods]}

if beatmap_id is not None:
if beatmap is not None:
raise ValueError(
"Exactly one of `beatmap_id` and `beatmap_data` should be set"
)

with self.session.get(
f"/api/difficulty/{beatmap_id}", params=params
) as response:
if response.status_code == 200:
return OsuDifficultyAttributes.from_json(response.json())
elif response.status_code == 404:
raise BeatmapNotFound(f"No beatmap found for id {beatmap_id}")
else:
raise ServerError(
f"Unexpected status code {response.status_code}; check server logs for error details"
)

elif beatmap is not None:
with self.session.post(
"/api/difficulty", params=params, files={"beatmap": beatmap}
) as response:
if response.status_code == 200:
return OsuDifficultyAttributes.from_json(response.json())
else:
raise ServerError(
f"Unexpected status code {response.status_code}; check server logs for error details"
)

else:
raise ValueError(
"Exactly one of `beatmap_id` and `beatmap_data` should be set"
)


class LazerAsync(LazerBase):
"""Asynchronous implementation for interfacing with osu!lazer functionality."""
Expand Down Expand Up @@ -352,3 +398,52 @@ async def clear_cache(self) -> None:
raise ServerError(
f"Unexpected status code {response.status}; check server logs for error details"
)

async def calculate_difficulty(
self,
mods: list[OsuMod],
beatmap_id: Optional[int] = None,
beatmap: Optional[TextIO] = None,
) -> OsuDifficultyAttributes:
params = {"mods": [mod.value for mod in mods]}

if beatmap_id is not None:
if beatmap is not None:
raise ValueError(
"Exactly one of `beatmap_id` and `beatmap_data` should be set"
)

async with self.session.get(
f"/api/difficulty/{beatmap_id}", params=params
) as response:
if response.status == 200:
return OsuDifficultyAttributes.from_json(await response.json())
elif response.status == 404:
raise BeatmapNotFound(f"No beatmap found for id {beatmap_id}")
else:
raise ServerError(
f"Unexpected status code {response.status}; check server logs for error details"
)

elif beatmap is not None:
data = aiohttp.FormData()
data.add_field(
"beatmap",
beatmap.read(),
content_type="multipart/form-data",
filename="beatmap",
)
async with self.session.post(
"/api/difficulty", params=params, data=data
) as response:
if response.status == 200:
return OsuDifficultyAttributes.from_json(await response.json())
else:
raise ServerError(
f"Unexpected status code {response.status}; check server logs for error details"
)

else:
raise ValueError(
"Exactly one of `beatmap_id` and `beatmap_data` should be set"
)
Loading
Loading