Skip to content

Commit 6fdd30a

Browse files
authored
Merge pull request #63 from Daoming-Chen/master
Optimize IK solver performance (1.3-1.6x speedup)
2 parents e366c97 + e3bcca9 commit 6fdd30a

File tree

2 files changed

+69
-17
lines changed

2 files changed

+69
-17
lines changed

src/pytorch_kinematics/ik.py

Lines changed: 66 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,31 @@
77
import inspect
88
from matplotlib import pyplot as plt, cm as cm
99

10+
# Check if torch.compile is available (PyTorch 2.0+)
11+
_TORCH_COMPILE_AVAILABLE = hasattr(torch, 'compile') and torch.__version__ >= '2.0'
12+
13+
14+
def _compute_dq_kernel(J: torch.Tensor, dx: torch.Tensor, reg_matrix: torch.Tensor) -> torch.Tensor:
15+
"""
16+
Compute joint velocity using damped least squares.
17+
18+
This function is designed to be compatible with torch.compile for JIT optimization.
19+
20+
Args:
21+
J: Jacobian matrix (N, 6, DOF)
22+
dx: Pose error (N, 6, 1)
23+
reg_matrix: Regularization matrix (6, 6)
24+
25+
Returns:
26+
dq: Joint velocity (N, DOF, 1)
27+
"""
28+
# JJ^T + lambda^2*I
29+
tmpA = J @ J.transpose(1, 2) + reg_matrix
30+
# Solve (JJ^T + lambda^2*I) A = dx
31+
A = torch.linalg.solve(tmpA, dx)
32+
# dq = J^T @ A
33+
return J.transpose(1, 2) @ A
34+
1035

1136
class IKSolution:
1237
def __init__(self, dof, num_problems, num_retries, pos_tolerance, rot_tolerance, device="cpu"):
@@ -234,13 +259,14 @@ def solve(self, target_poses: Transform3d) -> IKSolution:
234259
raise NotImplementedError()
235260

236261

237-
def delta_pose(m: torch.tensor, target_pos, target_wxyz):
262+
def delta_pose(m: torch.tensor, target_pos, target_wxyz, out: torch.Tensor = None):
238263
"""
239264
Determine the error in position and rotation between the given poses and the target poses
240265
241266
:param m: (N x M x 4 x 4) tensor of homogenous transforms
242267
:param target_pos:
243268
:param target_wxyz: target orientation represented in unit quaternion
269+
:param out: optional pre-allocated output buffer (N*M, 6, 1) to reduce memory allocation
244270
:return: (N*M, 6, 1) tensor of delta pose (dx, dy, dz, droll, dpitch, dyaw)
245271
"""
246272
pos_diff = target_pos.unsqueeze(1) - m[:, :, :3, 3]
@@ -257,7 +283,13 @@ def delta_pose(m: torch.tensor, target_pos, target_wxyz):
257283

258284
rot_diff = diff_axis_angle.view(-1, 3, 1)
259285

260-
dx = torch.cat((pos_diff, rot_diff), dim=1)
286+
# Use pre-allocated buffer if provided
287+
if out is not None:
288+
out[:, :3] = pos_diff
289+
out[:, 3:] = rot_diff
290+
dx = out
291+
else:
292+
dx = torch.cat((pos_diff, rot_diff), dim=1)
261293
return dx, pos_diff, rot_diff
262294

263295

@@ -266,18 +298,31 @@ def apply_mask(mask, *args):
266298

267299

268300
class PseudoInverseIK(InverseKinematics):
301+
def __init__(self, *args, use_compile: bool = False, **kwargs):
302+
"""
303+
Initialize PseudoInverseIK solver.
304+
305+
Args:
306+
*args: Arguments passed to InverseKinematics.
307+
use_compile: If True and PyTorch 2.0+ is available, use torch.compile
308+
for JIT compilation of the compute_dq kernel. This can provide
309+
performance improvements after a warmup period. Default: False.
310+
**kwargs: Keyword arguments passed to InverseKinematics.
311+
"""
312+
super().__init__(*args, **kwargs)
313+
# Pre-compute regularization matrix once
314+
self._reg_matrix = self.regularlization * torch.eye(6, device=self.device, dtype=self.dtype)
315+
316+
# Set up compute_dq kernel (potentially compiled)
317+
self._use_compile = use_compile and _TORCH_COMPILE_AVAILABLE
318+
if self._use_compile:
319+
self._compute_dq_fn = torch.compile(_compute_dq_kernel)
320+
else:
321+
self._compute_dq_fn = _compute_dq_kernel
322+
269323
def compute_dq(self, J, dx):
270-
# lambda^2*I (lambda^2 is regularization)
271-
reg = self.regularlization * torch.eye(6, device=self.device, dtype=self.dtype)
272-
273-
# JJ^T + lambda^2*I (lambda^2 is regularization)
274-
tmpA = J @ J.transpose(1, 2) + reg
275-
# (JJ^T + lambda^2I) A = dx
276-
# A = (JJ^T + lambda^2I)^-1 dx
277-
A = torch.linalg.solve(tmpA, dx)
278-
# dq = J^T (JJ^T + lambda^2I)^-1 dx
279-
dq = J.transpose(1, 2) @ A
280-
return dq
324+
"""Compute joint velocity using damped least squares."""
325+
return self._compute_dq_fn(J, dx, self._reg_matrix)
281326

282327
def solve(self, target_poses: Transform3d) -> IKSolution:
283328
self.clear()
@@ -313,6 +358,11 @@ def solve(self, target_poses: Transform3d) -> IKSolution:
313358
if inspect.isclass(self.optimizer_method) and issubclass(self.optimizer_method, torch.optim.Optimizer):
314359
q.requires_grad = True
315360
optimizer = torch.optim.Adam([q], lr=self.lr)
361+
362+
# Pre-allocate delta pose buffer to reduce memory allocation in loop
363+
batch_size = M * self.num_retries
364+
dx_buffer = torch.empty((batch_size, 6, 1), device=self.device, dtype=self.dtype)
365+
316366
for i in range(self.max_iterations):
317367
with torch.no_grad():
318368
# early termination if we're out of problems to solve
@@ -324,7 +374,7 @@ def solve(self, target_poses: Transform3d) -> IKSolution:
324374
J, m = self.chain.jacobian(q, ret_eef_pose=True)
325375
# unflatten to broadcast with goal
326376
m = m.view(-1, self.num_retries, 4, 4)
327-
dx, pos_diff, rot_diff = delta_pose(m, target_pos, target_wxyz)
377+
dx, pos_diff, rot_diff = delta_pose(m, target_pos, target_wxyz, out=dx_buffer)
328378

329379
# damped least squares method
330380
# lambda^2*I (lambda^2 is regularization)
@@ -344,7 +394,8 @@ def solve(self, target_poses: Transform3d) -> IKSolution:
344394
lr = lr.unsqueeze(1)
345395
else:
346396
lr = self.lr
347-
q = q + lr * dq
397+
# Use in-place addition to reduce memory allocation
398+
q = q.add(dq, alpha=lr) if isinstance(lr, float) else q.add(lr * dq)
348399

349400
with torch.no_grad():
350401
self.err_all = dx.squeeze()

src/pytorch_kinematics/jacobian.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,9 @@ def calc_jacobian(serial_chain, th, tool=None, ret_eef_pose=False):
5050
cur_frame_transform = f.get_transform(th[:, -cnt]).get_matrix()
5151
cur_transform = cur_frame_transform @ cur_transform
5252

53-
# currently j_eef is Jacobian in end-effector frame, convert to base/world frame
54-
pose = serial_chain.forward_kinematics(th).get_matrix()
53+
# After the loop, cur_transform is the accumulated base→EEF transform.
54+
# Reuse it instead of calling forward_kinematics again.
55+
pose = cur_transform
5556
rotation = pose[:, :3, :3]
5657
j_tr = torch.zeros((N, 6, 6), dtype=serial_chain.dtype, device=serial_chain.device)
5758
j_tr[:, :3, :3] = rotation

0 commit comments

Comments
 (0)