Skip to content

Commit

Permalink
fix: solve problem with different shapes in input
Browse files Browse the repository at this point in the history
  • Loading branch information
lvjonok committed Sep 17, 2024
1 parent 56ee252 commit a6742c7
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 2 deletions.
2 changes: 1 addition & 1 deletion jaxadi/_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,6 @@
OP_ATANH: "jnp.arctanh(work[{0}])",
OP_ATAN2: "jnp.arctan2(work[{0}], work[{1}])",
OP_CONST: "{0:.16f}",
OP_INPUT: "inputs[{0}, {1}, {2}]",
OP_INPUT: "inputs[{0}][{1}, {2}]",
OP_OUTPUT: "work[{0}][0]",
}
2 changes: 1 addition & 1 deletion jaxadi/_translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def translate(func: Function, add_jit=False, add_import=False) -> str:
codegen += "@jax.jit\n" if add_jit else ""
codegen += f"def evaluate_{func.name()}(*args):\n"
# combine all inputs into a single list
codegen += " inputs = jnp.expand_dims(jnp.array(args), axis=-1)\n"
codegen += " inputs = [jnp.expand_dims(jnp.array(arg), axis=-1) for arg in args]\n"
# output variables
codegen += f" outputs = [jnp.zeros(out) for out in {out_shapes}]\n"

Expand Down
18 changes: 18 additions & 0 deletions tests/test_input.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import casadi as cs
import jax.numpy as jnp
import numpy as np

from jaxadi import convert


def test_different_shapes():
x = cs.SX.sym("x", 2, 3)
y = cs.SX.sym("y", 3, 2)
casadi_fn = cs.Function("myfunc", [x, y], [x @ y])

jax_fn = convert(casadi_fn, compile=True)

in1 = jnp.array(np.random.randn(2, 3))
in2 = jnp.array(np.random.randn(3, 2))

jax_fn(in1, in2)

0 comments on commit a6742c7

Please sign in to comment.