Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

worker/game_process: make worker connection init more robust #5

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 58 additions & 3 deletions dweam/game_process.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,63 @@
import os
from pathlib import Path
import sys


def setup_logging() -> None:
    """Tee this process's stdout/stderr into a per-process log file.

    The log directory is ``$CACHE_DIR/worker_logs`` when the ``CACHE_DIR``
    environment variable is set, otherwise ``~/.dweam/cache/worker_logs``.
    The file is named ``game_process_<pid>.log`` and is truncated on open.

    Side effects:
        Replaces ``sys.stdout`` and ``sys.stderr`` with objects that write
        to both the original stream and the log file, then prints a short
        startup banner through them.

    Raises:
        OSError: If the log directory or the log file cannot be created.
    """
    pid = os.getpid()

    cache_dir = os.environ.get("CACHE_DIR")
    if cache_dir is None:
        cache_dir = Path.home() / ".dweam" / "cache"
    else:
        cache_dir = Path(cache_dir)

    log_dir = cache_dir / "worker_logs"
    # parents=True: the cache directory itself may not exist yet (e.g. on a
    # fresh machine); without it mkdir raises FileNotFoundError.
    log_dir.mkdir(parents=True, exist_ok=True)

    log_file = log_dir / f"game_process_{pid}.log"

    # Line-buffered so the log stays useful if the process dies abruptly.
    # Explicit UTF-8 with replacement so non-ASCII output can never raise
    # UnicodeEncodeError under a narrow locale encoding.
    # The handle is deliberately left open: it backs sys.stdout/sys.stderr
    # for the remaining lifetime of the process.
    log_handle = open(
        log_file, 'w', buffering=1, encoding="utf-8", errors="replace"
    )

    # Save original stdout/stderr
    original_stdout = sys.stdout
    original_stderr = sys.stderr

    class DualOutput:
        """Minimal text stream that duplicates every write to two files."""

        def __init__(self, file1, file2):
            self.file1 = file1
            self.file2 = file2

        def write(self, data):
            self.file1.write(data)
            self.file2.write(data)
            self.file1.flush()
            self.file2.flush()
            # Text streams report the number of characters written.
            return len(data)

        def flush(self):
            self.file1.flush()
            self.file2.flush()

        def __getattr__(self, name):
            # Only reached for attributes not defined above. Delegate to the
            # original stream so callers probing isatty(), encoding,
            # fileno(), ... keep working instead of hitting AttributeError.
            if name in ("file1", "file2"):
                # Not initialised yet (e.g. during copy); avoid recursion.
                raise AttributeError(name)
            return getattr(self.file1, name)

    # Replace stdout/stderr with dual-output versions
    sys.stdout = DualOutput(original_stdout, log_handle)
    sys.stderr = DualOutput(original_stderr, log_handle)

    print(f"=== Game Process {pid} Starting ===")
    print(f"Logging to {log_file}")
    print(f"Command line args: {sys.argv}")


# Install the stdout/stderr tee at import time, ahead of the remaining
# imports below, so anything they print also reaches the log file.
setup_logging()


import logging
# Turn off the "aioice.ice" logger entirely.
logging.getLogger("aioice.ice").disabled = True

import asyncio
import json
import sys
from typing import Any
from datetime import datetime, timedelta
from dweam.constants import JS_TO_PYGAME_BUTTON_MAP, JS_TO_PYGAME_KEY_MAP
Expand All @@ -15,7 +69,6 @@
from aiortc import VideoStreamTrack, RTCPeerConnection, RTCSessionDescription, RTCConfiguration, RTCIceServer, RTCDataChannel
from aiortc.contrib.signaling import object_from_string, object_to_string
import torch
import os
import socket
from dweam.utils.process import patch_subprocess_popen

Expand Down Expand Up @@ -160,6 +213,8 @@ async def main():

log.info("Parsed args", game_type=game_type, game_id=game_id, port=port)

log.info("Attempting to connect to parent", host='127.0.0.1', port=port)

try:
# Connect to parent process
reader, writer = await asyncio.open_connection(
Expand All @@ -168,7 +223,7 @@ async def main():
)
log.info("Connected to parent")
except Exception as e:
log.error("Failed to connect to parent", error=str(e))
log.exception("Failed to connect to parent")
raise

# Load the game implementation
Expand Down
185 changes: 89 additions & 96 deletions dweam/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,58 +114,52 @@ async def _collect_process_output(self, process: Process) -> tuple[str | None, s

return stdout_str, stderr_str

async def start(self):
"""Start the worker process and establish communication"""
async def _establish_connection(self, timeout: float) -> bool:
"""Establish connection with the worker process"""
if sys.platform == "win32":
venv_python = self.venv_path / "Scripts" / "python.exe"
else:
venv_python = self.venv_path / "bin" / "python"

# Use importlib.resources to reliably locate the module file
if getattr(sys, 'frozen', False):
# In PyInstaller bundle
worker_script = Path(sys._MEIPASS) / "dweam" / "dweam" / "game_process.py"
else:
# In development
worker_script = files('dweam').joinpath('game_process.py')

max_retries = 3
retry_delay = 1.0

for attempt in range(max_retries):
try:
if attempt > 0:
self.log.info(f"Retry attempt {attempt + 1}/{max_retries}")
await asyncio.sleep(retry_delay)

# Create a TCP server socket
client_connected = asyncio.Event()

async def handle_client(reader, writer):
self.reader = reader
self.writer = writer
client_connected.set()

server = await asyncio.start_server(
handle_client,
host='127.0.0.1',
port=0 # Let OS choose port
)
addr = server.sockets[0].getsockname()
port = addr[1]
# Create a TCP server socket
client_connected = asyncio.Event()

# TODO bind a uuid to the logger
async def handle_client(reader, writer):
self.reader = reader
self.writer = writer
client_connected.set()

# Create server and get the port
server = await asyncio.start_server(
handle_client,
host='127.0.0.1',
port=0
)
port = server.sockets[0].getsockname()[1]

self.log.info("Started TCP server", port=port)

# Log what we're about to execute
self.log.info("Starting worker process",
python=str(venv_python),
script=str(worker_script),
game_type=self.game_type,
game_id=self.game_id)

# Start the worker process with the port number
self.log.info("Got available port", port=port)

# Start the worker process with the port number
self.log.info(
"Starting worker process",
python=str(venv_python),
script=str(worker_script),
game_type=self.game_type,
game_id=self.game_id
)

# Start serving (but don't block)
async with server:
self.log.debug("Starting server")
server_task = asyncio.create_task(server.serve_forever())

try:
self.process = await asyncio.create_subprocess_exec(
str(venv_python),
str(worker_script),
Expand Down Expand Up @@ -194,78 +188,77 @@ async def handle_client(reader, writer):
stdout=stdout,
stderr=stderr
)
continue # Retry
return False
except asyncio.TimeoutError:
# Process is still running, this is good
pass
# Wait for either client connection or process termination
done, pending = await asyncio.wait([
asyncio.create_task(client_connected.wait()),
asyncio.create_task(self.process.wait())
], timeout=timeout, return_when=asyncio.FIRST_COMPLETED)

# If process completed first, it means it crashed
if self.process.returncode is not None:
stdout_str, stderr_str = await self._collect_process_output(self.process)
self.log.error(
"Process terminated before connection",
returncode=self.process.returncode,
stderr=stderr_str,
stdout=stdout_str
)
return False # Changed from raise to return False to allow retries

# Try to connect with timeout
# If we get here and nothing completed, it was a timeout
if not done:
raise asyncio.TimeoutError("Worker process failed to connect")

finally:
# Cancel server task
server_task.cancel()
try:
# Start serving (but don't block)
async with server:
self.log.debug("Starting server")
server_task = asyncio.create_task(server.serve_forever())

try:
# Wait for either client connection or process termination
done, pending = await asyncio.wait([
asyncio.create_task(client_connected.wait()),
asyncio.create_task(self.process.wait())
], timeout=5, return_when=asyncio.FIRST_COMPLETED)

# If process completed first, it means it crashed
if self.process.returncode is not None:
stdout_str, stderr_str = await self._collect_process_output(self.process)
self.log.error(
"Process terminated before connection",
returncode=self.process.returncode,
stderr=stderr_str,
stdout=stdout_str
)
raise RuntimeError(f"Process terminated with return code {self.process.returncode}")

# If we get here and nothing completed, it was a timeout
if not done:
raise asyncio.TimeoutError("Worker process failed to connect")
except asyncio.TimeoutError:
# If timeout occurs, collect stderr/stdout before raising
stdout_str, stderr_str = await self._collect_process_output(self.process)

self.log.error(
"Worker failed to connect",
returncode=self.process.returncode,
stdout=stdout_str,
stderr=stderr_str,
)
raise TimeoutError("Worker process failed to connect")
finally:
# Cancel server task
server_task.cancel()
try:
await server_task
except asyncio.CancelledError:
pass

if not self.reader or not self.writer:
raise RuntimeError("No connection received")

# Only start monitoring after successful connection
asyncio.create_task(self._monitor_process_output(self.process.stdout, "stdout"))
asyncio.create_task(self._monitor_process_output(self.process.stderr, "stderr"))

self.log.info("Client connected")
await server_task
except asyncio.CancelledError:
pass

if not self.reader or not self.writer:
raise RuntimeError("No connection received")

return True

async def start(self):
"""Start the worker process and establish communication"""
max_retries = 3
base_timeout = 5.0

for attempt in range(max_retries):
try:
timeout = base_timeout * (2 ** attempt) # 5s, 10s, 20s

# Try to establish connection
try:
if not await self._establish_connection(timeout):
continue

# Start monitoring process output
asyncio.create_task(self._monitor_process_output(self.process.stdout, "stdout"))
asyncio.create_task(self._monitor_process_output(self.process.stderr, "stderr"))

self.log.info("Client connected")
return # Success!

except Exception:
self.log.exception("Error during connection")
continue # Retry
if attempt == max_retries - 1:
raise
continue

except Exception:
self.log.exception(f"Attempt {attempt + 1} failed")
if self.process:
self.process.kill()
if attempt == max_retries - 1:
raise # Re-raise on last attempt
raise

raise RuntimeError(f"Failed to start worker after {max_retries} attempts")

Expand Down