Skip to content

Commit

Permalink
Merge pull request #12 from notjagan/difficulty-calculation
Browse files Browse the repository at this point in the history
Add difficulty calculation
  • Loading branch information
notjagan authored Dec 9, 2023
2 parents c0fa783 + 8c9f6cc commit d1de8ca
Show file tree
Hide file tree
Showing 6 changed files with 2,455 additions and 3 deletions.
1 change: 1 addition & 0 deletions .gitattributes
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
*.osu binary
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "vibrio"
version = "0.1.8"
version = "0.2.0"
readme = "README.md"
requires-python = ">=3.9"
license = { file = "LICENSE" }
Expand Down
99 changes: 97 additions & 2 deletions vibrio/lazer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import asyncio
import atexit
import io
import logging
import platform
import signal
import socket
Expand All @@ -13,13 +14,15 @@
import urllib.parse
from abc import ABC
from pathlib import Path
from typing import Any, BinaryIO, Optional
from typing import Any, BinaryIO, Optional, TextIO

import aiohttp
import psutil
import requests
from typing_extensions import Self

from vibrio.types import OsuDifficultyAttributes, OsuMod

PACKAGE_DIR = Path(__file__).absolute().parent


Expand Down Expand Up @@ -83,11 +86,12 @@ def _start(self) -> None:
self.running = True

if self.use_logging:
logging.info(f"Launching server on port {self.port}")
self.log = tempfile.NamedTemporaryFile(delete=False)

def _stop(self) -> None:
if self.log is not None:
print(f"Server output logged at {self.log.file.name}")
logging.info(f"Server output logged at {self.log.file.name}")
self.log.close()
self.log = None

Expand Down Expand Up @@ -228,6 +232,48 @@ def clear_cache(self) -> None:
f"Unexpected status code {response.status_code}; check server logs for error details"
)

def calculate_difficulty(
self,
mods: list[OsuMod],
beatmap_id: Optional[int] = None,
beatmap: Optional[TextIO] = None,
) -> OsuDifficultyAttributes:
params = {"mods": [mod.value for mod in mods]}

if beatmap_id is not None:
if beatmap is not None:
raise ValueError(
"Exactly one of `beatmap_id` and `beatmap_data` should be set"
)

with self.session.get(
f"/api/difficulty/{beatmap_id}", params=params
) as response:
if response.status_code == 200:
return OsuDifficultyAttributes.from_json(response.json())
elif response.status_code == 404:
raise BeatmapNotFound(f"No beatmap found for id {beatmap_id}")
else:
raise ServerError(
f"Unexpected status code {response.status_code}; check server logs for error details"
)

elif beatmap is not None:
with self.session.post(
"/api/difficulty", params=params, files={"beatmap": beatmap}
) as response:
if response.status_code == 200:
return OsuDifficultyAttributes.from_json(response.json())
else:
raise ServerError(
f"Unexpected status code {response.status_code}; check server logs for error details"
)

else:
raise ValueError(
"Exactly one of `beatmap_id` and `beatmap_data` should be set"
)


class LazerAsync(LazerBase):
"""Asynchronous implementation for interfacing with osu!lazer functionality."""
Expand Down Expand Up @@ -352,3 +398,52 @@ async def clear_cache(self) -> None:
raise ServerError(
f"Unexpected status code {response.status}; check server logs for error details"
)

async def calculate_difficulty(
self,
mods: list[OsuMod],
beatmap_id: Optional[int] = None,
beatmap: Optional[TextIO] = None,
) -> OsuDifficultyAttributes:
params = {"mods": [mod.value for mod in mods]}

if beatmap_id is not None:
if beatmap is not None:
raise ValueError(
"Exactly one of `beatmap_id` and `beatmap_data` should be set"
)

async with self.session.get(
f"/api/difficulty/{beatmap_id}", params=params
) as response:
if response.status == 200:
return OsuDifficultyAttributes.from_json(await response.json())
elif response.status == 404:
raise BeatmapNotFound(f"No beatmap found for id {beatmap_id}")
else:
raise ServerError(
f"Unexpected status code {response.status}; check server logs for error details"
)

elif beatmap is not None:
data = aiohttp.FormData()
data.add_field(
"beatmap",
beatmap.read(),
content_type="multipart/form-data",
filename="beatmap",
)
async with self.session.post(
"/api/difficulty", params=params, data=data
) as response:
if response.status == 200:
return OsuDifficultyAttributes.from_json(await response.json())
else:
raise ServerError(
f"Unexpected status code {response.status}; check server logs for error details"
)

else:
raise ValueError(
"Exactly one of `beatmap_id` and `beatmap_data` should be set"
)
Loading

0 comments on commit d1de8ca

Please sign in to comment.