1+ import numpy as np
2+ import pettingzoo
3+ from pettingzoo .utils import parallel_to_aec , wrappers
4+ from gymnasium import spaces
5+
6+ from ogm .occupancy_grid_map import OccupancyGridMap
7+
8+ class PivotingCubesEnv (pettingzoo .ParallelEnv ):
9+ metadata = {"render_modes" : ["human" ], "name" : "pivoting_cubes_v0" }
10+
11+ def __init__ (self , initial_positions , final_positions , n_modules , empathy_lambda = 0.0 , max_episode_steps = 200 ):
12+ """
13+ The constructor for the environment.
14+ """
15+ self .ogm = OccupancyGridMap (initial_positions , final_positions , n_modules )
16+
17+ self .agents = [f"module_{ i } " for i in range (1 , n_modules + 1 )]
18+ self .possible_agents = self .agents [:]
19+ self .n_modules = n_modules
20+ self .empathy_lambda = empathy_lambda
21+ self .max_episode_steps = max_episode_steps
22+ self .episode_step = 0
23+
24+ self ._define_spaces ()
25+
26+ def _define_spaces (self ):
27+ # Action space: 48 pivots + 1 NO-OP action
28+ self .action_spaces = {
29+ agent : spaces .Discrete (49 ) for agent in self .agents
30+ }
31+
32+ # Observation space: A dictionary containing the agent's local grid
33+ # and a mask of legal actions.
34+ self .observation_spaces = {
35+ agent : spaces .Dict ({
36+ # The 5x5x5 local map around the agent
37+ "observation" : spaces .Box (low = 0 , high = self .n_modules , shape = (5 , 5 , 5 ), dtype = np .int8 ),
38+ # A binary mask for legal actions
39+ "action_mask" : spaces .Box (low = 0 , high = 1 , shape = (49 ,), dtype = np .int8 )
40+ }) for agent in self .agents
41+ }
42+
43+ def reset (self , seed = None , options = None ):
44+ # Re-initialize the underlying OGM simulation
45+ self .ogm = OccupancyGridMap (
46+ self .ogm .original_module_positions ,
47+ self .ogm .original_final_module_positions ,
48+ self .n_modules
49+ )
50+ self .agents = [f"module_{ i } " for i in range (1 , self .n_modules + 1 )]
51+ self .episode_step = 0
52+
53+ # Get initial observations and infos
54+ observations = self ._get_obs ()
55+ infos = {agent : {} for agent in self .agents }
56+
57+ return observations , infos
58+
59+ def step (self , actions ):
60+ grid_map_t = self .ogm .curr_grid_map .copy ()
61+
62+ proposed_moves = {}
63+ target_positions = {}
64+
65+ for agent_name , action in actions .items ():
66+ if action == 0 : # NO-OP
67+ continue
68+ module_id = int (agent_name .split ('_' )[1 ])
69+ new_pos = self .ogm ._compute_new_position (self .ogm .module_positions [module_id ], action )
70+ if new_pos in target_positions :
71+ # Both moves fail. The first agent that claimed the spot also fails.
72+ conflicting_agent_id = target_positions [new_pos ]
73+ if conflicting_agent_id in proposed_moves :
74+ del proposed_moves [conflicting_agent_id ]
75+ else :
76+ target_positions [new_pos ] = module_id
77+ proposed_moves [module_id ] = new_pos
78+
79+ # validate connectivity
80+ if proposed_moves :
81+ future_positions = self .ogm .module_positions .copy ()
82+ future_positions .update (proposed_moves )
83+ if not self .ogm .is_connected (future_positions ):
84+ # the set of moves is invalid because it breaks the structure.
85+ # reject all moves for this timestep by clearing the dictionary.
86+ proposed_moves = {}
87+
88+ # Execute valid, non-conflicting moves
89+ self .ogm .execute_moves (proposed_moves )
90+
91+ # calc results
92+ terminations = {agent : self .ogm .check_final () for agent in self .agents }
93+ self .episode_step += 1
94+ truncations = {agent : False for agent in self .agents }
95+ if self .episode_step >= self .max_episode_steps :
96+ truncations = {agent : True for agent in self .agents }
97+ self .agents = []
98+ rewards = self ._get_rewards (grid_map_t )
99+ observations = self ._get_obs ()
100+ infos = {agent : {} for agent in self .agents }
101+
102+ # if any agent terminates, the episode is over for all
103+ if any (terminations .values ()):
104+ self .agents = []
105+
106+ return observations , rewards , terminations , truncations , infos
107+
108+ def _get_obs (self ):
109+ # First, calculate all possible actions for the current state
110+ available_actions = self .ogm .calc_possible_actions ()
111+
112+ observations = {}
113+ for agent_name in self .agents :
114+ module_id = int (agent_name .split ('_' )[1 ])
115+
116+ # Action Mask (always allow NO-OP)
117+ action_mask = np .zeros (49 , dtype = np .int8 )
118+ action_mask [0 ] = 1
119+ legal_pivots = np .where (available_actions [module_id ])[0 ]
120+ action_mask [legal_pivots + 1 ] = 1
121+
122+ local_map = self .ogm .get_local_map (module_id , patch_size = 5 )
123+
124+ observations [agent_name ] = {
125+ "observation" : local_map ,
126+ "action_mask" : action_mask
127+ }
128+ return observations
129+
130+ def _get_rewards (self , grid_map_t ):
131+ rewards = {}
132+ local_maps_t = {}
133+ local_maps_tp1 = {}
134+ final_local_maps = {}
135+ positions = {}
136+ for agent_name in self .agents :
137+ module_id = int (agent_name .split ('_' )[1 ])
138+ positions [agent_name ] = self .ogm .module_positions [module_id ]
139+ pos = positions [agent_name ]
140+ half = 2
141+ x , y , z = pos
142+ x_min = max (x - half , 0 )
143+ x_max = min (x + half + 1 , grid_map_t .shape [0 ])
144+ y_min = max (y - half , 0 )
145+ y_max = min (y + half + 1 , grid_map_t .shape [1 ])
146+ z_min = max (z - half , 0 )
147+ z_max = min (z + half + 1 , grid_map_t .shape [2 ])
148+ local_map_t = np .zeros ((5 , 5 , 5 ), dtype = np .int8 )
149+ x_slice = slice (x_min , x_max )
150+ y_slice = slice (y_min , y_max )
151+ z_slice = slice (z_min , z_max )
152+ local_map_t [
153+ (x_min - (x - half )):(x_max - (x - half )),
154+ (y_min - (y - half )):(y_max - (y - half )),
155+ (z_min - (z - half )):(z_max - (z - half ))
156+ ] = grid_map_t [x_slice , y_slice , z_slice ]
157+ local_maps_t [agent_name ] = local_map_t
158+ local_maps_tp1 [agent_name ] = self .ogm .get_local_map (module_id , patch_size = 5 )
159+ final_local_maps [agent_name ] = self .ogm .get_final_local_map (module_id , patch_size = 5 )
160+ base_rewards = {}
161+ for agent_name in self .agents :
162+ obs_t = local_maps_t [agent_name ]
163+ obs_tp1 = local_maps_tp1 [agent_name ]
164+ obs_f = final_local_maps [agent_name ]
165+ # A: positions where obs_tp1 == obs_f
166+ A = set (zip (* np .where (obs_tp1 == obs_f )))
167+ # B: positions where obs_t == obs_f
168+ B = set (zip (* np .where (obs_t == obs_f )))
169+ base_rewards [agent_name ] = len (A - B ) - len (B - A )
170+ # Compute empathy term
171+ for agent_name in self .agents :
172+ pos = positions [agent_name ]
173+ # Find all agents in the 5x5x5 box centered at pos
174+ neighbors = []
175+ for other_name in self .agents :
176+ if other_name == agent_name :
177+ continue
178+ other_pos = positions [other_name ]
179+ if all (abs (p - q ) <= 2 for p , q in zip (pos , other_pos )):
180+ neighbors .append (other_name )
181+ empathy_sum = sum (base_rewards [n ] for n in neighbors )
182+ rewards [agent_name ] = base_rewards [agent_name ] + self .empathy_lambda * empathy_sum
183+ return rewards
184+
185+ def render (self , mode = "human" ):
186+ print ("Current Module Positions:" , self .ogm .module_positions )
0 commit comments