Skip to content

Commit d5f236d

Browse files
mattephisimeon-nedlvjonok
authored
Migrate to new translation function (#18)
* feat: optimisiing comp graph * feat: graph tranlsation * Feat/pendulum rollout example (#9) * added pendulum rollout as example * fixed README * fix: rebase * Update README.md Added hyperlink for landing * fix: update branch for colab benchmark * added demo notebook (#13) * added demo notebook * minor readme change * fix: rebase * fix: rebase and fix tests * fix: pre-commit * feat: graph expansion * fix: benchmarks for more powers * fix: disable graph compression * fix: densify structural zeros * fix: adaprive dimensionality #16 * fix: translate as graph_translate * feat: test expand, docs * feat: test examples * fix: exclude running examples from workflow * fix: remove ast debug files * fix: organize imports * fix: ops fix * fix: update plots for graph translation * fix: add plots to website --------- Co-authored-by: Simeon Nedelchev <[email protected]> Co-authored-by: Lev Kozlov <[email protected]>
1 parent 59316b8 commit d5f236d

24 files changed

+1186
-174
lines changed

.github/workflows/build.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ jobs:
3838
python -m pip install --upgrade pip
3939
pip install -e .[dev]
4040
- name: Run tests
41-
run: python -m pytest
41+
run: python -m pytest --ignore=tests/test_examples.py
4242

4343
deploy:
4444
runs-on: ubuntu-latest

benchmarks/cuda_benchmark_results.npz

-1.45 KB
Binary file not shown.
-1.45 KB
Binary file not shown.

benchmarks/run_benchmark.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
class Paths:
1818
FUNCTIONS_DIR = os.path.join(cur_dir, "data")
1919
OUTPUT_DIR = cur_dir
20-
RUN_CUSADI = False
20+
RUN_CUSADI = True
2121

2222

2323
class ColabPaths:
@@ -43,11 +43,11 @@ class ColabPaths:
4343

4444
torch.manual_seed(0)
4545

46-
N_ENVS_SWEEP = [2**i for i in range(21)]
47-
N_EVALS = 20
46+
N_ENVS_SWEEP = [2**i for i in range(15)]
47+
N_EVALS = 10
4848

4949
# Load functions for CUDA benchmarking
50-
fn_files = ["fn_1e1.casadi", "fn_1e2.casadi"]
50+
fn_files = ["fn_1e1.casadi", "fn_1e2.casadi", "fn_1e3.casadi", "fn_1e4.casadi"]
5151
benchmark_fns = [Function.load(os.path.join(PathsProvider.FUNCTIONS_DIR, fn)) for fn in fn_files]
5252

5353

benchmarks/visualize.ipynb

+176-12
Large diffs are not rendered by default.

docs/index.html

+6
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,12 @@ <h2 class="title is-3">Benchmarks</h2>
166166
<div class="item item-chair-tp">
167167
<img src="./static/images/compare_1e2_bar.png"/>
168168
</div>
169+
<div class="item item-chair-tp">
170+
<img src="./static/images/compare_1e3_bar.png"/>
171+
</div>
172+
<div class="item item-chair-tp">
173+
<img src="./static/images/compare_1e4_bar.png"/>
174+
</div>
169175
<div class="item item-steve">
170176
<img src="./static/images/speedup_ratio.png"/>
171177
</div>
9.22 KB
Loading
3.19 KB
Loading
32.1 KB
Loading
29.6 KB
Loading

docs/static/images/speedup_ratio.png

19.4 KB
Loading

examples/00_translate.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
import casadi as cs
1010

11-
from jaxadi import translate
11+
from jaxadi import graph_translate as translate
1212

1313
# define input variables for the function
1414
x = cs.SX.sym("x", 3)

examples/01_lower.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,20 @@
11
import casadi as cs
22

3-
from jaxadi import lower, translate, declare
3+
from jaxadi import lower, declare
4+
from jaxadi import graph_translate as translate
45

56
# define input variables for the function
6-
x = cs.SX.sym("x", 10, 10)
7-
y = cs.SX.sym("y", 10, 10)
8-
casadi_fn = cs.Function("myfunc", [x, y], [x @ y])
7+
x = cs.SX.sym("x", 2, 1)
8+
y = cs.SX.sym("y", 2, 1)
9+
casadi_fn = cs.Function("myfunc", [x, y], [x.T @ y])
910

1011
print("Signature of the CasADi function:")
1112
print(casadi_fn)
1213

1314
# define jax function from casadi one
1415
jax_fn = declare(translate(casadi_fn))
1516

17+
print(translate(casadi_fn))
18+
1619
print("Lowered JAX function:")
1720
print(lower(jax_fn, casadi_fn).as_text())

examples/04_pinocchio.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616
import pinocchio as pin
1717
import pinocchio.casadi as cpin
1818
from robot_descriptions.panda_description import URDF_PATH
19-
from jaxadi import convert, translate
19+
from jaxadi import convert
20+
from jaxadi import graph_translate as translate
2021

2122
# Load the Panda robot model
2223
model = pin.buildModelFromUrdf(URDF_PATH)

jaxadi/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from ._compile import lower
22
from ._convert import convert
3-
from ._translate import translate
43
from ._declare import declare
4+
from ._expand import translate as expand_translate
5+
from ._graph import translate as graph_translate

jaxadi/_convert.py

+13-6
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
1-
from casadi import Function
2-
from typing import Any
31
from collections.abc import Callable
2+
from typing import Any
3+
4+
from casadi import Function
45

5-
from ._declare import declare
6-
from ._translate import translate
76
from ._compile import compile as compile_fn
7+
from ._declare import declare
8+
from ._graph import translate as graph_translate
9+
from ._preprocess import densify
810

911

10-
def convert(casadi_fn: Function, compile=False, num_threads=1) -> Callable[..., Any]:
12+
def convert(casadi_fn: Function, translate=None, compile=False) -> Callable[..., Any]:
1113
"""
1214
Convert given casadi function into python
1315
callable based on JAX backend, optionally
@@ -17,7 +19,12 @@ def convert(casadi_fn: Function, compile=False, num_threads=1) -> Callable[...,
1719
:param compile (bool): Whether to AOT compile function
1820
:return (Callable[..., Any]): Resulting python function
1921
"""
20-
jax_str = translate(casadi_fn, num_threads=num_threads)
22+
if translate is None:
23+
translate = graph_translate
24+
25+
casadi_fn = densify(casadi_fn)
26+
27+
jax_str = translate(casadi_fn)
2128
jax_fn = declare(jax_str)
2229

2330
if compile:

0 commit comments

Comments
 (0)