diff --git a/algobattle/types.py b/algobattle/types.py index 43f59c42..1df1cd6c 100644 --- a/algobattle/types.py +++ b/algobattle/types.py @@ -1,5 +1,6 @@ """Utility types used to easily define Problems.""" from dataclasses import dataclass +from functools import cache, cached_property from sys import float_info from typing import ( Annotated, @@ -24,6 +25,7 @@ SupportsLt, SupportsMod, ) +from itertools import pairwise from pydantic import GetCoreSchemaHandler, GetJsonSchemaHandler from pydantic.json_schema import JsonSchemaValue @@ -62,6 +64,7 @@ "DirectedGraph", "UndirectedGraph", "Edge", + "Path", "EdgeLen", "EdgeWeights", "VertexWeights", @@ -411,6 +414,24 @@ def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHa # * Graph classes +Vertex = SizeIndex +"""Type for vertices, encoded as numbers `0 <= v < instance.num_vertices`.""" + + +Edge = Annotated[int, IndexInto[InstanceRef.edges]] +"""Type for edges, encoded as indices into `instance.edges`.""" + + +def path_in_graph(path: list[Vertex], edge_set: set[tuple[Vertex, Vertex]]): + """Checks that a path actually exists in the graph.""" + for edge in pairwise(path): + if edge not in edge_set: + raise ValueError(f"The edge {edge} does not exist in the graph.") + + +Path = Annotated[list[Vertex], AttributeReferenceValidator(path_in_graph, InstanceRef.edge_set)] + + class DirectedGraph(InstanceModel): """Base instance class for problems on directed graphs.""" @@ -422,6 +443,21 @@ def size(self) -> int: """A graph's size is the number of vertices in it.""" return self.num_vertices + @cached_property + def edge_set(self) -> set[tuple[Vertex, Vertex]]: + """The set of edges in this graph.""" + return set(self.edges) + + @cache + def neighbors(self, vertex: Vertex, direction: Literal["all", "outgoing", "incoming"] = "all") -> set[Vertex]: + """The neighbors of a vertex.""" + res = set[Vertex]() + if direction in {"all", "outgoing"}: + res |= set(v for (u, v) in self.edges if u == vertex) + if direction in {"all", "incoming"}: + res |= set(v for (v, u) in self.edges if u == vertex) + return res + class UndirectedGraph(DirectedGraph): """Base instance class for problems on undirected graphs.""" @@ -440,13 +476,20 @@ def validate_instance(self): if any(edge[::-1] in edge_set for edge in self.edges): raise ValidationError("Undirected graph contains back and forth edges between two vertices.") + @cached_property + def edge_set(self) -> set[tuple[Vertex, Vertex]]: + """The set of edges in this graph. -Vertex = SizeIndex -"""Type for vertices, encoded as numbers `0 <= v < instance.num_vertices`.""" + Normalized to contain every edge in both directions. + """ + return set(self.edges) | set((v, u) for (u, v) in self.edges) + @cache + def neighbors(self, vertex: Vertex, direction: Literal["all", "outgoing", "incoming"] = "all") -> set[Vertex]: + """The neighbors of a vertex.""" + # more efficient specialization -Edge = IndexInto[InstanceRef.edges] -"""Type for edges, encoded as indices into `instance.edges`.""" + return set(v for (u, v) in self.edge_set if u == vertex) class EdgeLen: @@ -477,12 +520,42 @@ class EdgeWeights(DirectedGraph, BaseModel, Generic[Weight]): edge_weights: Annotated[list[Weight], EdgeLen] + @cached_property + def edges_with_weights(self) -> Iterator[tuple[tuple[Vertex, Vertex], Weight]]: + """Iterate over all edges and their weights.""" + return zip(self.edges, self.edge_weights) + + @cache + def weight(self, edge: Edge | tuple[Vertex, Vertex]) -> Weight: + """Returns the weight of an edge. + + Raises KeyError if the given edge does not exist. + """ + if isinstance(edge, tuple): + try: + edge = self.edges.index(edge) + except ValueError: + if isinstance(self, UndirectedGraph): + try: + edge = self.edges.index((edge[1], edge[0])) + except ValueError: + raise KeyError + else: + raise KeyError + + return self.edge_weights[edge] + class VertexWeights(DirectedGraph, BaseModel, Generic[Weight]): """Mixin for graphs with weighted vertices.""" vertex_weights: Annotated[list[Weight], SizeLen] + @cached_property + def vertices_with_weights(self) -> Iterator[tuple[Vertex, Weight]]: + """Iterate over all edges and their weights.""" + return enumerate(self.vertex_weights) + @dataclass(frozen=True, slots=True) class LaxComp: diff --git a/docs/instructor/problem/advanced.md b/docs/instructor/problem/advanced.md index 26327851..29379a15 100644 --- a/docs/instructor/problem/advanced.md +++ b/docs/instructor/problem/advanced.md @@ -172,7 +172,7 @@ directionless. Both graph's size is the number of vertices in it. !!! tip "Associated Annotation Types" As you can see in the example above, we also provide several types that are useful in type annotations of graph - problems such as `Vertex` or `Edge`. These are documented in more detail in the + problems such as `Vertex`, `Edge`, or `Path`. How these function is explained in more detail in the [advanced annotations](annotations.md) section. If you want the problem instance to also contain additional information associated with each vertex and/or each edge @@ -201,6 +201,11 @@ indexed with the type of the weights you want to use. ... ``` +!!! tip + These classes also contain some utility methods to easily perform common graph operations. For example, + `UndirectedGraph.edge_set` contains all edges in both directions, and the `neighbors` methods lets you quickly + access a vertex's neighbours. + ## Comparing Floats !!! abstract