Skip to content

Commit 62fcc5a

Browse files
authored
Merge pull request #123 from Tigul/world-copy
Removed JSON serializer from World deepcopy for better perfromance
2 parents 0f4be1b + d2bc4c8 commit 62fcc5a

4 files changed

Lines changed: 155 additions & 7 deletions

File tree

src/semantic_digital_twin/world.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import inspect
55
import logging
6-
from copy import deepcopy
6+
from copy import deepcopy, copy
77
from dataclasses import dataclass, field
88
from enum import IntEnum
99
from functools import wraps, lru_cache, cached_property
@@ -1772,7 +1772,6 @@ def __deepcopy__(self, memo):
17721772
new_world = World(name=self.name)
17731773
memo[me_id] = new_world
17741774

1775-
tracker = KinematicStructureEntityKwargsTracker.from_world(new_world)
17761775
with new_world.modify_world():
17771776
for body in self.bodies:
17781777
new_body = Body(
@@ -1790,9 +1789,7 @@ def __deepcopy__(self, memo):
17901789
new_world.add_degree_of_freedom(new_dof)
17911790
new_world.state[dof.name] = self.state[dof.name].data
17921791
for connection in self.connections:
1793-
new_connection = SubclassJSONSerializer.from_json(
1794-
connection.to_json(), **tracker.create_kwargs()
1795-
)
1792+
new_connection = connection.copy_for_world(new_world)
17961793
new_world.add_connection(new_connection)
17971794
return new_world
17981795

src/semantic_digital_twin/world_description/connections.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,24 @@ def jerk(self, value: float) -> None:
309309
self._world.state[self.raw_dof.name].jerk = value / self.multiplier
310310
self._world.notify_state_change()
311311

312+
def copy_for_world(self, world: World):
313+
(
314+
other_parent,
315+
other_child,
316+
parent_T_connection_expression,
317+
) = self._find_references_in_world(world)
318+
319+
return self.__class__(
320+
name=PrefixedName(self.name.name, self.name.prefix),
321+
parent=other_parent,
322+
child=other_child,
323+
parent_T_connection_expression=parent_T_connection_expression,
324+
dof_name=PrefixedName(self.dof_name.name, self.dof_name.prefix),
325+
axis=self.axis,
326+
multiplier=self.multiplier,
327+
offset=self.offset,
328+
)
329+
312330

313331
@dataclass(eq=False)
314332
class PrismaticConnection(ActiveConnection1DOF):
@@ -557,6 +575,32 @@ def origin(
557575
self._world.state[self.qw.name].position = orientation[3]
558576
self._world.notify_state_change()
559577

578+
def copy_for_world(self, world: World) -> Connection6DoF:
579+
"""
580+
Copies this 6DoF connection for another world. Returns a new connection with references to the given world.
581+
:param world: The world to copy this connection for.
582+
:return: A copy of this connection for the given world.
583+
"""
584+
(
585+
other_parent,
586+
other_child,
587+
parent_T_connection_expression,
588+
) = self._find_references_in_world(world)
589+
590+
return Connection6DoF(
591+
name=deepcopy(self.name),
592+
parent=other_parent,
593+
child=other_child,
594+
parent_T_connection_expression=parent_T_connection_expression,
595+
x_name=deepcopy(self.x_name),
596+
y_name=deepcopy(self.y_name),
597+
z_name=deepcopy(self.z_name),
598+
qx_name=deepcopy(self.qx_name),
599+
qy_name=deepcopy(self.qy_name),
600+
qz_name=deepcopy(self.qz_name),
601+
qw_name=deepcopy(self.qw_name),
602+
)
603+
560604

561605
@dataclass(eq=False)
562606
class OmniDrive(ActiveConnection, HasUpdateState):
@@ -821,3 +865,30 @@ def has_hardware_interface(self, value: bool) -> None:
821865
self.x_velocity.has_hardware_interface = value
822866
self.y_velocity.has_hardware_interface = value
823867
self.yaw.has_hardware_interface = value
868+
869+
def copy_for_world(self, world: World) -> OmniDrive:
870+
"""
871+
Copies this OmniDriveConnection for the provided world. This finds the references for the parent and child in
872+
the new world and returns a new connection with references to the new parent and child.
873+
:param world: The world where the connection is copied.
874+
:return: The connection with references to the new parent and child.
875+
"""
876+
(
877+
other_parent,
878+
other_child,
879+
parent_T_connection_expression,
880+
) = self._find_references_in_world(world)
881+
882+
return OmniDrive(
883+
name=deepcopy(self.name),
884+
parent=other_parent,
885+
child=other_child,
886+
parent_T_connection_expression=parent_T_connection_expression,
887+
x_name=deepcopy(self.x_name),
888+
y_name=deepcopy(self.y_name),
889+
roll_name=deepcopy(self.roll_name),
890+
pitch_name=deepcopy(self.pitch_name),
891+
yaw_name=deepcopy(self.yaw_name),
892+
x_velocity_name=deepcopy(self.x_velocity_name),
893+
y_velocity_name=deepcopy(self.y_velocity_name),
894+
)

src/semantic_digital_twin/world_description/world_entity.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from __future__ import annotations
22

33
import inspect
4+
from copy import copy, deepcopy
5+
46
import itertools
57
from abc import ABC, abstractmethod
68
from collections import deque
@@ -793,7 +795,9 @@ class Connection(WorldEntity, SubclassJSONSerializer):
793795
"""
794796

795797
parent_T_connection_expression: TransformationMatrix = field(default=None)
796-
_connection_T_child_expression: TransformationMatrix = field(default=None)
798+
_connection_T_child_expression: TransformationMatrix = field(
799+
default=None, init=False
800+
)
797801
"""
798802
The origin expression of a connection is split into 2 transforms:
799803
1. parent_T_connection describes the pose of the connection and is always constant.
@@ -955,6 +959,51 @@ def create_with_dofs(
955959
"ConnectionWithDofs.create_with_dofs is not implemented."
956960
)
957961

962+
def _find_references_in_world(
963+
self, world: World
964+
) -> Tuple[
965+
KinematicStructureEntity, KinematicStructureEntity, TransformationMatrix
966+
]:
967+
"""
968+
Finds the reference frames to this connection in the given world and returns them as usable objects.
969+
:param world: Reference to the world where the reference frames are searched.
970+
:return: The other parent and child and new connection expressions with correct reference frames.
971+
"""
972+
other_parent = world.get_kinematic_structure_entity_by_name(self.parent.name)
973+
other_child = world.get_kinematic_structure_entity_by_name(self.child.name)
974+
975+
parent_T_connection_expression = deepcopy(self.parent_T_connection_expression)
976+
parent_T_connection_expression.reference_frame = (
977+
world.get_kinematic_structure_entity_by_name(
978+
parent_T_connection_expression.reference_frame.name
979+
)
980+
)
981+
return (
982+
other_parent,
983+
other_child,
984+
parent_T_connection_expression,
985+
)
986+
987+
def copy_for_world(self, world: World) -> Self:
988+
"""
989+
Copies this connection to the given world the parent and child references are updated to the new world as well
990+
as the references from the expression.
991+
:param world: World in which the connection should be copied.
992+
:return: The copied connection.
993+
"""
994+
(
995+
other_parent,
996+
other_child,
997+
parent_T_connection_expression,
998+
) = self._find_references_in_world(world)
999+
1000+
return self.__class__(
1001+
other_parent,
1002+
other_child,
1003+
parent_T_connection_expression=parent_T_connection_expression,
1004+
name=PrefixedName(self.name.name, prefix=self.name.prefix),
1005+
)
1006+
9581007

9591008
GenericConnection = TypeVar("GenericConnection", bound=Connection)
9601009

test/test_worlds/test_world.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
import numpy as np
44
import pytest
5+
from numpy.testing import assert_raises
6+
57
from semantic_digital_twin.semantic_annotations.semantic_annotations import Handle
68

79
from semantic_digital_twin.spatial_types import Vector3
@@ -647,7 +649,7 @@ def test_copy_pr2_world_connection_origin(pr2_world):
647649
pr2_body = pr2_world.get_kinematic_structure_entity_by_name(body.name)
648650
pr2_copy_body = pr2_copy.get_kinematic_structure_entity_by_name(body.name)
649651
np.testing.assert_array_almost_equal(
650-
pr2_body.global_pose.to_np(), pr2_copy_body.global_pose.to_np()
652+
pr2_body.global_pose.to_np(), pr2_copy_body.global_pose.to_np(), decimal=4
651653
)
652654

653655

@@ -676,6 +678,35 @@ def test_copy_pr2(pr2_world):
676678
).global_pose.to_np()[2, 3] == pytest.approx(1.472, abs=1e-3)
677679

678680

681+
def test_copy_connections(pr2_world):
682+
pr2_copy = deepcopy(pr2_world)
683+
for connection in pr2_world.connections:
684+
pr2_copy_connection = pr2_copy.get_connection_by_name(connection.name)
685+
assert connection.name == pr2_copy_connection.name
686+
np.testing.assert_array_almost_equal(
687+
connection.origin.to_np(), pr2_copy_connection.origin.to_np(), decimal=3
688+
)
689+
pr2_copy.state[
690+
pr2_copy.get_degree_of_freedom_by_name("torso_lift_joint").name
691+
].position = 0.3
692+
pr2_copy.notify_state_change()
693+
694+
assert_raises(
695+
AssertionError,
696+
np.testing.assert_array_almost_equal,
697+
pr2_world.get_connection_by_name("torso_lift_joint").origin.to_np(),
698+
pr2_copy.get_connection_by_name("torso_lift_joint").origin.to_np(),
699+
)
700+
701+
702+
def test_copy_two_times(pr2_world):
703+
pr2_copy = deepcopy(pr2_world)
704+
pr2_copy_2 = deepcopy(pr2_copy)
705+
for connection in pr2_world.connections:
706+
pr2_copy_connection = pr2_copy_2.get_connection_by_name(connection.name)
707+
assert connection.name == pr2_copy_connection.name
708+
709+
679710
def test_add_entity_with_duplicate_name(world_setup):
680711
world, l1, l2, bf, r1, r2 = world_setup
681712
body_duplicate = Body(name=PrefixedName("l1"))

0 commit comments

Comments
 (0)