From 0b2b615c80285ffb913a32dfaaceab04489e2255 Mon Sep 17 00:00:00 2001 From: Stan van der Linde Date: Tue, 7 Jan 2025 14:59:14 +0100 Subject: [PATCH] Fix tests --- qgym/envs/routing/routing.py | 2 ++ qgym/envs/routing/routing_state.py | 9 +++++---- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/qgym/envs/routing/routing.py b/qgym/envs/routing/routing.py index bc8ea27..067e4a1 100644 --- a/qgym/envs/routing/routing.py +++ b/qgym/envs/routing/routing.py @@ -116,6 +116,8 @@ class Routing(Environment[dict[str, NDArray[np.int_]], int]): """RL environment for the routing problem of OpenQL.""" + _state: RoutingState # type: ignore[assignment] + def __init__( # noqa: PLR0913 self, connection_graph: nx.Graph | ArrayLike | Gridspecs, diff --git a/qgym/envs/routing/routing_state.py b/qgym/envs/routing/routing_state.py index 6133e97..2d268da 100644 --- a/qgym/envs/routing/routing_state.py +++ b/qgym/envs/routing/routing_state.py @@ -21,7 +21,7 @@ from collections import deque from itertools import starmap -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Union import networkx as nx import numpy as np @@ -38,7 +38,7 @@ # pylint: disable=too-many-instance-attributes -class RoutingState(State[dict[str, NDArray[np.int_]], int]): +class RoutingState(State[dict[str, Union[NDArray[np.int_], NDArray[np.int8]]], int]): """The :class:`RoutingState` class.""" def __init__( # pylint: disable=too-many-arguments @@ -263,7 +263,7 @@ def create_observation_space(self) -> qgym.spaces.Dict: def obtain_observation( self, - ) -> dict[str, NDArray[np.int_]]: + ) -> dict[str, NDArray[np.int_] | NDArray[np.int8]]: """Observe the current state. Returns: @@ -279,6 +279,7 @@ def obtain_observation( constant_values=self.n_qubits, ) + observation: dict[str, NDArray[np.int_] | NDArray[np.int8]] observation = { "interaction_gates_ahead": interaction_gates_ahead.flatten(), "mapping": self.mapping, @@ -291,7 +292,7 @@ def obtain_observation( is_legal_surpass = np.fromiter( iter=starmap(self.is_legal_surpass, interaction_gates_ahead), count=len(interaction_gates_ahead), - dtype=np.int_, + dtype=np.int8, ) observation["is_legal_surpass"] = is_legal_surpass