Skip to content

Commit

Permalink
Add 1x1 and 3x3 special cases to Solve()
Browse files Browse the repository at this point in the history
  • Loading branch information
calcmogul committed Jun 28, 2024
1 parent ecc2624 commit 2ec0b06
Show file tree
Hide file tree
Showing 3 changed files with 129 additions and 43 deletions.
59 changes: 42 additions & 17 deletions jormungandr/test/autodiff/variable_matrix_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,20 +180,45 @@ def test_block_free_function():


def test_solve_free_function():
    """Exercise autodiff.solve() on each sized system.

    The 1x1, 2x2, and 3x3 cases hit the closed-form inverse special cases in
    Solve(); the 4x4 case falls through to the general solver. The small exact
    cases are compared with ``==``, while the cases whose solutions involve
    rounding use a norm tolerance.
    """
    # 1x1: 2x = 5  =>  x = 2.5 (exact, so compare with ==)
    A_11 = VariableMatrix([[2.0]])
    B_11 = VariableMatrix([[5.0]])
    X_11 = autodiff.solve(A_11, B_11)

    expected_11 = np.array([[2.5]])
    assert X_11.shape == (1, 1)
    assert (A_11.value() @ X_11.value() == B_11.value()).all()
    assert (X_11.value() == expected_11).all()

    # 2x2: closed-form adjugate inverse; solution is exactly representable
    A_22 = VariableMatrix([[1.0, 2.0], [3.0, 4.0]])
    B_21 = VariableMatrix([[5.0], [6.0]])
    X_21 = autodiff.solve(A_22, B_21)

    expected_21 = np.array([[-4.0], [4.5]])
    assert X_21.shape == (2, 1)
    assert (A_22.value() @ X_21.value() == B_21.value()).all()
    assert (X_21.value() == expected_21).all()

    # 3x3: solution contains 11/6, so compare within a tolerance
    A_33 = VariableMatrix([[1.0, 2.0, 3.0], [-4.0, -5.0, 6.0], [7.0, 8.0, 9.0]])
    B_31 = VariableMatrix([[10.0], [11.0], [12.0]])
    X_31 = autodiff.solve(A_33, B_31)

    expected_31 = np.array([[-7.5], [6.0], [11.0 / 6.0]])
    assert X_31.shape == (3, 1)
    assert np.linalg.norm(A_33.value() @ X_31.value() - B_31.value()) < 1e-12
    assert np.linalg.norm(X_31.value() - expected_31) < 1e-12

    # 4x4: no special case — exercises the general solver path
    A_44 = VariableMatrix(
        [
            [1.0, 2.0, 3.0, -4.0],
            [-5.0, 6.0, 7.0, 8.0],
            [9.0, 10.0, 11.0, 12.0],
            [13.0, 14.0, 15.0, 16.0],
        ]
    )
    B_41 = VariableMatrix([[17.0], [18.0], [19.0], [20.0]])
    X_41 = autodiff.solve(A_44, B_41)

    # First entry is numerically zero (solver round-off), hence the tolerance
    expected_41 = np.array([[4.44089e-16], [-16.25], [16.5], [0.0]])
    assert X_41.shape == (4, 1)
    assert np.linalg.norm(A_44.value() @ X_41.value() - B_41.value()) < 1e-12
    assert np.linalg.norm(X_41.value() - expected_41) < 1e-12
51 changes: 44 additions & 7 deletions src/autodiff/VariableMatrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,51 @@ VariableMatrix Solve(const VariableMatrix& A, const VariableMatrix& B) {
// m x n * n x p = m x p
Assert(A.Rows() == B.Rows());

if (A.Rows() == 2 && A.Cols() == 2) {
if (A.Rows() == 1 && A.Cols() == 1) {
// Compute optimal inverse instead of using Eigen's general solver
sleipnir::VariableMatrix Ainv{2, 2};
Ainv(0, 0) = A(1, 1);
Ainv(0, 1) = -A(0, 1);
Ainv(1, 0) = -A(1, 0);
Ainv(1, 1) = A(0, 0);
auto detA = A(0, 0) * A(1, 1) - A(0, 1) * A(1, 0);
return 1.0 / A(0, 0) * B;
} else if (A.Rows() == 2 && A.Cols() == 2) {
// Compute optimal inverse instead of using Eigen's general solver
//
// [a b]⁻¹ ___1___ [ d −b]
// [c d] = ad − bc [−c a]

const auto& a = A(0, 0);
const auto& b = A(0, 1);
const auto& c = A(1, 0);
const auto& d = A(1, 1);

sleipnir::VariableMatrix Ainv{{d, -b}, {-c, a}};
auto detA = a * d - b * c;
Ainv /= detA;

return Ainv * B;
} else if (A.Rows() == 3 && A.Cols() == 3) {
// Compute optimal inverse instead of using Eigen's general solver
//
// [a b c]⁻¹
// [d e f]
// [g h i]
// 1 [ei − fh ch − bi bf − ce]
// = --------------------------------- [fg − di ai − cg cd − af]
// aei − afh − bdi + bfg + cdh − ceg [dh − eg bg − ah ae − bd]

const auto& a = A(0, 0);
const auto& b = A(0, 1);
const auto& c = A(0, 2);
const auto& d = A(1, 0);
const auto& e = A(1, 1);
const auto& f = A(1, 2);
const auto& g = A(2, 0);
const auto& h = A(2, 1);
const auto& i = A(2, 2);

sleipnir::VariableMatrix Ainv{
{e * i - f * h, c * h - b * i, b * f - c * e},
{f * g - d * i, a * i - c * g, c * d - a * f},
{d * h - e * g, b * g - a * h, a * e - b * d}};
auto detA =
a * e * i - a * f * h - b * d * i + b * f * g + c * d * h - c * e * g;
Ainv /= detA;

return Ainv * B;
Expand Down
62 changes: 43 additions & 19 deletions test/src/autodiff/VariableMatrixTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -211,24 +211,48 @@ TEST_CASE("VariableMatrix - Block() free function", "[VariableMatrix]") {
}

TEST_CASE("VariableMatrix - Solve() free function", "[VariableMatrix]") {
  // 1x1: 2x = 5 => x = 2.5 (exact, so compare with ==)
  sleipnir::VariableMatrix A_11{{2.0}};
  sleipnir::VariableMatrix B_11{{5.0}};
  sleipnir::VariableMatrix X_11 = sleipnir::Solve(A_11, B_11);

  Eigen::Matrix<double, 1, 1> expected_11{{2.5}};
  CHECK(X_11.Rows() == 1);
  CHECK(X_11.Cols() == 1);
  CHECK(A_11.Value() * X_11.Value() == B_11.Value());
  CHECK(X_11.Value() == expected_11);

  // 2x2: closed-form inverse special case; solution exactly representable
  sleipnir::VariableMatrix A_22{{1.0, 2.0}, {3.0, 4.0}};
  sleipnir::VariableMatrix B_21{{5.0}, {6.0}};
  sleipnir::VariableMatrix X_21 = sleipnir::Solve(A_22, B_21);

  Eigen::Matrix<double, 2, 1> expected_21{{-4.0}, {4.5}};
  CHECK(X_21.Rows() == 2);
  CHECK(X_21.Cols() == 1);
  CHECK(A_22.Value() * X_21.Value() == B_21.Value());
  CHECK(X_21.Value() == expected_21);

  // 3x3: solution contains 11/6, so compare within a tolerance
  sleipnir::VariableMatrix A_33{
      {1.0, 2.0, 3.0}, {-4.0, -5.0, 6.0}, {7.0, 8.0, 9.0}};
  sleipnir::VariableMatrix B_31{{10.0}, {11.0}, {12.0}};
  sleipnir::VariableMatrix X_31 = sleipnir::Solve(A_33, B_31);

  Eigen::Matrix<double, 3, 1> expected_31{{-7.5}, {6.0}, {11.0 / 6.0}};
  CHECK(X_31.Rows() == 3);
  CHECK(X_31.Cols() == 1);
  CHECK((A_33.Value() * X_31.Value() - B_31.Value()).norm() < 1e-12);
  CHECK((X_31.Value() - expected_31).norm() < 1e-12);

  // 4x4: no special case — exercises the general solver path
  sleipnir::VariableMatrix A_44{{1.0, 2.0, 3.0, -4.0},
                                {-5.0, 6.0, 7.0, 8.0},
                                {9.0, 10.0, 11.0, 12.0},
                                {13.0, 14.0, 15.0, 16.0}};
  sleipnir::VariableMatrix B_41{{17.0}, {18.0}, {19.0}, {20.0}};
  sleipnir::VariableMatrix X_41 = sleipnir::Solve(A_44, B_41);

  // First entry is numerically zero (solver round-off), hence the tolerance
  Eigen::Matrix<double, 4, 1> expected_41{
      {4.44089e-16}, {-16.25}, {16.5}, {0.0}};
  CHECK(X_41.Rows() == 4);
  CHECK(X_41.Cols() == 1);
  CHECK((A_44.Value() * X_41.Value() - B_41.Value()).norm() < 1e-12);
  CHECK((X_41.Value() - expected_41).norm() < 1e-12);
}

0 comments on commit 2ec0b06

Please sign in to comment.