Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
S-Linde committed Jan 7, 2025
1 parent 77d2727 commit 0b2b615
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 4 deletions.
2 changes: 2 additions & 0 deletions qgym/envs/routing/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,8 @@
class Routing(Environment[dict[str, NDArray[np.int_]], int]):
"""RL environment for the routing problem of OpenQL."""

_state: RoutingState # type: ignore[assignment]

def __init__( # noqa: PLR0913
self,
connection_graph: nx.Graph | ArrayLike | Gridspecs,
Expand Down
9 changes: 5 additions & 4 deletions qgym/envs/routing/routing_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from collections import deque
from itertools import starmap
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Union

import networkx as nx
import numpy as np
Expand All @@ -38,7 +38,7 @@
# pylint: disable=too-many-instance-attributes


class RoutingState(State[dict[str, NDArray[np.int_]], int]):
class RoutingState(State[dict[str, Union[NDArray[np.int_], NDArray[np.int8]]], int]):
"""The :class:`RoutingState` class."""

def __init__( # pylint: disable=too-many-arguments
Expand Down Expand Up @@ -263,7 +263,7 @@ def create_observation_space(self) -> qgym.spaces.Dict:

def obtain_observation(
self,
) -> dict[str, NDArray[np.int_]]:
) -> dict[str, NDArray[np.int_] | NDArray[np.int8]]:
"""Observe the current state.
Returns:
Expand All @@ -279,6 +279,7 @@ def obtain_observation(
constant_values=self.n_qubits,
)

observation: dict[str, NDArray[np.int_] | NDArray[np.int8]]
observation = {
"interaction_gates_ahead": interaction_gates_ahead.flatten(),
"mapping": self.mapping,
Expand All @@ -291,7 +292,7 @@ def obtain_observation(
is_legal_surpass = np.fromiter(
iter=starmap(self.is_legal_surpass, interaction_gates_ahead),
count=len(interaction_gates_ahead),
dtype=np.int_,
dtype=np.int8,
)
observation["is_legal_surpass"] = is_legal_surpass

Expand Down

0 comments on commit 0b2b615

Please sign in to comment.