Skip to content

Commit 8f9e275

Browse files
committed
Remove wrappers module and custom hash methods
1 parent ccb7f27 commit 8f9e275

File tree

15 files changed

+330
-665
lines changed

15 files changed

+330
-665
lines changed

src/jaxsim/api/kin_dyn_parameters.py

+13-33
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import jaxsim.typing as jtp
1313
from jaxsim.math import Adjoint, Inertia, JointModel, supported_joint_motion
1414
from jaxsim.parsers.descriptions import JointDescription, JointType, ModelDescription
15-
from jaxsim.utils import HashedNumpyArray, JaxsimDataclass
15+
from jaxsim.utils import JaxsimDataclass
1616

1717

1818
@jax_dataclasses.pytree_dataclass(eq=False, unsafe_hash=False)
@@ -34,9 +34,9 @@ class KinDynParameters(JaxsimDataclass):
3434

3535
# Static
3636
link_names: Static[tuple[str]]
37-
_parent_array: Static[HashedNumpyArray]
38-
_support_body_array_bool: Static[HashedNumpyArray]
39-
_motion_subspaces: Static[HashedNumpyArray]
37+
_parent_array: Static[tuple[int]]
38+
_support_body_array_bool: Static[tuple[int]]
39+
_motion_subspaces: Static[tuple[float]]
4040

4141
# Links
4242
link_parameters: LinkParameters
@@ -56,21 +56,21 @@ def motion_subspaces(self) -> jtp.Matrix:
5656
r"""
5757
Return the motion subspaces :math:`\mathbf{S}(s)` of the joints.
5858
"""
59-
return self._motion_subspaces.get()
59+
return jnp.array(self._motion_subspaces, dtype=float)
6060

6161
@property
6262
def parent_array(self) -> jtp.Vector:
6363
r"""
6464
Return the parent array :math:`\lambda(i)` of the model.
6565
"""
66-
return self._parent_array.get()
66+
return jnp.array(self._parent_array, dtype=int)
6767

6868
@property
6969
def support_body_array_bool(self) -> jtp.Matrix:
7070
r"""
7171
Return the boolean support parent array :math:`\kappa_{b}(i)` of the model.
7272
"""
73-
return self._support_body_array_bool.get()
73+
return jnp.array(self._support_body_array_bool, dtype=int)
7474

7575
@staticmethod
7676
def build(model_description: ModelDescription) -> KinDynParameters:
@@ -227,8 +227,8 @@ def motion_subspace(joint_type: int, axis: npt.ArrayLike) -> npt.ArrayLike:
227227

228228
S = {
229229
JointType.Fixed: np.zeros(shape=(6, 1)),
230-
JointType.Revolute: np.vstack(np.hstack([np.zeros(3), axis.axis])),
231-
JointType.Prismatic: np.vstack(np.hstack([axis.axis, np.zeros(3)])),
230+
JointType.Revolute: np.vstack(np.hstack([np.zeros(3), axis])),
231+
JointType.Prismatic: np.vstack(np.hstack([axis, np.zeros(3)])),
232232
}
233233

234234
return S[joint_type]
@@ -254,36 +254,16 @@ def motion_subspace(joint_type: int, axis: npt.ArrayLike) -> npt.ArrayLike:
254254

255255
return KinDynParameters(
256256
link_names=tuple(l.name for l in ordered_links),
257-
_parent_array=HashedNumpyArray(array=parent_array),
258-
_support_body_array_bool=HashedNumpyArray(array=support_body_array_bool),
259-
_motion_subspaces=HashedNumpyArray(array=motion_subspaces),
257+
_parent_array=tuple(parent_array.tolist()),
258+
_support_body_array_bool=tuple(support_body_array_bool.tolist()),
259+
_motion_subspaces=tuple(motion_subspaces.tolist()),
260260
link_parameters=link_parameters,
261261
joint_model=joint_model,
262262
joint_parameters=joint_parameters,
263263
contact_parameters=contact_parameters,
264264
frame_parameters=frame_parameters,
265265
)
266266

267-
def __eq__(self, other: KinDynParameters) -> bool:
268-
269-
if not isinstance(other, KinDynParameters):
270-
return False
271-
272-
return hash(self) == hash(other)
273-
274-
def __hash__(self) -> int:
275-
276-
return hash(
277-
(
278-
hash(self.number_of_links()),
279-
hash(self.number_of_joints()),
280-
hash(self.frame_parameters.name),
281-
hash(self.frame_parameters.body),
282-
hash(self._parent_array),
283-
hash(self._support_body_array_bool),
284-
)
285-
)
286-
287267
# =============================
288268
# Helpers to extract parameters
289269
# =============================
@@ -409,7 +389,7 @@ def joint_transforms(
409389
pre_H_suc_J = jax.vmap(supported_joint_motion)(
410390
joint_types=jnp.array(self.joint_model.joint_types[1:]).astype(int),
411391
joint_positions=jnp.array(joint_positions),
412-
joint_axes=jnp.array([j.axis for j in self.joint_model.joint_axis]),
392+
joint_axes=jnp.array(self.joint_model.joint_axis),
413393
)
414394

415395
# Extract the transforms and motion subspaces of the joints.

src/jaxsim/api/model.py

+10-36
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import jaxsim.typing as jtp
1919
from jaxsim.math import Adjoint, Cross
2020
from jaxsim.parsers.descriptions import ModelDescription
21-
from jaxsim.utils import JaxsimDataclass, Mutability, wrappers
21+
from jaxsim.utils import JaxsimDataclass, Mutability
2222

2323
from .common import VelRepr
2424

@@ -59,43 +59,16 @@ class JaxSimModel(JaxsimDataclass):
5959
default=None, repr=False
6060
)
6161

62-
_description: Static[wrappers.HashlessObject[ModelDescription | None]] = (
63-
dataclasses.field(default=None, repr=False)
62+
_description: Static[ModelDescription | None] = dataclasses.field(
63+
default=None, repr=False
6464
)
6565

6666
@property
6767
def description(self) -> ModelDescription:
6868
"""
6969
Return the model description.
7070
"""
71-
return self._description.get()
72-
73-
def __eq__(self, other: JaxSimModel) -> bool:
74-
75-
if not isinstance(other, JaxSimModel):
76-
return False
77-
78-
if self.model_name != other.model_name:
79-
return False
80-
81-
if self.time_step != other.time_step:
82-
return False
83-
84-
if self.kin_dyn_parameters != other.kin_dyn_parameters:
85-
return False
86-
87-
return True
88-
89-
def __hash__(self) -> int:
90-
91-
return hash(
92-
(
93-
hash(self.model_name),
94-
hash(self.time_step),
95-
hash(self.kin_dyn_parameters),
96-
hash(self.contact_model),
97-
)
98-
)
71+
return self._description
9972

10073
# ========================
10174
# Initialization and state
@@ -252,7 +225,7 @@ def build(
252225
# don't want to trigger recompilation if it changes. All relevant parameters
253226
# needed to compute kinematics and dynamics quantities are stored in the
254227
# kin_dyn_parameters attribute.
255-
_description=wrappers.HashlessObject(obj=model_description),
228+
_description=model_description,
256229
)
257230

258231
return model
@@ -423,15 +396,16 @@ def reduce(
423396

424397
# Operate on a deep copy of the model description in order to prevent problems
425398
# when mutable attributes are updated.
426-
intermediate_description = copy.deepcopy(model.description)
399+
intermediate_description = copy.deepcopy(model._description)
427400

428401
# Update the initial position of the joints.
429402
# This is necessary to compute the correct pose of the link pairs connected
430403
# to removed joints.
431404
for joint_name in set(model.joint_names()) - set(considered_joints):
432-
j = intermediate_description.joints_dict[joint_name]
433-
with j.mutable_context():
434-
j.initial_position = locked_joint_positions.get(joint_name, 0.0)
405+
intermediate_description.joints_dict[joint_name] = dataclasses.replace(
406+
intermediate_description.joints_dict[joint_name],
407+
initial_position=float(locked_joint_positions.get(joint_name, 0.0)),
408+
)
435409

436410
# Reduce the model description.
437411
# If `considered_joints` contains joints not existing in the model,

src/jaxsim/math/joint_model.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
import jaxsim.typing as jtp
1010
from jaxsim.math import Rotation
11-
from jaxsim.parsers.descriptions import JointGenericAxis, JointType, ModelDescription
11+
from jaxsim.parsers.descriptions import JointType, ModelDescription
1212
from jaxsim.parsers.kinematic_graph import KinematicGraphTransforms
1313

1414

@@ -39,7 +39,7 @@ class JointModel:
3939
joint_dofs: Static[tuple[int, ...]]
4040
joint_names: Static[tuple[str, ...]]
4141
joint_types: Static[tuple[int, ...]]
42-
joint_axis: Static[tuple[JointGenericAxis, ...]]
42+
joint_axis: Static[tuple[tuple[int]]]
4343

4444
@staticmethod
4545
def build(description: ModelDescription) -> JointModel:
@@ -108,7 +108,7 @@ def build(description: ModelDescription) -> JointModel:
108108
joint_dofs=tuple([base_dofs] + [1 for _ in ordered_joints]),
109109
joint_names=tuple(["world_to_base"] + [j.name for j in ordered_joints]),
110110
joint_types=tuple([JointType.Fixed] + [j.jtype for j in ordered_joints]),
111-
joint_axis=tuple(JointGenericAxis(axis=j.axis) for j in ordered_joints),
111+
joint_axis=tuple(j.axis for j in ordered_joints),
112112
)
113113

114114
def parent_H_predecessor(self, joint_index: jtp.IntLike) -> jtp.Matrix:

src/jaxsim/parsers/descriptions/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,6 @@
55
MeshCollision,
66
SphereCollision,
77
)
8-
from .joint import JointDescription, JointGenericAxis, JointType
8+
from .joint import JointDescription, JointType
99
from .link import LinkDescription
1010
from .model import ModelDescription

src/jaxsim/parsers/descriptions/collision.py

+29-70
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,9 @@
33
import abc
44
import dataclasses
55

6-
import jax.numpy as jnp
76
import numpy as np
87
import numpy.typing as npt
98

10-
import jaxsim.typing as jtp
119
from jaxsim import logging
1210

1311
from .link import LinkDescription
@@ -25,8 +23,28 @@ class CollidablePoint:
2523
"""
2624

2725
parent_link: LinkDescription
28-
position: npt.NDArray = dataclasses.field(default_factory=lambda: np.zeros(3))
2926
enabled: bool = True
27+
_position: tuple[float] = dataclasses.field(default=(0.0, 0.0, 0.0))
28+
29+
@property
30+
def position(self) -> npt.NDArray:
31+
"""
32+
Get the position of the collidable point.
33+
34+
Returns:
35+
The position of the collidable point.
36+
"""
37+
return np.array(self._position)
38+
39+
@position.setter
40+
def position(self, value: npt.NDArray) -> None:
41+
"""
42+
Set the position of the collidable point.
43+
44+
Args:
45+
value: The new position of the collidable point.
46+
"""
47+
self._position = tuple(value.tolist())
3048

3149
def change_link(
3250
self, new_link: LinkDescription, new_H_old: npt.NDArray
@@ -35,8 +53,8 @@ def change_link(
3553
Move the collidable point to a new parent link.
3654
3755
Args:
38-
new_link (LinkDescription): The new parent link to which the collidable point is moved.
39-
new_H_old (npt.NDArray): The transformation matrix from the new link's frame to the old link's frame.
56+
new_link: The new parent link to which the collidable point is moved.
57+
new_H_old: The transformation matrix from the new link's frame to the old link's frame.
4058
4159
Returns:
4260
CollidablePoint: A new collidable point associated with the new parent link.
@@ -47,27 +65,12 @@ def change_link(
4765

4866
return CollidablePoint(
4967
parent_link=new_link,
50-
position=(new_H_old @ jnp.hstack([self.position, 1.0])).squeeze()[0:3],
68+
_position=tuple(
69+
(new_H_old @ np.hstack([self.position, 1.0])).squeeze()[0:3].tolist()
70+
),
5171
enabled=self.enabled,
5272
)
5373

54-
def __hash__(self) -> int:
55-
56-
return hash(
57-
(
58-
hash(self.parent_link),
59-
hash(tuple(self.position.tolist())),
60-
hash(self.enabled),
61-
)
62-
)
63-
64-
def __eq__(self, other: CollidablePoint) -> bool:
65-
66-
if not isinstance(other, CollidablePoint):
67-
return False
68-
69-
return hash(self) == hash(other)
70-
7174
def __str__(self) -> str:
7275
return (
7376
f"{self.__class__.__name__}("
@@ -107,22 +110,7 @@ class BoxCollision(CollisionShape):
107110
center: The center of the box in the local frame of the collision shape.
108111
"""
109112

110-
center: jtp.VectorLike
111-
112-
def __hash__(self) -> int:
113-
return hash(
114-
(
115-
hash(super()),
116-
hash(tuple(self.center.tolist())),
117-
)
118-
)
119-
120-
def __eq__(self, other: BoxCollision) -> bool:
121-
122-
if not isinstance(other, BoxCollision):
123-
return False
124-
125-
return hash(self) == hash(other)
113+
center: tuple[float, float, float]
126114

127115

128116
@dataclasses.dataclass
@@ -134,22 +122,7 @@ class SphereCollision(CollisionShape):
134122
center: The center of the sphere in the local frame of the collision shape.
135123
"""
136124

137-
center: jtp.VectorLike
138-
139-
def __hash__(self) -> int:
140-
return hash(
141-
(
142-
hash(super()),
143-
hash(tuple(self.center.tolist())),
144-
)
145-
)
146-
147-
def __eq__(self, other: BoxCollision) -> bool:
148-
149-
if not isinstance(other, BoxCollision):
150-
return False
151-
152-
return hash(self) == hash(other)
125+
center: tuple[float, float, float]
153126

154127

155128
@dataclasses.dataclass
@@ -161,18 +134,4 @@ class MeshCollision(CollisionShape):
161134
center: The center of the mesh in the local frame of the collision shape.
162135
"""
163136

164-
center: jtp.VectorLike
165-
166-
def __hash__(self) -> int:
167-
return hash(
168-
(
169-
hash(tuple(self.center.tolist())),
170-
hash(self.collidable_points),
171-
)
172-
)
173-
174-
def __eq__(self, other: MeshCollision) -> bool:
175-
if not isinstance(other, MeshCollision):
176-
return False
177-
178-
return hash(self) == hash(other)
137+
center: tuple[float, float, float]

0 commit comments

Comments
 (0)