Skip to content

Commit 87bd6ce

Browse files
committed
add rendering
1 parent 2e1b04b commit 87bd6ce

File tree

2 files changed

+93
-25
lines changed

2 files changed

+93
-25
lines changed

balrog/environments/battleships/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,9 @@ def get_instruction_prompt(env, instruction):
3939
- When you get a hit, explore adjacent cells to determine ship orientation
4040
- Avoid targeting cells adjacent to sunken ships
4141
42-
For each turn, provide the coordinate you wish to attack (e.g., "E5")
42+
IMPORTANT: Your response must be EXACTLY one coordinate in the format of a letter followed by a number (e.g., "E5", "A1", "J10"). Do not provide any explanation or reasoning in your response.
43+
Valid responses: "A1", "B3", "J10"
44+
Invalid responses: "A", "1", "Attack A1", "I choose A1"
4345
4446
PLAY
4547
""".strip()

balrog/environments/battleships/base.py

Lines changed: 90 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -22,28 +22,6 @@ def default_action(self):
2222
def get_text_action(self, action):
2323
return self.language_action_space[action]
2424

25-
def get_text_observation(self, obs):
26-
board = np.empty(self.env.board_size, dtype=str)
27-
28-
# Create a mask for sunk ships
29-
sunk_mask = np.zeros_like(self.ships, dtype=bool)
30-
for i in self.sunk_ships:
31-
sunk_mask = np.logical_or(sunk_mask, self.ships == i)
32-
33-
board[obs[0] != 0] = "❌"
34-
board[obs[1] != 0] = "⚫"
35-
board[sunk_mask] = "💥" # Sunk ships
36-
37-
num_rows, num_columns = board.shape
38-
columns = [chr(i) for i in range(ord("A"), ord("A") + num_columns)]
39-
index = [i + 1 for i in range(num_rows)]
40-
41-
dataframe = pd.DataFrame(board, columns=columns, index=index)
42-
dataframe = dataframe.replace([""], "⬜")
43-
obsv = str(dataframe)
44-
45-
return obsv
46-
4725
def get_feedback(self, reward, old_reward):
4826
if reward is None:
4927
return ""
@@ -60,7 +38,9 @@ def get_feedback(self, reward, old_reward):
6038
return "MISS! Your missile splashed into empty water."
6139

6240
def battleships_process_obsv(self, obs, reward, old_reward):
63-
text_observation = self.get_text_observation(obs)
41+
dataframe = self.get_dataframe(obs)
42+
43+
text_observation = self.get_text_observation(dataframe)
6444
feedback = self.get_feedback(reward, old_reward)
6545

6646
prompt = (
@@ -72,7 +52,7 @@ def battleships_process_obsv(self, obs, reward, old_reward):
7252
obs = defaultdict(lambda: None)
7353

7454
obs["text"] = {"long_term_context": prompt, "short_term_context": ""}
75-
image = None # TODO add rendering
55+
image = self.get_image_observation(dataframe)
7656
obs["image"] = image
7757

7858
return obs
@@ -125,3 +105,89 @@ def step(self, action):
125105

126106
def get_stats(self):
127107
return {"progression": self.progression}
108+
109+
def get_dataframe(self, obs):
110+
board = np.empty(self.env.board_size, dtype=str)
111+
112+
# Create a mask for sunk ships
113+
sunk_mask = np.zeros_like(self.ships, dtype=bool)
114+
for i in self.sunk_ships:
115+
sunk_mask = np.logical_or(sunk_mask, self.ships == i)
116+
117+
board[obs[0] != 0] = "❌"
118+
board[obs[1] != 0] = "⚫"
119+
board[sunk_mask] = "💥" # Sunk ships
120+
121+
num_rows, num_columns = board.shape
122+
columns = [chr(i) for i in range(ord("A"), ord("A") + num_columns)]
123+
index = [i + 1 for i in range(num_rows)]
124+
125+
dataframe = pd.DataFrame(board, columns=columns, index=index)
126+
dataframe = dataframe.replace([""], "⬜")
127+
128+
return dataframe
129+
130+
def get_text_observation(self, dataframe):
131+
obsv = str(dataframe)
132+
133+
return obsv
134+
135+
def get_image_observation(self, dataframe):
136+
# import matplotlib.pyplot as plt
137+
# from matplotlib.colors import LinearSegmentedColormap
138+
# from matplotlib.figure import Figure
139+
# from matplotlib.backends.backend_agg import FigureCanvasAgg
140+
# import io
141+
# from PIL import Image
142+
143+
# # Define colors for each cell type
144+
# color_map = {
145+
# "⬜": [0.9, 0.9, 1.0], # Light blue for empty water
146+
# "❌": [1.0, 0.0, 0.0], # Red for hits
147+
# "⚫": [0.3, 0.3, 0.3], # Dark gray for misses
148+
# "💥": [1.0, 0.6, 0.0] # Orange for sunk ships
149+
# }
150+
151+
# # Create a numerical representation for colormapping
152+
# numeric_board = np.zeros(dataframe.shape + (3,), dtype=float)
153+
154+
# for i in range(dataframe.shape[0]):
155+
# for j in range(dataframe.shape[1]):
156+
# cell_value = dataframe.iloc[i, j]
157+
# numeric_board[i, j] = color_map.get(cell_value, [1, 1, 1])
158+
159+
# # Create a figure with the right dimensions and no padding
160+
# fig_width = dataframe.shape[1] + 1 # +1 for row labels
161+
# fig_height = dataframe.shape[0] + 1 # +1 for column labels
162+
# fig = Figure(figsize=(fig_width, fig_height), dpi=72)
163+
# canvas = FigureCanvasAgg(fig)
164+
# ax = fig.add_subplot(111)
165+
166+
# # Plot the board
167+
# ax.imshow(numeric_board, aspect='equal')
168+
169+
# # Add grid lines
170+
# ax.set_xticks(np.arange(-0.5, dataframe.shape[1], 1), minor=True)
171+
# ax.set_yticks(np.arange(-0.5, dataframe.shape[0], 1), minor=True)
172+
# ax.grid(which='minor', color='black', linestyle='-', linewidth=1)
173+
174+
# # Add column labels (A, B, C, ...)
175+
# ax.set_xticks(np.arange(dataframe.shape[1]))
176+
# ax.set_xticklabels(dataframe.columns)
177+
178+
# # Add row labels (1, 2, 3, ...)
179+
# ax.set_yticks(np.arange(dataframe.shape[0]))
180+
# ax.set_yticklabels(dataframe.index)
181+
182+
# # Remove axis padding
183+
# ax.set_xlim(-0.5, dataframe.shape[1] - 0.5)
184+
# ax.set_ylim(-0.5, dataframe.shape[0] - 0.5)
185+
186+
# # Render the figure to a numpy array
187+
# canvas.draw()
188+
# buf = io.BytesIO()
189+
# fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0.1)
190+
# buf.seek(0)
191+
192+
# return Image.open(buf)
193+
return None

0 commit comments

Comments
 (0)