Skip to content

Commit 280bb4f

Browse files
committed
[TEST] Refactor AD test visualizations with consistent formatting and conditional rendering
- Standardized code formatting for better readability and consistency. - Introduced conditional rendering of plots based on the `image` flag to improve test flexibility. - Slight adjustments to visualization parameters for enhanced clarity and debugging support.
1 parent dcfa6da commit 280bb4f

1 file changed

Lines changed: 20 additions & 16 deletions

File tree

test/test_modules/test_ad/test_ad_II.py

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -57,19 +57,19 @@ def _add_gradient_glyphs(plotter, sp_coords, geo_data, arrow_scale=0.5):
5757
arrows_poly["scaled_mag"] = scaled_mag
5858

5959
glyphs = arrows_poly.glyph(orient="vectors", scale="scaled_mag",
60-
factor=arrow_scale)
60+
factor=arrow_scale)
6161
plotter.add_mesh(
6262
glyphs,
6363
scalars="gradient_norm",
6464
cmap="plasma",
65-
scalar_bar_args={"title": "‖∇Z‖ (vertex → surface pts)",
65+
scalar_bar_args={"title" : "‖∇Z‖ (vertex → surface pts)",
6666
"title_font_size": 10,
6767
"label_font_size": 9, "n_labels": 3,
68-
"fmt": "%.1e",
69-
"position_x": 0.75, "position_y": 0.02,
70-
"width": 0.22, "height": 0.06,
71-
"color": "black",
72-
"vertical": False},
68+
"fmt" : "%.1e",
69+
"position_x" : 0.75, "position_y": 0.02,
70+
"width" : 0.22, "height": 0.06,
71+
"color" : "black",
72+
"vertical" : False},
7373
label="Gradient (Z-vertex w.r.t. SP)",
7474
)
7575

@@ -168,18 +168,19 @@ def test_generate_fold_model():
168168
vertex_idx = 14
169169
vertices_tensor = geo_data.solutions.dc_meshes[0].vertices_tensor
170170
vertices_tensor[vertex_idx, 2].backward(retain_graph=True,
171-
create_graph=True)
171+
create_graph=True)
172172

173173
# --- Visualisation ---
174+
image = True
174175
sp_coords = geo_data.taped_interpolation_input.surface_points.sp_coords
175176
p3d = gpv.plot_3d(geo_data, show_surfaces=True, show_data=True,
176-
show=False, show_lith=False, image=True,
177+
show=False, show_lith=False, image=image,
177178
kwargs_plot_surfaces={"opacity": 0.7})
178179
plotter = p3d.p
179180

180181
mesh = geo_data.solutions.dc_meshes[0]
181182
vtx_world = _highlight_vertex_and_triangles(plotter, geo_data, mesh,
182-
vertex_idx)
183+
vertex_idx)
183184

184185
# Mocked borehole through the chosen vertex
185186
_add_borehole(plotter, vtx_world, extent_z=(0, 1000))
@@ -188,7 +189,8 @@ def test_generate_fold_model():
188189
_add_gradient_glyphs(plotter, sp_coords, geo_data)
189190

190191
_style_plotter(plotter, title="Fold model – AD gradients")
191-
plotter.show()
192+
if not image:
193+
plotter.show()
192194

193195

194196
# ---------------------------------------------------------------------------
@@ -217,19 +219,20 @@ def test_generate_combination_model():
217219
mesh_id = 2
218220
vertices_tensor = geo_data.solutions.dc_meshes[mesh_id].vertices_tensor
219221
vertices_tensor[vertex_idx, 2].backward(retain_graph=True,
220-
create_graph=True)
222+
create_graph=True)
221223

222224
# --- Visualisation ---
223225
sp_coords = geo_data.taped_interpolation_input.surface_points.sp_coords
224-
226+
227+
image = True
225228
p3d = gpv.plot_3d(geo_data, show_surfaces=True, show_data=True,
226-
show=False, show_lith=False, image=True,
229+
show=False, show_lith=False, image=image,
227230
kwargs_plot_surfaces={"opacity": 0.7})
228231
plotter = p3d.p
229232

230233
mesh = geo_data.solutions.dc_meshes[mesh_id]
231234
vtx_world = _highlight_vertex_and_triangles(plotter, geo_data, mesh,
232-
vertex_idx)
235+
vertex_idx)
233236

234237
# Mocked borehole through the chosen vertex
235238
extent = geo_data.grid.regular_grid.extent
@@ -239,4 +242,5 @@ def test_generate_combination_model():
239242
_add_gradient_glyphs(plotter, sp_coords, geo_data)
240243

241244
_style_plotter(plotter, title="Combination model – AD gradients")
242-
plotter.show()
245+
if not image:
246+
plotter.show()

0 commit comments

Comments
 (0)