Skip to content

Commit 7d17335

Browse files
authored
Merge branch 'main' into solver-uvs
2 parents 7ff2ffa + 1b5aadf commit 7d17335

File tree

7 files changed

+23
-25
lines changed

7 files changed

+23
-25
lines changed

genesis/engine/sensors/base_sensor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ class SharedSensorMetadata:
4848
"""
4949

5050
cache_sizes: list[int] = field(default_factory=list)
51-
delays_ts: torch.Tensor = make_tensor_field((0, 0), dtype_factory=lambda: gs.tc_int)
51+
delays_ts: torch.Tensor = make_tensor_field((0, 0), dtype=gs.tc_int)
5252

5353
def __del__(self):
5454
try:
@@ -329,7 +329,7 @@ class RigidSensorMetadataMixin:
329329
"""
330330

331331
solver: "RigidSolver | None" = None
332-
links_idx: torch.Tensor = make_tensor_field((0,), dtype_factory=lambda: gs.tc_int)
332+
links_idx: torch.Tensor = make_tensor_field((0,), dtype=gs.tc_int)
333333
offsets_pos: torch.Tensor = make_tensor_field((0, 0, 3))
334334
offsets_quat: torch.Tensor = make_tensor_field((0, 0, 4))
335335

genesis/engine/sensors/contact_force.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ class ContactSensorMetadata(SharedSensorMetadata):
7676
"""
7777

7878
solver: "RigidSolver | None" = None
79-
expanded_links_idx: torch.Tensor = make_tensor_field((0,), dtype_factory=lambda: gs.tc_int)
79+
expanded_links_idx: torch.Tensor = make_tensor_field((0,), dtype=gs.tc_int)
8080

8181

8282
@register_sensor(ContactSensorOptions, ContactSensorMetadata, tuple)

genesis/engine/sensors/imu.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,9 +70,9 @@ class IMUSharedMetadata(RigidSensorMetadataMixin, NoisySensorMetadataMixin, Shar
7070

7171
alignment_rot_matrix: torch.Tensor = make_tensor_field((0, 0, 3, 3))
7272
magnetic_field_vector: torch.Tensor = make_tensor_field((0, 0, 3)) # added another dimension to match data layout
73-
acc_indices: torch.Tensor = make_tensor_field((0, 0), dtype_factory=lambda: gs.tc_int)
74-
gyro_indices: torch.Tensor = make_tensor_field((0, 0), dtype_factory=lambda: gs.tc_int)
75-
mag_indices: torch.Tensor = make_tensor_field((0, 0), dtype_factory=lambda: gs.tc_int)
73+
acc_indices: torch.Tensor = make_tensor_field((0, 0), dtype=gs.tc_int)
74+
gyro_indices: torch.Tensor = make_tensor_field((0, 0), dtype=gs.tc_int)
75+
mag_indices: torch.Tensor = make_tensor_field((0, 0), dtype=gs.tc_int)
7676

7777

7878
class IMUData(NamedTuple):

genesis/engine/sensors/raycaster.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -293,18 +293,18 @@ class RaycasterSharedMetadata(RigidSensorMetadataMixin, SharedSensorMetadata):
293293
min_ranges: torch.Tensor = make_tensor_field((0,))
294294
max_ranges: torch.Tensor = make_tensor_field((0,))
295295
no_hit_values: torch.Tensor = make_tensor_field((0,))
296-
return_world_frame: torch.Tensor = make_tensor_field((0,), dtype_factory=lambda: gs.tc_bool)
296+
return_world_frame: torch.Tensor = make_tensor_field((0,), dtype=gs.tc_bool)
297297

298298
patterns: list[RaycastPattern] = field(default_factory=list)
299299
ray_dirs: torch.Tensor = make_tensor_field((0, 3))
300300
ray_starts: torch.Tensor = make_tensor_field((0, 3))
301301
ray_starts_world: torch.Tensor = make_tensor_field((0, 3))
302302
ray_dirs_world: torch.Tensor = make_tensor_field((0, 3))
303303

304-
points_to_sensor_idx: torch.Tensor = make_tensor_field((0,), dtype_factory=lambda: gs.tc_int)
305-
sensor_cache_offsets: torch.Tensor = make_tensor_field((0,), dtype_factory=lambda: gs.tc_int)
306-
sensor_point_offsets: torch.Tensor = make_tensor_field((0,), dtype_factory=lambda: gs.tc_int)
307-
sensor_point_counts: torch.Tensor = make_tensor_field((0,), dtype_factory=lambda: gs.tc_int)
304+
points_to_sensor_idx: torch.Tensor = make_tensor_field((0,), dtype=gs.tc_int)
305+
sensor_cache_offsets: torch.Tensor = make_tensor_field((0,), dtype=gs.tc_int)
306+
sensor_point_offsets: torch.Tensor = make_tensor_field((0,), dtype=gs.tc_int)
307+
sensor_point_counts: torch.Tensor = make_tensor_field((0,), dtype=gs.tc_int)
308308

309309

310310
class RaycasterData(NamedTuple):

genesis/engine/solvers/rigid/collider/gjk.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -223,10 +223,10 @@ def func_gjk_contact(
223223
)
224224

225225
if shrink_sphere:
226-
# If we shrinked the sphere and capsule to point and line and the distance is larger than the
227-
# collision epsilon, it means a shallow penetration. Thus we subtract the radius of the sphere and
228-
# the capsule to get the actual distance. If the distance is smaller than the collision epsilon, it
229-
# means a deep penetration, which requires the default GJK handling.
226+
# If we shrunk the sphere and capsule to point and line and the distance is larger than the collision
227+
# epsilon, it means a shallow penetration. Thus we subtract the radius of the sphere and the capsule to
228+
# get the actual distance. If the distance is smaller than the collision epsilon, it means a deep
229+
# penetration, which requires the default GJK handling.
230230
if distance > gjk_info.collision_eps[None]:
231231
radius_a, radius_b = 0.0, 0.0
232232
if is_sphere_swept_geom_a:

genesis/engine/solvers/rigid/collider/narrowphase.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -614,7 +614,7 @@ def func_convex_convex_contact(
614614

615615
if multi_contact and is_col_0:
616616
# Perturbation axis must not be aligned with the principal axes of inertia the geometry,
617-
# otherwise it would be more sensitive to ill-conditionning.
617+
# otherwise it would be more sensitive to ill-conditioning.
618618
axis = (2 * (i_detection % 2) - 1) * axis_0 + (1 - 2 * ((i_detection // 2) % 2)) * axis_1
619619
qrot = gu.ti_rotvec_to_quat(collider_info.mc_perturbation[None] * axis, EPS)
620620
func_rotate_frame(i_ga, contact_pos_0, qrot, i_b, geoms_state, geoms_info)
@@ -863,8 +863,8 @@ def func_convex_convex_contact(
863863
# First-order correction of the normal direction.
864864
# The way the contact normal gets twisted by applying perturbation of geometry poses is
865865
# unpredictable as it depends on the final portal discovered by MPR. Alternatively, let compute
866-
# the mininal rotation that makes the corrected twisted normal as closed as possible to the
867-
# original one, up to the scale of the perturbation, then apply first-order Taylor expension of
866+
# the minimal rotation that makes the corrected twisted normal as closed as possible to the
867+
# original one, up to the scale of the perturbation, then apply first-order Taylor expansion of
868868
# Rodrigues' rotation formula.
869869
twist_rotvec = ti.math.clamp(
870870
normal.cross(normal_0),

genesis/utils/misc.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -343,24 +343,22 @@ def concat_with_tensor(
343343
return torch.cat([tensor, value], dim=dim)
344344

345345

346-
def make_tensor_field(shape: tuple[int, ...] = (), dtype_factory: Callable[[], torch.dtype] | None = None):
346+
def make_tensor_field(shape: tuple[int, ...] = (), dtype: torch.dtype | None = None):
347347
"""
348348
Helper method to create a tensor field for dataclasses.
349349
350350
Parameters
351351
----------
352352
shape : tuple
353353
The shape of the tensor field. It must have zero elements, otherwise it will trigger an exception.
354-
dtype_factory : Callable[[], torch.dtype], optional
355-
The factory function to create the dtype of the tensor field. Default is gs.tc_float.
356-
A factory is used because gs types may not be available at the time of field creation.
354+
dtype : torch.dtype, optional
355+
Data type of the tensor field. Default is gs.tc_float.
357356
"""
358357
assert not shape or math.prod(shape) == 0
359358

360359
def _default_factory():
361-
nonlocal shape, dtype_factory
362-
dtype = dtype_factory() if dtype_factory is not None else gs.tc_float
363-
return torch.empty(shape, dtype=dtype, device=gs.device)
360+
nonlocal shape, dtype
361+
return torch.empty(shape, dtype=dtype or gs.tc_float, device=gs.device)
364362

365363
return field(default_factory=_default_factory)
366364

0 commit comments

Comments
 (0)