Skip to content

Commit 4a964b8

Browse files
authored
Merge pull request #48 from ichumuh/spatial_types_everywhere
Spatial types everywhere
2 parents 2e5bf34 + 3d288c3 commit 4a964b8

23 files changed

Lines changed: 2144 additions & 746 deletions

examples/graph_of_convex_sets.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -90,8 +90,8 @@ Let's use graph theory to find a path!
9090
```{code-cell} ipython2
9191
from semantic_world.spatial_types import Point3
9292
93-
start = Point3.from_xyz(-0.75, 0, 0.15)
94-
goal = Point3.from_xyz(0.75, 0, 0.15)
93+
start = Point3(-0.75, 0, 0.15)
94+
goal = Point3(0.75, 0, 0.15)
9595
path = gcs.path_from_to(start, goal)
9696
print("A potential path is", [(point.x, point.y) for point in path])
9797
```
@@ -145,8 +145,8 @@ This allows the accessing of locations using a sequence of local problems put to
145145
Finally, let's find a way from here to there:
146146

147147
```{code-cell} ipython2
148-
start = Point3.from_xyz(-0.75, 0, 0.15)
149-
goal = Point3.from_xyz(0.75, 0, 0.15)
148+
start = Point3(-0.75, 0, 0.15)
149+
goal = Point3(0.75, 0, 0.15)
150150
path = gcs.path_from_to(start, goal)
151151
print("A potential path is", [(point.x, point.y, point.z) for point in path])
152152
```

scripts/generate_orm.py

Lines changed: 8 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,16 @@
1-
import logging
21
import os
3-
import sys
42
from enum import Enum
53

6-
from ormatic.ormatic import logger, ORMatic
7-
from ormatic.utils import classes_of_module, recursive_subclasses
8-
from sqlacodegen.generators import TablesGenerator
9-
from sqlalchemy import create_engine
10-
from sqlalchemy.orm import registry, Session
11-
4+
import semantic_world.degree_of_freedom
125
import semantic_world.geometry
6+
import semantic_world.views.views
137
import semantic_world.world_entity
14-
from semantic_world.connections import FixedConnection, OmniDrive
15-
import semantic_world.degree_of_freedom
16-
from semantic_world.prefixed_name import PrefixedName
8+
from ormatic.ormatic import ORMatic
9+
from ormatic.utils import classes_of_module, recursive_subclasses
10+
from semantic_world.connections import FixedConnection
1711
from semantic_world.orm.model import *
12+
from semantic_world.prefixed_name import PrefixedName
1813
from semantic_world.world import *
19-
import semantic_world.views.views
2014

2115
# ----------------------------------------------------------------------------------------------------------------------
2216
# This script generates the ORM classes for the semantic_world package.
@@ -41,34 +35,22 @@
4135
World, ForwardKinematicsVisitor, Has1DOFState, DegreeOfFreedom}
4236
classes -= set(recursive_subclasses(Enum))
4337

38+
4439
def generate_orm():
4540
"""
4641
Generate the ORM classes for the pycram package.
4742
"""
48-
# Set up logging
49-
handler = logging.StreamHandler(sys.stdout)
50-
handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
51-
52-
logger.addHandler(handler)
53-
logger.setLevel(logging.INFO)
54-
55-
mapper_registry = registry()
56-
engine = create_engine('sqlite:///:memory:')
57-
session = Session(engine)
5843

5944
# Create an ORMatic object with the classes to be mapped
6045
ormatic = ORMatic(list(classes))
6146

6247
# Generate the ORM classes
6348
ormatic.make_all_tables()
6449

65-
# Create the tables in the database
66-
mapper_registry.metadata.create_all(session.bind)
67-
6850
path = os.path.abspath(os.path.join(os.getcwd(), '../src/semantic_world/orm/'))
6951
with open(os.path.join(path, 'ormatic_interface.py'), 'w') as f:
7052
ormatic.to_sqlalchemy_file(f)
7153

7254

7355
if __name__ == '__main__':
74-
generate_orm()
56+
generate_orm()

src/semantic_world/adapters/multi_parser.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
JointBuilder, JointType)
1010
from pxr import UsdUrdf
1111

12-
from ..connections import RevoluteConnection, PrismaticConnection, FixedConnection, UnitVector
12+
from ..connections import RevoluteConnection, PrismaticConnection, FixedConnection
1313
from ..spatial_types.derivatives import DerivativeMap
1414
from ..prefixed_name import PrefixedName
1515
from ..spatial_types import spatial_types as cas
@@ -115,11 +115,11 @@ def parse_joints(self, body_builder: BodyBuilder, world: World) -> list[Connecti
115115
transform = body_builder.xform.GetLocalTransformation()
116116
pos = transform.ExtractTranslation()
117117
quat = transform.ExtractRotationQuat()
118-
point_expr = cas.Point3((pos[0], pos[1], pos[2]))
119-
quaternion_expr = cas.Quaternion((quat.GetImaginary()[0],
118+
point_expr = cas.Point3(pos[0], pos[1], pos[2])
119+
quaternion_expr = cas.Quaternion(quat.GetImaginary()[0],
120120
quat.GetImaginary()[1],
121121
quat.GetImaginary()[2],
122-
quat.GetReal()))
122+
quat.GetReal())
123123
origin = cas.TransformationMatrix.from_point_rotation_matrix(point=point_expr,
124124
rotation_matrix=quaternion_expr.to_rotation_matrix())
125125
connection = FixedConnection(parent=parent_body, child=child_body, origin_expression=origin)
@@ -132,11 +132,11 @@ def parse_joint(self, joint_builder: JointBuilder, parent_body: Body, child_body
132132
joint_name = joint_prim.GetName()
133133
joint_pos = joint_builder.pos
134134
joint_quat = joint_builder.quat
135-
point_expr = cas.Point3((joint_pos[0], joint_pos[1], joint_pos[2]))
136-
quaternion_expr = cas.Quaternion((joint_quat.GetImaginary()[0],
135+
point_expr = cas.Point3(joint_pos[0], joint_pos[1], joint_pos[2])
136+
quaternion_expr = cas.Quaternion(joint_quat.GetImaginary()[0],
137137
joint_quat.GetImaginary()[1],
138138
joint_quat.GetImaginary()[2],
139-
joint_quat.GetReal()))
139+
joint_quat.GetReal())
140140
origin = cas.TransformationMatrix.from_point_rotation_matrix(point=point_expr,
141141
rotation_matrix=quaternion_expr.to_rotation_matrix())
142142
free_variable_name = PrefixedName(joint_name)
@@ -153,9 +153,10 @@ def parse_joint(self, joint_builder: JointBuilder, parent_body: Body, child_body
153153
elif joint_builder.type == JointType.FIXED:
154154
return FixedConnection(parent=parent_body, child=child_body, origin_expression=origin)
155155
elif joint_builder.type in [JointType.REVOLUTE, JointType.CONTINUOUS, JointType.PRISMATIC]:
156-
axis = UnitVector(float(joint_builder.axis.to_array()[0]),
157-
float(joint_builder.axis.to_array()[1]),
158-
float(joint_builder.axis.to_array()[2]))
156+
axis = cas.Vector3(float(joint_builder.axis.to_array()[0]),
157+
float(joint_builder.axis.to_array()[1]),
158+
float(joint_builder.axis.to_array()[2]),
159+
reference_frame=parent_body)
159160
try:
160161
dof = world.get_degree_of_freedom_by_name(free_variable_name)
161162
except KeyError:

src/semantic_world/adapters/urdf.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44
from ..spatial_types import spatial_types as cas
55
from urdf_parser_py import urdf
66

7-
from ..connections import RevoluteConnection, PrismaticConnection, FixedConnection, UnitVector
7+
from ..connections import RevoluteConnection, PrismaticConnection, FixedConnection
88
from ..prefixed_name import PrefixedName
99
from ..spatial_types.derivatives import Derivatives, DerivativeMap
10-
from ..spatial_types.spatial_types import TransformationMatrix
10+
from ..spatial_types.spatial_types import TransformationMatrix, Vector3
1111
from ..utils import suppress_stdout_stderr, hacky_urdf_parser_fix
1212
from ..world import World, Body, Connection
1313
from ..geometry import Box, Sphere, Cylinder, Mesh, Scale, Shape, Color
@@ -157,7 +157,8 @@ def parse_joint(self, joint: urdf.Joint, parent: Body, child: Body, world: World
157157

158158
result = connection_type(parent=parent, child=child, origin_expression=parent_T_child,
159159
multiplier=multiplier, offset=offset,
160-
axis=UnitVector(*map(int, joint.axis)),
160+
axis=Vector3(*map(int, joint.axis),
161+
reference_frame=parent),
161162
dof=dof)
162163
return result
163164

src/semantic_world/connections.py

Lines changed: 18 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,14 @@
22

33
from abc import ABC, abstractmethod
44
from dataclasses import dataclass, field
5-
from typing import List, Tuple, TYPE_CHECKING
5+
from typing import List, TYPE_CHECKING, Union
66

77
import numpy as np
88

99
from . import spatial_types as cas
1010
from .degree_of_freedom import DegreeOfFreedom
1111
from .prefixed_name import PrefixedName
12-
from .spatial_types.derivatives import Derivatives, DerivativeMap
12+
from .spatial_types.derivatives import DerivativeMap
1313
from .spatial_types.math import quaternion_from_rotation_matrix
1414
from .types import NpMatrix4x4
1515
from .world_entity import Connection
@@ -111,61 +111,13 @@ def passive_dofs(self) -> List[DegreeOfFreedom]:
111111
return []
112112

113113

114-
@dataclass
115-
class UnitVector:
116-
"""
117-
Represents a unit vector which is always of size 1.
118-
"""
119-
120-
x: float
121-
y: float
122-
z: float
123-
124-
def __post_init__(self):
125-
self.normalize()
126-
127-
def normalize(self):
128-
length = self.length
129-
self.x /= length
130-
self.y /= length
131-
self.z /= length
132-
133-
@property
134-
def length(self):
135-
return np.sqrt(self.x ** 2 + self.y ** 2 + self.z ** 2)
136-
137-
def __getitem__(self, item: int) -> float:
138-
if item == 0:
139-
return self.x
140-
if item == 1:
141-
return self.y
142-
if item == 2:
143-
return self.z
144-
raise IndexError
145-
146-
def as_tuple(self) -> Tuple[float, float, float]:
147-
return self.x, self.y, self.z
148-
149-
@classmethod
150-
def X(cls):
151-
return cls(1, 0, 0)
152-
153-
@classmethod
154-
def Y(cls):
155-
return cls(0, 1, 0)
156-
157-
@classmethod
158-
def Z(cls):
159-
return cls(0, 0, 1)
160-
161-
162114
@dataclass
163115
class PrismaticConnection(ActiveConnection, Has1DOFState):
164116
"""
165117
Allows the movement along an axis.
166118
"""
167119

168-
axis: UnitVector = field(kw_only=True)
120+
axis: cas.Vector3 = field(kw_only=True)
169121
"""
170122
Connection moves along this axis, should be a unit vector.
171123
The axis is defined relative to the local reference frame of the parent body.
@@ -200,7 +152,7 @@ def __post_init__(self):
200152
self._post_init_world_part()
201153

202154
motor_expression = self.dof.symbols.position * self.multiplier + self.offset
203-
translation_axis = cas.Vector3(self.axis) * motor_expression
155+
translation_axis = cas.Vector3.from_iterable(self.axis) * motor_expression
204156
parent_T_child = cas.TransformationMatrix.from_xyz_rpy(x=translation_axis[0],
205157
y=translation_axis[1],
206158
z=translation_axis[2])
@@ -230,7 +182,7 @@ class RevoluteConnection(ActiveConnection, Has1DOFState):
230182
Allows rotation about an axis.
231183
"""
232184

233-
axis: UnitVector = field(kw_only=True)
185+
axis: cas.Vector3 = field(kw_only=True)
234186
"""
235187
Connection rotates about this axis, should be a unit vector.
236188
The axis is defined relative to the local reference frame of the parent body.
@@ -265,9 +217,8 @@ def __post_init__(self):
265217
self._post_init_world_part()
266218

267219
motor_expression = self.dof.symbols.position * self.multiplier + self.offset
268-
rotation_axis = cas.Vector3(self.axis)
269-
parent_R_child = cas.RotationMatrix.from_axis_angle(rotation_axis, motor_expression)
270-
self.origin_expression = self.origin_expression.dot(cas.TransformationMatrix(parent_R_child))
220+
parent_R_child = cas.RotationMatrix.from_axis_angle(self.axis, motor_expression)
221+
self.origin_expression = self.origin_expression @ cas.TransformationMatrix(parent_R_child)
271222
self.origin_expression.reference_frame = self.parent
272223
self.origin_expression.child_frame = self.child
273224

@@ -318,13 +269,13 @@ class Connection6DoF(PassiveConnection):
318269
def __post_init__(self):
319270
super().__post_init__()
320271
self._post_init_world_part()
321-
parent_P_child = cas.Point3((self.x.symbols.position,
322-
self.y.symbols.position,
323-
self.z.symbols.position))
324-
parent_R_child = cas.Quaternion((self.qx.symbols.position,
325-
self.qy.symbols.position,
326-
self.qz.symbols.position,
327-
self.qw.symbols.position)).to_rotation_matrix()
272+
parent_P_child = cas.Point3(x=self.x.symbols.position,
273+
y=self.y.symbols.position,
274+
z=self.z.symbols.position)
275+
parent_R_child = cas.Quaternion(x=self.qx.symbols.position,
276+
y=self.qy.symbols.position,
277+
z=self.qz.symbols.position,
278+
w=self.qw.symbols.position).to_rotation_matrix()
328279
self.origin_expression = cas.TransformationMatrix.from_point_rotation_matrix(point=parent_P_child,
329280
rotation_matrix=parent_R_child,
330281
reference_frame=self.parent,
@@ -354,11 +305,13 @@ def passive_dofs(self) -> List[DegreeOfFreedom]:
354305
return [self.x, self.y, self.z, self.qx, self.qy, self.qz, self.qw]
355306

356307
@property
357-
def origin(self) -> NpMatrix4x4:
308+
def origin(self) -> cas.TransformationMatrix:
358309
return super().origin
359310

360311
@origin.setter
361-
def origin(self, transformation: NpMatrix4x4) -> None:
312+
def origin(self, transformation: Union[NpMatrix4x4, cas.TransformationMatrix]) -> None:
313+
if isinstance(transformation, cas.TransformationMatrix):
314+
transformation = transformation.to_np()
362315
orientation = quaternion_from_rotation_matrix(transformation)
363316
self._world.state[self.x.name].position = transformation[0, 3]
364317
self._world.state[self.y.name].position = transformation[1, 3]

src/semantic_world/exceptions.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
from __future__ import annotations
22
from typing_extensions import List, TYPE_CHECKING
33

4+
from .prefixed_name import PrefixedName
5+
46
if TYPE_CHECKING:
5-
from semantic_world.world_entity import View
7+
from .world_entity import View
68

79

810
class LogicalError(Exception):
@@ -29,3 +31,9 @@ class DuplicateViewError(UsageError):
2931
def __init__(self, views: List[View]):
3032
msg = f'Views {views} are duplicates, while views elements should be unique.'
3133
super().__init__(msg)
34+
35+
36+
class ViewNotFoundError(UsageError):
37+
def __init__(self, name: PrefixedName):
38+
msg = f'View with name {name} not found'
39+
super().__init__(msg)

src/semantic_world/geometry.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -383,7 +383,7 @@ def get_points(self) -> List[Point3]:
383383
384384
:return: A list of Point3 objects representing the corners of the bounding box.
385385
"""
386-
return [Point3.from_xyz(x, y, z)
386+
return [Point3(x, y, z)
387387
for x in (self.min_x, self.max_x)
388388
for y in (self.min_y, self.max_y)
389389
for z in (self.min_z, self.max_z)]

src/semantic_world/graph_of_convex_sets.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ def path_from_to(self, start: Point3, goal: Point3) -> Optional[List[Point3]]:
198198
x_target = intersection.x_interval.center()
199199
y_target = intersection.y_interval.center()
200200
z_target = intersection.z_interval.center()
201-
result.append(Point3.from_xyz(x_target, y_target, z_target))
201+
result.append(Point3(x_target, y_target, z_target))
202202

203203
result.append(goal)
204204
return result

src/semantic_world/ik_solver.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,6 @@
1111
from .connections import ActiveConnection, PassiveConnection
1212
from .degree_of_freedom import DegreeOfFreedom
1313
from .spatial_types import spatial_types as cas
14-
from .spatial_types.derivatives import Derivatives
15-
from .types import NpMatrix4x4
1614

1715
if TYPE_CHECKING:
1816
from .world import World
@@ -137,7 +135,7 @@ class InverseKinematicsSolver:
137135
Unit is m for the position target or rad for the orientation target.
138136
"""
139137

140-
def solve(self, root: Body, tip: Body, target: NpMatrix4x4,
138+
def solve(self, root: Body, tip: Body, target: cas.TransformationMatrix,
141139
dt: float = 0.05, max_iterations: int = 200,
142140
translation_velocity: float = 0.2, rotation_velocity: float = 0.2) -> Dict[DegreeOfFreedom, float]:
143141
"""
@@ -257,7 +255,7 @@ class QPProblem:
257255
Tip body of the kinematic chain.
258256
"""
259257

260-
target: NpMatrix4x4
258+
target: cas.TransformationMatrix
261259
"""
262260
Desired tip pose relative to the root body.
263261
"""
@@ -373,7 +371,7 @@ class ConstraintBuilder:
373371
Tip body of the kinematic chain.
374372
"""
375373

376-
target: NpMatrix4x4
374+
target: cas.TransformationMatrix
377375
"""
378376
Desired tip pose relative to the root body.
379377
"""
@@ -461,7 +459,7 @@ def _compute_rotation_error(self, root_T_tip: cas.TransformationMatrix) -> Tuple
461459
"""
462460
rotation_cap = self.max_rotation_velocity * self.dt
463461

464-
hack = cas.RotationMatrix.from_axis_angle(cas.Vector3((0, 0, 1)), -0.0001)
462+
hack = cas.RotationMatrix.from_axis_angle(cas.Vector3.Z(), -0.0001)
465463
root_R_tip = root_T_tip.to_rotation().dot(hack)
466464
q_actual = cas.TransformationMatrix(self.target).to_quaternion()
467465
q_goal = root_R_tip.to_quaternion()

0 commit comments

Comments
 (0)