Skip to content

Commit

Permalink
AoC 2024 Day 20 - faster
Browse files Browse the repository at this point in the history
  • Loading branch information
pareronia committed Dec 20, 2024
1 parent 2eec489 commit 337f4cd
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 15 deletions.
36 changes: 21 additions & 15 deletions src/main/python/AoC2024_20.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from aoc.common import InputData
from aoc.common import SolutionBase
from aoc.graph import bfs
from aoc.graph import bfs_full
from aoc.grid import CharGrid

Input = CharGrid
Expand Down Expand Up @@ -39,28 +39,34 @@ def parse_input(self, input_data: InputData) -> Input:
return CharGrid.from_strings(list(input_data))

def solve(self, grid: CharGrid, cheat_len: int, target: int) -> int:
for cell in grid.get_cells():
if grid.get_value(cell) == "S":
start = cell
if grid.get_value(cell) == "E":
end = cell
time, path = bfs(
start = next(
cell for cell in grid.get_cells() if grid.get_value(cell) == "S"
)
distances, _ = bfs_full(
start,
lambda cell: cell == end,
lambda cell: grid.get_value(cell) != "#",
lambda cell: (
n
for n in grid.get_capital_neighbours(cell)
if grid.get_value(n) != "#"
),
)

dist = {(k.row, k.col): v for k, v in distances.items()}
ans = 0
for i1 in range(len(path) - cheat_len):
for i2 in range(i1 + cheat_len, len(path)):
p1, p2 = path[i1], path[i2]
md = abs(p1.row - p2.row) + abs(p1.col - p2.col)
if md <= cheat_len and i2 - i1 - md >= target:
ans += 1
for r, c in dist.keys():
for md in range(2, cheat_len + 1):
for dr in range(md + 1):
dc = md - dr
for rr, cc in {
(r + dr, c + dc),
(r + dr, c - dc),
(r - dr, c + dc),
(r - dr, c - dc),
}:
if (rr, cc) not in dist:
continue
if dist[(rr, cc)] - dist[(r, c)] >= target + md:
ans += 1
return ans

def part_1(self, grid: Input) -> Output1:
Expand Down
24 changes: 24 additions & 0 deletions src/main/python/aoc/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,30 @@ def bfs(
raise RuntimeError("unsolvable")


def bfs_full(
start: T,
is_end: Callable[[T], bool],
adjacent: Callable[[T], Iterator[T]],
) -> tuple[dict[T, int], dict[T, T]]:
q: deque[tuple[int, T]] = deque()
q.append((0, start))
seen: set[T] = set()
seen.add(start)
parent: dict[T, T] = {}
dists = defaultdict[T, int](int)
while not len(q) == 0:
distance, node = q.popleft()
if is_end(node):
dists[node] = distance
for n in adjacent(node):
if n in seen:
continue
seen.add(n)
parent[n] = node
q.append((distance + 1, n))
return dists, parent


def flood_fill(
start: T,
adjacent: Callable[[T], Iterator[T]],
Expand Down
12 changes: 12 additions & 0 deletions src/main/python/aoc/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,18 @@ def to(self, other: Cell) -> Direction:
return Direction.DOWN if self.row < other.row else Direction.UP
raise ValueError("not supported")

def get_all_at_manhattan_distance(self, distance: int) -> Iterator[Cell]:
r, c = self.row, self.col
for dr in range(distance + 1):
dc = distance - dr
for rr, cc in {
(r + dr, c + dc),
(r + dr, c - dc),
(r - dr, c + dc),
(r - dr, c - dc),
}:
yield Cell(rr, cc)


@unique
class IterDir(Enum):
Expand Down
47 changes: 47 additions & 0 deletions src/test/python/test_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,3 +186,50 @@ def test_merge(self) -> None:
"333333333",
],
)


class CellTest(unittest.TestCase):
def test_get_all_at_manhattan_distance_1(self) -> None:
cell = Cell(0, 0)

ans = {n for n in cell.get_all_at_manhattan_distance(1)}

self.assertTrue(len(ans) == 4)

def test_get_all_at_manhattan_distance_2(self) -> None:
cell = Cell(0, 0)

ans = {n for n in cell.get_all_at_manhattan_distance(2)}

self.assertTrue(len(ans) == 8)
self.assertTrue(Cell(-2, 0) in ans)
self.assertTrue(Cell(-1, 1) in ans)
self.assertTrue(Cell(0, 2) in ans)
self.assertTrue(Cell(1, 1) in ans)
self.assertTrue(Cell(2, 0) in ans)
self.assertTrue(Cell(1, -1) in ans)
self.assertTrue(Cell(0, -2) in ans)
self.assertTrue(Cell(-1, -1) in ans)

def test_get_all_at_manhattan_distance_3(self) -> None:
cell = Cell(0, 0)

ans = {n for n in cell.get_all_at_manhattan_distance(3)}

self.assertTrue(len(ans) == 12)
self.assertTrue(Cell(-3, 0) in ans)
self.assertTrue(Cell(-2, 1) in ans)
self.assertTrue(Cell(-1, 2) in ans)
self.assertTrue(Cell(0, 3) in ans)
self.assertTrue(Cell(1, 2) in ans)
self.assertTrue(Cell(2, 1) in ans)
self.assertTrue(Cell(3, 0) in ans)
self.assertTrue(Cell(2, -1) in ans)
self.assertTrue(Cell(1, -2) in ans)
self.assertTrue(Cell(0, -3) in ans)
self.assertTrue(Cell(-1, -2) in ans)
self.assertTrue(Cell(-2, -1) in ans)


if __name__ == '__main__':
unittest.main()

0 comments on commit 337f4cd

Please sign in to comment.