Skip to content

Commit ccb7f27

Browse files
authored
Merge pull request #352 from ami-iit/sprint/gpu-transfert
[Sprint] Fix device transfer and exceptions handling
2 parents 15e402e + 7aff901 commit ccb7f27

File tree

7 files changed

+17
-14
lines changed

7 files changed

+17
-14
lines changed

docs/guide/configuration.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ The logging and exceptions configurations is controlled by the following environ
6161

6262
*Default:* ``DEBUG`` for development, ``WARNING`` for production.
6363

64-
- ``JAXSIM_DISABLE_EXCEPTIONS``: Disables the runtime checks and exceptions.
64+
- ``JAXSIM_ENABLE_EXCEPTIONS``: Enables the runtime checks and exceptions. Note that enabling exceptions might lead to device-to-host transfer of data, increasing the computational time required.
6565

6666
*Default:* ``False``.
6767

src/jaxsim/api/model.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,8 @@ class JaxSimModel(JaxsimDataclass):
3333

3434
model_name: Static[str]
3535

36-
time_step: jtp.FloatLike = dataclasses.field(
37-
default_factory=lambda: jnp.array(0.001, dtype=float),
36+
time_step: float = dataclasses.field(
37+
default=0.001,
3838
)
3939

4040
terrain: Static[jaxsim.terrain.Terrain] = dataclasses.field(
@@ -91,7 +91,7 @@ def __hash__(self) -> int:
9191
return hash(
9292
(
9393
hash(self.model_name),
94-
hash(float(self.time_step)),
94+
hash(self.time_step),
9595
hash(self.kin_dyn_parameters),
9696
hash(self.contact_model),
9797
)
@@ -222,7 +222,7 @@ def build(
222222
time_step = (
223223
time_step
224224
if time_step is not None
225-
else JaxSimModel.__dataclass_fields__["time_step"].default_factory()
225+
else JaxSimModel.__dataclass_fields__["time_step"].default
226226
)
227227

228228
# Create the default contact model.
@@ -317,7 +317,7 @@ def floating_base(self) -> bool:
317317
True if the model is floating-base, False otherwise.
318318
"""
319319

320-
return bool(self.kin_dyn_parameters.joint_model.joint_dofs[0] == 6)
320+
return self.kin_dyn_parameters.joint_model.joint_dofs[0] == 6
321321

322322
def base_link(self) -> str:
323323
"""
@@ -348,7 +348,7 @@ def dofs(self) -> int:
348348
the number of joints. In the future, this could be different.
349349
"""
350350

351-
return int(sum(self.kin_dyn_parameters.joint_model.joint_dofs[1:]))
351+
return sum(self.kin_dyn_parameters.joint_model.joint_dofs[1:])
352352

353353
def joint_names(self) -> tuple[str, ...]:
354354
"""
@@ -431,7 +431,7 @@ def reduce(
431431
for joint_name in set(model.joint_names()) - set(considered_joints):
432432
j = intermediate_description.joints_dict[joint_name]
433433
with j.mutable_context():
434-
j.initial_position = float(locked_joint_positions.get(joint_name, 0.0))
434+
j.initial_position = locked_joint_positions.get(joint_name, 0.0)
435435

436436
# Reduce the model description.
437437
# If `considered_joints` contains joints not existing in the model,

src/jaxsim/exceptions.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ def raise_if(
2323

2424
# Disable host callback if running on unsupported hardware or if the user
2525
# explicitly disabled it.
26-
if jax.devices()[0].platform in {"tpu", "METAL"} or os.environ.get(
27-
"JAXSIM_DISABLE_EXCEPTIONS", 0
26+
if jax.devices()[0].platform in {"tpu", "METAL"} or not os.environ.get(
27+
"JAXSIM_ENABLE_EXCEPTIONS", 0
2828
):
2929
return
3030

tests/conftest.py

+3
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
11
import os
2+
3+
os.environ["JAXSIM_ENABLE_EXCEPTIONS"] = "1"
4+
25
import pathlib
36
import subprocess
47

tests/test_api_model.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def test_model_creation_and_reduction(
8181
locked_joint_positions=dict(
8282
zip(
8383
model_full.joint_names(),
84-
data_full.joint_positions,
84+
data_full.joint_positions.tolist(),
8585
strict=True,
8686
)
8787
),

tests/test_automatic_differentiation.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -344,7 +344,7 @@ def step(
344344
model=model,
345345
data=data_x0,
346346
joint_force_references=τ,
347-
link_forces=W_f_L,
347+
link_forces_inertial=W_f_L,
348348
)
349349

350350
xf_W_p_B = data_xf.base_position

tests/test_simulations.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def test_box_with_external_forces(
7474
data = js.model.step(
7575
model=model,
7676
data=data,
77-
link_forces=references.link_forces(model=model, data=data),
77+
link_forces_inertial=references._link_forces,
7878
)
7979

8080
# Check that the box didn't move.
@@ -148,7 +148,7 @@ def test_box_with_zero_gravity(
148148
data = js.model.step(
149149
model=model,
150150
data=data,
151-
link_forces=references.link_forces(model=model, data=data),
151+
link_forces_inertial=references.link_forces(model=model, data=data),
152152
)
153153

154154
# Check that the box moved as expected.

0 commit comments

Comments
 (0)