Skip to content

Commit

Permalink
Add wrapper templates
Browse files Browse the repository at this point in the history
  • Loading branch information
S-Linde committed Nov 29, 2024
1 parent 823c623 commit 820dcec
Show file tree
Hide file tree
Showing 8 changed files with 182 additions and 74 deletions.
2 changes: 1 addition & 1 deletion qgym/envs/routing/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@
class Routing(Environment[Dict[str, NDArray[np.int_]], int]):
"""RL environment for the routing problem of OpenQL."""

def __init__( # pylint: disable=too-many-arguments
def __init__( # pylint: disable=too-many-arguments,too-many-positional-arguments
self,
connection_graph: nx.Graph | ArrayLike | Gridspecs,
interaction_generator: InteractionGenerator | None = None,
Expand Down
3 changes: 2 additions & 1 deletion qgym/templates/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,6 @@
from qgym.templates.rewarder import Rewarder
from qgym.templates.state import State
from qgym.templates.visualiser import Visualiser
from qgym.templates.wrappers import AgentWrapper

__all__ = ["Environment", "Rewarder", "State", "Visualiser"]
__all__ = ["AgentWrapper", "Environment", "Rewarder", "State", "Visualiser"]
9 changes: 9 additions & 0 deletions qgym/templates/pass_protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,12 @@ class Mapper(Protocol):
@abstractmethod
def compute_mapping(self, circuit: QuantumCircuit | DAGCircuit) -> NDArray[np.int_]:
"""Compute a mapping for a provided quantum `circuit`."""


@runtime_checkable
class Router(Protocol):
    """Structural interface for qubit routers.

    Any object that provides a ``compute_routing`` method satisfies this
    protocol; because the protocol is runtime checkable, conformance can also
    be tested with ``isinstance``.
    """

    @abstractmethod
    def compute_routing(self, circuit: QuantumCircuit | DAGCircuit) -> DAGCircuit:
        """Compute a qubit routing for the provided quantum `circuit`.

        Args:
            circuit: Quantum circuit to route, given either as a
                ``QuantumCircuit`` or as a ``DAGCircuit``.

        Returns:
            The routing result as a ``DAGCircuit``.
        """
108 changes: 108 additions & 0 deletions qgym/templates/wrappers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
"""Generic abstract base class for RL agent wrappers.

Check failure on line 1 in qgym/templates/wrappers.py

View workflow job for this annotation

GitHub Actions / isort

Imports are incorrectly sorted and/or formatted.
All agent wrappers should inherit from ``AgentWrapper``.
"""

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Generic, TypeVar

from qgym.utils.qiskit_utils import parse_circuit

if TYPE_CHECKING:
from collections.abc import Mapping

from qiskit import QuantumCircuit
from qiskit.dagcircuit import DAGCircuit
from stable_baselines3.common.base_class import BaseAlgorithm
from qgym.templates.environment import Environment

WrapperOutputT = TypeVar("WrapperOutputT")


class AgentWrapper(ABC, Generic[WrapperOutputT]):  # pylint: disable=too-few-public-methods
    """Wrap any trained stable baselines 3 agent that inherits from
    :class:`~stable_baselines3.common.base_class.BaseAlgorithm`.

    The wrapper drives one episode of `env` with `agent` for a given circuit.
    Subclasses supply the two circuit specific steps: ``_prepare_episode``
    (circuit -> reset options) and ``_postprocess_episode`` (finished
    episode -> output of type ``WrapperOutputT``). ``run`` chains these around
    the shared episode loop.
    """

    def __init__(
        self,
        agent: BaseAlgorithm,
        env: Environment[Any, Any],
        max_steps: int = 1000,
        *,
        use_action_masking: bool = False,
    ) -> None:
        """Init of the :class:`AgentWrapper`.

        Args:
            agent: agent trained on the provided environment.
            env: environment the agent was trained on.
            max_steps: maximum number steps the `agent` can take to complete an
                episode.
            use_action_masking: If ``True`` it is assumed that action masking was
                used during training. The `env` should then have a `action_masks`
                method and the `predict` method of `agent` should accept the
                keyword argument `"action_masks"`. If ``False`` (default) no
                action masking is used.

        Raises:
            TypeError: If `use_action_masking` is ``True`` but `env` has no
                ``action_masks`` attribute.
        """
        self.agent = agent
        self.env = env
        self.max_steps = max_steps
        self.use_action_masking = use_action_masking
        if self.use_action_masking and not hasattr(self.env, "action_masks"):
            msg = "use_action_mask is True, but env has no action_masks attribute"
            raise TypeError(msg)

    @abstractmethod
    def _prepare_episode(self, circuit: DAGCircuit) -> Mapping[str, Any]:
        """Prepare the episode options with the information from the provided circuit.

        Args:
            circuit: Quantum circuit to extract the episode information from.

        Returns:
            Mapping containing the options that should be provided to
            ``self.env.reset``.
        """

    def _run_epsiode(self, options: Mapping[str, Any]) -> None:
        """Run an episode with the provided options.

        The environment is reset with `options` and then stepped with the
        actions predicted by the agent. The loop ends as soon as the
        environment reports that the episode terminated or was truncated, or
        after ``self.max_steps`` steps, whichever comes first.

        Args:
            options: Mapping to provide to the options argument of
                ``self.env.reset``.
        """
        obs, _ = self.env.reset(options=options)

        predict_kwargs: dict[str, Any] = {"observation": obs}
        for _ in range(self.max_steps):
            if self.use_action_masking:
                # The mask depends on the current state, so refresh it each step.
                action_masks = self.env.action_masks()  # type: ignore[attr-defined]
                predict_kwargs["action_masks"] = action_masks

            action, _ = self.agent.predict(**predict_kwargs)
            obs, _, terminated, truncated, _ = self.env.step(action)
            predict_kwargs["observation"] = obs
            # A truncated episode is over as well; do not keep stepping the
            # environment after it has signalled truncation.
            if terminated or truncated:
                break

    @abstractmethod
    def _postprocess_episode(self, circuit: DAGCircuit) -> WrapperOutputT:
        """Postprocess the episode.

        Extract the useful information from ``self.env`` and do something with it.

        Args:
            circuit: Quantum circuit the episode was run for.

        Returns:
            The wrapper specific output extracted from the episode.
        """

    def run(self, circuit: QuantumCircuit | DAGCircuit) -> WrapperOutputT:
        """Prepare, run and postprocess an episode.

        Output is based on the provided agent, env and circuit combination.

        Args:
            circuit: Quantum circuit to run the episode for.

        Returns:
            Some useful information extracted from the episode.
        """
        circuit = parse_circuit(circuit)
        options = self._prepare_episode(circuit)
        self._run_epsiode(options)
        return self._postprocess_episode(circuit)
2 changes: 1 addition & 1 deletion qgym/utils/qiskit_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from __future__ import annotations

from collections import deque
from collections.abc import Hashable, Iterable
from collections.abc import Iterable
from copy import deepcopy

import networkx as nx
Expand Down
8 changes: 7 additions & 1 deletion qgym/wrappers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
"""This subpackage contains wrapper classes."""

from qgym.wrappers.initial_mapping import AgentMapperWrapper, QiskitMapperWrapper
from qgym.wrappers.routing import AgentRoutingWrapper, QiskitRoutingWrapper

__all__ = ["AgentMapperWrapper", "QiskitMapperWrapper"]
__all__ = [
"AgentMapperWrapper",
"QiskitMapperWrapper",
"AgentRoutingWrapper",
"QiskitRoutingWrapper",
]
59 changes: 26 additions & 33 deletions qgym/wrappers/initial_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,31 +2,34 @@

from __future__ import annotations

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, cast

import numpy as np
from numpy.typing import NDArray
from qiskit import QuantumCircuit
from qiskit.dagcircuit import DAGCircuit
from qiskit.transpiler import AnalysisPass, Layout

from qgym.templates import AgentWrapper
from qgym.utils.qiskit_utils import get_interaction_graph, parse_circuit
from qgym.envs.initial_mapping import InitialMappingState

if TYPE_CHECKING:
import networkx as nx
from stable_baselines3.common.base_class import BaseAlgorithm

from qgym.envs import InitialMapping
from qgym.envs.initial_mapping import InitialMapping


class AgentMapperWrapper: # pylint: disable=too-few-public-methods
class AgentMapperWrapper(AgentWrapper[NDArray[np.int_]]): # pylint: disable=too-few-public-methods
"""Wrap any trained stable baselines 3 agent that inherits from
:class:`~stable_baselines3.common.base_class.BaseAlgorithm`.
The wrapper makes sure the agent upholds the Mapper protocol, which is required for
the qgym benchmarking tools.
"""

def __init__(
def __init__( # pylint: disable=useless-parent-delegation
self,
agent: BaseAlgorithm,
env: InitialMapping,
Expand All @@ -47,46 +50,36 @@ def __init__(
and the `predict` method of `agent` should accept the keyword argument
`"action_masks"`. If ``False`` (default) no action masking is used.
"""
self.agent = agent
self.env = env
self.max_steps = max_steps
self.use_action_masking = use_action_masking
if self.use_action_masking and not hasattr(self.env, "action_masks"):
msg = "use_action_mask is True, but env has no action_masks attribute"
raise TypeError(msg)
super().__init__(agent, env, max_steps, use_action_masking=use_action_masking)

def _prepare_episode(self, circuit: DAGCircuit) -> dict[str, nx.Graph]:
    """Build the reset options for mapping `circuit`.

    Args:
        circuit: Quantum circuit to extract the interaction graph from.

    Returns:
        Dictionary with the single key ``"interaction_graph"``, holding the
        result of ``get_interaction_graph(circuit)``.
    """
    return {"interaction_graph": get_interaction_graph(circuit)}

def _postprocess_episode(self, circuit: DAGCircuit) -> NDArray[np.int_]:  # pylint: disable=unused-argument
    """Read the mapping from the environment state after the episode.

    Args:
        circuit: Quantum circuit the episode was run for (not used here).

    Returns:
        The ``mapping`` attribute of the environment state.

    Raises:
        ValueError: If the environment state does not report being done.
    """
    final_state = cast(InitialMappingState, self.env._state)  # pylint: disable=protected-access
    if final_state.is_done():
        return final_state.mapping

    msg = (
        "mapping not found, "
        "the episode was truncated or 'max_steps' was reached"
    )
    raise ValueError(msg)

def compute_mapping(self, circuit: QuantumCircuit | DAGCircuit) -> NDArray[np.int_]:
"""Compute a mapping of the `circuit` using the provided `agent` and `env`.
Alias for ``run``.
Args:
circuit: Quantum circuit to map.
Returns:
Array of which the index represents a physical qubit, and the value a
virtual qubit.
"""
interaction_graph = get_interaction_graph(circuit)
obs, _ = self.env.reset(options={"interaction_graph": interaction_graph})

predict_kwargs = {"observation": obs}
for _ in range(self.max_steps):
if self.use_action_masking:
action_masks = self.env.action_masks() # type: ignore[attr-defined]
predict_kwargs["action_masks"] = action_masks

action, _ = self.agent.predict(**predict_kwargs)
predict_kwargs["observation"], _, done, _, _ = self.env.step(action)
if done:
break

if not done:
msg = (
"mapping not found, "
"the episode was truncated or 'max_steps' was reached"
)
raise ValueError(msg)

return obs["mapping"]
return self.run(circuit)


class QiskitMapperWrapper:
Expand Down
65 changes: 28 additions & 37 deletions qgym/wrappers/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from __future__ import annotations

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, cast

from qiskit import QuantumCircuit
from qiskit.dagcircuit import DAGCircuit
Expand All @@ -13,22 +13,25 @@
insert_swaps_in_circuit,
parse_circuit,
)
from qgym.envs.routing import RoutingState
from qgym.templates.wrappers import AgentWrapper

if TYPE_CHECKING:
import numpy as np
from numpy.typing import NDArray
from stable_baselines3.common.base_class import BaseAlgorithm
from qgym.envs.routing import Routing

from qgym.envs import Routing


class AgentRoutingWrapper: # pylint: disable=too-few-public-methods
class AgentRoutingWrapper(AgentWrapper[DAGCircuit]): # pylint: disable=too-few-public-methods
"""Wrap any trained stable baselines 3 agent that inherits from
:class:`~stable_baselines3.common.base_class.BaseAlgorithm`.
The wrapper makes sure the agent upholds the Router protocol, which is
required for the qgym benchmarking tools.
"""

def __init__(
def __init__( # pylint: disable=useless-parent-delegation
self,
agent: BaseAlgorithm,
env: Routing,
Expand All @@ -49,13 +52,25 @@ def __init__(
and the `predict` method of `agent` should accept the keyword argument
`"action_masks"`. If ``False`` (default) no action masking is used.
"""
self.agent = agent
self.env = env
self.max_steps = max_steps
self.use_action_masking = use_action_masking
if self.use_action_masking and not hasattr(self.env, "action_masks"):
msg = "use_action_mask is True, but env has no action_masks attribute"
raise TypeError(msg)
super().__init__(agent, env, max_steps, use_action_masking=use_action_masking)

def _prepare_episode(
    self, circuit: QuantumCircuit | DAGCircuit
) -> dict[str, NDArray[np.int_]]:
    """Build the reset options for routing `circuit`.

    Args:
        circuit: Quantum circuit to extract the interaction circuit from.

    Returns:
        Dictionary with the single key ``"interaction_circuit"``, holding the
        result of ``get_interaction_circuit(circuit)``.
    """
    return {"interaction_circuit": get_interaction_circuit(circuit)}

def _postprocess_episode(self, circuit: DAGCircuit) -> DAGCircuit:
    """Route `circuit` based on the findings of the current episode.

    Args:
        circuit: Quantum circuit the episode was run for.

    Returns:
        The result of inserting the swap gates recorded in the environment
        state (``swap_gates_inserted``) into `circuit`.

    Raises:
        ValueError: If the environment state does not report being done.
    """
    final_state = cast(RoutingState, self.env._state)  # pylint: disable=protected-access
    if final_state.is_done():
        return insert_swaps_in_circuit(circuit, final_state.swap_gates_inserted)

    msg = (
        "routing not found, "
        "the episode was truncated or 'max_steps' was reached"
    )
    raise ValueError(msg)

def compute_routing(self, circuit: QuantumCircuit | DAGCircuit) -> DAGCircuit:
"""Route the `circuit` using the provided `agent` and `env`.
Expand All @@ -69,31 +84,7 @@ def compute_routing(self, circuit: QuantumCircuit | DAGCircuit) -> DAGCircuit:
Routed circuit, i.e. a quantum circuit that only contains two qubit gates
between qubits that are part of the connection graph.
"""
interaction_circuit = get_interaction_circuit(circuit)
obs, _ = self.env.reset(options={"interaction_circuit": interaction_circuit})

predict_kwargs = {"observation": obs}
for _ in range(self.max_steps):
if self.use_action_masking:
action_masks = self.env.action_masks() # type: ignore[attr-defined]
predict_kwargs["action_masks"] = action_masks

action, _ = self.agent.predict(**predict_kwargs)
predict_kwargs["observation"], _, done, _, _ = self.env.step(action)
if done:
break

if not done:
msg = (
"routing not found, "
"the episode was truncated or 'max_steps' was reached"
)
raise ValueError(msg)

return insert_swaps_in_circuit(
circuit,
self.env._state.swap_gates_inserted, # type: ignore[attr-defined]
)
return self.run(circuit)


class QiskitRoutingWrapper:
Expand Down

0 comments on commit 820dcec

Please sign in to comment.