77import inspect
88from matplotlib import pyplot as plt , cm as cm
99
10+ # Check if torch.compile is available (PyTorch 2.0+)
11+ _TORCH_COMPILE_AVAILABLE = hasattr (torch , 'compile' ) and torch .__version__ >= '2.0'
12+
13+
14+ def _compute_dq_kernel (J : torch .Tensor , dx : torch .Tensor , reg_matrix : torch .Tensor ) -> torch .Tensor :
15+ """
16+ Compute joint velocity using damped least squares.
17+
18+ This function is designed to be compatible with torch.compile for JIT optimization.
19+
20+ Args:
21+ J: Jacobian matrix (N, 6, DOF)
22+ dx: Pose error (N, 6, 1)
23+ reg_matrix: Regularization matrix (6, 6)
24+
25+ Returns:
26+ dq: Joint velocity (N, DOF, 1)
27+ """
28+ # JJ^T + lambda^2*I
29+ tmpA = J @ J .transpose (1 , 2 ) + reg_matrix
30+ # Solve (JJ^T + lambda^2*I) A = dx
31+ A = torch .linalg .solve (tmpA , dx )
32+ # dq = J^T @ A
33+ return J .transpose (1 , 2 ) @ A
34+
1035
1136class IKSolution :
1237 def __init__ (self , dof , num_problems , num_retries , pos_tolerance , rot_tolerance , device = "cpu" ):
@@ -234,13 +259,14 @@ def solve(self, target_poses: Transform3d) -> IKSolution:
234259 raise NotImplementedError ()
235260
236261
237- def delta_pose (m : torch .tensor , target_pos , target_wxyz ):
262+ def delta_pose (m : torch .tensor , target_pos , target_wxyz , out : torch . Tensor = None ):
238263 """
239264 Determine the error in position and rotation between the given poses and the target poses
240265
241266 :param m: (N x M x 4 x 4) tensor of homogenous transforms
242267 :param target_pos:
243268 :param target_wxyz: target orientation represented in unit quaternion
269+ :param out: optional pre-allocated output buffer (N*M, 6, 1) to reduce memory allocation
244270 :return: (N*M, 6, 1) tensor of delta pose (dx, dy, dz, droll, dpitch, dyaw)
245271 """
246272 pos_diff = target_pos .unsqueeze (1 ) - m [:, :, :3 , 3 ]
@@ -257,7 +283,13 @@ def delta_pose(m: torch.tensor, target_pos, target_wxyz):
257283
258284 rot_diff = diff_axis_angle .view (- 1 , 3 , 1 )
259285
260- dx = torch .cat ((pos_diff , rot_diff ), dim = 1 )
286+ # Use pre-allocated buffer if provided
287+ if out is not None :
288+ out [:, :3 ] = pos_diff
289+ out [:, 3 :] = rot_diff
290+ dx = out
291+ else :
292+ dx = torch .cat ((pos_diff , rot_diff ), dim = 1 )
261293 return dx , pos_diff , rot_diff
262294
263295
@@ -266,18 +298,31 @@ def apply_mask(mask, *args):
266298
267299
268300class PseudoInverseIK (InverseKinematics ):
301+ def __init__ (self , * args , use_compile : bool = False , ** kwargs ):
302+ """
303+ Initialize PseudoInverseIK solver.
304+
305+ Args:
306+ *args: Arguments passed to InverseKinematics.
307+ use_compile: If True and PyTorch 2.0+ is available, use torch.compile
308+ for JIT compilation of the compute_dq kernel. This can provide
309+ performance improvements after a warmup period. Default: False.
310+ **kwargs: Keyword arguments passed to InverseKinematics.
311+ """
312+ super ().__init__ (* args , ** kwargs )
313+ # Pre-compute regularization matrix once
314+ self ._reg_matrix = self .regularlization * torch .eye (6 , device = self .device , dtype = self .dtype )
315+
316+ # Set up compute_dq kernel (potentially compiled)
317+ self ._use_compile = use_compile and _TORCH_COMPILE_AVAILABLE
318+ if self ._use_compile :
319+ self ._compute_dq_fn = torch .compile (_compute_dq_kernel )
320+ else :
321+ self ._compute_dq_fn = _compute_dq_kernel
322+
269323 def compute_dq (self , J , dx ):
270- # lambda^2*I (lambda^2 is regularization)
271- reg = self .regularlization * torch .eye (6 , device = self .device , dtype = self .dtype )
272-
273- # JJ^T + lambda^2*I (lambda^2 is regularization)
274- tmpA = J @ J .transpose (1 , 2 ) + reg
275- # (JJ^T + lambda^2I) A = dx
276- # A = (JJ^T + lambda^2I)^-1 dx
277- A = torch .linalg .solve (tmpA , dx )
278- # dq = J^T (JJ^T + lambda^2I)^-1 dx
279- dq = J .transpose (1 , 2 ) @ A
280- return dq
324+ """Compute joint velocity using damped least squares."""
325+ return self ._compute_dq_fn (J , dx , self ._reg_matrix )
281326
282327 def solve (self , target_poses : Transform3d ) -> IKSolution :
283328 self .clear ()
@@ -313,6 +358,11 @@ def solve(self, target_poses: Transform3d) -> IKSolution:
313358 if inspect .isclass (self .optimizer_method ) and issubclass (self .optimizer_method , torch .optim .Optimizer ):
314359 q .requires_grad = True
315360 optimizer = torch .optim .Adam ([q ], lr = self .lr )
361+
362+ # Pre-allocate delta pose buffer to reduce memory allocation in loop
363+ batch_size = M * self .num_retries
364+ dx_buffer = torch .empty ((batch_size , 6 , 1 ), device = self .device , dtype = self .dtype )
365+
316366 for i in range (self .max_iterations ):
317367 with torch .no_grad ():
318368 # early termination if we're out of problems to solve
@@ -324,7 +374,7 @@ def solve(self, target_poses: Transform3d) -> IKSolution:
324374 J , m = self .chain .jacobian (q , ret_eef_pose = True )
325375 # unflatten to broadcast with goal
326376 m = m .view (- 1 , self .num_retries , 4 , 4 )
327- dx , pos_diff , rot_diff = delta_pose (m , target_pos , target_wxyz )
377+ dx , pos_diff , rot_diff = delta_pose (m , target_pos , target_wxyz , out = dx_buffer )
328378
329379 # damped least squares method
330380 # lambda^2*I (lambda^2 is regularization)
@@ -344,7 +394,8 @@ def solve(self, target_poses: Transform3d) -> IKSolution:
344394 lr = lr .unsqueeze (1 )
345395 else :
346396 lr = self .lr
347- q = q + lr * dq
397+ # NOTE(review): torch.Tensor.add is OUT-of-place (add_ is the in-place
397+ # variant), so this still allocates a new tensor -- use q.add_(...) or
397+ # q += ... if in-place update is actually intended (and autograd allows it)
398+ q = q .add (dq , alpha = lr ) if isinstance (lr , float ) else q .add (lr * dq )
348399
349400 with torch .no_grad ():
350401 self .err_all = dx .squeeze ()
0 commit comments