Skip to content

Commit

Permalink
Specialize 4x4 matrix inverse (#659)
Browse files Browse the repository at this point in the history
  • Loading branch information
calcmogul authored Dec 12, 2024
1 parent 8002cb4 commit cf1e097
Show file tree
Hide file tree
Showing 3 changed files with 239 additions and 59 deletions.
8 changes: 8 additions & 0 deletions cmake/modules/CompilerFlags.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,14 @@ macro(compiler_flags target)
target_compile_options(${target} PRIVATE /wd4244 /wd4251 /WX)
endif()

# Disable warning false positives in Eigen: -Warray-bounds is switched off for
# GCC 12 and newer only.
#
# The variable names are passed to if() unexpanded so that if() dereferences
# them itself. Writing ${VAR} here would expand first and then let if() treat
# the resulting text as another variable name, and an empty value would leave
# the comparison without a left-hand operand.
if(
    CMAKE_CXX_COMPILER_ID STREQUAL "GNU"
    AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL "12"
)
    target_compile_options(${target} PRIVATE -Wno-array-bounds)
endif()

target_compile_features(${target} PUBLIC cxx_std_23)
if(MSVC)
target_compile_options(${target} PUBLIC /MP /utf-8 /bigobj)
Expand Down
198 changes: 183 additions & 15 deletions src/autodiff/VariableMatrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,20 +24,20 @@ VariableMatrix Solve(const VariableMatrix& A, const VariableMatrix& B) {
const auto& c = A(1, 0);
const auto& d = A(1, 1);

sleipnir::VariableMatrix Ainv{{d, -b}, {-c, a}};
sleipnir::VariableMatrix adjA{{d, -b}, {-c, a}};
auto detA = a * d - b * c;
Ainv /= detA;

return Ainv * B;
return adjA / detA * B;
} else if (A.Rows() == 3 && A.Cols() == 3) {
// Compute optimal inverse instead of using Eigen's general solver
//
// [a b c]⁻¹
// [d e f]
// [g h i]
// 1 [ei − fh ch − bi bf − ce]
// = --------------------------------- [fg − di ai − cg cd − af]
// aei − afh − bdi + bfg + cdh − ceg [dh − eg bg − ah ae − bd]
// 1 [ei − fh ch − bi bf − ce]
// = ------------------------------------ [fg − di ai − cg cd − af]
// a(ei − fh) + b(fg − di) + c(dh − eg) [dh − eg bg − ah ae − bd]
//
// https://www.wolframalpha.com/input?i=inverse+%7B%7Ba%2C+b%2C+c%7D%2C+%7Bd%2C+e%2C+f%7D%2C+%7Bg%2C+h%2C+i%7D%7D

const auto& a = A(0, 0);
const auto& b = A(0, 1);
Expand All @@ -49,15 +49,183 @@ VariableMatrix Solve(const VariableMatrix& A, const VariableMatrix& B) {
const auto& h = A(2, 1);
const auto& i = A(2, 2);

sleipnir::VariableMatrix Ainv{
{e * i - f * h, c * h - b * i, b * f - c * e},
{f * g - d * i, a * i - c * g, c * d - a * f},
{d * h - e * g, b * g - a * h, a * e - b * d}};
auto detA =
a * e * i - a * f * h - b * d * i + b * f * g + c * d * h - c * e * g;
Ainv /= detA;
auto ae = a * e;
auto af = a * f;
auto ah = a * h;
auto ai = a * i;
auto bd = b * d;
auto bf = b * f;
auto bg = b * g;
auto bi = b * i;
auto cd = c * d;
auto ce = c * e;
auto cg = c * g;
auto ch = c * h;
auto dh = d * h;
auto di = d * i;
auto eg = e * g;
auto ei = e * i;
auto fg = f * g;
auto fh = f * h;

auto adjA00 = ei - fh;
auto adjA10 = fg - di;
auto adjA20 = dh - eg;

sleipnir::VariableMatrix adjA{{adjA00, ch - bi, bf - ce},
{adjA10, ai - cg, cd - af},
{adjA20, bg - ah, ae - bd}};
auto detA = a * adjA00 + b * adjA10 + c * adjA20;
return adjA / detA * B;
} else if (A.Rows() == 4 && A.Cols() == 4) {
// Compute optimal inverse instead of using Eigen's general solver
//
// [a b c d]⁻¹
// [e f g h]
// [i j k l]
// [m n o p]
//
// https://www.wolframalpha.com/input?i=inverse+%7B%7Ba%2C+b%2C+c%2C+d%7D%2C+%7Be%2C+f%2C+g%2C+h%7D%2C+%7Bi%2C+j%2C+k%2C+l%7D%2C+%7Bm%2C+n%2C+o%2C+p%7D%7D

const auto& a = A(0, 0);
const auto& b = A(0, 1);
const auto& c = A(0, 2);
const auto& d = A(0, 3);
const auto& e = A(1, 0);
const auto& f = A(1, 1);
const auto& g = A(1, 2);
const auto& h = A(1, 3);
const auto& i = A(2, 0);
const auto& j = A(2, 1);
const auto& k = A(2, 2);
const auto& l = A(2, 3);
const auto& m = A(3, 0);
const auto& n = A(3, 1);
const auto& o = A(3, 2);
const auto& p = A(3, 3);

auto afk = a * f * k;
auto afl = a * f * l;
auto afo = a * f * o;
auto afp = a * f * p;
auto agj = a * g * j;
auto agl = a * g * l;
auto agn = a * g * n;
auto agp = a * g * p;
auto ahj = a * h * j;
auto ahk = a * h * k;
auto ahn = a * h * n;
auto aho = a * h * o;
auto ajo = a * j * o;
auto ajp = a * j * p;
auto akn = a * k * n;
auto akp = a * k * p;
auto aln = a * l * n;
auto alo = a * l * o;
auto bek = b * e * k;
auto bel = b * e * l;
auto beo = b * e * o;
auto bep = b * e * p;
auto bgi = b * g * i;
auto bgl = b * g * l;
auto bgm = b * g * m;
auto bgp = b * g * p;
auto bhi = b * h * i;
auto bhk = b * h * k;
auto bhm = b * h * m;
auto bho = b * h * o;
auto bio = b * i * o;
auto bip = b * i * p;
auto bjp = b * j * p;
auto bkm = b * k * m;
auto bkp = b * k * p;
auto blm = b * l * m;
auto blo = b * l * o;
auto cej = c * e * j;
auto cel = c * e * l;
auto cen = c * e * n;
auto cep = c * e * p;
auto cfi = c * f * i;
auto cfl = c * f * l;
auto cfm = c * f * m;
auto cfp = c * f * p;
auto chi = c * h * i;
auto chj = c * h * j;
auto chm = c * h * m;
auto chn = c * h * n;
auto cin = c * i * n;
auto cip = c * i * p;
auto cjm = c * j * m;
auto cjp = c * j * p;
auto clm = c * l * m;
auto cln = c * l * n;
auto dej = d * e * j;
auto dek = d * e * k;
auto den = d * e * n;
auto deo = d * e * o;
auto dfi = d * f * i;
auto dfk = d * f * k;
auto dfm = d * f * m;
auto dfo = d * f * o;
auto dgi = d * g * i;
auto dgj = d * g * j;
auto dgm = d * g * m;
auto dgn = d * g * n;
auto din = d * i * n;
auto dio = d * i * o;
auto djm = d * j * m;
auto djo = d * j * o;
auto dkm = d * k * m;
auto dkn = d * k * n;
auto ejo = e * j * o;
auto ejp = e * j * p;
auto ekn = e * k * n;
auto ekp = e * k * p;
auto eln = e * l * n;
auto elo = e * l * o;
auto fio = f * i * o;
auto fip = f * i * p;
auto fkm = f * k * m;
auto fkp = f * k * p;
auto flm = f * l * m;
auto flo = f * l * o;
auto gin = g * i * n;
auto gip = g * i * p;
auto gjm = g * j * m;
auto gjp = g * j * p;
auto glm = g * l * m;
auto gln = g * l * n;
auto hin = h * i * n;
auto hio = h * i * o;
auto hjm = h * j * m;
auto hjo = h * j * o;
auto hkm = h * k * m;
auto hkn = h * k * n;

auto adjA00 = fkp - flo - gjp + gln + hjo - hkn;
auto adjA01 = -bkp + blo + cjp - cln - djo + dkn;
auto adjA02 = bgp - bho - cfp + chn + dfo - dgn;
auto adjA03 = -bgl + bhk + cfl - chj - dfk + dgj;
auto adjA10 = -ekp + elo + gip - glm - hio + hkm;
auto adjA11 = akp - alo - cip + clm + dio - dkm;
auto adjA12 = -agp + aho + cep - chm - deo + dgm;
auto adjA13 = agl - ahk - cel + chi + dek - dgi;
auto adjA20 = ejp - eln - fip + flm + hin - hjm;
auto adjA21 = -ajp + aln + bip - blm - din + djm;
auto adjA22 = afp - ahn - bep + bhm + den - dfm;
auto adjA23 = -afl + ahj + bel - bhi - dej + dfi;
auto adjA30 = -ejo + ekn + fio - fkm - gin + gjm;
// NOLINTNEXTLINE
auto adjA31 = ajo - akn - bio + bkm + cin - cjm;
auto adjA32 = -afo + agn + beo - bgm - cen + cfm;
auto adjA33 = afk - agj - bek + bgi + cej - cfi;

return Ainv * B;
sleipnir::VariableMatrix adjA{{adjA00, adjA01, adjA02, adjA03},
{adjA10, adjA11, adjA12, adjA13},
{adjA20, adjA21, adjA22, adjA23},
{adjA30, adjA31, adjA32, adjA33}};
auto detA = a * adjA00 + b * adjA10 + c * adjA20 + d * adjA30;
return adjA / detA * B;
} else {
using MatrixXv = Eigen::Matrix<Variable, Eigen::Dynamic, Eigen::Dynamic>;

Expand Down
92 changes: 48 additions & 44 deletions test/src/autodiff/VariableMatrixTest.cpp
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
// Copyright (c) Sleipnir contributors

#include <format>
#include <functional>
#include <iterator>

#include <Eigen/Core>
#include <Eigen/QR>
#include <catch2/catch_test_macros.hpp>
#include <sleipnir/autodiff/VariableMatrix.hpp>

Expand Down Expand Up @@ -334,49 +336,51 @@ TEST_CASE("VariableMatrix - Block() free function", "[VariableMatrix]") {
CHECK(mat2.Value() == expected2);
}

// Test helper: checks sleipnir::Solve() on the square system A * X = B, where
// A is Rows x Rows and B is a single Rows x 1 column.
//
// The result is checked four ways: its row and column counts must match the
// reference solution, the residual A * X - B must have a norm below 1e-12, and
// the difference from the reference solution must have a norm below 1e-12.
//
// Rows: dimension of A and number of rows in B.
// A:    coefficient matrix.
// B:    right-hand-side column.
template <int Rows>
void ExpectSolve(const Eigen::Matrix<double, Rows, Rows>& A,
const Eigen::Matrix<double, Rows, 1>& B) {
// Label the checks below with the system size being solved
INFO(std::format("Solve {}x{}", Rows, Rows));

// Wrap the numeric inputs in VariableMatrix so they can be passed to Solve()
sleipnir::VariableMatrix slpA{A};
sleipnir::VariableMatrix slpB{B};
auto actualX = sleipnir::Solve(slpA, slpB);

// Reference solution from Eigen's Householder QR decomposition of A
Eigen::Matrix<double, Rows, 1> expectedX = A.householderQr().solve(B);

// Shape must match the reference solution
CHECK(actualX.Rows() == expectedX.rows());
CHECK(actualX.Cols() == expectedX.cols());
// Residual of the solved system, then agreement with the reference solution;
// both use an absolute tolerance of 1e-12
CHECK((slpA.Value() * actualX.Value() - slpB.Value()).norm() < 1e-12);
CHECK((actualX.Value() - expectedX).norm() < 1e-12);
}

// Exercises sleipnir::Solve() on fixed systems of every size from 1x1 to 5x5.
//
// Each case supplies a coefficient matrix and a single right-hand-side column
// and delegates all checking to ExpectSolve(), so every size is verified the
// same way instead of repeating the checks by hand per size.
TEST_CASE("VariableMatrix - Solve() free function", "[VariableMatrix]") {
  // 1x1 special case
  ExpectSolve(Eigen::Matrix<double, 1, 1>{{2.0}},
              Eigen::Matrix<double, 1, 1>{{5.0}});

  // 2x2 special case
  ExpectSolve(Eigen::Matrix<double, 2, 2>{{1.0, 2.0}, {3.0, 4.0}},
              Eigen::Matrix<double, 2, 1>{{5.0}, {6.0}});

  // 3x3 special case
  ExpectSolve(
      Eigen::Matrix<double, 3, 3>{
          {1.0, 2.0, 3.0}, {-4.0, -5.0, 6.0}, {7.0, 8.0, 9.0}},
      Eigen::Matrix<double, 3, 1>{{10.0}, {11.0}, {12.0}});

  // 4x4 special case
  ExpectSolve(Eigen::Matrix<double, 4, 4>{{1.0, 2.0, 3.0, -4.0},
                                          {-5.0, 6.0, 7.0, 8.0},
                                          {9.0, 10.0, 11.0, 12.0},
                                          {13.0, 14.0, 15.0, 16.0}},
              Eigen::Matrix<double, 4, 1>{{17.0}, {18.0}, {19.0}, {20.0}});

  // 5x5 general case
  ExpectSolve(
      Eigen::Matrix<double, 5, 5>{{1.0, 2.0, 3.0, -4.0, 5.0},
                                  {-5.0, 6.0, 7.0, 8.0, 9.0},
                                  {9.0, 10.0, 11.0, 12.0, 13.0},
                                  {13.0, 14.0, 15.0, 16.0, 17.0},
                                  {17.0, 18.0, 19.0, 20.0, 21.0}},
      Eigen::Matrix<double, 5, 1>{{21.0}, {22.0}, {23.0}, {24.0}, {25.0}});
}

0 comments on commit cf1e097

Please sign in to comment.