Skip to content

Commit

Permalink
Merge pull request #211 from ValeevGroup/gaudel/feature/scalar_fusion
Browse files Browse the repository at this point in the history
Factor out the gcd of the scalars when two `Product` summands are fused.
  • Loading branch information
bimalgaudel authored Jun 20, 2024
2 parents 237e141 + f16fdbc commit 7f31851
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 21 deletions.
51 changes: 39 additions & 12 deletions SeQuant/core/optimize/fusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,13 @@ ExprPtr Fusion::fuse_left(Product const& lhs, Product const& rhs) {
auto lsmand_prod = Product{lsmand.begin(), lsmand.end()};
auto rsmand_prod = Product{rsmand.begin(), rsmand.end()};

if (lhs.scalar() == rhs.scalar())
fac_prod.scale(lhs.scalar());
else {
lsmand_prod.scale(lhs.scalar());
rsmand_prod.scale(rhs.scalar());
}
assert(lhs.scalar().imag().is_zero() && rhs.scalar().imag().is_zero() &&
"Complex valued gcd not supported");
auto scalars_fused = fuse_scalar(lhs.scalar().real(), rhs.scalar().real());

fac_prod.scale(scalars_fused.at(0));
lsmand_prod.scale(scalars_fused.at(1));
rsmand_prod.scale(scalars_fused.at(2));

// f (a + b)

Expand Down Expand Up @@ -86,12 +87,13 @@ ExprPtr Fusion::fuse_right(Product const& lhs, Product const& rhs) {
auto lsmand_prod = Product{lsmand.begin(), lsmand.end()};
auto rsmand_prod = Product{rsmand.begin(), rsmand.end()};

if (lhs.scalar() == rhs.scalar())
fac_prod.scale(lhs.scalar());
else {
lsmand_prod.scale(lhs.scalar());
rsmand_prod.scale(rhs.scalar());
}
assert(lhs.scalar().imag().is_zero() && rhs.scalar().imag().is_zero() &&
"Complex valued gcd not supported");
auto scalars_fused = fuse_scalar(lhs.scalar().real(), rhs.scalar().real());

fac_prod.scale(scalars_fused.at(0));
lsmand_prod.scale(scalars_fused.at(1));
rsmand_prod.scale(scalars_fused.at(2));

// (a + b) f

Expand All @@ -102,4 +104,29 @@ ExprPtr Fusion::fuse_right(Product const& lhs, Product const& rhs) {
return ex<Product>(ExprPtrList{ex<Sum>(ExprPtrList{a, b}), f});
}

rational Fusion::gcd_rational(rational const& left, rational const& right) {
auto&& r1 = left.real();
auto&& r2 = right.real();
auto&& n1 = numerator(r1);
auto&& d1 = denominator(r1);
auto&& n2 = numerator(r2);
auto&& d2 = denominator(r2);

auto num = gcd(n1 * d2, n2 * d1);
return {num, d1 * d2};
}

std::array<rational, 3> Fusion::fuse_scalar(rational const& left,
rational const& right) {
auto fused = gcd_rational(left, right);
rational left_fused = left / fused;
rational right_fused = right / fused;
if (left < 0 && right < 0) {
fused *= -1;
left_fused *= -1;
right_fused *= -1;
}
return {fused, left_fused, right_fused};
}

} // namespace sequant::opt
14 changes: 14 additions & 0 deletions SeQuant/core/optimize/fusion.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,20 @@ class Fusion {

static ExprPtr fuse_right(Product const& lhs, Product const& rhs);

///
/// Get the greatest common divisor of two rational numbers.
///
static rational gcd_rational(rational const& left, rational const& right);

///
/// Fuse scalars @param left and @param right and the return the result
/// as an array of three elements: first is the greatest common factor,
/// second the fused sub-factor of @param left and the third is that
/// of @param right.
///
static std::array<rational, 3> fuse_scalar(rational const& left,
rational const& right);

private:
ExprPtr left_;

Expand Down
23 changes: 14 additions & 9 deletions tests/unit/test_fusion.cpp
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
#include <catch2/catch_test_macros.hpp>

#include <SeQuant/core/attr.hpp>
#include <SeQuant/core/expr.hpp>
#include <SeQuant/core/optimize/fusion.hpp>
#include <SeQuant/core/parse_expr.hpp>
#include <SeQuant/domain/mbpt/convention.hpp>

#include <array>
#include <memory>
#include <string_view>
#include <vector>

TEST_CASE("TEST_FUSION", "[Fusion]") {
TEST_CASE("TEST_FUSION", "[optimize]") {
using sequant::opt::Fusion;
using namespace sequant;
std::vector<std::array<std::wstring_view, 3>> fused_terms{
Expand All @@ -23,7 +21,7 @@ TEST_CASE("TEST_FUSION", "[Fusion]") {

{L"1/8 g{a1,a2;a3,a4} t{a3,a4;i1,i2}",
L"1/4 g{a1,a2;a3,a4} t{a3;i1} t{a4;i2}",
L"g{a1,a2;a3,a4}(1/8 t{a3,a4;i1,i2} + 1/4 t{a3;i1} t{a4;i2})"},
L"1/8 g{a1,a2;a3,a4}(t{a3,a4;i1,i2} + 2 t{a3;i1} t{a4;i2})"},

{L"1/4 g{a1,a2;a3,a4} t{a3;i1} t{a4;i2}",
L"1/4 g{i3,i4;a3,a4} t{a1;i3} t{a2;i4} t{a3;i1} t{a4;i2}",
Expand All @@ -32,13 +30,20 @@ TEST_CASE("TEST_FUSION", "[Fusion]") {

{L"1/8 g{i3,i4;a3,a4} t{a1;i3} t{a2;i4} t{a3,a4;i1,i2}",
L"1/4 g{i3,i4;a3,a4} t{a1;i3} t{a2;i4} t{a3;i1} t{a4;i2}",
L"g{i3,i4;a3,a4} t{a1;i3} t{a2;i4} "
L" (1/8 t{a3,a4;i1,i2} + 1/4 t{a3;i1} t{a4;i2})"}};
L"1/8 g{i3,i4;a3,a4} t{a1;i3} t{a2;i4} "
L" (t{a3,a4;i1,i2} + 2 t{a3;i1} t{a4;i2})"},

{L"-1/8 g{i3,i4;a3,a4} t{a1;i3} t{a2;i4} t{a3,a4;i1,i2}",
L"-1/4 g{i3,i4;a3,a4} t{a1;i3} t{a2;i4} t{a3;i1} t{a4;i2}",
L"-1/8 g{i3,i4;a3,a4} t{a1;i3} t{a2;i4} "
L" (t{a3,a4;i1,i2} + 2 t{a3;i1} t{a4;i2})"}

};

for (auto&& [l, r, f] : fused_terms) {
auto const le = parse_expr(l, Symmetry::nonsymm);
auto const re = parse_expr(r, Symmetry::nonsymm);
auto const fe = parse_expr(f, Symmetry::nonsymm);
auto const le = parse_expr(l);
auto const re = parse_expr(r);
auto const fe = parse_expr(f);
auto fu = Fusion{le->as<Product>(), re->as<Product>()};
REQUIRE((fu.left() || fu.right()));

Expand Down

0 comments on commit 7f31851

Please sign in to comment.