Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Fix bug in JsonScalar::from_scalar #74

Merged
merged 5 commits into from
Oct 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 37 additions & 2 deletions quizx/src/json/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
//!
//! This definition is compatible with the `pyzx` JSON format for scalars.

use num::complex::ComplexFloat;
use std::f64::consts::PI;

use num::{One, Zero};
Expand All @@ -20,6 +21,7 @@ impl JsonScalar {
let phase_options = PhaseOptions {
ignore_approx: true,
ignore_pi: true,
limit_denom: Some(256),
..Default::default()
};
match scalar {
Expand All @@ -35,11 +37,39 @@ impl JsonScalar {
..Default::default()
}
}
Scalar::Exact(pow, _) => {
Scalar::Exact(pow, coeffs) => {
// pow is an integer specifying the power of 2 that is applied
// power2 in the JsonScalar representation and in pyzx refers to the power of sqrt(2)

// Extract the phase. scalar.phase() will return exact representations of multiples of pi/4. In
// other cases, we lose precision.
let phase = JsonPhase::from_phase(scalar.phase(), phase_options);

// In the Clifford+T case where we have Scalar4, we can extract factors of sqrt(2) directly from the
// coefficients. Since the coefficients are reduced, sqrt(2) is represented as
// [1, 0, +-1, 0], [0, 1, +-1, 0], where the +- lead to phase contributions already extracted in `phase`
let (power_sqrt2, floatfactor) =
match coeffs.iter_coeffs().collect::<Vec<_>>().as_slice() {
[a, 0, b, 0] | [0, a, 0, b]
if a.abs() == 1 && b.abs() == 1 && coeffs.len() == 4 =>
{
(*pow * 2 + 1, Default::default()) // Coefficients represent a factor of sqrt(2)
}
cf => (
// In all other cases, we simply assign the complex value to the pyzx floatfactor
*pow * 2,
Scalar::<Vec<_>>::from_int_coeffs(cf).complex_value().abs(),
),
};

JsonScalar {
power2: *pow,
power2: power_sqrt2,
phase,
floatfactor: if floatfactor == 1.0 {
Default::default()
} else {
floatfactor
},
is_zero: scalar.is_zero(),
..Default::default()
}
Expand Down Expand Up @@ -102,6 +132,11 @@ mod test {
#[case(ScalarN::from_phase((-1,2)))]
#[case(ScalarN::real(2.0))]
#[case(ScalarN::complex(1.0, 1.0))]
#[case(ScalarN::from_int_coeffs(&[0, 1, 0, -1]))]
#[case(ScalarN::from_int_coeffs(&[0, 7, 0, 7]))]
#[case(ScalarN::from_int_coeffs(&[-2, 0, -2, 0]))]
#[case(ScalarN::from_int_coeffs(&[2, 0, -2, 0]))]
#[case(ScalarN::from_int_coeffs(&[2, 0, 0, 0, 0, 0]))]
fn scalar_roundtrip(#[case] scalar: ScalarN) -> Result<(), JsonError> {
let json_scalar = JsonScalar::from_scalar(&scalar);
let decoded: ScalarN = json_scalar.to_scalar()?;
Expand Down
57 changes: 54 additions & 3 deletions quizx/src/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@

use approx::AbsDiffEq;
use num::complex::Complex;
use num::integer::sqrt;
pub use num::traits::identities::{One, Zero};
use num::{integer, Integer};
use num::{integer, Integer, Rational64};
use std::cmp::min;
use std::f64::consts::PI;
use std::fmt;
Expand Down Expand Up @@ -167,8 +168,39 @@ impl<T: Coeffs> Scalar<T> {

/// Returns the phase of the scalar, expressed as half turns.
///
/// As [`Phase`] is encoded as a rational number, this method may lose precision.
/// We deal with Pi/4 phases of Scalar4 (Clifford+T) exactly. For other cases, [`Phase`] is encoded as a rational
/// number, which may lose precision.
pub fn phase(&self) -> Phase {
if let Exact(_, coeffs) = self {
if coeffs.len() == 4 {
// cases where the phase is a multiple of 1/4 are handled exactly
match coeffs.iter_coeffs().collect::<Vec<_>>().as_slice() {
[a, b, 0, c] if -b == *c => {
return Phase::new(((-a - b * sqrt(2)).signum() as i64 + 1) / 2)
}
[0, c, 0, 0] => {
return Phase::new(Rational64::new(if *c > 0 { 1 } else { 5 }, 4))
}
[0, 0, c, 0] => {
return Phase::new(Rational64::new(if *c > 0 { 1 } else { 3 }, 2))
}
[0, 0, 0, c] => {
return Phase::new(Rational64::new(if *c > 0 { 3 } else { 7 }, 4))
}
[c, 0, d, 0] if c == d => {
return Phase::new(Rational64::new(if *c > 0 { 1 } else { 5 }, 4))
}
[0, c, 0, d] if c == d => {
return Phase::new(Rational64::new(if *c > 0 { 1 } else { 3 }, 2))
}
[d, 0, c, 0] if -c == *d => {
return Phase::new(Rational64::new(if *c > 0 { 3 } else { 7 }, 4))
}
_ => {}
}
}
}
// for other cases, we use the floating point representation
Phase::from_f64(self.complex_value().arg() / PI)
}

Expand Down Expand Up @@ -632,7 +664,7 @@ impl<T: Coeffs> PartialEq for Scalar<T> {

all_eq
}
_ => false,
_ => self.complex_value() == other.complex_value(),
}
}
}
Expand Down Expand Up @@ -705,6 +737,7 @@ mod tests {
use super::*;
use approx::assert_abs_diff_eq;
use num::Rational64;
use rstest::rstest;

#[test]
fn approx_mul() {
Expand Down Expand Up @@ -766,6 +799,24 @@ mod tests {
);
}

#[rstest]
#[case(ScalarN::from_int_coeffs(&[3, 0, 0, 0]))]
#[case(ScalarN::from_int_coeffs(&[0, -2, 0, 0]))]
#[case(ScalarN::from_int_coeffs(&[0, 0, 1, 0]))]
#[case(ScalarN::from_int_coeffs(&[0, 0, 0, 1]))]
#[case(ScalarN::from_int_coeffs(&[0, 0, 0, -1]))]
#[case(ScalarN::from_int_coeffs(&[2, 0, 2, 0]))]
#[case(ScalarN::from_int_coeffs(&[2, 0, -2, 0]))]
#[case(ScalarN::from_int_coeffs(&[-2, 0, -2, 0]))]
#[case(ScalarN::from_int_coeffs(&[0, 1, 0, 1]))]
#[case(ScalarN::from_int_coeffs(&[0, 1, 0, -1]))]
#[case(ScalarN::from_int_coeffs(&[0, -2, 0, -2]))]
#[case(ScalarN::from_int_coeffs(&[0, 2, 0, -2]))]
#[case(ScalarN::from_int_coeffs(&[-1, 2, 3, -4]))]
fn get_phase(#[case] s: ScalarN) {
assert_abs_diff_eq!(s.phase().to_f64(), s.complex_value().arg() / PI);
}

#[test]
fn additions() {
let s = ScalarN::from_int_coeffs(&[1, 2, 3, 4]);
Expand Down
Loading