Skip to content

Commit

Permalink
feat: prep the variance computation on cirq
Browse files Browse the repository at this point in the history
  • Loading branch information
Henri-ColibrITD committed Jun 27, 2024
1 parent 6f8ade3 commit ce6df23
Showing 1 changed file with 9 additions and 5 deletions.
14 changes: 9 additions & 5 deletions mpqp/execution/providers/google.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from mpqp.execution.devices import GOOGLEDevice
from mpqp.execution.job import Job, JobType
from mpqp.execution.result import Result, Sample, StateVector
from mpqp.tools.generics import flatten


@typechecked
Expand Down Expand Up @@ -114,7 +115,8 @@ def run_local(job: Job) -> Result:
ValueError: If the job device is not GOOGLEDevice.
"""
from cirq.circuits.circuit import Circuit as CirqCircuit
from cirq.ops.linear_combinations import PauliSum as Cirq_PauliSum
from cirq.ops.linear_combinations import PauliSum as CirqPauliSum
from cirq.ops.pauli_string import PauliString as CirqPauliString
from cirq.sim.sparse_simulator import Simulator
from cirq.work.observable_measurement import (
RepetitionsStoppingCriteria,
Expand Down Expand Up @@ -150,7 +152,9 @@ def run_local(job: Job) -> Result:
cirq_obs = job.measure.observable.to_other_language(
language=Language.CIRQ, circuit=cirq_circuit
)
assert type(cirq_obs) == Cirq_PauliSum
assert isinstance(cirq_obs, CirqPauliSum) or isinstance(
cirq_obs, CirqPauliString
)

if job.measure.shots == 0:
result_sim = simulator.simulate_expectation_values(
Expand All @@ -159,11 +163,11 @@ def run_local(job: Job) -> Result:
else:
result_sim = measure_observables(
cirq_circuit,
observables=cirq_obs, # type: ignore[reportArgumentType]
observables=flatten(cirq_obs),
sampler=simulator,
stopping_criteria=RepetitionsStoppingCriteria(job.measure.shots),
)

print(result_sim)
return extract_result_OBSERVABLE(result_sim, job)
else:
raise ValueError(f"Job type {job.job_type} not handled")
Expand Down Expand Up @@ -313,7 +317,7 @@ def extract_result_OBSERVABLE(
raise NotImplementedError("job.measure is None")
for result in results:
if isinstance(result, float):
mean += abs(result)
mean += result
if isinstance(result, ObservableMeasuredResult):
mean += result.mean
# TODO variance not supported variance += result1.variance
Expand Down

0 comments on commit ce6df23

Please sign in to comment.