Skip to content

Commit c68029b

Browse files
committed
Implement PettingZoo wrapper for OGM
1 parent d8f5cc2 commit c68029b

File tree

4 files changed

+486
-108
lines changed

4 files changed

+486
-108
lines changed

env/pivoting_cubes_env.py

Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
import numpy as np
2+
import pettingzoo
3+
from pettingzoo.utils import parallel_to_aec, wrappers
4+
from gymnasium import spaces
5+
6+
from ogm.occupancy_grid_map import OccupancyGridMap
7+
8+
class PivotingCubesEnv(pettingzoo.ParallelEnv):
9+
metadata = {"render_modes": ["human"], "name": "pivoting_cubes_v0"}
10+
11+
def __init__(self, initial_positions, final_positions, n_modules, empathy_lambda=0.0, max_episode_steps=200):
12+
"""
13+
The constructor for the environment.
14+
"""
15+
self.ogm = OccupancyGridMap(initial_positions, final_positions, n_modules)
16+
17+
self.agents = [f"module_{i}" for i in range(1, n_modules + 1)]
18+
self.possible_agents = self.agents[:]
19+
self.n_modules = n_modules
20+
self.empathy_lambda = empathy_lambda
21+
self.max_episode_steps = max_episode_steps
22+
self.episode_step = 0
23+
24+
self._define_spaces()
25+
26+
def _define_spaces(self):
27+
# Action space: 48 pivots + 1 NO-OP action
28+
self.action_spaces = {
29+
agent: spaces.Discrete(49) for agent in self.agents
30+
}
31+
32+
# Observation space: A dictionary containing the agent's local grid
33+
# and a mask of legal actions.
34+
self.observation_spaces = {
35+
agent: spaces.Dict({
36+
# The 5x5x5 local map around the agent
37+
"observation": spaces.Box(low=0, high=self.n_modules, shape=(5, 5, 5), dtype=np.int8),
38+
# A binary mask for legal actions
39+
"action_mask": spaces.Box(low=0, high=1, shape=(49,), dtype=np.int8)
40+
}) for agent in self.agents
41+
}
42+
43+
def reset(self, seed=None, options=None):
44+
# Re-initialize the underlying OGM simulation
45+
self.ogm = OccupancyGridMap(
46+
self.ogm.original_module_positions,
47+
self.ogm.original_final_module_positions,
48+
self.n_modules
49+
)
50+
self.agents = [f"module_{i}" for i in range(1, self.n_modules + 1)]
51+
self.episode_step = 0
52+
53+
# Get initial observations and infos
54+
observations = self._get_obs()
55+
infos = {agent: {} for agent in self.agents}
56+
57+
return observations, infos
58+
59+
def step(self, actions):
60+
grid_map_t = self.ogm.curr_grid_map.copy()
61+
62+
proposed_moves = {}
63+
target_positions = {}
64+
65+
for agent_name, action in actions.items():
66+
if action == 0: # NO-OP
67+
continue
68+
module_id = int(agent_name.split('_')[1])
69+
new_pos = self.ogm._compute_new_position(self.ogm.module_positions[module_id], action)
70+
if new_pos in target_positions:
71+
# Both moves fail. The first agent that claimed the spot also fails.
72+
conflicting_agent_id = target_positions[new_pos]
73+
if conflicting_agent_id in proposed_moves:
74+
del proposed_moves[conflicting_agent_id]
75+
else:
76+
target_positions[new_pos] = module_id
77+
proposed_moves[module_id] = new_pos
78+
79+
# validate connectivity
80+
if proposed_moves:
81+
future_positions = self.ogm.module_positions.copy()
82+
future_positions.update(proposed_moves)
83+
if not self.ogm.is_connected(future_positions):
84+
# the set of moves is invalid because it breaks the structure.
85+
# reject all moves for this timestep by clearing the dictionary.
86+
proposed_moves = {}
87+
88+
# Execute valid, non-conflicting moves
89+
self.ogm.execute_moves(proposed_moves)
90+
91+
# calc results
92+
terminations = {agent: self.ogm.check_final() for agent in self.agents}
93+
self.episode_step += 1
94+
truncations = {agent: False for agent in self.agents}
95+
if self.episode_step >= self.max_episode_steps:
96+
truncations = {agent: True for agent in self.agents}
97+
self.agents = []
98+
rewards = self._get_rewards(grid_map_t)
99+
observations = self._get_obs()
100+
infos = {agent: {} for agent in self.agents}
101+
102+
# if any agent terminates, the episode is over for all
103+
if any(terminations.values()):
104+
self.agents = []
105+
106+
return observations, rewards, terminations, truncations, infos
107+
108+
def _get_obs(self):
109+
# First, calculate all possible actions for the current state
110+
available_actions = self.ogm.calc_possible_actions()
111+
112+
observations = {}
113+
for agent_name in self.agents:
114+
module_id = int(agent_name.split('_')[1])
115+
116+
# Action Mask (always allow NO-OP)
117+
action_mask = np.zeros(49, dtype=np.int8)
118+
action_mask[0] = 1
119+
legal_pivots = np.where(available_actions[module_id])[0]
120+
action_mask[legal_pivots + 1] = 1
121+
122+
local_map = self.ogm.get_local_map(module_id, patch_size=5)
123+
124+
observations[agent_name] = {
125+
"observation": local_map,
126+
"action_mask": action_mask
127+
}
128+
return observations
129+
130+
def _get_rewards(self, grid_map_t):
131+
rewards = {}
132+
local_maps_t = {}
133+
local_maps_tp1 = {}
134+
final_local_maps = {}
135+
positions = {}
136+
for agent_name in self.agents:
137+
module_id = int(agent_name.split('_')[1])
138+
positions[agent_name] = self.ogm.module_positions[module_id]
139+
pos = positions[agent_name]
140+
half = 2
141+
x, y, z = pos
142+
x_min = max(x - half, 0)
143+
x_max = min(x + half + 1, grid_map_t.shape[0])
144+
y_min = max(y - half, 0)
145+
y_max = min(y + half + 1, grid_map_t.shape[1])
146+
z_min = max(z - half, 0)
147+
z_max = min(z + half + 1, grid_map_t.shape[2])
148+
local_map_t = np.zeros((5, 5, 5), dtype=np.int8)
149+
x_slice = slice(x_min, x_max)
150+
y_slice = slice(y_min, y_max)
151+
z_slice = slice(z_min, z_max)
152+
local_map_t[
153+
(x_min - (x - half)):(x_max - (x - half)),
154+
(y_min - (y - half)):(y_max - (y - half)),
155+
(z_min - (z - half)):(z_max - (z - half))
156+
] = grid_map_t[x_slice, y_slice, z_slice]
157+
local_maps_t[agent_name] = local_map_t
158+
local_maps_tp1[agent_name] = self.ogm.get_local_map(module_id, patch_size=5)
159+
final_local_maps[agent_name] = self.ogm.get_final_local_map(module_id, patch_size=5)
160+
base_rewards = {}
161+
for agent_name in self.agents:
162+
obs_t = local_maps_t[agent_name]
163+
obs_tp1 = local_maps_tp1[agent_name]
164+
obs_f = final_local_maps[agent_name]
165+
# A: positions where obs_tp1 == obs_f
166+
A = set(zip(*np.where(obs_tp1 == obs_f)))
167+
# B: positions where obs_t == obs_f
168+
B = set(zip(*np.where(obs_t == obs_f)))
169+
base_rewards[agent_name] = len(A - B) - len(B - A)
170+
# Compute empathy term
171+
for agent_name in self.agents:
172+
pos = positions[agent_name]
173+
# Find all agents in the 5x5x5 box centered at pos
174+
neighbors = []
175+
for other_name in self.agents:
176+
if other_name == agent_name:
177+
continue
178+
other_pos = positions[other_name]
179+
if all(abs(p - q) <= 2 for p, q in zip(pos, other_pos)):
180+
neighbors.append(other_name)
181+
empathy_sum = sum(base_rewards[n] for n in neighbors)
182+
rewards[agent_name] = base_rewards[agent_name] + self.empathy_lambda * empathy_sum
183+
return rewards
184+
185+
def render(self, mode="human"):
186+
print("Current Module Positions:", self.ogm.module_positions)

env/test-temp.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
import numpy as np
2+
from env.pivoting_cubes_env import PivotingCubesEnv
3+
4+
initial_positions = {
5+
1: (3, 3, 3),
6+
2: (3, 4, 3)
7+
}
8+
final_positions = {
9+
1: (3, 3, 3),
10+
2: (4, 3, 3)
11+
}
12+
n_modules = 2
13+
14+
def main():
15+
env = PivotingCubesEnv(initial_positions, final_positions, n_modules, empathy_lambda=0.1, max_episode_steps=10)
16+
obs, infos = env.reset()
17+
print("Initial observations:")
18+
for agent, ob in obs.items():
19+
print(f"{agent}: {ob}")
20+
done = False
21+
step = 0
22+
while not done and step < 10:
23+
actions = {}
24+
for agent in env.agents:
25+
mask = obs[agent]["action_mask"]
26+
legal_actions = np.where(mask)[0]
27+
actions[agent] = np.random.choice(legal_actions)
28+
obs, rewards, terminations, truncations, infos = env.step(actions)
29+
print(f"\nStep {step+1}")
30+
print("Actions:", actions)
31+
print("Rewards:", rewards)
32+
print("Terminations:", terminations)
33+
print("Truncations:", truncations)
34+
done = not env.agents or all(terminations.values()) or all(truncations.values())
35+
step += 1
36+
print("\nFinal module positions:")
37+
env.render()
38+
39+
if __name__ == "__main__":
40+
main()

0 commit comments

Comments
 (0)