44import gym
55import numpy as np
66import pandas as pd
7+ from PIL import Image , ImageDraw , ImageFont
78from scipy .ndimage import label
89
910
@@ -38,17 +39,16 @@ def get_feedback(self, reward, old_reward):
3839 return "MISS! Your missile splashed into empty water."
3940
4041 def battleships_process_obsv (self , obs , reward , old_reward ):
41- dataframe = self .get_dataframe (obs )
42+ text_observation = self .get_text_observation (obs )
43+ image = self .get_image_observation (obs )
4244
43- text_observation = self .get_text_observation (dataframe )
4445 feedback = self .get_feedback (reward , old_reward )
4546
4647 prompt = f"{ text_observation } \n { feedback } " if feedback else f"Objects on the map:\n { text_observation } "
4748
4849 obs = defaultdict (lambda : None )
4950
5051 obs ["text" ] = {"long_term_context" : prompt , "short_term_context" : "" }
51- image = self .get_image_observation (dataframe )
5252 obs ["image" ] = image
5353
5454 return obs
@@ -102,7 +102,7 @@ def step(self, action):
102102 def get_stats (self ):
103103 return {"progression" : self .progression }
104104
105- def get_dataframe (self , obs ):
105+ def get_text_observation (self , obs ):
106106 board = np .empty (self .env .board_size , dtype = str )
107107
108108 # Create a mask for sunk ships
@@ -121,69 +121,97 @@ def get_dataframe(self, obs):
121121 dataframe = pd .DataFrame (board , columns = columns , index = index )
122122 dataframe = dataframe .replace (["" ], " " )
123123
124- return dataframe
125-
126- def get_text_observation (self , dataframe ):
127124 obsv = str (dataframe )
128125
129126 return obsv
130127
131- def get_image_observation (self , dataframe ):
132- # import matplotlib.pyplot as plt
133- # from matplotlib.colors import LinearSegmentedColormap
134- # from matplotlib.figure import Figure
135- # from matplotlib.backends.backend_agg import FigureCanvasAgg
136- # import io
137- # from PIL import Image
138-
139- # # Define colors for each cell type
140- # color_map = {
141- # "⬜": [0.9, 0.9, 1.0], # Light blue for empty water
142- # "❌": [1.0, 0.0, 0.0], # Red for hits
143- # "⚫": [0.3, 0.3, 0.3], # Dark gray for misses
144- # "💥": [1.0, 0.6, 0.0] # Orange for sunk ships
145- # }
146-
147- # # Create a numerical representation for colormapping
148- # numeric_board = np.zeros(dataframe.shape + (3,), dtype=float)
149-
150- # for i in range(dataframe.shape[0]):
151- # for j in range(dataframe.shape[1]):
152- # cell_value = dataframe.iloc[i, j]
153- # numeric_board[i, j] = color_map.get(cell_value, [1, 1, 1])
154-
155- # # Create a figure with the right dimensions and no padding
156- # fig_width = dataframe.shape[1] + 1 # +1 for row labels
157- # fig_height = dataframe.shape[0] + 1 # +1 for column labels
158- # fig = Figure(figsize=(fig_width, fig_height), dpi=72)
159- # canvas = FigureCanvasAgg(fig)
160- # ax = fig.add_subplot(111)
161-
162- # # Plot the board
163- # ax.imshow(numeric_board, aspect='equal')
164-
165- # # Add grid lines
166- # ax.set_xticks(np.arange(-0.5, dataframe.shape[1], 1), minor=True)
167- # ax.set_yticks(np.arange(-0.5, dataframe.shape[0], 1), minor=True)
168- # ax.grid(which='minor', color='black', linestyle='-', linewidth=1)
169-
170- # # Add column labels (A, B, C, ...)
171- # ax.set_xticks(np.arange(dataframe.shape[1]))
172- # ax.set_xticklabels(dataframe.columns)
173-
174- # # Add row labels (1, 2, 3, ...)
175- # ax.set_yticks(np.arange(dataframe.shape[0]))
176- # ax.set_yticklabels(dataframe.index)
177-
178- # # Remove axis padding
179- # ax.set_xlim(-0.5, dataframe.shape[1] - 0.5)
180- # ax.set_ylim(-0.5, dataframe.shape[0] - 0.5)
181-
182- # # Render the figure to a numpy array
183- # canvas.draw()
184- # buf = io.BytesIO()
185- # fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0.1)
186- # buf.seek(0)
187-
188- # return Image.open(buf)
189- return None
128+ def get_image_observation (self , obs , symbol = " " , cell_size = 80 , font_size = 40 , border_width = 2 ):
129+ board = np .empty (self .env .board_size , dtype = str )
130+ board [obs [0 ] != 0 ] = "X"
131+ board [obs [1 ] != 0 ] = "O"
132+
133+ num_rows , num_columns = board .shape
134+ columns = [chr (i ) for i in range (ord ("A" ), ord ("A" ) + num_columns )]
135+ index = [str (i + 1 ) for i in range (num_rows )]
136+
137+ # Calculate image dimensions with space for row/column labels
138+ header_size = cell_size // 2
139+ width = (num_columns * cell_size ) + header_size
140+ height = (num_rows * cell_size ) + header_size
141+
142+ # Create image with white background
143+ image = Image .new ("RGB" , (width , height ), color = "white" )
144+ draw = ImageDraw .Draw (image )
145+
146+ try :
147+ # Try to load a font that supports Unicode symbols
148+ font = ImageFont .truetype ("Arial Unicode MS" , font_size )
149+ except IOError :
150+ try :
151+ # Try another common font
152+ font = ImageFont .truetype ("DejaVuSans.ttf" , font_size )
153+ except IOError :
154+ # Fallback to default font
155+ font = ImageFont .load_default ()
156+
157+ # Draw column headers (A, B, C, ...)
158+ for col_idx , col in enumerate (columns ):
159+ x = header_size + (col_idx * cell_size ) + (cell_size // 2 )
160+ y = header_size // 2
161+ draw .text ((x , y ), col , fill = "black" , font = font , anchor = "mm" )
162+
163+ # Draw row headers (1, 2, 3, ...)
164+ for row_idx , row in enumerate (index ):
165+ x = header_size // 2
166+ y = header_size + (row_idx * cell_size ) + (cell_size // 2 )
167+ draw .text ((x , y ), row , fill = "black" , font = font , anchor = "mm" )
168+
169+ # Draw grid
170+ for row_idx in range (num_rows + 1 ):
171+ y = header_size + (row_idx * cell_size )
172+ draw .line ([(header_size , y ), (width , y )], fill = "black" , width = border_width )
173+
174+ for col_idx in range (num_columns + 1 ):
175+ x = header_size + (col_idx * cell_size )
176+ draw .line ([(x , header_size ), (x , height )], fill = "black" , width = border_width )
177+
178+ # Draw cell contents
179+ for row_idx in range (num_rows ):
180+ for col_idx in range (num_columns ):
181+ cell_content = board [row_idx , col_idx ]
182+ x0 = header_size + (col_idx * cell_size ) + 5
183+ y0 = header_size + (row_idx * cell_size ) + 5
184+ x1 = header_size + ((col_idx + 1 ) * cell_size ) - 5
185+ y1 = header_size + ((row_idx + 1 ) * cell_size ) - 5
186+
187+ # Center of the cell for drawing shapes
188+ center_x = header_size + (col_idx * cell_size ) + (cell_size // 2 )
189+ center_y = header_size + (row_idx * cell_size ) + (cell_size // 2 )
190+ radius = (cell_size // 2 ) - 10
191+ radius = int (radius * 0.7 )
192+ # Draw based on cell content
193+ if cell_content == "X" : # Hit
194+ # Draw a red X
195+ draw .line (
196+ [(center_x - radius , center_y - radius ), (center_x + radius , center_y + radius )],
197+ fill = "red" ,
198+ width = 10 ,
199+ )
200+ draw .line (
201+ [(center_x + radius , center_y - radius ), (center_x - radius , center_y + radius )],
202+ fill = "red" ,
203+ width = 10 ,
204+ )
205+ elif cell_content == "O" : # Miss
206+ # Draw a blue circle
207+ radius = 5
208+ draw .ellipse (
209+ [(center_x - radius , center_y - radius ), (center_x + radius , center_y + radius )],
210+ outline = "black" ,
211+ width = 5 ,
212+ )
213+ else : # Empty
214+ # Draw a light white square
215+ draw .rectangle ([x0 , y0 , x1 , y1 ], fill = "white" )
216+
217+ return image
0 commit comments