Skip to content

Commit 87a738e

Browse files
committed
finish image rendering
1 parent 2a56cce commit 87a738e

File tree

1 file changed

+94
-66
lines changed
  • balrog/environments/battleships

1 file changed

+94
-66
lines changed

balrog/environments/battleships/base.py

Lines changed: 94 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import gym
55
import numpy as np
66
import pandas as pd
7+
from PIL import Image, ImageDraw, ImageFont
78
from scipy.ndimage import label
89

910

@@ -38,17 +39,16 @@ def get_feedback(self, reward, old_reward):
3839
return "MISS! Your missile splashed into empty water."
3940

4041
def battleships_process_obsv(self, obs, reward, old_reward):
41-
dataframe = self.get_dataframe(obs)
42+
text_observation = self.get_text_observation(obs)
43+
image = self.get_image_observation(obs)
4244

43-
text_observation = self.get_text_observation(dataframe)
4445
feedback = self.get_feedback(reward, old_reward)
4546

4647
prompt = f"{text_observation}\n{feedback}" if feedback else f"Objects on the map:\n{text_observation}"
4748

4849
obs = defaultdict(lambda: None)
4950

5051
obs["text"] = {"long_term_context": prompt, "short_term_context": ""}
51-
image = self.get_image_observation(dataframe)
5252
obs["image"] = image
5353

5454
return obs
@@ -102,7 +102,7 @@ def step(self, action):
102102
def get_stats(self):
103103
return {"progression": self.progression}
104104

105-
def get_dataframe(self, obs):
105+
def get_text_observation(self, obs):
106106
board = np.empty(self.env.board_size, dtype=str)
107107

108108
# Create a mask for sunk ships
@@ -121,69 +121,97 @@ def get_dataframe(self, obs):
121121
dataframe = pd.DataFrame(board, columns=columns, index=index)
122122
dataframe = dataframe.replace([""], " ")
123123

124-
return dataframe
125-
126-
def get_text_observation(self, dataframe):
127124
obsv = str(dataframe)
128125

129126
return obsv
130127

131-
def get_image_observation(self, dataframe):
132-
# import matplotlib.pyplot as plt
133-
# from matplotlib.colors import LinearSegmentedColormap
134-
# from matplotlib.figure import Figure
135-
# from matplotlib.backends.backend_agg import FigureCanvasAgg
136-
# import io
137-
# from PIL import Image
138-
139-
# # Define colors for each cell type
140-
# color_map = {
141-
# "⬜": [0.9, 0.9, 1.0], # Light blue for empty water
142-
# "❌": [1.0, 0.0, 0.0], # Red for hits
143-
# "⚫": [0.3, 0.3, 0.3], # Dark gray for misses
144-
# "💥": [1.0, 0.6, 0.0] # Orange for sunk ships
145-
# }
146-
147-
# # Create a numerical representation for colormapping
148-
# numeric_board = np.zeros(dataframe.shape + (3,), dtype=float)
149-
150-
# for i in range(dataframe.shape[0]):
151-
# for j in range(dataframe.shape[1]):
152-
# cell_value = dataframe.iloc[i, j]
153-
# numeric_board[i, j] = color_map.get(cell_value, [1, 1, 1])
154-
155-
# # Create a figure with the right dimensions and no padding
156-
# fig_width = dataframe.shape[1] + 1 # +1 for row labels
157-
# fig_height = dataframe.shape[0] + 1 # +1 for column labels
158-
# fig = Figure(figsize=(fig_width, fig_height), dpi=72)
159-
# canvas = FigureCanvasAgg(fig)
160-
# ax = fig.add_subplot(111)
161-
162-
# # Plot the board
163-
# ax.imshow(numeric_board, aspect='equal')
164-
165-
# # Add grid lines
166-
# ax.set_xticks(np.arange(-0.5, dataframe.shape[1], 1), minor=True)
167-
# ax.set_yticks(np.arange(-0.5, dataframe.shape[0], 1), minor=True)
168-
# ax.grid(which='minor', color='black', linestyle='-', linewidth=1)
169-
170-
# # Add column labels (A, B, C, ...)
171-
# ax.set_xticks(np.arange(dataframe.shape[1]))
172-
# ax.set_xticklabels(dataframe.columns)
173-
174-
# # Add row labels (1, 2, 3, ...)
175-
# ax.set_yticks(np.arange(dataframe.shape[0]))
176-
# ax.set_yticklabels(dataframe.index)
177-
178-
# # Remove axis padding
179-
# ax.set_xlim(-0.5, dataframe.shape[1] - 0.5)
180-
# ax.set_ylim(-0.5, dataframe.shape[0] - 0.5)
181-
182-
# # Render the figure to a numpy array
183-
# canvas.draw()
184-
# buf = io.BytesIO()
185-
# fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0.1)
186-
# buf.seek(0)
187-
188-
# return Image.open(buf)
189-
return None
128+
def get_image_observation(self, obs, symbol=" ", cell_size=80, font_size=40, border_width=2):
129+
board = np.empty(self.env.board_size, dtype=str)
130+
board[obs[0] != 0] = "X"
131+
board[obs[1] != 0] = "O"
132+
133+
num_rows, num_columns = board.shape
134+
columns = [chr(i) for i in range(ord("A"), ord("A") + num_columns)]
135+
index = [str(i + 1) for i in range(num_rows)]
136+
137+
# Calculate image dimensions with space for row/column labels
138+
header_size = cell_size // 2
139+
width = (num_columns * cell_size) + header_size
140+
height = (num_rows * cell_size) + header_size
141+
142+
# Create image with white background
143+
image = Image.new("RGB", (width, height), color="white")
144+
draw = ImageDraw.Draw(image)
145+
146+
try:
147+
# Try to load a font that supports Unicode symbols
148+
font = ImageFont.truetype("Arial Unicode MS", font_size)
149+
except IOError:
150+
try:
151+
# Try another common font
152+
font = ImageFont.truetype("DejaVuSans.ttf", font_size)
153+
except IOError:
154+
# Fallback to default font
155+
font = ImageFont.load_default()
156+
157+
# Draw column headers (A, B, C, ...)
158+
for col_idx, col in enumerate(columns):
159+
x = header_size + (col_idx * cell_size) + (cell_size // 2)
160+
y = header_size // 2
161+
draw.text((x, y), col, fill="black", font=font, anchor="mm")
162+
163+
# Draw row headers (1, 2, 3, ...)
164+
for row_idx, row in enumerate(index):
165+
x = header_size // 2
166+
y = header_size + (row_idx * cell_size) + (cell_size // 2)
167+
draw.text((x, y), row, fill="black", font=font, anchor="mm")
168+
169+
# Draw grid
170+
for row_idx in range(num_rows + 1):
171+
y = header_size + (row_idx * cell_size)
172+
draw.line([(header_size, y), (width, y)], fill="black", width=border_width)
173+
174+
for col_idx in range(num_columns + 1):
175+
x = header_size + (col_idx * cell_size)
176+
draw.line([(x, header_size), (x, height)], fill="black", width=border_width)
177+
178+
# Draw cell contents
179+
for row_idx in range(num_rows):
180+
for col_idx in range(num_columns):
181+
cell_content = board[row_idx, col_idx]
182+
x0 = header_size + (col_idx * cell_size) + 5
183+
y0 = header_size + (row_idx * cell_size) + 5
184+
x1 = header_size + ((col_idx + 1) * cell_size) - 5
185+
y1 = header_size + ((row_idx + 1) * cell_size) - 5
186+
187+
# Center of the cell for drawing shapes
188+
center_x = header_size + (col_idx * cell_size) + (cell_size // 2)
189+
center_y = header_size + (row_idx * cell_size) + (cell_size // 2)
190+
radius = (cell_size // 2) - 10
191+
radius = int(radius * 0.7)
192+
# Draw based on cell content
193+
if cell_content == "X": # Hit
194+
# Draw a red X
195+
draw.line(
196+
[(center_x - radius, center_y - radius), (center_x + radius, center_y + radius)],
197+
fill="red",
198+
width=10,
199+
)
200+
draw.line(
201+
[(center_x + radius, center_y - radius), (center_x - radius, center_y + radius)],
202+
fill="red",
203+
width=10,
204+
)
205+
elif cell_content == "O": # Miss
206+
# Draw a blue circle
207+
radius = 5
208+
draw.ellipse(
209+
[(center_x - radius, center_y - radius), (center_x + radius, center_y + radius)],
210+
outline="black",
211+
width=5,
212+
)
213+
else: # Empty
214+
# Draw a light white square
215+
draw.rectangle([x0, y0, x1, y1], fill="white")
216+
217+
return image

0 commit comments

Comments
 (0)