Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,16 @@ dependencies = [
'pandas',
'tqdm',
'gymnasium',
'roma', # For batched sim
'torch>=1.11.0', # For batched sim
'torchdiffeq', # For batched sim
'opt-einsum', # For batched sim
'timed_count', # Only for ardupilot sitl example
]

[project.optional-dependencies]
batched = [
'torch>=1.11.0',
'torchdiffeq',
'roma',
'opt-einsum',
]
learning = [
'stable_baselines3',
'tensorboard',
Expand All @@ -58,6 +60,10 @@ px4 = [
'pymavlink',
]
all = [
"torch>=1.11.0",
"torchdiffeq",
"roma",
"opt-einsum",
"stable_baselines3",
"tensorboard",
"pytest",
Expand Down
7 changes: 5 additions & 2 deletions rotorpy/controllers/quadrotor_control.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import numpy as np
import torch
import roma
try:
import torch
import roma
except ImportError:
pass
from scipy.spatial.transform import Rotation

class SE3Control(object):
Expand Down
5 changes: 4 additions & 1 deletion rotorpy/sensors/imu.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import numpy as np
from scipy.spatial.transform import Rotation
import torch
try:
import torch
except ImportError:
pass
import copy

class Imu:
Expand Down
7 changes: 5 additions & 2 deletions rotorpy/simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@
from enum import Enum
import copy
import numpy as np
import roma
import torch
try:
import roma
import torch
except ImportError:
pass
from numpy.linalg import norm
from scipy.spatial.transform import Rotation
from time import perf_counter
Expand Down
5 changes: 4 additions & 1 deletion rotorpy/trajectories/circular_traj.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import numpy as np
import torch
try:
import torch
except ImportError:
pass
import sys

class ThreeDCircularTraj(object):
Expand Down
5 changes: 4 additions & 1 deletion rotorpy/trajectories/hover_traj.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import numpy as np
import torch
try:
import torch
except ImportError:
pass

class HoverTraj(object):
"""
Expand Down
5 changes: 4 additions & 1 deletion rotorpy/trajectories/lissajous_traj.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import numpy as np
import torch
try:
import torch
except ImportError:
pass

"""
Lissajous curves are defined by trigonometric functions parameterized in time.
Expand Down
5 changes: 4 additions & 1 deletion rotorpy/trajectories/minsnap.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
import cvxopt
from scipy.linalg import block_diag
from typing import List
import torch
try:
import torch
except ImportError:
pass

def cvxopt_solve_qp(P, q, G=None, h=None, A=None, b=None):
"""
Expand Down
5 changes: 4 additions & 1 deletion rotorpy/trajectories/traj_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@
Imports
"""
import numpy as np
import torch
try:
import torch
except ImportError:
pass

class TrajTemplate(object):
"""
Expand Down
9 changes: 6 additions & 3 deletions rotorpy/vehicles/multirotor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,12 @@
from scipy.spatial.transform import Rotation as R

# imports for Batched Dynamics
import torch
from torchdiffeq import odeint
import roma
try:
import torch
from torchdiffeq import odeint
import roma
except ImportError:
pass

import time

Expand Down
13 changes: 6 additions & 7 deletions rotorpy/vehicles/px4_multirotor.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,8 +285,8 @@ def _send_hil_sensor(self, state, statedot):
fields_updated=updated_bitmask,
)

def step(self, state, control, t_step):
def step(self, state, control, t_step):

# Compute state derivative once for state and messages
# and send both HIL messages
statedot = self.statedot(state, control, 0.0)
Expand All @@ -298,14 +298,13 @@ def step(self, state, control, t_step):
px4_control = self._fetch_latest_px4_control(blocking=self._lockstep_enabled)
if px4_control is not None:
self._last_control = px4_control
control = px4_control
else:
control = self._last_control
pass # Do not modify _last_control

else: # In this case we use the control provided by the external controller
pass
state = super().step(state, control, t_step)
self._last_control = control

state = super().step(state, self._last_control, t_step)
self.state = state
self.t += t_step

Expand Down
5 changes: 4 additions & 1 deletion rotorpy/wind/default_winds.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import numpy as np
import sys
import torch
try:
import torch
except ImportError:
pass
import math
import random

Expand Down
5 changes: 4 additions & 1 deletion rotorpy/wind/dryden_winds.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import numpy as np
import torch
try:
import torch
except ImportError:
pass
import os
import sys

Expand Down
Loading