@@ -22,28 +22,6 @@ def default_action(self):
2222 def get_text_action (self , action ):
2323 return self .language_action_space [action ]
2424
25- def get_text_observation (self , obs ):
26- board = np .empty (self .env .board_size , dtype = str )
27-
28- # Create a mask for sunk ships
29- sunk_mask = np .zeros_like (self .ships , dtype = bool )
30- for i in self .sunk_ships :
31- sunk_mask = np .logical_or (sunk_mask , self .ships == i )
32-
33- board [obs [0 ] != 0 ] = "❌"
34- board [obs [1 ] != 0 ] = "⚫"
35- board [sunk_mask ] = "💥" # Sunk ships
36-
37- num_rows , num_columns = board .shape
38- columns = [chr (i ) for i in range (ord ("A" ), ord ("A" ) + num_columns )]
39- index = [i + 1 for i in range (num_rows )]
40-
41- dataframe = pd .DataFrame (board , columns = columns , index = index )
42- dataframe = dataframe .replace (["" ], "⬜" )
43- obsv = str (dataframe )
44-
45- return obsv
46-
4725 def get_feedback (self , reward , old_reward ):
4826 if reward is None :
4927 return ""
@@ -60,7 +38,9 @@ def get_feedback(self, reward, old_reward):
6038 return "MISS! Your missile splashed into empty water."
6139
6240 def battleships_process_obsv (self , obs , reward , old_reward ):
63- text_observation = self .get_text_observation (obs )
41+ dataframe = self .get_dataframe (obs )
42+
43+ text_observation = self .get_text_observation (dataframe )
6444 feedback = self .get_feedback (reward , old_reward )
6545
6646 prompt = (
@@ -72,7 +52,7 @@ def battleships_process_obsv(self, obs, reward, old_reward):
7252 obs = defaultdict (lambda : None )
7353
7454 obs ["text" ] = {"long_term_context" : prompt , "short_term_context" : "" }
75- image = None # TODO add rendering
55+ image = self . get_image_observation ( dataframe )
7656 obs ["image" ] = image
7757
7858 return obs
@@ -125,3 +105,89 @@ def step(self, action):
125105
126106 def get_stats (self ):
127107 return {"progression" : self .progression }
108+
109+ def get_dataframe (self , obs ):
110+ board = np .empty (self .env .board_size , dtype = str )
111+
112+ # Create a mask for sunk ships
113+ sunk_mask = np .zeros_like (self .ships , dtype = bool )
114+ for i in self .sunk_ships :
115+ sunk_mask = np .logical_or (sunk_mask , self .ships == i )
116+
117+ board [obs [0 ] != 0 ] = "❌"
118+ board [obs [1 ] != 0 ] = "⚫"
119+ board [sunk_mask ] = "💥" # Sunk ships
120+
121+ num_rows , num_columns = board .shape
122+ columns = [chr (i ) for i in range (ord ("A" ), ord ("A" ) + num_columns )]
123+ index = [i + 1 for i in range (num_rows )]
124+
125+ dataframe = pd .DataFrame (board , columns = columns , index = index )
126+ dataframe = dataframe .replace (["" ], "⬜" )
127+
128+ return dataframe
129+
130+ def get_text_observation (self , dataframe ):
131+ obsv = str (dataframe )
132+
133+ return obsv
134+
135+ def get_image_observation (self , dataframe ):
136+ # import matplotlib.pyplot as plt
137+ # from matplotlib.colors import LinearSegmentedColormap
138+ # from matplotlib.figure import Figure
139+ # from matplotlib.backends.backend_agg import FigureCanvasAgg
140+ # import io
141+ # from PIL import Image
142+
143+ # # Define colors for each cell type
144+ # color_map = {
145+ # "⬜": [0.9, 0.9, 1.0], # Light blue for empty water
146+ # "❌": [1.0, 0.0, 0.0], # Red for hits
147+ # "⚫": [0.3, 0.3, 0.3], # Dark gray for misses
148+ # "💥": [1.0, 0.6, 0.0] # Orange for sunk ships
149+ # }
150+
151+ # # Create a numerical representation for colormapping
152+ # numeric_board = np.zeros(dataframe.shape + (3,), dtype=float)
153+
154+ # for i in range(dataframe.shape[0]):
155+ # for j in range(dataframe.shape[1]):
156+ # cell_value = dataframe.iloc[i, j]
157+ # numeric_board[i, j] = color_map.get(cell_value, [1, 1, 1])
158+
159+ # # Create a figure with the right dimensions and no padding
160+ # fig_width = dataframe.shape[1] + 1 # +1 for row labels
161+ # fig_height = dataframe.shape[0] + 1 # +1 for column labels
162+ # fig = Figure(figsize=(fig_width, fig_height), dpi=72)
163+ # canvas = FigureCanvasAgg(fig)
164+ # ax = fig.add_subplot(111)
165+
166+ # # Plot the board
167+ # ax.imshow(numeric_board, aspect='equal')
168+
169+ # # Add grid lines
170+ # ax.set_xticks(np.arange(-0.5, dataframe.shape[1], 1), minor=True)
171+ # ax.set_yticks(np.arange(-0.5, dataframe.shape[0], 1), minor=True)
172+ # ax.grid(which='minor', color='black', linestyle='-', linewidth=1)
173+
174+ # # Add column labels (A, B, C, ...)
175+ # ax.set_xticks(np.arange(dataframe.shape[1]))
176+ # ax.set_xticklabels(dataframe.columns)
177+
178+ # # Add row labels (1, 2, 3, ...)
179+ # ax.set_yticks(np.arange(dataframe.shape[0]))
180+ # ax.set_yticklabels(dataframe.index)
181+
182+ # # Remove axis padding
183+ # ax.set_xlim(-0.5, dataframe.shape[1] - 0.5)
184+ # ax.set_ylim(-0.5, dataframe.shape[0] - 0.5)
185+
186+ # # Render the figure to a numpy array
187+ # canvas.draw()
188+ # buf = io.BytesIO()
189+ # fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0.1)
190+ # buf.seek(0)
191+
192+ # return Image.open(buf)
193+ return None
0 commit comments