-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathentrypoint.py
420 lines (354 loc) · 15.1 KB
/
entrypoint.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
from collections import defaultdict
import importlib
import os
import uuid
from typing_extensions import assert_never
import subprocess
import sys
try:
import tomli as toml_lib
except ImportError:
if sys.version_info >= (3, 11):
import tomllib as toml_lib
else:
raise ImportError("Neither tomli nor tomllib (Python >= 3.11) are available. Please install tomli.")
from pathlib import Path
from structlog.stdlib import BoundLogger
import importlib.util
from typing import BinaryIO
import shutil
from packaging import markers as pkg_markers
from dweam.models import (
PackageMetadata, GameInfo, GameSource,
GitBranchSource, PathSource, PyPISource, SourceConfig
)
from dweam.utils.venv import ensure_correct_dweam_version, run_pip_with_output
# Default installation sources for each game package.
# Keys are the package/module names used for `pip install ...#egg=<name>`,
# `pip show <name>` and `importlib.util.find_spec(<name>)`. Values are the
# candidate sources, tried in list order by `load_games` until one installs and
# yields metadata. The local `PathSource` (resolved against the current working
# directory via `Path.absolute()`) comes first, so a checked-out copy takes
# precedence over the upstream git branch. `markers` are environment-marker
# expressions checked by `evaluate_markers`; a source whose markers evaluate to
# false is skipped by `install_game_source`.
DEFAULT_SOURCE_CONFIG = SourceConfig(
    packages={
        "diamond_yumenikki": [
            PathSource(
                path=Path("diamond-yumenikki"),
            ),
            GitBranchSource(
                git="https://github.com/dweam-team/diamond-yumenikki",
                branch="master",
            ),
        ],
        "lucid_v1": [
            PathSource(
                path=Path("lucid-v1"),
            ),
            GitBranchSource(
                git="https://github.com/dweam-team/lucid-v1",
                branch="master",
                markers="platform_system != 'Windows'", # JAX does not support GPU on Windows
            ),
        ],
        "open_oasis": [
            PathSource(
                path=Path("open-oasis"),
            ),
            GitBranchSource(
                git="https://github.com/dweam-team/open-oasis",
                branch="master",
                markers="platform_system != 'Windows'", # triton does not support Windows
            ),
        ],
        "diamond_mariokart": [
            PathSource(
                path=Path("AI-MarioKart64"),
            ),
            GitBranchSource(
                git="https://github.com/dweam-team/AI-MarioKart64",
                branch="main",
            ),
        ],
        "diamond_csgo": [
            PathSource(
                path=Path("diamond-csgo"),
            ),
            # Same repository as diamond_atari, but the `csgo` branch.
            GitBranchSource(
                git="https://github.com/dweam-team/diamond",
                branch="csgo",
            ),
        ],
        "snake_diffusion": [
            PathSource(
                path=Path("snake-diffusion"),
            ),
            GitBranchSource(
                git="https://github.com/dweam-team/snake-diffusion",
                branch="main",
            ),
        ],
        "diamond_atari": [
            PathSource(
                path=Path("diamond"),
            ),
            # Same repository as diamond_csgo, but the `main` branch.
            GitBranchSource(
                git="https://github.com/dweam-team/diamond",
                branch="main",
            ),
        ],
    }
)
def get_cache_dir() -> Path:
    """Return the directory used to cache git repositories.

    The ``CACHE_DIR`` environment variable takes precedence whenever it is set
    (even to an empty string); otherwise ``~/.dweam/cache`` is used.
    """
    override = os.environ.get("CACHE_DIR")
    if override is None:
        return Path.home() / ".dweam" / "cache"
    return Path(override)
def get_pip_path(venv_path: Path) -> Path:
    """Return the location of the ``pip`` executable inside the venv at *venv_path*.

    Windows virtualenvs keep executables under ``Scripts/`` with an ``.exe``
    suffix; every other platform uses ``bin/``.
    """
    if sys.platform == "win32":
        return venv_path / "Scripts" / "pip.exe"
    return venv_path / "bin" / "pip"
def evaluate_markers(markers: str | None) -> bool:
"""Evaluate environment markers like 'platform_system != "Windows"'"""
if not markers:
return True
marker = pkg_markers.Marker(markers)
return marker.evaluate()
def _parse_pip_show_location(stdout: str, name: str) -> Path | None:
    """Extract the module directory of package *name* from ``pip show`` output.

    The ``Editable project location`` field (present for ``pip install -e``)
    is preferred over ``Location``. In both cases *name* is appended, i.e. the
    module is assumed to live in ``<location>/<name>``. Returns ``None`` when
    neither field is present.
    """
    editable_prefix = "Editable project location: "
    location_prefix = "Location: "
    location = None
    editable_location = None
    for line in stdout.splitlines():
        # removeprefix (not split(": ")[1]) keeps a path that itself contains ": " intact
        if line.startswith(editable_prefix):
            editable_location = Path(line.removeprefix(editable_prefix)) / name
        elif line.startswith(location_prefix):
            location = Path(line.removeprefix(location_prefix)) / name
    return editable_location if editable_location is not None else location


def install_game_source(log: BoundLogger, venv_path: Path, source: GameSource, name: str) -> Path | None:
    """Install *source* into the venv at *venv_path* and return the module directory.

    Args:
        log: structlog logger used for progress and error reporting.
        venv_path: Root of the virtualenv whose ``pip`` performs the install.
        source: Where to install from: a local path (installed editable with
            ``-e``), a git branch, or a pinned PyPI version.
        name: Package name. Also used as the ``#egg=`` fragment for git
            installs, as the argument to ``pip show``, and as the directory
            name appended to the reported install location.

    Returns:
        ``<location>/<name>`` as reported by ``pip show`` (preferring the
        editable project location), or ``None`` when the source's environment
        markers do not match, pip is missing, the local path does not exist,
        the install or lookup fails, or any unexpected error occurs (logged).
    """
    if not evaluate_markers(source.markers):
        log.info("Skipping installation due to environment markers", markers=source.markers)
        return None

    pip_path = get_pip_path(venv_path)
    if not pip_path.exists():
        log.error("Pip executable not found", path=str(pip_path))
        return None

    try:
        # Common pip install args. The extra index is presumably there so that
        # CUDA 12.1 PyTorch wheels can be resolved.
        pip_base_args = [
            str(pip_path),
            "install",
            # "--force-reinstall",
            "--extra-index-url",
            "https://download.pytorch.org/whl/cu121",
        ]

        # Build the source-specific install arguments.
        if isinstance(source, PathSource):
            abs_path = source.path.absolute()
            if not abs_path.exists():
                log.warning("Source path does not exist", path=str(abs_path))
                return None
            log.info("Installing from local path", path=str(abs_path))
            install_args = ["-e", str(abs_path)]
            failure_message = "Failed to install from local path"
        elif isinstance(source, GitBranchSource):
            git_url = f"git+{source.git}@{source.branch}#egg={name}"
            log.info("Installing from git", url=git_url)
            install_args = [git_url]
            failure_message = "Failed to install from git"
        elif isinstance(source, PyPISource):
            package_spec = f"{name}=={source.version}"
            log.info("Installing from PyPI", package=package_spec)
            install_args = [package_spec]
            failure_message = "Failed to install from PyPI"
        else:
            assert_never(source)
            return None

        if run_pip_with_output(log, [*pip_base_args, *install_args]) != 0:
            log.error(failure_message)
            return None

        # Ask pip where the package ended up.
        result = subprocess.run(
            [str(pip_path), "show", name],
            capture_output=True,
            text=True,
        )
        if result.returncode != 0:
            log.error("Failed to get package location", stdout=result.stdout, stderr=result.stderr)
            return None

        module_path = _parse_pip_show_location(result.stdout, name)
        if module_path is None:
            log.error("Could not find package location in pip show output")
        return module_path
    except Exception:
        # Best effort: any unexpected failure is logged and reported as None,
        # so the caller can move on to the next source.
        log.exception("Unexpected error installing game source")
        return None
def load_toml(file: BinaryIO) -> dict:
"""Load TOML from a binary file object.
Uses tomli if available, otherwise falls back to tomllib on Python >= 3.11.
Raises ImportError if neither is available."""
return toml_lib.load(file)
def load_metadata_from_module(log: BoundLogger, module_name: str) -> PackageMetadata | None:
    """Load a game package's ``dweam.toml`` by locating an importable module.

    The module is located with ``importlib.util.find_spec``, and ``dweam.toml``
    is expected in the directory containing the spec's origin file.

    Args:
        log: structlog logger used for error reporting.
        module_name: Importable module name, e.g. ``"diamond_atari"``.

    Returns:
        The validated metadata with ``_module_dir`` set to the module's
        directory, or ``None`` when the module cannot be found (or has no
        file location), ``dweam.toml`` is missing, or the file cannot be read,
        parsed or validated (logged).
    """
    try:
        spec = importlib.util.find_spec(module_name)
        # has_location is False for built-in/frozen modules, whose origin
        # ("built-in", "frozen") is not a path and would otherwise resolve to the CWD.
        if spec is None or spec.origin is None or not spec.has_location:
            log.error("Module not found", module=module_name)
            return None

        # The module's root directory
        module_path = Path(spec.origin).parent

        dweam_path = module_path / "dweam.toml"
        if not dweam_path.exists():
            log.error("dweam.toml not found", path=str(dweam_path))
            return None

        with open(dweam_path, "rb") as f:
            data = load_toml(f)
        metadata = PackageMetadata.model_validate(data)
        metadata._module_dir = module_path
        return metadata
        # FIXME: pyproject.toml [tool.dweam] is not supported here. Similarly to
        # load_metadata_from_path, the module directory is not the project root.
    except (TypeError, OSError, ValueError, ImportError):
        # ValueError covers TOML syntax errors and pydantic validation errors
        # (both ValueError subclasses). ImportError covers find_spec failing on
        # a dotted name whose parent package is missing.
        log.exception("Error loading metadata from module")
        return None
def load_metadata_from_path(log: BoundLogger, path: Path) -> PackageMetadata | None:
    """Load a game package's ``dweam.toml`` from a module directory.

    Args:
        log: structlog logger used for error reporting.
        path: Directory of the installed module (e.g. as returned by
            ``install_game_source``). ``dweam.toml`` is expected directly inside it.

    Returns:
        The validated metadata with ``_module_dir`` set to *path*, or ``None``
        when ``dweam.toml`` is missing or cannot be read, parsed or validated
        (logged).
    """
    try:
        dweam_path = path / "dweam.toml"
        if not dweam_path.exists():
            log.error("dweam.toml not found", path=str(dweam_path))
            return None

        with open(dweam_path, "rb") as f:
            data = load_toml(f)
        metadata = PackageMetadata.model_validate(data)
        metadata._module_dir = path
        return metadata
        # FIXME: pyproject.toml [tool.dweam] is not supported. *path* is the
        # installed module, not the package/project root, so pyproject.toml is
        # not found next to it.
    except (TypeError, OSError, ValueError):
        # ValueError covers TOML syntax errors and pydantic validation errors,
        # both ValueError subclasses, so a malformed file yields None as documented.
        log.exception("Error loading metadata from path")
        return None
def load_metadata_from_package(package_path: Path) -> PackageMetadata | None:
    """Load ``dweam.toml`` from the directory that contains *package_path*.

    Intended for git and local installs where ``dweam.toml`` sits one level
    above the package directory (the repository root).

    Returns:
        The validated metadata, or ``None`` when there is no ``dweam.toml``
        there or it cannot be read, parsed or validated. In the failure case
        the error is printed to stderr.
    """
    try:
        # TODO this'll work for git and local when dweam.toml is in root
        # but not for pypi
        dweam_path = package_path.parent / "dweam.toml"
        if not dweam_path.exists():
            return None
        with open(dweam_path, "rb") as f:
            data = load_toml(f)
        return PackageMetadata.model_validate(data)
    except (TypeError, OSError, ValueError) as e:
        # ValueError covers TOML syntax errors and pydantic validation errors.
        # This function has no logger, so diagnostics go to stderr.
        print(f"Error loading metadata from package: {e}", file=sys.stderr)
        return None
def load_game_implementation(entrypoint: str) -> type:
    """Import and return the object named by an entry-point string.

    The format is ``"package.module:Attr"``. As with standard entry-point
    object references, the attribute part may be dotted (e.g.
    ``"package.module:Outer.Inner"``), and whitespace around either part is
    ignored.

    Raises:
        ImportError: if the string is malformed, the module cannot be
            imported, or an attribute is missing. The underlying error is
            chained as ``__cause__``.
    """
    try:
        module_part, sep, attr_part = entrypoint.partition(":")
        module_path = module_part.strip()
        attr_path = attr_part.strip()
        if not sep or not module_path or not attr_path:
            raise ValueError(f"expected 'module:attribute', got {entrypoint!r}")
        obj = importlib.import_module(module_path)
        # Walk dotted attributes, e.g. "Outer.Inner" -> module.Outer.Inner
        for attr in attr_path.split("."):
            obj = getattr(obj, attr)
        return obj
    except Exception as e:
        raise ImportError(f"Failed to load game implementation from {entrypoint}") from e
def _register_games(
    log: BoundLogger,
    games: dict[str, dict[str, GameInfo]],
    metadata: PackageMetadata,
) -> None:
    """Add every game declared in *metadata* to *games*, keyed by type then id."""
    # setdefault (rather than relying on defaultdict) lets callers pass a plain dict.
    games_of_type = games.setdefault(metadata.type, {})
    for game_id, game_info in metadata.games.items():
        if game_id in games_of_type:
            log.warning(
                "Game ID already exists for type. Overriding...",
                type=metadata.type,
                id=game_id,
            )
        game_info._metadata = metadata
        games_of_type[game_id] = game_info


def load_games(
    log: BoundLogger,
    venv_path: Path | None = None,
    games: defaultdict[str, dict[str, GameInfo]] | None = None,
) -> dict[str, dict[str, GameInfo]]:
    """Load every package in ``DEFAULT_SOURCE_CONFIG`` and register its games.

    For each package, the configured sources are tried in order:

    * with *venv_path*, each source is installed into that venv via
      ``install_game_source`` and its ``dweam.toml`` is read from the
      installed module directory;
    * without *venv_path*, the package is looked up as an already-importable
      module. That lookup does not depend on the source, so it is attempted
      only once per package.

    The first source that yields metadata wins. Failures are logged and the
    next source is tried; a package with no working source is logged as an
    error and skipped.

    Args:
        log: structlog logger used for progress and error reporting.
        venv_path: Virtualenv to install into, or ``None`` to use packages
            already importable in the current environment.
        games: Mapping to fill in place, keyed by game type then game id.
            A plain ``dict`` works as well as a ``defaultdict``. A new
            ``defaultdict`` is created when ``None``.

    Returns:
        *games* (the same object when one was passed in).
    """
    if games is None:
        games = defaultdict(dict)

    for name, sources in DEFAULT_SOURCE_CONFIG.packages.items():
        # Without a venv every attempt is the same module lookup, so one suffices.
        candidates = sources if venv_path is not None else list(sources)[:1]
        for source in candidates:
            try:
                if venv_path is not None:
                    # Install and load from venv
                    module_path = install_game_source(log, venv_path, source, name)
                    if module_path is None:
                        continue
                    metadata = load_metadata_from_path(log, module_path)
                else:
                    # Try to load from installed package
                    metadata = load_metadata_from_module(log, name)

                if metadata is None:
                    log.error("No metadata found for game", name=name)
                    continue

                _register_games(log, games, metadata)
                log.info("Successfully loaded game", name=name)
                break
            except Exception:
                log.warning("Failed to load game from source", name=name, source=source, exc_info=True)
        else:
            # The loop finished without `break`: no source succeeded.
            log.error("Failed to load game from any source", name=name)

    if venv_path is not None:
        # Presumably re-pins dweam in the venv in case a game install changed it.
        # NOTE(review): confirm against ensure_correct_dweam_version.
        ensure_correct_dweam_version(log, get_pip_path(venv_path))

    log.info("Finished loading games")
    return games
# def load_game_entrypoints(log: BoundLogger, games: defaultdict[str, dict[str, GameInfo]] | None = None) -> dict[str, dict[str, GameInfo]]:
# """Load all games from installed packages"""
# if games is None:
# games = defaultdict(dict)
# game_entrypoints = defaultdict(dict)
# entrypoints = importlib_metadata.entry_points(group="dweam")
# for entry_point in entrypoints.select(name="game"):
# try:
# game_class = entry_point.load()
# except Exception as e:
# log.exception("Error loading game entrypoint", entrypoint=entry_point)
# continue
# if isinstance(game_class.game_info, list):
# game_infos = game_class.game_info
# else:
# game_infos = [game_class.game_info]
# for game_info in game_infos:
# game_info._implementation = game_class
# if game_info.id in games[game_info.type]:
# previous_entrypoint = game_entrypoints[game_info.type][game_info.id]
# current_entrypoint = entry_point.name
# log.error(
# "Game ID already exists for type. Overriding...",
# type=game_info.type,
# id=game_info.id,
# previous_entrypoint=previous_entrypoint,
# new_entrypoint=current_entrypoint,
# )
# game_entrypoints[game_info.type][game_info.id] = entry_point.name
# games[game_info.type][game_info.id] = game_info
# log.info("Loaded game entrypoint", entrypoint=entry_point)
# return games