Skip to content

Commit

Permalink
feat: add optional imports and jit
Browse files Browse the repository at this point in the history
  • Loading branch information
lvjonok committed Sep 3, 2024
1 parent e9de0d7 commit c332e32
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
4 changes: 2 additions & 2 deletions examples/00_translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,5 @@
print(casadi_function)

print("Translated JAX function:")
for cg_str in translate(casadi_function):
print(cg_str)
# secure add_import and add_jit to True to get the complete code
print(translate(casadi_function, add_import=True, add_jit=True))
8 changes: 4 additions & 4 deletions jaxadi/_translate.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
import textwrap

from casadi import OP_CONST, OP_INPUT, OP_OUTPUT, OP_SQ, Function

from ._ops import OP_JAX_DICT


def translate(func: Function) -> list[str]:
def translate(func: Function, add_jit=False, add_import=False) -> list[str]:
# Get information about Casadi function
n_instr = func.n_instructions()
n_out = func.n_out() # number of outputs in the function
Expand All @@ -23,7 +21,9 @@ def translate(func: Function) -> list[str]:

# generate string with complete code
codegen = ""
# codegen += "@jax.jit\n"
if add_import:
codegen += "import jax\nimport jax.numpy as jnp\n\n"
codegen += "@jax.jit\n" if add_jit else ""
codegen += f"def evaluate_{func.name()}(*args):\n"
codegen += " inputs = args\n" # combine all inputs into a single list
codegen += f" outputs = [jnp.zeros(out) for out in {out_shapes}]\n" # output variables
Expand Down

0 comments on commit c332e32

Please sign in to comment.