Skip to content

Commit b6dcdc1

Browse files
committed
add check for number of lines being plotted and the labels
1 parent 0437b21 commit b6dcdc1

File tree

1 file changed

+56
-20
lines changed

1 file changed

+56
-20
lines changed

tests/test_fitrecipe.py

Lines changed: 56 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,23 @@ def optimize_recipe(recipe):
343343
leastsq(residuals, values)
344344

345345

346+
def get_labels_and_linecount(ax):
347+
"""Helper to get line labels and count from a matplotlib Axes."""
348+
labels = [
349+
line.get_label()
350+
for line in ax.get_lines()
351+
if not line.get_label().startswith("_")
352+
]
353+
line_count = len(
354+
[
355+
line
356+
for line in ax.get_lines()
357+
if not line.get_label().startswith("_")
358+
]
359+
)
360+
return labels, line_count
361+
362+
346363
def test_plot_recipe_bad_display():
347364
recipe = build_recipe_one_contribution()
348365
# Case: All plots are disabled
@@ -377,7 +394,10 @@ def test_plot_recipe_before_refinement(capsys):
377394
recipe = build_recipe_one_contribution()
378395
plt.close("all")
379396
before = set(plt.get_fignums())
380-
recipe.plot_recipe(show=False)
397+
# include fit_label="nothing" to make sure fit line is not plotted
398+
fig, ax = recipe.plot_recipe(
399+
show=False, data_label="my data", fit_label="nothing", return_fig=True
400+
)
381401
after = set(plt.get_fignums())
382402
new_figs = after - before
383403
captured = capsys.readouterr()
@@ -386,6 +406,12 @@ def test_plot_recipe_before_refinement(capsys):
386406
"Contribution 'c1' has no calculated values (ycalc is None). "
387407
"Only observed data will be plotted."
388408
)
409+
# get labels from the plotted line
410+
actual_label, actual_line_count = get_labels_and_linecount(ax)
411+
expected_line_count = 1
412+
expected_label = ["my data"]
413+
assert actual_line_count == expected_line_count
414+
assert actual_label == expected_label
389415
assert len(new_figs) == 1
390416
assert actual == expected
391417

@@ -397,9 +423,14 @@ def test_plot_recipe_after_refinement():
397423
optimize_recipe(recipe)
398424
plt.close("all")
399425
before = set(plt.get_fignums())
400-
recipe.plot_recipe(show=False)
426+
fig, ax = recipe.plot_recipe(show=False, return_fig=True)
401427
after = set(plt.get_fignums())
402428
new_figs = after - before
429+
actual_label, actual_line_count = get_labels_and_linecount(ax)
430+
expected_label = ["Observed", "Calculated", "Difference"]
431+
expected_line_count = 3
432+
assert actual_line_count == expected_line_count
433+
assert actual_label == expected_label
403434
assert len(new_figs) == 1
404435

405436

@@ -410,7 +441,13 @@ def test_plot_recipe_two_contributions():
410441
optimize_recipe(recipe)
411442
plt.close("all")
412443
before = set(plt.get_fignums())
413-
recipe.plot_recipe(show=False)
444+
figs, axes = recipe.plot_recipe(show=False, return_fig=True)
445+
for ax in axes:
446+
actual_label, actual_line_count = get_labels_and_linecount(ax)
447+
expected_label = ["Observed", "Calculated", "Difference"]
448+
expected_line_count = 3
449+
assert actual_line_count == expected_line_count
450+
assert actual_label == expected_label
414451
after = set(plt.get_fignums())
415452
new_figs = after - before
416453
assert len(new_figs) == 2
@@ -428,9 +465,12 @@ def test_plot_recipe_on_existing_plot():
428465
recipe.plot_recipe(ax=ax, show=False)
429466
actual_title = ax.get_title()
430467
expected_title = "User Title"
468+
actual_labels, actual_line_count = get_labels_and_linecount(ax)
469+
expected_line_count = 4
470+
expected_labels = ["Calculated", "Difference", "New Data", "Observed"]
471+
assert actual_line_count == expected_line_count
472+
assert sorted(actual_labels) == sorted(expected_labels)
431473
assert actual_title == expected_title
432-
labels = [label.get_label() for label in ax.get_lines()]
433-
assert "New Data" in labels
434474

435475

436476
def test_plot_recipe_add_new_data():
@@ -440,20 +480,18 @@ def test_plot_recipe_add_new_data():
440480
optimize_recipe(recipe)
441481
plt.close("all")
442482
before = set(plt.get_fignums())
443-
figure, ax = recipe.plot_recipe(return_fig=True, show=False)
483+
fig, ax = recipe.plot_recipe(return_fig=True, show=False)
444484
after = set(plt.get_fignums())
445485
new_figs = after - before
446486
# add new data to existing plot
447487
ax.plot([0, pi], [0, 0], label="New Data")
448488
ax.legend()
449-
legend = ax.get_legend()
450-
# get sorted list of legend labels for comparison
451-
actual_labels = sorted([t.get_text() for t in legend.get_texts()])
452-
expected_labels = sorted(
453-
["Observed", "Calculated", "Difference", "New Data"]
454-
)
489+
actual_labels, actual_line_count = get_labels_and_linecount(ax)
490+
expected_labels = ["Observed", "Calculated", "Difference", "New Data"]
491+
expected_line_count = 4
455492
assert len(new_figs) == 1
456-
assert actual_labels == expected_labels
493+
assert actual_line_count == expected_line_count
494+
assert sorted(actual_labels) == sorted(expected_labels)
457495

458496

459497
def test_plot_recipe_add_new_data_two_figs():
@@ -471,13 +509,11 @@ def test_plot_recipe_add_new_data_two_figs():
471509
for ax in axes:
472510
ax.plot([0, pi], [0, 0], label="New Data")
473511
ax.legend()
474-
legend = ax.get_legend()
475-
# get sorted list of legend labels for comparison
476-
actual_labels = sorted([t.get_text() for t in legend.get_texts()])
477-
expected_labels = sorted(
478-
["Observed", "Calculated", "Difference", "New Data"]
479-
)
480-
assert actual_labels == expected_labels
512+
actual_labels, actual_line_count = get_labels_and_linecount(ax)
513+
expected_labels = ["Observed", "Calculated", "Difference", "New Data"]
514+
expected_line_count = 4
515+
assert actual_line_count == expected_line_count
516+
assert sorted(actual_labels) == sorted(expected_labels)
481517
assert len(new_figs) == 2
482518

483519

0 commit comments

Comments
 (0)