Skip to content

Commit 12b076d

Browse files
Tuxliriclaude
andcommitted
feat: make torch an optional dependency for the batched simulator
torch, torchdiffeq, roma, and opt-einsum are only required for the batched simulator (BatchedMultirotor, simulate_batch, BatchedSE3Control, etc.). Moving them to an optional extra avoids forcing all users to install PyTorch (~2 GB) when they only need the standard single-drone simulator. - pyproject.toml: remove the four packages from core dependencies and add a new `batched` extra; also include them in `all` - Wrap top-level torch/roma/torchdiffeq imports with try/except in every file that mixes batched and non-batched classes, so those modules remain importable without the extra installed Install the batched simulator with: pip install rotorpy[batched] Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 7a5b6c9 commit 12b076d

File tree

12 files changed

+58
-19
lines changed

12 files changed

+58
-19
lines changed

pyproject.toml

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,16 @@ dependencies = [
2929
'pandas',
3030
'tqdm',
3131
'gymnasium',
32-
'roma', # For batched sim
33-
'torch>=1.11.0', # For batched sim
34-
'torchdiffeq', # For batched sim
35-
'opt-einsum', # For batched sim
3632
'timed_count', # Only for ardupilot sitl example
3733
]
3834

3935
[project.optional-dependencies]
36+
batched = [
37+
'torch>=1.11.0',
38+
'torchdiffeq',
39+
'roma',
40+
'opt-einsum',
41+
]
4042
learning = [
4143
'stable_baselines3',
4244
'tensorboard',
@@ -64,6 +66,10 @@ px4 = [
6466
'pymavlink',
6567
]
6668
all = [
69+
"torch>=1.11.0",
70+
"torchdiffeq",
71+
"roma",
72+
"opt-einsum",
6773
"stable_baselines3",
6874
"tensorboard",
6975
"pytest",

rotorpy/controllers/quadrotor_control.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
import numpy as np
2-
import torch
3-
import roma
2+
try:
3+
import torch
4+
import roma
5+
except ImportError:
6+
pass
47
from scipy.spatial.transform import Rotation
58

69
class SE3Control(object):

rotorpy/sensors/imu.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
import numpy as np
22
from scipy.spatial.transform import Rotation
3-
import torch
3+
try:
4+
import torch
5+
except ImportError:
6+
pass
47
import copy
58

69
class Imu:

rotorpy/simulate.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,11 @@
22
from enum import Enum
33
import copy
44
import numpy as np
5-
import roma
6-
import torch
5+
try:
6+
import roma
7+
import torch
8+
except ImportError:
9+
pass
710
from numpy.linalg import norm
811
from scipy.spatial.transform import Rotation
912
from time import perf_counter

rotorpy/trajectories/circular_traj.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
import numpy as np
2-
import torch
2+
try:
3+
import torch
4+
except ImportError:
5+
pass
36
import sys
47

58
class ThreeDCircularTraj(object):

rotorpy/trajectories/hover_traj.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
import numpy as np
2-
import torch
2+
try:
3+
import torch
4+
except ImportError:
5+
pass
36

47
class HoverTraj(object):
58
"""

rotorpy/trajectories/lissajous_traj.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
import numpy as np
2-
import torch
2+
try:
3+
import torch
4+
except ImportError:
5+
pass
36

47
"""
58
Lissajous curves are defined by trigonometric functions parameterized in time.

rotorpy/trajectories/minsnap.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,10 @@
55
import cvxopt
66
from scipy.linalg import block_diag
77
from typing import List
8-
import torch
8+
try:
9+
import torch
10+
except ImportError:
11+
pass
912

1013
def cvxopt_solve_qp(P, q, G=None, h=None, A=None, b=None):
1114
"""

rotorpy/trajectories/traj_template.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,10 @@
22
Imports
33
"""
44
import numpy as np
5-
import torch
5+
try:
6+
import torch
7+
except ImportError:
8+
pass
69

710
class TrajTemplate(object):
811
"""

rotorpy/vehicles/multirotor.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,12 @@
88
from scipy.spatial.transform import Rotation as R
99

1010
# imports for Batched Dynamics
11-
import torch
12-
from torchdiffeq import odeint
13-
import roma
11+
try:
12+
import torch
13+
from torchdiffeq import odeint
14+
import roma
15+
except ImportError:
16+
pass
1417

1518
import time
1619

0 commit comments

Comments
 (0)