Skip to content

Commit 2cb6e9c

Browse files
authored
Merge pull request #150 from ichumuh/fast_vibes
Faster state synchronization
2 parents 7f1e218 + d15a5f6 commit 2cb6e9c

5 files changed

Lines changed: 175 additions & 22 deletions

File tree

src/semantic_digital_twin/adapters/ros/world_synchronizer.py

Lines changed: 29 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -229,17 +229,28 @@ def world_callback(self):
229229
self.publish(msg)
230230

231231
def compute_state_changes(self) -> Dict[UUID, float]:
232-
changes = {
233-
_id: current_state
234-
for _id, current_state in zip(
235-
self.world.state.keys(), self.world.state.positions
236-
)
237-
if _id not in self.previous_world_state_data
238-
or not np.allclose(
239-
current_state, self.previous_world_state_data[_id].position
240-
)
241-
}
242-
return changes
232+
"""
233+
Compute and return only the position changes since the last published snapshot.
234+
235+
Returns a mapping of DOF name to current position for entries whose position
236+
differs from the previous snapshot, using a vectorized tolerance-based diff.
237+
"""
238+
ids = self.world.state.keys() # List[PrefixedName] in column order
239+
curr = self.world.state.positions # np.ndarray shape (N,)
240+
prev = self.previous_world_state_data # np.ndarray shape (N,)
241+
242+
# If the number of DOFs changed (model update), send everything once
243+
# so the other side can resync, then the snapshot will be updated afterward.
244+
if prev.shape != curr.shape:
245+
return {n: float(v) for n, v in zip(ids, curr)}
246+
247+
# Vectorized comparison: O(N) with minimal Python overhead
248+
changed_mask = ~np.isclose(curr, prev, rtol=1e-8, atol=1e-12, equal_nan=True)
249+
if not np.any(changed_mask):
250+
return {}
251+
252+
idx = np.nonzero(changed_mask)[0]
253+
return {ids[i]: float(curr[i]) for i in idx}
243254

244255

245256
@dataclass
@@ -259,10 +270,15 @@ def __post_init__(self):
259270
SynchronizerOnCallback.__post_init__(self)
260271

261272
def apply_message(self, msg: ModificationBlock):
262-
for callback in self.world.state.state_change_callbacks:
273+
running_callbacks = [
274+
callback
275+
for callback in self.world.state.state_change_callbacks
276+
if not callback._is_paused
277+
]
278+
for callback in running_callbacks:
263279
callback.pause()
264280
msg.modifications.apply(self.world)
265-
for callback in self.world.state.state_change_callbacks:
281+
for callback in running_callbacks:
266282
callback.resume()
267283

268284
def world_callback(self):

src/semantic_digital_twin/spatial_types/spatial_types.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -240,9 +240,23 @@ def _setup_constant_result(self) -> None:
240240
self._function_evaluator()
241241
self._is_constant = True
242242

243+
def bind_args_to_memory_view(self, arg_idx: int, numpy_array: np.ndarray) -> None:
244+
"""
245+
Binds the arg at index arg_idx to the memoryview of a numpy_array.
246+
If your args keep the same memory across calls, you only need to bind them once.
247+
"""
248+
self._function_buffer.set_arg(arg_idx, memoryview(numpy_array))
249+
250+
def evaluate(self) -> Union[np.ndarray, sp.csc_matrix]:
251+
"""
252+
Evaluate the compiled function with the current args.
253+
"""
254+
self._function_evaluator()
255+
return self._out
256+
243257
def __call__(self, *args: np.ndarray) -> Union[np.ndarray, sp.csc_matrix]:
244258
"""
245-
Efficiently evaluate the compiled function with positional arguments, by directly writing the memory of the
259+
Efficiently evaluate the compiled function with positional arguments by directly writing the memory of the
246260
numpy arrays to the memoryview of the compiled function.
247261
Similarly, the result will be written to the output buffer and doesn't allocate new memory on each eval.
248262
@@ -262,9 +276,8 @@ def __call__(self, *args: np.ndarray) -> Union[np.ndarray, sp.csc_matrix]:
262276
actual_number_of_args,
263277
)
264278
for arg_idx, arg in enumerate(args):
265-
self._function_buffer.set_arg(arg_idx, memoryview(arg))
266-
self._function_evaluator()
267-
return self._out
279+
self.bind_args_to_memory_view(arg_idx, arg)
280+
return self.evaluate()
268281

269282
def call_with_kwargs(self, **kwargs: float) -> np.ndarray:
270283
"""
@@ -1183,6 +1196,10 @@ def fmod(a: ScalarData, b: ScalarData) -> Expression:
11831196
return Expression(ca.fmod(a, b))
11841197

11851198

1199+
def sum(*expressions: ScalarData) -> Expression:
1200+
return Expression(ca.sum(to_sx(Expression(expressions))))
1201+
1202+
11861203
def normalize_angle_positive(angle: ScalarData) -> Expression:
11871204
"""
11881205
Normalizes the angle to be 0 to 2*pi

test/test_casadi/test_casadi_wrapper.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3261,6 +3261,56 @@ def test_stacked_compiled_function_dense(self):
32613261
assert_allclose(actual_e1, expected_e1)
32623262
assert_allclose(actual_e2, expected_e2)
32633263

3264+
def test_single_args(self):
3265+
size = 10_000
3266+
variables = cas.create_float_variables([str(i) for i in range(size)])
3267+
expr = cas.sum(*variables)
3268+
f = expr.compile()
3269+
for i in range(10):
3270+
data = np.random.rand(size)
3271+
assert np.isclose(f(data), np.sum(data))
3272+
3273+
def test_single_args_with_bind(self):
3274+
size = 10_000
3275+
data = np.random.rand(size)
3276+
variables = cas.create_float_variables([str(i) for i in range(size)])
3277+
expr = cas.sum(*variables)
3278+
f = expr.compile()
3279+
f.bind_args_to_memory_view(0, data)
3280+
for i in range(10):
3281+
np.copyto(data, np.random.rand(size))
3282+
assert np.isclose(f.evaluate(), np.sum(data))
3283+
3284+
def test_multiple_args(self):
3285+
size = 10_000
3286+
n = 10
3287+
element_size = size // n
3288+
variables = cas.create_float_variables([str(i) for i in range(size)])
3289+
expr = cas.sum(*variables)
3290+
args = [variables[i * element_size : (i + 1) * element_size] for i in range(n)]
3291+
f = expr.compile(parameters=args)
3292+
for i in range(100):
3293+
args_values = [np.ones(element_size)] * n
3294+
assert f(*args_values) == size
3295+
3296+
def test_multiple_args_with_bind(self):
3297+
size = 10_000
3298+
n = 10
3299+
element_size = size // n
3300+
variables = cas.create_float_variables([str(i) for i in range(size)])
3301+
expr = cas.sum(*variables)
3302+
args = [variables[i * element_size : (i + 1) * element_size] for i in range(n)]
3303+
f = expr.compile(parameters=args)
3304+
3305+
datas = []
3306+
for i in range(n):
3307+
datas.append(np.random.rand(element_size))
3308+
f.bind_args_to_memory_view(i, datas[i])
3309+
for i in range(100):
3310+
for i in range(n):
3311+
datas[i][:] = np.random.rand(element_size)
3312+
assert np.isclose(f.evaluate(), np.sum(datas))
3313+
32643314
def test_missing_free_variables(self):
32653315
s1, s2 = cas.create_float_variables(["s1", "s2"])
32663316
e = cas.sqrt(cas.cos(s1) + cas.sin(s2))

test/test_ros/test_world_synchronizer.py

Lines changed: 70 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import unittest
55
import uuid
66
from typing import Optional
7+
from uuid import UUID, uuid4
78

89
import numpy as np
910
import sqlalchemy
@@ -51,17 +52,37 @@ def deterministic_uuid(seed: str) -> uuid.UUID:
5152
w.add_degree_of_freedom(y_dof)
5253
z_dof = DegreeOfFreedom(name=PrefixedName("z"), id=deterministic_uuid("z_dof"))
5354
w.add_degree_of_freedom(z_dof)
54-
qx_dof = DegreeOfFreedom(name=PrefixedName("qx"), id=deterministic_uuid("qx_dof"))
55+
qx_dof = DegreeOfFreedom(
56+
name=PrefixedName("qx"), id=deterministic_uuid("qx_dof")
57+
)
5558
w.add_degree_of_freedom(qx_dof)
56-
qy_dof = DegreeOfFreedom(name=PrefixedName("qy"), id=deterministic_uuid("qy_dof"))
59+
qy_dof = DegreeOfFreedom(
60+
name=PrefixedName("qy"), id=deterministic_uuid("qy_dof")
61+
)
5762
w.add_degree_of_freedom(qy_dof)
58-
qz_dof = DegreeOfFreedom(name=PrefixedName("qz"), id=deterministic_uuid("qz_dof"))
63+
qz_dof = DegreeOfFreedom(
64+
name=PrefixedName("qz"), id=deterministic_uuid("qz_dof")
65+
)
5966
w.add_degree_of_freedom(qz_dof)
60-
qw_dof = DegreeOfFreedom(name=PrefixedName("qw"), id=deterministic_uuid("qw_dof"))
67+
qw_dof = DegreeOfFreedom(
68+
name=PrefixedName("qw"), id=deterministic_uuid("qw_dof")
69+
)
6170
w.add_degree_of_freedom(qw_dof)
6271
w.state[qw_dof.id].position = 1.0
6372

64-
w.add_connection(Connection6DoF(parent=b1, child=b2, x_id=x_dof.id, y_id=y_dof.id, z_id=z_dof.id, qx_id=qx_dof.id, qy_id=qy_dof.id, qz_id=qz_dof.id, qw_id=qw_dof.id))
73+
w.add_connection(
74+
Connection6DoF(
75+
parent=b1,
76+
child=b2,
77+
x_id=x_dof.id,
78+
y_id=y_dof.id,
79+
z_id=z_dof.id,
80+
qx_id=qx_dof.id,
81+
qy_id=qy_dof.id,
82+
qz_id=qz_dof.id,
83+
qw_id=qw_dof.id,
84+
)
85+
)
6586
return w
6687

6788

@@ -428,5 +449,49 @@ def test_synchronize_6dof(rclpy_node):
428449
np.testing.assert_array_almost_equal(w1.state.data, w2.state.data)
429450

430451

452+
def test_compute_state_changes_no_changes(rclpy_node):
453+
w = create_dummy_world()
454+
s = StateSynchronizer(node=rclpy_node, world=w)
455+
# Immediately compare without changing state
456+
changes = s.compute_state_changes()
457+
assert changes == {}
458+
s.close()
459+
460+
461+
def test_compute_state_changes_single_change(rclpy_node):
462+
w = create_dummy_world()
463+
s = StateSynchronizer(node=rclpy_node, world=w)
464+
# change first position
465+
w.state.data[0, 0] += 1e-3
466+
changes = s.compute_state_changes()
467+
names = w.state.keys()
468+
assert list(changes.keys()) == [names[0]]
469+
assert np.isclose(changes[names[0]], w.state.positions[0])
470+
s.close()
471+
472+
473+
def test_compute_state_changes_shape_change_full_snapshot(rclpy_node):
474+
w = create_dummy_world()
475+
s = StateSynchronizer(node=rclpy_node, world=w)
476+
# append a new DOF by writing a new name into state
477+
new_uuid = uuid4()
478+
w.state._add_dof(new_uuid)
479+
w.state[new_uuid] = np.zeros(4)
480+
changes = s.compute_state_changes()
481+
# full snapshot expected
482+
assert len(changes) == len(w.state)
483+
s.close()
484+
485+
486+
def test_compute_state_changes_nan_handling(rclpy_node):
487+
w = create_dummy_world()
488+
s = StateSynchronizer(node=rclpy_node, world=w)
489+
# set both previous and current to NaN for entry 0
490+
w.state.data[0, 0] = np.nan
491+
s.previous_world_state_data[0] = np.nan
492+
assert s.compute_state_changes() == {}
493+
s.close()
494+
495+
431496
if __name__ == "__main__":
432497
unittest.main()

test/test_worlds/test_world.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,8 +221,10 @@ def test_compute_fk(world_setup):
221221

222222
connection: PrismaticConnection = world.get_connection(r1, r2)
223223

224+
state_memory_id = id(world.state.data)
224225
world.state[connection.dof.id].position = 1.0
225226
world.notify_state_change()
227+
assert state_memory_id == id(world.state.data)
226228
fk = world.compute_forward_kinematics_np(l2, r2)
227229
assert np.allclose(
228230
fk,
@@ -269,6 +271,7 @@ def test_compute_fk_expression(world_setup):
269271

270272
def test_apply_control_commands(world_setup):
271273
world, l1, l2, bf, r1, r2 = world_setup
274+
state_memory_id = id(world.state.data)
272275
connection: PrismaticConnection = world.get_connection(r1, r2)
273276
cmd = np.array([100.0, 0, 0, 0, 0, 0, 0, 0])
274277
dt = 0.1
@@ -277,6 +280,8 @@ def test_apply_control_commands(world_setup):
277280
assert world.state[connection.dof.id].acceleration == 100.0 * dt
278281
assert world.state[connection.dof.id].velocity == 100.0 * dt * dt
279282
assert world.state[connection.dof.id].position == 100.0 * dt * dt * dt
283+
# the state should reuse the same memory
284+
assert state_memory_id == id(world.state.data)
280285

281286

282287
def test_compute_relative_pose(world_setup):

0 commit comments

Comments
 (0)