Commit ee16c9d ("formatting")

1 parent bac622e, commit ee16c9d

11 files changed (+127, -107 lines)

main.py

Lines changed: 5 additions & 3 deletions

@@ -62,10 +62,12 @@ def main():
         keys_pressed = penv.get_key_pressed()
         move_action = robot.get_key_move_action(keys_pressed=keys_pressed)
         sensor_data = robot.get_key_sensor_action(keys_pressed=keys_pressed)
-
-        joint_vels, jacobian = robot.calculate_joint_velocities_from_ee_velocity_dls(end_effector_velocity=move_action)
+
+        joint_vels, jacobian = robot.calculate_joint_velocities_from_ee_velocity_dls(
+            end_effector_velocity=move_action
+        )
         singularity = robot.set_joint_velocities(joint_velocities=joint_vels)
-
+
         # Step simulation
         pbutils.pbclient.stepSimulation()
         time.sleep(0.001)
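Note: this hunk only reflows the call site; the body of calculate_joint_velocities_from_ee_velocity_dls is not part of the diff. The method name points at damped least squares (DLS) inverse kinematics, so here is a minimal standalone sketch of that technique. The 6x6 Jacobian, the damping value, and the (6, 1) velocity column are illustrative assumptions, not the repo's code.

import numpy as np

def dls_joint_velocities(jacobian: np.ndarray, ee_velocity: np.ndarray, damping: float = 0.05) -> np.ndarray:
    """q_dot = J^T (J J^T + damping^2 I)^-1 v; damping keeps the solve stable near singularities."""
    jjt = jacobian @ jacobian.T                          # (6, 6)
    reg = (damping ** 2) * np.eye(jjt.shape[0])          # Tikhonov-style regularizer
    return jacobian.T @ np.linalg.solve(jjt + reg, ee_velocity)

J = np.random.default_rng(0).standard_normal((6, 6))    # stand-in Jacobian for a 6-DOF arm
v = np.zeros((6, 1))
v[0, 0] = 0.01                                           # same (6, 1) shape main.py builds from key presses
print(dls_joint_velocities(J, v).shape)                  # (6, 1)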

pybullet_tree_sim/pruning_environment.py

Lines changed: 4 additions & 9 deletions

@@ -64,7 +64,6 @@ def __init__(
         name: str = "PruningEnv",
         # num_trees: int | None = None,
         renders: bool = False,
-
         verbose: bool = True,
     ) -> None:
         """Initialize the Pruning Environment

@@ -87,10 +86,9 @@ def __init__(
         self.global_step_counter = 0
         # self.max_steps = max_steps

-
         self.verbose = verbose

-        self.collision_object_ids = { # TODO: move to tree.py
+        self.collision_object_ids = {  # TODO: move to tree.py
             "SPUR": None,
             "TRUNK": None,
             "BRANCH": None,

@@ -103,8 +101,6 @@ def __init__(
         self.last_button_push_time = time.time()
         return

-
-
     def load_tree( # TODO: Clean up Tree init vs create_tree, probably not needed. Too many file checks.
         self,
         pbutils: PyBUtils,

@@ -335,10 +331,9 @@ def main():
     start = 0.31
     stop = 0.35
     # Depth data IRL comes in as a C-format nx1 array. start with this IRL
-    depth_data[:, 3:5] = np.array([
-        np.arange(start, stop, (stop - start) / 8),
-        np.arange(start, stop, (stop - start) / 8)
-    ]).T
+    depth_data[:, 3:5] = np.array(
+        [np.arange(start, stop, (stop - start) / 8), np.arange(start, stop, (stop - start) / 8)]
+    ).T
     depth_data[-1, 3] = 0.31
     # Switch to F-format
     depth_data = depth_data.reshape((tof0.depth_width * tof0.depth_height, 1), order="F")
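Note: the last hunk flattens the depth image with order="F" (column-major), matching the comment about switching from C-format to F-format. A tiny sketch of what that reordering does; the 2x4 toy grid stands in for the sensor's depth_width x depth_height image.

import numpy as np

grid = np.arange(8).reshape((2, 4))          # C-order (row-major) layout: [[0 1 2 3], [4 5 6 7]]
col_major = grid.reshape((8, 1), order="F")  # F-order reads down columns first
print(col_major.ravel())                     # [0 4 1 5 2 6 3 7]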

pybullet_tree_sim/robot.py

Lines changed: 84 additions & 64 deletions

@@ -43,15 +43,19 @@ def __init__(
         self.verbose = verbose
         self.position = position
         self.orientation = orientation
-        self.randomize_pose = randomize_pose # TODO: This isn't set up anymore... fix
+        self.randomize_pose = randomize_pose  # TODO: This isn't set up anymore... fix
         self.init_joint_angles = (
-            -np.pi / 2,
-            -np.pi * 2 / 3,
-            np.pi * 2 / 3,
-            -np.pi,
-            -np.pi / 2,
-            np.pi,
-        ) if init_joint_angles is None else init_joint_angles
+            (
+                -np.pi / 2,
+                -np.pi * 2 / 3,
+                np.pi * 2 / 3,
+                -np.pi,
+                -np.pi / 2,
+                np.pi,
+            )
+            if init_joint_angles is None
+            else init_joint_angles
+        )

         # Robot setup
         self.robot = None

@@ -61,13 +65,13 @@ def __init__(
         self._setup_robot()
         self.num_joints = self.pbclient.getNumJoints(self.robot)
         self.robot_stack: list = self.robot_conf["robot_stack"]
-
+
         # Links
         self.links = self._get_links()
         self.robot_collision_filter_idxs = self._assign_collision_links()
         self.set_collision_filter(self.robot_collision_filter_idxs)
         self.tool0_link_idx = self._get_tool0_link_idx()
-
+
         # Joints
         self.joints = self._get_joints()
         self.control_joints, self.control_joint_idxs = self._assign_control_joints(self.joints)

@@ -183,7 +187,7 @@ def _assign_control_joints(self, joints: dict) -> list:
         control_joints = []
         control_joint_idxs = []
         for joint, joint_info in joints.items():
-            if joint_info["type"] == 0: # TODO: Check if this works for prismatic joints or just revolute
+            if joint_info["type"] == 0:  # TODO: Check if this works for prismatic joints or just revolute
                 control_joints.append(joint)
                 control_joint_idxs.append(joint_info["id"])
         return control_joints, control_joint_idxs

@@ -194,13 +198,13 @@ def _get_links(self) -> dict:
             info = self.pbclient.getJointInfo(self.robot, i)
             # log.debug(info)
             child_link_name = info[12].decode("utf-8")
-            links.update({child_link_name: {'id': i, "tf_from_parent": info[14]}})
+            links.update({child_link_name: {"id": i, "tf_from_parent": info[14]}})
         return links

     def _assign_collision_links(self) -> list:
         """Find tool0/base pairs, add to collision filter list.
         Requires that the robot part is ordered from base to tool0.
-
+
         TODO: Clean this up, there must be a better way.
         """
         robot_collision_filter_idxs = []

@@ -215,15 +219,15 @@ def _assign_collision_links(self) -> list:
             ):
                 robot_collision_filter_idxs.append(
                     (
-                        self.links[robot_part + "__base"]['id'],
-                        self.links[self.robot_conf["robot_stack"][i - 1] + "__tool0"]['id'],
+                        self.links[robot_part + "__base"]["id"],
+                        self.links[self.robot_conf["robot_stack"][i - 1] + "__tool0"]["id"],
                     )
                 )
         return robot_collision_filter_idxs

     def _get_tool0_link_idx(self):
         """TODO: Clean up, find a better way?"""
-        return self.links[self.robot_conf["robot_stack"][-1] + "__tool0"]['id']
+        return self.links[self.robot_conf["robot_stack"][-1] + "__tool0"]["id"]

     def _get_sensors(self) -> dict:
         """Get sensors on robot based on runtime config files"""

@@ -251,19 +255,23 @@ def _get_sensors(self) -> dict:
                 robot_part + "__" + metadata["tf_frame"]
             ) # TODO: find a better way to get the prefix. If
             # from robot_conf, need standard for all robots TODO: log an error if robot_part doesn't have all the right frames. Xacro utils?
-            sensors[sensor_name].tf_id = self.links[sensors[sensor_name].tf_frame]['id']
-            sensors[sensor_name].tf_from_parent = self.links[sensors[sensor_name].tf_frame]['tf_from_parent']
-            sensors[sensor_name].pan = metadata["pan"] # TODO: Are these only for cameras/toFs? If so, needs reorg
+            sensors[sensor_name].tf_id = self.links[sensors[sensor_name].tf_frame]["id"]
+            sensors[sensor_name].tf_from_parent = self.links[sensors[sensor_name].tf_frame]["tf_from_parent"]
+            sensors[sensor_name].pan = metadata["pan"]  # TODO: Are these only for cameras/toFs? If so, needs reorg
             sensors[sensor_name].tilt = metadata["tilt"]
         # for key, value in yamlcontent.items():
         # sensors.update({Path(file).stem: yamlcontent})
         return sensors
-
+
     def _get_sensor_attributes(self) -> dict:
         """TODO: Delete? This is not used"""
         sensor_attributes = {}
         # Cameras
-        camera_configs_path = os.path.join(CONFIG_PATH, "description", "camera",)
+        camera_configs_path = os.path.join(
+            CONFIG_PATH,
+            "description",
+            "camera",
+        )
         camera_configs_files = glob.glob(os.path.join(camera_configs_path, "*.yaml"))
         for file in camera_configs_files:
             yamlcontent = yutils.load_yaml(file)

@@ -432,7 +440,6 @@ def calculate_joint_velocities_from_ee_velocity_dls(self, end_effector_velocity,
     def create_camera_transform(self, world_position, world_orientation, camera: Camera | None) -> np.ndarray:
         """Create rotation matrix for camera"""
         base_offset_tf = np.identity(4)
-

         ee_transform = np.identity(4)
         ee_rot_mat = np.array(self.pbclient.getMatrixFromQuaternion(world_orientation)).reshape(3, 3)

@@ -450,20 +457,15 @@ def create_camera_transform(self, world_position, world_orientation, camera: Cam
             pan = camera.tilt
             base_offset_tf[:3, 3] = camera.xyz_offset

-            tilt_rot = np.array(
-                [[1, 0, 0], [0, np.cos(tilt), -np.sin(tilt)], [0, np.sin(tilt), np.cos(tilt)]]
-            )
+            tilt_rot = np.array([[1, 0, 0], [0, np.cos(tilt), -np.sin(tilt)], [0, np.sin(tilt), np.cos(tilt)]])
             tilt_tf[:3, :3] = tilt_rot

-            pan_rot = np.array(
-                [[np.cos(pan), 0, np.sin(pan)], [0, 1, 0], [-np.sin(pan), 0, np.cos(pan)]]
-            )
+            pan_rot = np.array([[np.cos(pan), 0, np.sin(pan)], [0, 1, 0], [-np.sin(pan), 0, np.cos(pan)]])
             pan_tf[:3, :3] = pan_rot
-
-
+
         tf = ee_transform @ pan_tf @ tilt_tf @ base_offset_tf
         return tf
-
+
     def set_collision_filter(self, robot_collision_filter_idxs) -> None:
         """Disable collision between pruner and arm"""
         for i in robot_collision_filter_idxs:
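Note: the create_camera_transform hunks above inline tilt_rot (a rotation about x) and pan_rot (a rotation about y), then chain 4x4 homogeneous transforms. A standalone sketch of that composition; the function name, the xyz_offset default, and the omitted ee_transform (identity here) are illustrative, not the repo's API.

import numpy as np

def pan_tilt_tf(pan: float, tilt: float, xyz_offset=(0.0, 0.0, 0.0)) -> np.ndarray:
    tilt_tf, pan_tf, offset_tf = np.identity(4), np.identity(4), np.identity(4)
    tilt_tf[:3, :3] = [[1, 0, 0], [0, np.cos(tilt), -np.sin(tilt)], [0, np.sin(tilt), np.cos(tilt)]]  # about x
    pan_tf[:3, :3] = [[np.cos(pan), 0, np.sin(pan)], [0, 1, 0], [-np.sin(pan), 0, np.cos(pan)]]       # about y
    offset_tf[:3, 3] = xyz_offset
    return pan_tf @ tilt_tf @ offset_tf  # same order as tf = ee_transform @ pan_tf @ tilt_tf @ base_offset_tf

print(np.round(pan_tilt_tf(pan=0.0, tilt=np.pi / 2), 3))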
@@ -582,7 +584,7 @@ def get_view_mat_at_curr_pose(self, camera: Camera | TimeOfFlight) -> np.ndarray
             cameraUpVector=up_vector,
         )
         return view_matrix
-
+
     def get_view_mat_by_id_at_curr_pose(self, id) -> np.ndarray:
         pos, orientation = self.get_current_pose(id)
         camera_tf = self.create_camera_transform(pos, orientation, camera=None)

@@ -591,18 +593,18 @@ def get_view_mat_by_id_at_curr_pose(self, id) -> np.ndarray:
         # Initial vectors
         camera_vector = np.array([0, 0, 1]) @ camera_tf[:3, :3].T
         up_vector = np.array([0, 1, 0]) @ camera_tf[:3, :3].T
-
+
         # log.debug(f"camera_vector: {camera_vector}")
         # log.debug(f"up_vector: {up_vector}")
-
+
         view_matrix = self.pbclient.computeViewMatrix(
             cameraEyePosition=camera_tf[:3, 3],
             cameraTargetPosition=camera_tf[:3, 3] + 0.1 * camera_vector,
             cameraUpVector=up_vector,
         )
         # log.warn(np.asarray(view_matrix).reshape((4,4), order="F"))
         return view_matrix
-
+
     def get_rgbd_at_cur_pose(self, camera, type, view_matrix) -> Tuple:
         """Get RGBD image at current pose
         @param camera (Camera): Camera object

@@ -623,7 +625,7 @@ def get_rgbd_at_cur_pose(self, camera, type, view_matrix) -> Tuple:
         # log.debug(f"depth_after_lin: {depth}")

         return rgb, depth
-
+
     def get_image_at_curr_pose(self, camera, type, view_matrix=None) -> list:
         """Take the current pose of the sensor and capture an image
         TODO: Add support for different types of sensors? For now, full rgbd

@@ -661,7 +663,7 @@ def get_image_at_curr_pose(self, camera, type, view_matrix=None) -> list:
         # return camera_tf

     # Collision checking
-    #
+    #
     def deproject_pixels_to_points(
         self, sensor, data: np.ndarray, view_matrix: np.ndarray, return_frame: str = "world", debug=False
     ) -> np.ndarray:

@@ -701,16 +703,22 @@ def deproject_pixels_to_points(

         # Get camera coordinates from film-plane coordinates. Scale, add z (depth), then homogenize the matrix.
         sensor_coords = np.divide(np.multiply(sensor.depth_film_coords, data), [fx, fy])
-        sensor_coords = np.concatenate((sensor_coords, data, np.ones((sensor.depth_width * sensor.depth_height, 1))), axis=1)
+        sensor_coords = np.concatenate(
+            (sensor_coords, data, np.ones((sensor.depth_width * sensor.depth_height, 1))), axis=1
+        )

         return_frame = return_frame.strip().lower()
         if return_frame == "sensor":
             return sensor_coords
         elif return_frame == "world":
-            world_coords = (mr.TransInv(view_matrix) @ sensor_coords.T).T
+            world_coords = (mr.TransInv(view_matrix) @ sensor_coords.T).T
             if debug:
                 plot.debug_deproject_pixels_to_points(
-                    sensor=sensor, data=data, cam_coords=sensor_coords, world_coords=world_coords, view_matrix=view_matrix
+                    sensor=sensor,
+                    data=data,
+                    cam_coords=sensor_coords,
+                    world_coords=world_coords,
+                    view_matrix=view_matrix,
                 )
             return world_coords
         else:

@@ -780,65 +788,77 @@ def compute_deprojected_point_mask(self):

     def get_key_move_action(self, keys_pressed: list) -> np.ndarray:
         """Return an action based on the keys pressed."""
-        action = np.zeros((6,1), dtype=float)
+        action = np.zeros((6, 1), dtype=float)
         if keys_pressed:
             if ord("a") in keys_pressed:
-                action[0,0] += 0.01
+                action[0, 0] += 0.01
             if ord("d") in keys_pressed:
-                action[0,0] += -0.01
+                action[0, 0] += -0.01
             if ord("s") in keys_pressed:
-                action[1,0] += 0.01
+                action[1, 0] += 0.01
             if ord("w") in keys_pressed:
-                action[1,0] += -0.01
+                action[1, 0] += -0.01
             if ord("q") in keys_pressed:
-                action[2,0] += 0.01
+                action[2, 0] += 0.01
             if ord("e") in keys_pressed:
-                action[2,0] += -0.01
+                action[2, 0] += -0.01
             if ord("z") in keys_pressed:
-                action[3,0] += 0.01
+                action[3, 0] += 0.01
             if ord("c") in keys_pressed:
-                action[3,0] += -0.01
+                action[3, 0] += -0.01
             if ord("x") in keys_pressed:
-                action[4,0] += 0.01
+                action[4, 0] += 0.01
             if ord("v") in keys_pressed:
-                action[4,0] += -0.01
+                action[4, 0] += -0.01
             if ord("r") in keys_pressed:
-                action[5,0] += 0.05
+                action[5, 0] += 0.05
             if ord("f") in keys_pressed:
-                action[5,0] += -0.05
+                action[5, 0] += -0.05
         return action
-
+
     def get_key_sensor_action(self, keys_pressed: list) -> dict | None:
         if keys_pressed:
-            if ord('p') in keys_pressed:
+            if ord("p") in keys_pressed:
                 if time.time() - self.debounce_time > 0.1:
                     sensor_data = {}
                     for sensor_name, sensor in self.sensors.items():
-                        if sensor_name.startswith('tof'):
+                        if sensor_name.startswith("tof"):
                             view_matrix = self.get_view_mat_at_curr_pose(camera=sensor)
-                            rgb, depth = self.get_rgbd_at_cur_pose(camera=sensor, type='sensor', view_matrix=view_matrix)
+                            rgb, depth = self.get_rgbd_at_cur_pose(
+                                camera=sensor, type="sensor", view_matrix=view_matrix
+                            )
                             view_matrix = np.asarray(view_matrix).reshape([4, 4], order="F")
                             depth = depth.reshape((sensor.depth_width * sensor.depth_height, 1), order="F")
-
-                            camera_points = self.deproject_pixels_to_points(sensor=sensor, data=depth, view_matrix=view_matrix, return_frame='sensor')
-
-                            sensor_data.update({sensor_name: {'data': camera_points, 'tf_frame': sensor.tf_frame, 'view_matrix': view_matrix, 'sensor': sensor}})
+
+                            camera_points = self.deproject_pixels_to_points(
+                                sensor=sensor, data=depth, view_matrix=view_matrix, return_frame="sensor"
+                            )
+
+                            sensor_data.update(
+                                {
+                                    sensor_name: {
+                                        "data": camera_points,
+                                        "tf_frame": sensor.tf_frame,
+                                        "view_matrix": view_matrix,
+                                        "sensor": sensor,
+                                    }
+                                }
+                            )
                     # plot.debug_sensor_world_data(sensor_data)
                     self.debounce_time = time.time()
                     return sensor_data
                 else:
                     return
         return
-

-
     def get_key_action(self, keys_pressed: list):
         move_action = self.get_key_move_action(keys_pressed=keys_pressed)
         sensor_data = self.get_key_sensor_action(keys_pressed=keys_pressed)
         # controller_action = self.get_key_controller_action(keys_pressed=keys_pressed)
-
+
         return

+
 def main():
     from pybullet_tree_sim.utils.pyb_utils import PyBUtils
     import time
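Note: deproject_pixels_to_points (reflowed above) scales film-plane coordinates by depth, appends the depth and a 1 to homogenize each point, then maps sensor-frame points to world through the inverse view matrix (mr.TransInv is presumably modern_robotics.TransInv). A minimal sketch under a pinhole-camera assumption; fx, fy, the 2x2 image, and the identity view matrix are illustrative stand-ins for the sensor's real intrinsics and pose.

import numpy as np

def trans_inv(T: np.ndarray) -> np.ndarray:
    """Inverse of a 4x4 homogeneous transform, the operation TransInv performs."""
    R, p = T[:3, :3], T[:3, 3]
    Tinv = np.identity(4)
    Tinv[:3, :3] = R.T
    Tinv[:3, 3] = -R.T @ p
    return Tinv

fx = fy = 1.0                                                                  # stand-in focal lengths
film_coords = np.array([[-0.5, -0.5], [0.5, -0.5], [-0.5, 0.5], [0.5, 0.5]])  # toy 2x2 "image"
depth = np.full((4, 1), 0.33)                                                  # flattened depth column, like the diff
sensor_coords = np.divide(np.multiply(film_coords, depth), [fx, fy])           # scale film coords by depth
sensor_coords = np.concatenate((sensor_coords, depth, np.ones((4, 1))), axis=1)  # homogenize
world_coords = (trans_inv(np.identity(4)) @ sensor_coords.T).T                 # camera at world origin here
print(world_coords[:, :3])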

pybullet_tree_sim/sensors/camera.py

Lines changed: 1 addition & 1 deletion

@@ -21,7 +21,7 @@ class Camera(RGBSensor, DepthSensor):
     def __init__(self, sensor_type: str = "camera", *args, **kwargs) -> None:
         super().__init__(sensor_type=sensor_type, *args, **kwargs)
         # TODO: check if camera has depth, check here, not in depth class. How to avoid inheritance without depth? Bool flag passed to super?
-
+
         return
