diff --git a/src/main/python/AoC2024_20.py b/src/main/python/AoC2024_20.py index 099b184..92d4b75 100644 --- a/src/main/python/AoC2024_20.py +++ b/src/main/python/AoC2024_20.py @@ -7,7 +7,7 @@ from aoc.common import InputData from aoc.common import SolutionBase -from aoc.graph import bfs +from aoc.graph import bfs_full from aoc.grid import CharGrid Input = CharGrid @@ -39,28 +39,34 @@ def parse_input(self, input_data: InputData) -> Input: return CharGrid.from_strings(list(input_data)) def solve(self, grid: CharGrid, cheat_len: int, target: int) -> int: - for cell in grid.get_cells(): - if grid.get_value(cell) == "S": - start = cell - if grid.get_value(cell) == "E": - end = cell - time, path = bfs( + start = next( + cell for cell in grid.get_cells() if grid.get_value(cell) == "S" + ) + distances, _ = bfs_full( start, - lambda cell: cell == end, + lambda cell: grid.get_value(cell) != "#", lambda cell: ( n for n in grid.get_capital_neighbours(cell) if grid.get_value(n) != "#" ), ) - + dist = {(k.row, k.col): v for k, v in distances.items()} ans = 0 - for i1 in range(len(path) - cheat_len): - for i2 in range(i1 + cheat_len, len(path)): - p1, p2 = path[i1], path[i2] - md = abs(p1.row - p2.row) + abs(p1.col - p2.col) - if md <= cheat_len and i2 - i1 - md >= target: - ans += 1 + for r, c in dist.keys(): + for md in range(2, cheat_len + 1): + for dr in range(md + 1): + dc = md - dr + for rr, cc in { + (r + dr, c + dc), + (r + dr, c - dc), + (r - dr, c + dc), + (r - dr, c - dc), + }: + if (rr, cc) not in dist: + continue + if dist[(rr, cc)] - dist[(r, c)] >= target + md: + ans += 1 return ans def part_1(self, grid: Input) -> Output1: diff --git a/src/main/python/aoc/graph.py b/src/main/python/aoc/graph.py index 5847820..d908160 100644 --- a/src/main/python/aoc/graph.py +++ b/src/main/python/aoc/graph.py @@ -71,6 +71,30 @@ def bfs( raise RuntimeError("unsolvable") +def bfs_full( + start: T, + is_end: Callable[[T], bool], + adjacent: Callable[[T], Iterator[T]], +) -> tuple[dict[T, int], dict[T, T]]: + q: deque[tuple[int, T]] = deque() + q.append((0, start)) + seen: set[T] = set() + seen.add(start) + parent: dict[T, T] = {} + dists = defaultdict[T, int](int) + while not len(q) == 0: + distance, node = q.popleft() + if is_end(node): + dists[node] = distance + for n in adjacent(node): + if n in seen: + continue + seen.add(n) + parent[n] = node + q.append((distance + 1, n)) + return dists, parent + + def flood_fill( start: T, adjacent: Callable[[T], Iterator[T]], diff --git a/src/main/python/aoc/grid.py b/src/main/python/aoc/grid.py index 5abe1ee..4d5ef1b 100644 --- a/src/main/python/aoc/grid.py +++ b/src/main/python/aoc/grid.py @@ -38,6 +38,18 @@ def to(self, other: Cell) -> Direction: return Direction.DOWN if self.row < other.row else Direction.UP raise ValueError("not supported") + def get_all_at_manhattan_distance(self, distance: int) -> Iterator[Cell]: + r, c = self.row, self.col + for dr in range(distance + 1): + dc = distance - dr + for rr, cc in { + (r + dr, c + dc), + (r + dr, c - dc), + (r - dr, c + dc), + (r - dr, c - dc), + }: + yield Cell(rr, cc) + @unique class IterDir(Enum): diff --git a/src/test/python/test_grid.py b/src/test/python/test_grid.py index e7faaa7..f1b16ae 100644 --- a/src/test/python/test_grid.py +++ b/src/test/python/test_grid.py @@ -186,3 +186,50 @@ def test_merge(self) -> None: "333333333", ], ) + + +class CellTest(unittest.TestCase): + def test_get_all_at_manhattan_distance_1(self) -> None: + cell = Cell(0, 0) + + ans = {n for n in cell.get_all_at_manhattan_distance(1)} + + self.assertTrue(len(ans) == 4) + + def test_get_all_at_manhattan_distance_2(self) -> None: + cell = Cell(0, 0) + + ans = {n for n in cell.get_all_at_manhattan_distance(2)} + + self.assertTrue(len(ans) == 8) + self.assertTrue(Cell(-2, 0) in ans) + self.assertTrue(Cell(-1, 1) in ans) + self.assertTrue(Cell(0, 2) in ans) + self.assertTrue(Cell(1, 1) in ans) + self.assertTrue(Cell(2, 0) in ans) + self.assertTrue(Cell(1, -1) in ans) + self.assertTrue(Cell(0, -2) in ans) + self.assertTrue(Cell(-1, -1) in ans) + + def test_get_all_at_manhattan_distance_3(self) -> None: + cell = Cell(0, 0) + + ans = {n for n in cell.get_all_at_manhattan_distance(3)} + + self.assertTrue(len(ans) == 12) + self.assertTrue(Cell(-3, 0) in ans) + self.assertTrue(Cell(-2, 1) in ans) + self.assertTrue(Cell(-1, 2) in ans) + self.assertTrue(Cell(0, 3) in ans) + self.assertTrue(Cell(1, 2) in ans) + self.assertTrue(Cell(2, 1) in ans) + self.assertTrue(Cell(3, 0) in ans) + self.assertTrue(Cell(2, -1) in ans) + self.assertTrue(Cell(1, -2) in ans) + self.assertTrue(Cell(0, -3) in ans) + self.assertTrue(Cell(-1, -2) in ans) + self.assertTrue(Cell(-2, -1) in ans) + + +if __name__ == '__main__': + unittest.main()