Skip to content

Commit 4fb3c85

Browse files
author
saramontese
committed
improve visual explanations
1 parent d9fc185 commit 4fb3c85

1 file changed

Lines changed: 56 additions & 29 deletions

File tree

src/policy_graph/plot_utils.py

Lines changed: 56 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def get_trajectory_metrics(ii: AVIntentionIntrospector, trajectory:List[Tuple[Tu
3232

3333

3434

35-
def plot_int_progess(ii: AVIntentionIntrospector, s_a_trajectory:List[Tuple[Tuple[AVPredicate], int]], scene_id:str="", desires:Set[AVDesire] = [], min_int_threshold:float=0.1, output_folder:str=None):#, fill_desires = True ):
35+
def plot_int_progess(ii: AVIntentionIntrospector, s_a_trajectory:List[Tuple[Tuple[AVPredicate], int]], scene_id:str="", desires:Set[AVDesire] = [], min_int_threshold:float=0.1, output_folder:str=None, format='pdf'):#, fill_desires = True ):
3636

3737
"""
3838
Plot intention progression of a scene over the specified desires.
@@ -54,22 +54,31 @@ def plot_int_progess(ii: AVIntentionIntrospector, s_a_trajectory:List[Tuple[Tupl
5454

5555
#Filter intentions in the scene to not overload the plot
5656
if max(intention_vals) > min_int_threshold: #sum(intention_vals) >min_int_threshold:
57-
ax.plot(range(episode_length), intention_vals, label=d_name, color=desire_color[d_name],linestyle='dotted', linewidth=5 )
57+
ax.plot(range(episode_length), intention_vals, label=d_name, color=desire_color[d_name],linestyle='dotted', linewidth=3)
5858

59-
ax.legend(loc='best', facecolor='white', framealpha = 1, fontsize=24, frameon=True)
60-
ax.tick_params(axis='x', labelsize=27)
61-
ax.tick_params(axis='y', labelsize=27)
62-
ax.set_xlabel('Time', fontsize=36)
63-
ax.set_ylabel('Intention Value', fontsize = 36)
64-
59+
#ax.legend(loc='best', facecolor='white', framealpha = 1, fontsize=23, frameon=True)
60+
61+
# Put a legend below current axis
62+
ax.legend(loc='upper center', bbox_to_anchor=(0.5, -0.10),
63+
fancybox=True, framealpha=1, edgecolor='black', fontsize=23, ncol=len(desire_names),frameon=True)
64+
ax.tick_params(axis='x', labelsize=25)
65+
ax.tick_params(axis='y', labelsize=25)
66+
ax.set_xlabel('Step', fontsize=25)
67+
ax.set_ylabel('Intention Value', fontsize = 25)
68+
ax.set_xlim(0, episode_length-1)
69+
70+
#black plot border
71+
for spine in ax.spines.values():
72+
spine.set_color('black')
6573

6674
ax.set_ylim(bottom=0, top=1)
6775
for t, d_name in desire_fulfill_track.items():
6876
if d_name in desire_names:
69-
ax.vlines(t, -0.05, 1.05, label=d_name, colors=desire_color[d_name], linestyles='-', linewidth=4)
70-
plt.title(f'Intention evolution in scene {scene_id}', fontsize=37)
77+
ax.vlines(t, -0.05, 1.05, label=d_name, colors=desire_color[d_name], linestyles='-', linewidth=2)
78+
#plt.title(f'Intention evolution in scene {scene_id}', fontsize=27)
79+
#plt.grid(False)
7180
if output_folder:
72-
plt.savefig(f'{output_folder}/int_progress_{scene_id}.png', bbox_inches = 'tight', dpi=200)
81+
plt.savefig(f'{output_folder}/int_progress_{scene_id}.{format}', bbox_inches = 'tight', dpi=200, format=format)
7382

7483

7584

@@ -131,7 +140,7 @@ def update(frame):
131140

132141

133142

134-
def roc_curve(ipgs: List[AVIntentionIntrospector], output_folder:str=None, step:float=0.1) -> Dict[str, float]:
143+
def roc_curve(ipgs: List[AVIntentionIntrospector], output_folder:str=None, step:float=0.1, format='pdf') -> Dict[str, float]:
135144
"""
136145
Generate ROC curve for intention metrics and find the best threshold for each discretizer.
137146
@@ -174,18 +183,16 @@ def roc_curve(ipgs: List[AVIntentionIntrospector], output_folder:str=None, step:
174183
plt.plot(intention_probs, expected_probs, label=f'D{disc_id}')
175184

176185
plt.xlabel('Probability of having any intention', fontsize=15)
177-
#plt.xlabel('Intention Probability', fontsize=15)
178186
plt.ylabel('Expected Intention Probability', fontsize = 15)
179187
plt.title("Evolution of 'Any' intention metrics", fontsize = 17)
180-
#plt.title('Intention Metrics for ANY Desire', fontsize = 17)
181188
plt.legend(fontsize=13)
182189
plt.grid(False)
183190
plt.xticks(fontsize=13)
184191
plt.yticks(fontsize=13)
185192

186193

187194
if output_folder:
188-
plt.savefig(f'{output_folder}/roc_s{step}.png', bbox_inches = 'tight', dpi=200)
195+
plt.savefig(f'{output_folder}/roc_s{step}.{format}', bbox_inches = 'tight', dpi=200)
189196
else:
190197
plt.show()
191198

@@ -203,7 +210,7 @@ def annotate_bars(rects, ax, fontsize=8):
203210
)
204211

205212

206-
def plot_metrics(metrics_data:Dict[str, Tuple[float, float]], discretizer_id:str, output_folder:str, c:float=0.5, metric_type:str='Desire', fig_size:Tuple[int, int]=(45, 15), y_lim:float=1.15, colors:Tuple[str,str]=['#008080', '#FF7F50']):
213+
def plot_metrics(metrics_data:Dict[str, Tuple[float, float]], discretizer_id:str, output_folder:str, c:float=0.5, metric_type:str='Desire', fig_size:Tuple[int, int]=(45, 15), y_lim:float=1.05, colors:Tuple[str,str]=['#008080', '#FF7F50'], format='pdf'):
207214
"""
208215
Displays bar plots with desire or intention metrics for each desire.
209216
@@ -219,33 +226,52 @@ def plot_metrics(metrics_data:Dict[str, Tuple[float, float]], discretizer_id:str
219226
val1 = np.array([metrics_data[desire][0] for desire in desires])
220227
val2 = np.array([metrics_data[desire][1] for desire in desires])
221228

222-
x = np.arange(len(desires)) # Ensure x values are tightly packed
229+
x = np.arange(len(desires))
223230
width = 0.3 # Bar width to make pairs touch
231+
fontsize = 62 #35
224232

225233
fig, ax = plt.subplots(figsize=fig_size)
226234

227-
rects1 = ax.bar(x - width/2, val1, width, color=colors[0], label=f'{metric_type} Probability')
228-
metric_label = 'Expected Action Probability' if metric_type == 'Desire' else 'Expected Intention Probability'
229-
rects2 = ax.bar(x + width/2, val2, width, color=colors[1], label=metric_label)
235+
rects1 = ax.bar(x - width/2, val1, width, color=colors[0], label=f'Probability of {metric_type} Attribution',zorder=2)
236+
metric_label = 'Expected Action Probability' if metric_type == 'Desire' else 'Expected Intention'
237+
rects2 = ax.bar(x + width/2, val2, width, color=colors[1], label=metric_label, zorder=2)
230238

231-
ax.set_ylabel('Probability', fontsize=35)
239+
ax.set_ylabel('Probability', fontsize=fontsize)
232240
ax.set_ylim(0, y_lim)
233-
ax.set_title(f'Intention Metrics, C = {c}' if metric_type == 'Intention' else 'Desire Metrics', fontsize=50)
241+
242+
#ax.set_title(f'Intention Metrics for D{discretizer_id}, C_threshold = {c}' if metric_type == 'Intention' else 'Desire Metrics', fontsize=65) #50
234243

235244
ax.set_xticks(x)
236-
ax.set_xticklabels(desires, fontsize=35)
237-
plt.yticks(fontsize=35)
245+
ax.set_xticklabels(desires, fontsize=55)
246+
ax.tick_params( pad=40)
247+
248+
plt.yticks(fontsize=fontsize)
238249

239250
# Remove padding before first bar and after last bar
240251
ax.set_xlim([min(x) - width, max(x) + width])
241252

242-
annotate_bars(rects1, ax, fontsize=35)
243-
annotate_bars(rects2, ax, fontsize=35)
253+
annotate_bars(rects1, ax, fontsize=fontsize)
254+
annotate_bars(rects2, ax, fontsize=fontsize)
255+
256+
#black plot border
257+
for spine in ax.spines.values():
258+
spine.set_color('black')
244259

245-
ax.legend(ncol=2, fontsize=35, loc='upper left', facecolor='white')
260+
ax.grid(False)
246261

247-
save_path = f"{output_folder}/{metric_type}_{discretizer_id}.png"
248-
plt.savefig(save_path, bbox_inches='tight', dpi=200)
262+
# Add horizontal lines at every y-tick
263+
for y_val in ax.get_yticks():
264+
ax.axhline(y=y_val, color='gray', linestyle='-', linewidth=0.3, zorder=0)
265+
266+
267+
#ax.legend(ncol=2, fontsize=fontsize, loc='upper left', facecolor='white')
268+
# Put a legend below current axis
269+
ax.legend(loc='lower center', bbox_to_anchor=(0.5, 1.05),
270+
fancybox=True, framealpha=1, edgecolor='black', fontsize=fontsize, ncol=2,frameon=True)
271+
272+
save_path = f"{output_folder}/{metric_type}_{discretizer_id}.{format}"
273+
os.makedirs(output_folder, exist_ok=True)
274+
plt.savefig(save_path, bbox_inches='tight', dpi=200, format=format)
249275

250276

251277
def plot_metrics_per_desire(desires_data, desire, output_folder, metric_type='Desire', y_lim=1, colors=['#008080', '#FF7F50']):
@@ -308,5 +334,6 @@ def plot_metrics_per_desire(desires_data, desire, output_folder, metric_type='De
308334
fig.delaxes(axes[j])
309335

310336
plt.tight_layout()
337+
plt.grid(False)
311338
plt.savefig(f'{output_folder}/desire_metrics_{desire.name}.png', dpi=100)
312339

0 commit comments

Comments
 (0)