Skip to content

Commit

Permalink
feat: mostly works
Browse files Browse the repository at this point in the history
  • Loading branch information
mattephi committed Sep 3, 2024
1 parent 1b723cd commit ca6957d
Showing 1 changed file with 71 additions and 67 deletions.
138 changes: 71 additions & 67 deletions test/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ def test_simo_trig():
jax_f = convert(casadi_f)
x_val = np.random.randn(1, 1)
compare_results(casadi_f, jax_f, x_val)
#
#


def test_simo_poly():
x = ca.SX.sym('x', 1, 1)
casadi_f = ca.Function('simo_poly', [x], [x**2, x**3, ca.sqrt(x)])
Expand Down Expand Up @@ -70,20 +70,20 @@ def test_miso_multiply():
x_val = np.random.randn(3, 3)
y_val = np.random.randn(3, 3)
compare_results(casadi_f, jax_f, x_val, y_val)
#
#
# def test_miso_combined():
# x = ca.SX.sym('x', 3, 3)
# y = ca.SX.sym('y', 3, 3)
# z = ca.SX.sym('z', 3, 3)
# casadi_f = ca.Function('miso_combined', [x, y, z], [ca.mtimes(x, y) + z])
# jax_f = convert(casadi_f)
# x_val = np.random.randn(3, 3)
# y_val = np.random.randn(3, 3)
# z_val = np.random.randn(3, 3)
# compare_results(casadi_f, jax_f, x_val, y_val, z_val)
#
#


def test_miso_combined():
x = ca.SX.sym('x', 3, 3)
y = ca.SX.sym('y', 3, 3)
z = ca.SX.sym('z', 3, 3)
casadi_f = ca.Function('miso_combined', [x, y, z], [ca.mtimes(x, y) + z])
jax_f = convert(casadi_f)
x_val = np.random.randn(3, 3)
y_val = np.random.randn(3, 3)
z_val = np.random.randn(3, 3)
compare_results(casadi_f, jax_f, x_val, y_val, z_val)


def test_mimo_arith():
x = ca.SX.sym('x', 3, 3)
y = ca.SX.sym('y', 3, 3)
Expand All @@ -94,6 +94,8 @@ def test_mimo_arith():
compare_results(casadi_f, jax_f, x_val, y_val)
#
#


def test_mimo_trig():
x = ca.SX.sym('x', 3, 3)
y = ca.SX.sym('y', 3, 3)
Expand All @@ -102,21 +104,21 @@ def test_mimo_trig():
x_val = np.random.randn(3, 3)
y_val = np.random.randn(3, 3)
compare_results(casadi_f, jax_f, x_val, y_val)
#
#
# def test_mimo_complex():
# x = ca.SX.sym('x', 3, 3)
# y = ca.SX.sym('y', 3, 3)
# z = ca.SX.sym('z', 3, 3)
# casadi_f = ca.Function('mimo_complex', [x, y, z], [
# ca.mtimes(x, y), ca.inv(z), x + z])
# jax_f = convert(casadi_f)
# x_val = np.random.randn(3, 3)
# y_val = np.random.randn(3, 3)
# z_val = np.random.randn(3, 3)
# compare_results(casadi_f, jax_f, x_val, y_val, z_val)
#
#


def test_mimo_complex():
x = ca.SX.sym('x', 3, 3)
y = ca.SX.sym('y', 3, 3)
z = ca.SX.sym('z', 3, 3)
casadi_f = ca.Function('mimo_complex', [x, y, z], [
ca.mtimes(x, y), ca.inv(z), x + z])
jax_f = convert(casadi_f)
x_val = np.random.randn(3, 3)
y_val = np.random.randn(3, 3)
z_val = np.random.randn(3, 3)
compare_results(casadi_f, jax_f, x_val, y_val, z_val)


def test_sin():
x = ca.SX.sym('x', 1, 1)
casadi_f = ca.Function('sin', [x], [ca.sin(x)])
Expand All @@ -135,16 +137,17 @@ def test_cos():
compare_results(casadi_f, jax_f, x_val)


# def test_mtimes():
# x = ca.SX.sym('x', 2, 2)
# y = ca.SX.sym('y', 2, 2)
# casadi_f = ca.Function('mtimes', [x, y], [ca.mtimes(x, y)])
# jax_f = convert(casadi_f)
# # x_val = np.random.randn(2, 2)
# # y_val = np.random.randn(2, 2)
# x_val = np.array([[1, 1], [2, 2]])
# y_val = np.array([[2, 2], [2, 2]])
# compare_results(casadi_f, jax_f, x_val, y_val)
def test_mtimes():
x = ca.SX.sym('x', 2, 2)
y = ca.SX.sym('y', 2, 2)
casadi_f = ca.Function('mtimes', [x, y], [x @ y])
jax_f = convert(casadi_f)
# x_val = np.random.randn(2, 2)
# y_val = np.random.randn(2, 2)
x_val = np.array([[1, 1], [2, 2]])
y_val = np.array([[2, 2], [2, 2]])
print(translate(casadi_f))
compare_results(casadi_f, jax_f, x_val, y_val)


def test_inv():
Expand All @@ -160,19 +163,20 @@ def test_norm_2():
casadi_f = ca.Function('norm_2', [x], [ca.norm_2(x)])
jax_f = convert(casadi_f)
x_val = np.random.randn(3, 1)
print(translate(casadi_f))
compare_results(casadi_f, jax_f, x_val)


def test_sum1():
x = ca.SX.sym('x', 2, 2)
casadi_f = ca.Function('sum1', [x], [ca.sum1(x)])
x_val = np.random.randn(2, 2)
print(ca.sum1(x_val).shape)
jax_f = convert(casadi_f)
print(translate(casadi_f))
compare_results(casadi_f, jax_f, x_val)


# def test_sum1():
# x = ca.SX.sym('x', 2, 2)
# casadi_f = ca.Function('sum1', [x], [ca.sum1(x)])
# x_val = np.random.randn(2, 2)
# print(ca.sum1(x_val).shape)
# jax_f = convert(casadi_f)
# print(translate(casadi_f))
# compare_results(casadi_f, jax_f, x_val)
#
#
def test_dot():
x = ca.SX.sym('x', 3, 1)
y = ca.SX.sym('y', 3, 1)
Expand All @@ -181,16 +185,16 @@ def test_dot():
x_val = np.random.randn(3, 1)
y_val = np.random.randn(3, 1)
compare_results(casadi_f, jax_f, x_val, y_val)
#
#


def test_transpose():
x = ca.SX.sym('x', 3, 3)
casadi_f = ca.Function('transpose', [x], [ca.transpose(x)])
jax_f = convert(casadi_f)
x_val = np.random.randn(3, 3)
compare_results(casadi_f, jax_f, x_val)
#
#


def test_add():
x = ca.SX.sym('x', 3, 3)
y = ca.SX.sym('y', 3, 3)
Expand All @@ -199,8 +203,8 @@ def test_add():
x_val = np.random.randn(3, 3)
y_val = np.random.randn(3, 3)
compare_results(casadi_f, jax_f, x_val, y_val)
#
#


def test_multiply():
x = ca.SX.sym('x', 3, 3)
y = ca.SX.sym('y', 3, 3)
Expand All @@ -209,13 +213,13 @@ def test_multiply():
x_val = np.random.randn(3, 3)
y_val = np.random.randn(3, 3)
compare_results(casadi_f, jax_f, x_val, y_val)
#
#
# def test_combined():
# x = ca.SX.sym('x', 3, 3)
# y = ca.SX.sym('y', 3, 3)
# casadi_f = ca.Function('combined', [x, y], [ca.mtimes(x, y) + ca.inv(x)])
# jax_f = convert(casadi_f)
# x_val = np.random.randn(3, 3)
# y_val = np.random.randn(3, 3)
# compare_results(casadi_f, jax_f, x_val, y_val)


def test_combined():
x = ca.SX.sym('x', 3, 3)
y = ca.SX.sym('y', 3, 3)
casadi_f = ca.Function('combined', [x, y], [ca.mtimes(x, y) + ca.inv(x)])
jax_f = convert(casadi_f)
x_val = np.random.randn(3, 3)
y_val = np.random.randn(3, 3)
compare_results(casadi_f, jax_f, x_val, y_val)

0 comments on commit ca6957d

Please sign in to comment.