diff --git a/mpqp/execution/providers/google.py b/mpqp/execution/providers/google.py index 12d8b3af..9b827727 100644 --- a/mpqp/execution/providers/google.py +++ b/mpqp/execution/providers/google.py @@ -16,6 +16,7 @@ from mpqp.execution.devices import GOOGLEDevice from mpqp.execution.job import Job, JobType from mpqp.execution.result import Result, Sample, StateVector +from mpqp.tools.generics import flatten @typechecked @@ -114,7 +115,8 @@ def run_local(job: Job) -> Result: ValueError: If the job device is not GOOGLEDevice. """ from cirq.circuits.circuit import Circuit as CirqCircuit - from cirq.ops.linear_combinations import PauliSum as Cirq_PauliSum + from cirq.ops.linear_combinations import PauliSum as CirqPauliSum + from cirq.ops.pauli_string import PauliString as CirqPauliString from cirq.sim.sparse_simulator import Simulator from cirq.work.observable_measurement import ( RepetitionsStoppingCriteria, @@ -150,7 +152,9 @@ def run_local(job: Job) -> Result: cirq_obs = job.measure.observable.to_other_language( language=Language.CIRQ, circuit=cirq_circuit ) - assert type(cirq_obs) == Cirq_PauliSum + assert isinstance(cirq_obs, CirqPauliSum) or isinstance( + cirq_obs, CirqPauliString + ) if job.measure.shots == 0: result_sim = simulator.simulate_expectation_values( @@ -159,11 +163,11 @@ def run_local(job: Job) -> Result: else: result_sim = measure_observables( cirq_circuit, - observables=cirq_obs, # type: ignore[reportArgumentType] + observables=flatten(cirq_obs), sampler=simulator, stopping_criteria=RepetitionsStoppingCriteria(job.measure.shots), ) - + print(result_sim) return extract_result_OBSERVABLE(result_sim, job) else: raise ValueError(f"Job type {job.job_type} not handled") @@ -313,7 +317,7 @@ def extract_result_OBSERVABLE( raise NotImplementedError("job.measure is None") for result in results: if isinstance(result, float): - mean += abs(result) + mean += result if isinstance(result, ObservableMeasuredResult): mean += result.mean # TODO variance not supported variance += result1.variance