Skip to content

Commit

Permalink
Merge pull request #19 from QuantumApplicationLab/migrate_to_estimato…
Browse files Browse the repository at this point in the history
…r_v2

Add compatibility with EstimatorV2 (`qiskit-aer` and `qiskit-ibm-runtime`)
  • Loading branch information
Cmurilochem authored Jul 17, 2024
2 parents 1f05569 + 1660ba2 commit c3bd0c4
Show file tree
Hide file tree
Showing 8 changed files with 381 additions and 45 deletions.
24 changes: 11 additions & 13 deletions tests/test_vqls.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,13 @@

from qiskit import QuantumCircuit
from qiskit.circuit.library import RealAmplitudes
from qiskit.primitives import Estimator, Sampler

from qiskit_algorithms.optimizers import ADAM
from qiskit.primitives import Estimator, Sampler

from qiskit_aer.primitives import EstimatorV2 as aer_EstimatorV2
from qiskit_aer.primitives import SamplerV2 as aer_SamplerV2

from vqls_prototype import VQLS

# 8-11-2023
Expand All @@ -43,11 +47,13 @@ def setUp(self):

self.estimators = (
Estimator(),
aer_EstimatorV2(),
# AerEstimator(),
)

self.samplers = (
Sampler(),
aer_SamplerV2(),
# AerSampler(),
)

Expand All @@ -66,12 +72,8 @@ def test_numpy_input(self):
rhs = np.array([0.1] * 4)
ansatz = RealAmplitudes(num_qubits=2, reps=3, entanglement="full")

for iprim, (estimator, sampler) in enumerate(
zip(self.estimators, self.samplers)
):
for iopt, opt in enumerate(self.options):
if iprim == 1 and iopt == 2:
continue
for _, (estimator, sampler) in enumerate(zip(self.estimators, self.samplers)):
for _, opt in enumerate(self.options):
vqls = VQLS(
estimator,
ansatz,
Expand Down Expand Up @@ -101,12 +103,8 @@ def test_circuit_input_statevector(self):
qc2.x(1)
qc2.cx(0, 1)

for iprim, (estimator, sampler) in enumerate(
zip(self.estimators, self.samplers)
):
for iopt, opt in enumerate(self.options):
if iprim == 1 and iopt == 2:
continue
for _, (estimator, sampler) in enumerate(zip(self.estimators, self.samplers)):
for _, opt in enumerate(self.options):
vqls = VQLS(
estimator,
ansatz,
Expand Down
48 changes: 37 additions & 11 deletions vqls_prototype/hadamard_test/direct_hadamard_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@
import numpy as np
import numpy.typing as npt

from qiskit.primitives.sampler import SamplerResult
from qiskit.primitives.containers import PrimitiveResult
from vqls_prototype.primitives_run_builder import SamplerRunBuilder


class BatchDirectHadammardTest:
r"""Class that execute batches of Hadammard Test"""
Expand Down Expand Up @@ -31,14 +35,18 @@ def get_values(self, sampler, parameter_sets: List, zne_strategy=None) -> List:
"""

ncircuits = len(self.circuits)
all_parameter_sets = [parameter_sets] * ncircuits

sampler_run_builder = SamplerRunBuilder(
sampler,
self.circuits,
all_parameter_sets,
options={"shots": self.shots},
)

try:
if zne_strategy is None:
job = sampler.run(
self.circuits,
[parameter_sets] * ncircuits,
shots=self.shots,
)
job = sampler_run_builder.build_run()
else:
job = sampler.run(
self.circuits,
Expand Down Expand Up @@ -133,8 +141,24 @@ def post_processing(self, sampler_result) -> npt.NDArray[np.cdouble]:
Returns:
List: value of the overlap hadammard test
"""
if isinstance(sampler_result, SamplerResult):
quasi_dist = sampler_result.quasi_dists

elif isinstance(sampler_result, PrimitiveResult):
quasi_dist = [
{
key: value / result.data.meas.num_shots
for key, value in result.data.meas.get_int_counts().items()
}
for result in sampler_result
]

else:
raise NotImplementedError(
f"Cannot post processing for {type(sampler_result)} type class."
f"Please, refer to {self.__class__.__name__}.post_processing()."
)

quasi_dist = sampler_result.quasi_dists
val = []
for qdist in quasi_dist:
# add missing keys
Expand All @@ -158,14 +182,16 @@ def get_value(
Returns:
List: value of the test
"""
sampler_run_builder = SamplerRunBuilder(
sampler,
self.circuits,
parameter_sets,
options={"shots": self.shots},
)

try:
if zne_strategy is None:
job = sampler.run(
self.circuits,
parameter_sets,
shots=self.shots,
)
job = sampler_run_builder.build_run()
else:
job = sampler.run(
self.circuits,
Expand Down
46 changes: 39 additions & 7 deletions vqls_prototype/hadamard_test/hadamard_overlap_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@
import numpy as np
import numpy.typing as npt

from qiskit.primitives.sampler import SamplerResult
from qiskit.primitives.containers import PrimitiveResult
from vqls_prototype.primitives_run_builder import SamplerRunBuilder


class BatchHadammardOverlapTest:
r"""Class that execute batches of Hadammard Test"""
Expand Down Expand Up @@ -31,14 +35,18 @@ def get_values(self, sampler, parameter_sets: List, zne_strategy=None) -> List:
"""

ncircuits = len(self.circuits)
all_parameter_sets = [parameter_sets] * ncircuits

sampler_run_builder = SamplerRunBuilder(
sampler,
self.circuits,
all_parameter_sets,
options={"shots": self.shots},
)

try:
if zne_strategy is None:
job = sampler.run(
self.circuits,
[parameter_sets] * ncircuits,
shots=self.shots,
)
job = sampler_run_builder.build_run()
else:
job = sampler.run(
self.circuits,
Expand Down Expand Up @@ -240,8 +248,24 @@ def post_processing(self, sampler_result) -> npt.NDArray[np.cdouble]:
Returns:
List: value of the overlap hadammard test
"""
if isinstance(sampler_result, SamplerResult):
quasi_dist = sampler_result.quasi_dists

elif isinstance(sampler_result, PrimitiveResult):
quasi_dist = [
{
key: value / result.data.meas.num_shots
for key, value in result.data.meas.get_int_counts().items()
}
for result in sampler_result
]

else:
raise NotImplementedError(
f"Cannot post processing for {type(sampler_result)} type class."
f"Please, refer to {self.__class__.__name__}.post_processing()."
)

quasi_dist = sampler_result.quasi_dists
output = []

for qdist in quasi_dist:
Expand Down Expand Up @@ -269,7 +293,15 @@ def get_value(self, sampler, parameter_sets: List) -> float:
float: value of the overlap hadammard test
"""
ncircuits = len(self.circuits)
job = sampler.run(self.circuits, [parameter_sets] * ncircuits, shots=self.shots)
all_parameter_sets = [parameter_sets] * ncircuits

sampler_run_builder = SamplerRunBuilder(
sampler,
self.circuits,
all_parameter_sets,
options={"shots": self.shots},
)
job = sampler_run_builder.build_run()
results = self.post_processing(job.result())

results *= np.array([1.0, 1.0j])
Expand Down
52 changes: 38 additions & 14 deletions vqls_prototype/hadamard_test/hadamard_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@
import numpy as np
import numpy.typing as npt

from qiskit.primitives.estimator import EstimatorResult
from qiskit.primitives.containers import PrimitiveResult
from vqls_prototype.primitives_run_builder import EstimatorRunBuilder


class BatchHadammardTest:
r"""Class that execute batches of Hadammard Test"""
Expand Down Expand Up @@ -36,15 +40,19 @@ def get_values(self, primitive, parameter_sets: List, zne_strategy=None) -> List
"""

ncircuits = len(self.circuits)
all_parameter_sets = [parameter_sets] * ncircuits

estimator_run_builder = EstimatorRunBuilder(
primitive,
self.circuits,
self.observable,
all_parameter_sets,
options={"shots": self.shots},
)

try:
if zne_strategy is None:
job = primitive.run(
self.circuits,
self.observable,
[parameter_sets] * ncircuits,
shots=self.shots,
)
job = estimator_run_builder.build_run()
else:
job = primitive.run(
self.circuits,
Expand Down Expand Up @@ -236,8 +244,19 @@ def post_processing(self, estimator_result) -> npt.NDArray[np.cdouble]:
Returns:
npt.NDArray[np.cdouble]: value of the test
"""
return np.array([1.0 - 2.0 * val for val in estimator_result.values]).astype(
"complex128"
if isinstance(estimator_result, EstimatorResult):
return np.array(
[1.0 - 2.0 * val for val in estimator_result.values]
).astype("complex128")

if isinstance(estimator_result, PrimitiveResult):
return np.array(
[1.0 - 2.0 * val.data.evs for val in estimator_result]
).astype("complex128")

raise NotImplementedError(
f"Cannot post processing for {type(estimator_result)} type class."
f"Please, refer to {self.__class__.__name__}.post_processing()."
)

def get_value(self, estimator, parameter_sets: List, zne_strategy=None) -> List:
Expand All @@ -252,15 +271,20 @@ def get_value(self, estimator, parameter_sets: List, zne_strategy=None) -> List:
"""

ncircuits = len(self.circuits)
all_parameter_sets = [parameter_sets] * ncircuits
all_observables = [self.observable] * ncircuits

estimator_run_builder = EstimatorRunBuilder(
estimator,
self.circuits,
all_observables,
all_parameter_sets,
options={"shots": self.shots},
)

try:
if zne_strategy is None:
job = estimator.run(
self.circuits,
[self.observable] * ncircuits,
[parameter_sets] * ncircuits,
shots=self.shots,
)
job = estimator_run_builder.build_run()
else:
job = estimator.run(
self.circuits,
Expand Down
10 changes: 10 additions & 0 deletions vqls_prototype/primitives_run_builder/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
"""Primitives builder package."""

from .estimator_run_builder import EstimatorRunBuilder
from .sampler_run_builder import SamplerRunBuilder


__all__ = [
"EstimatorRunBuilder",
"SamplerRunBuilder",
]
78 changes: 78 additions & 0 deletions vqls_prototype/primitives_run_builder/base_run_builder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
"""This module defines a base class for primitive run builders."""

from typing import Union, List, Tuple, Dict, Any
from qiskit import QuantumCircuit
from qiskit.primitives import PrimitiveJob
from qiskit_ibm_runtime import RuntimeJobV2


class BasePrimitiveRunBuilder:
"""
Base class for building and configuring primitive runs based on their provenance and options.
"""

def __init__(
self,
primitive,
circuits: List[QuantumCircuit],
parameter_sets: List[List[float]],
options: Dict[str, Any],
):
"""
Initializes BasePrimitiveRunBuilder for given primitive, circuits, parameters, and options.
Args:
primitive (Union[SamplerValidType, EstimatorValidType]): The primitive to use for runs.
circuits (List[QuantumCircuit]): The quantum circuits to run.
parameter_sets (List[List[float]]): The parameters to vary in the circuits.
options (Dict[str, Any]): Configuration options such as number of shots.
"""
self.primitive = primitive
self.circuits = circuits
self.parameter_sets = parameter_sets
self.shots = options.pop("shots", None)
self.seed = options.pop("seed", None)
self.provenance = self.find_provenance()

def find_provenance(self) -> Tuple[str, str]:
"""Determines the provenance of the primitive based on its class and module."""
return (
self.primitive.__class__.__module__.split(".")[0],
self.primitive.__class__.__name__,
)

def build_run(self) -> Union[PrimitiveJob, RuntimeJobV2]:
"""
Configures and returns primitive runs based on its provenance.
Raises:
NotImplementedError: If the primitive's provenance is not supported.
Returns:
Union[PrimitiveJob, RuntimeJobV2]: A primitive job.
"""
primitive_job = self._select_run_builder()
return primitive_job()

def _select_run_builder(self) -> Union[PrimitiveJob, RuntimeJobV2]:
"""Selects the appropriate builder function based on the primitive's provenance."""
raise NotImplementedError("This method should be implemented by subclasses.")

def _build_native_qiskit_run(self) -> PrimitiveJob:
"""Builds a run function for a standard qiskit primitive."""
raise NotImplementedError("This method should be implemented by subclasses.")

def _build_v2_run(self) -> Union[PrimitiveJob, RuntimeJobV2]:
"""Builds a run function for qiskit-aer and qiskit-ibm-runtime V2 primitives."""
raise NotImplementedError("This method should be implemented by subclasses.")

def _build_v1_run(self):
"""
Attempts to build a run function for primitives V1, which will be soon deprecated.
Raises:
NotImplementedError: Indicates that V1 will be soon deprecated.
"""
raise NotImplementedError(
"Primitives V1 will be soon deprecated. Please, use V2 implementation."
)
Loading

0 comments on commit c3bd0c4

Please sign in to comment.