diff --git a/src/aoc_cj/aoc2019/day18.py b/src/aoc_cj/aoc2019/day18.py index 31d7f2af..9d65561c 100644 --- a/src/aoc_cj/aoc2019/day18.py +++ b/src/aoc_cj/aoc2019/day18.py @@ -1,77 +1,92 @@ +import dataclasses import itertools from collections import deque +from collections.abc import Generator +from typing import TypeVar, Union + +import more_itertools as mi from aoc_cj import util +Grid = dict[complex, str] -def parse(txt: str): - grid = {complex(x, y): c for y, line in enumerate(txt.splitlines()) for x, c in enumerate(line)} - start_pos = next(p for p, c in grid.items() if c == "@") - num_keys = sum(1 for c in grid.values() if c.islower()) - return grid, start_pos, num_keys +@dataclasses.dataclass +class StateA: + pos: complex + steps: int = 0 + collected: frozenset[str] = frozenset() + def next_states(self: "StateA", grid: Grid) -> Generator["StateA", None, None]: + q = deque(((self.pos, self.steps),)) + seen: set[complex] = set() + while len(q) > 0: + pos, steps = q.popleft() + if pos in seen: + continue + val = grid[pos] + if val in ".@" or val.lower() in self.collected: + q.extend((pos + delta, steps + 1) for delta in (-1j, 1, 1j, -1)) + elif val != "#" and val.islower(): + yield StateA(pos, steps, frozenset((*self.collected, val))) + seen.add(pos) -def next_states_a(grid, state): - steps, pos, collected = state - q = deque() - q.append((pos, steps)) - seen = set() - while len(q) > 0: - pos, steps = q.popleft() - if pos in seen: - continue - val: str = grid.get(pos) - if val in ".@" or val.lower() in collected: - q.extend((pos + delta, steps + 1) for delta in (-1j, 1, 1j, -1)) - elif val != "#" and val.islower(): - yield (steps, pos, frozenset((*collected, val))) - seen.add(pos) - - -def next_states_b(grid, state): - steps, positions, collected = state - for i, pos in enumerate(positions): - for s in next_states_a(grid, (steps, pos, collected)): - new_steps, new_pos, new_collected = s - new_positions = list(positions) - new_positions[i] = new_pos - yield (new_steps, tuple(new_positions), new_collected) - - -def search(initial, num_keys, grid, next_states=next_states_a): - q = util.PriorityQueue() - q.push(0, initial) - seen = set() + +@dataclasses.dataclass +class StateB: + pos: tuple[complex, ...] + steps: int = 0 + collected: frozenset[str] = frozenset() + + def next_states(self, grid: Grid) -> Generator["StateB", None, None]: + for i, pos in enumerate(self.pos): + for new_state in StateA(pos, self.steps, self.collected).next_states(grid): + new_positions = list(self.pos) + new_positions[i] = new_state.pos + yield StateB(tuple(new_positions), new_state.steps, new_state.collected) + + +# TODO: replace w/ protocol +_TState = TypeVar("_TState", StateA, StateB) + + +def search(start: _TState, grid: Grid) -> int: + all_keys = frozenset(c for c in grid.values() if c.islower()) + + q: util.PriorityQueue[_TState] = util.PriorityQueue() + q.push(0, start) + seen: set[tuple[Union[complex, tuple[complex, ...]], frozenset[str]]] = set() while q: state = q.pop() - steps, pos, collected = state - if len(collected) == num_keys: - return steps - if (pos, collected) in seen: + if (state.pos, state.collected) in seen: continue - for s in next_states(grid, state): - q.push(s[0], s) - seen.add((pos, collected)) - return -1 + # if all keys have been collected + if state.collected == all_keys: + return state.steps + for s in state.next_states(grid): + q.push(s.steps, s) + seen.add((state.pos, state.collected)) + assert False, "unreachable" -def parta(txt: str): - grid, start_pos, num_keys = parse(txt) - initial = (0, start_pos, frozenset()) - return search(initial, num_keys, grid) +def parta(txt: str) -> int: + grid = {complex(x, y): c for y, line in enumerate(txt.splitlines()) for x, c in enumerate(line)} + start_pos = mi.one(p for p, c in grid.items() if c == "@") + + return search(StateA(start_pos), grid) -def partb(txt: str): - grid, start_pos, num_keys = parse(txt) +def partb(txt: str) -> int: + grid = {complex(x, y): c for y, line in enumerate(txt.splitlines()) for x, c in enumerate(line)} + start_pos = mi.one(p for p, c in grid.items() if c == "@") # update the cave entrance replace = "@#@\n###\n@#@".splitlines() for x, y in itertools.product(range(3), repeat=2): grid[start_pos + complex(x - 1, y - 1)] = replace[y][x] - initial = (0, tuple(start_pos + delta for delta in (-1 - 1j, 1 - 1j, 1 + 1j, -1 + 1j)), frozenset()) - return search(initial, num_keys, grid, next_states_b) + start = tuple(start_pos + delta for delta in (-1 - 1j, 1 - 1j, 1 + 1j, -1 + 1j)) + return search(StateB(start), grid) if __name__ == "__main__": diff --git a/tests/aoc2019/y2019d18_test.py b/tests/aoc2019/y2019d18_test.py index 20e9ff62..3be1850b 100644 --- a/tests/aoc2019/y2019d18_test.py +++ b/tests/aoc2019/y2019d18_test.py @@ -98,7 +98,7 @@ (EXAMPLE_INPUT_A4, 81), ], ) -def test_a(input, expected): +def test_a(input: str, expected: int) -> None: assert d.parta(input) == expected @@ -111,5 +111,5 @@ def test_a(input, expected): (EXAMPLE_INPUT_B4, 72), ], ) -def test_b(input, expected): +def test_b(input: str, expected: int) -> None: assert d.partb(input) == expected