Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 55 additions & 30 deletions balance/stats_and_plots/ascii_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,12 @@ def _auto_n_bins(n_samples: int, n_unique: int) -> int:
return max(2, min(sturges, n_unique, 50))


def _auto_bar_width(label_width: int, n_datasets: int) -> int:
"""Pick bar_width to fit within terminal width."""
def _auto_bar_width(label_width: int) -> int:
"""Pick bar_width to fit within terminal width.

Used by grouped barplots and histograms where each dataset gets its own
line within a row (single bar per line).
"""
import shutil

term_width = shutil.get_terminal_size((80, 24)).columns
Expand All @@ -49,6 +53,24 @@ def _auto_bar_width(label_width: int, n_datasets: int) -> int:
return max(10, available)


def _auto_bar_width_columnar(range_width: int, n_columns: int) -> int:
    """Pick per-column bar_width for a columnar (side-by-side) layout.

    Used by :func:`ascii_comparative_hist` where all datasets are rendered as
    columns on the same line. Each column needs space for the bar, a
    percentage string (~6 chars), and inter-column separators (`` | ``, 3
    chars each).

    Args:
        range_width: Character width of the leading range-label column.
        n_columns: Number of dataset columns rendered side by side. Values
            below 1 are treated as 1 so the width computation never divides
            by zero.

    Returns:
        The bar width (in characters) to use for every column; never less
        than 10, even when the terminal is too narrow to fit the layout.
    """
    import shutil

    term_width = shutil.get_terminal_size((80, 24)).columns
    # Guard against an empty dataset list: the floor division below would
    # otherwise raise ZeroDivisionError.
    n_columns = max(1, n_columns)

    # Layout: "Range | col1 | col2 | ..."
    # The label column consumes range_width plus 4 characters for its
    # trailing separator and padding.
    available = term_width - range_width - 4
    # Every column after the first is preceded by a 3-character " | ".
    separators = (n_columns - 1) * 3
    # Reserve ~6 characters per column for the percentage string.
    return max(10, (available - separators) // n_columns - 6)


def _weighted_histogram(
values: pd.Series,
weights: Optional[pd.Series],
Expand Down Expand Up @@ -174,7 +196,8 @@ def ascii_plot_bar(
names: Names for each DataFrame (e.g., ["self", "target"]).
column: The categorical column name to plot.
weighted: Whether to use weights. Defaults to True.
bar_width: Maximum character width for bars. Defaults to 40.
bar_width: Maximum character width for bars. Defaults to None,
which auto-detects based on terminal width.
dist_type: Accepted for compatibility but only "hist_ascii" is supported.
A warning is logged if any other value is passed.
separate_categories: If True, insert a blank line between categories
Expand Down Expand Up @@ -243,7 +266,7 @@ def ascii_plot_bar(
label_width = max(label_width, 8) # minimum width for "Category"

if bar_width is None:
bar_width = _auto_bar_width(label_width, len(legend_names))
bar_width = _auto_bar_width(label_width)

# Build output
lines: List[str] = []
Expand Down Expand Up @@ -304,8 +327,10 @@ def ascii_plot_hist(
names: Names for each DataFrame (e.g., ["self", "target"]).
column: The numeric column name to plot.
weighted: Whether to use weights. Defaults to True.
n_bins: Number of histogram bins. Defaults to 10.
bar_width: Maximum character width for bars. Defaults to 40.
n_bins: Number of histogram bins. Defaults to None, which
auto-detects using Sturges' rule.
bar_width: Maximum character width for bars. Defaults to None,
which auto-detects based on terminal width.
dist_type: Accepted for compatibility but only "hist_ascii" is supported.
A warning is logged if any other value is passed.

Expand Down Expand Up @@ -395,7 +420,7 @@ def ascii_plot_hist(
label_width = max(label_width, 3) # minimum width for "Bin"

if bar_width is None:
bar_width = _auto_bar_width(label_width, len(legend_names))
bar_width = _auto_bar_width(label_width)

# Build output
lines: List[str] = []
Expand Down Expand Up @@ -459,8 +484,10 @@ def ascii_comparative_hist(
names: Names for each DataFrame (e.g., ["Target", "Sample"]).
column: The numeric column name to plot.
weighted: Whether to use weights. Defaults to True.
n_bins: Number of histogram bins. Defaults to 10.
bar_width: Maximum character width for bars. Defaults to 20.
n_bins: Number of histogram bins. Defaults to None, which
auto-detects using Sturges' rule.
bar_width: Maximum character width for bars. Defaults to None,
which auto-detects based on terminal width.

Returns:
ASCII comparative histogram text.
Expand All @@ -470,6 +497,8 @@ def ascii_comparative_hist(

>>> print(ascii_comparative_hist(dfs, names=["Target", "Sample"],
... column="income", n_bins=2, bar_width=20))
=== income (numeric, comparative) ===
<BLANKLINE>
Range | Target (%) | Sample (%)
---------------------------------------------------------------
[10.00, 25.00) | █████████████ 50.0 | █████████████▒▒▒▒▒▒▒ 75.0
Expand Down Expand Up @@ -538,14 +567,7 @@ def ascii_comparative_hist(
range_width = max(len(range_header), max(len(lbl) for lbl in bin_labels))

if bar_width is None:
import shutil

term_width = shutil.get_terminal_size((80, 24)).columns
n_cols = len(legend_names)
# Each column needs: bar_width + pct string (~6) + spacing (3)
available = term_width - range_width - 4 # " | " separator
per_col = max(10, (available - (n_cols - 1) * 3) // n_cols - 6)
bar_width = per_col
bar_width = _auto_bar_width_columnar(range_width, len(legend_names))

# Baseline percentages (first dataset)
baseline_pcts = hist_pcts[0]
Expand Down Expand Up @@ -597,6 +619,8 @@ def ascii_comparative_hist(

# Build output
lines: List[str] = []
lines.append(f"=== {column} (numeric, comparative) ===")
lines.append("")

# Header row
header_parts = [range_header.ljust(range_width)]
Expand Down Expand Up @@ -650,7 +674,7 @@ def ascii_plot_dist(

Iterates over variables, classifying each as categorical or numeric
(using the same logic as :func:`seaborn_plot_dist`), then delegates to
:func:`ascii_plot_bar` or :func:`ascii_plot_hist` respectively.
:func:`ascii_plot_bar` or :func:`ascii_comparative_hist` respectively.

The output is both printed to stdout and returned as a string.

Expand All @@ -662,8 +686,10 @@ def ascii_plot_dist(
numeric_n_values_threshold: Columns with fewer unique values than this
are treated as categorical. Defaults to 15.
weighted: Whether to use weights. Defaults to True.
n_bins: Number of bins for numeric histograms. Defaults to 10.
bar_width: Maximum character width for the longest bar. Defaults to 40.
n_bins: Number of bins for numeric histograms. Defaults to None,
which auto-detects using Sturges' rule.
bar_width: Maximum character width for the longest bar. Defaults to
None, which auto-detects based on terminal width.
dist_type: Accepted for compatibility but only "hist_ascii" is supported.
A warning is logged if any other value is passed.
separate_categories: If True, insert a blank line between categories
Expand Down Expand Up @@ -707,17 +733,16 @@ def ascii_plot_dist(
Legend: █ sample ▒ population
Bar lengths are proportional to weighted frequency within each dataset.
<BLANKLINE>
=== age (numeric) ===
=== age (numeric, comparative) ===
<BLANKLINE>
Bin | sample population
|
[10.00, 25.00) | █████████████ (50.0%)
| ▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒ (75.0%)
[25.00, 40.00] | █████████████ (50.0%)
| ▒▒▒▒▒▒▒ (25.0%)
Range | sample (%) | population (%)
---------------------------------------------------------------
[10.00, 25.00) | █████████████ 50.0 | █████████████▒▒▒▒▒▒▒ 75.0
[25.00, 40.00] | █████████████ 50.0 | ███████ ] 25.0
---------------------------------------------------------------
Total | 100.0 | 100.0
<BLANKLINE>
Legend: █ sample ▒ population
Bar lengths are proportional to weighted frequency within each dataset.
Key: █ = shared with sample, ▒ = excess, ] = deficit
"""
if dist_type is not None and dist_type != "hist_ascii":
logger.warning(
Expand Down Expand Up @@ -758,7 +783,7 @@ def ascii_plot_dist(
)
else:
output_parts.append(
ascii_plot_hist(
ascii_comparative_hist(
dfs,
names,
o,
Expand Down
54 changes: 38 additions & 16 deletions tests/test_ascii_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import pandas as pd
from balance.stats_and_plots.ascii_plots import (
_auto_bar_width,
_auto_bar_width_columnar,
_auto_n_bins,
_build_legend,
_render_horizontal_bars,
Expand Down Expand Up @@ -362,7 +363,7 @@ def test_dispatches_categorical_and_numeric(self) -> None:
dfs, names=["self", "target"], numeric_n_values_threshold=0
)
self.assertIn("(categorical)", result)
self.assertIn("(numeric)", result)
self.assertIn("(numeric, comparative)", result)

def test_respects_numeric_n_values_threshold(self) -> None:
"""Test that low-cardinality numeric columns are treated as categorical."""
Expand All @@ -376,7 +377,7 @@ def test_respects_numeric_n_values_threshold(self) -> None:

# With threshold=0, treated as numeric
result = ascii_plot_dist(dfs, names=["self"], numeric_n_values_threshold=0)
self.assertIn("(numeric)", result)
self.assertIn("(numeric, comparative)", result)

def test_returns_string(self) -> None:
"""Test that the function returns a string."""
Expand Down Expand Up @@ -607,17 +608,16 @@ def test_e2e_ascii_plot_dist_mixed(self) -> None:
Legend: █ sample ▒ population
Bar lengths are proportional to weighted frequency within each dataset.

=== age (numeric) ===
=== age (numeric, comparative) ===

Bin | sample population
|
[10.00, 25.00) | █████████████ (50.0%)
| ▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒▒ (75.0%)
[25.00, 40.00] | █████████████ (50.0%)
| ▒▒▒▒▒▒▒ (25.0%)
Range | sample (%) | population (%)
---------------------------------------------------------------
[10.00, 25.00) | █████████████ 50.0 | █████████████▒▒▒▒▒▒▒ 75.0
[25.00, 40.00] | █████████████ 50.0 | ███████ ] 25.0
---------------------------------------------------------------
Total | 100.0 | 100.0

Legend: █ sample ▒ population
Bar lengths are proportional to weighted frequency within each dataset.
Key: █ = shared with sample, ▒ = excess, ] = deficit
""",
)

Expand Down Expand Up @@ -675,6 +675,8 @@ def test_e2e_comparative_hist_single_dataset(self) -> None:
self._assert_lines_equal(
result,
"""\
=== age (numeric, comparative) ===

Range | Normal (%)
------------------------------------------
[10.00, 25.00) | ████████████████████ 50.0
Expand All @@ -698,6 +700,8 @@ def test_e2e_comparative_hist_two_datasets(self) -> None:
self._assert_lines_equal(
result,
"""\
=== age (numeric, comparative) ===

Range | Normal (%) | Skewed (%)
---------------------------------------------------------------
[10.00, 25.00) | █████████████ 50.0 | █████████████▒▒▒▒▒▒▒ 75.0
Expand Down Expand Up @@ -726,6 +730,8 @@ def test_e2e_comparative_hist_three_datasets(self) -> None:
self._assert_lines_equal(
result,
"""\
=== v (numeric, comparative) ===

Range | Baseline (%) | Left (%) | Right (%)
---------------------------------------------------------------------------------
[1.00, 1.67) | ██████████ 33.3 | ██████████▒▒▒▒▒ 50.0 | █████ ] 16.7
Expand Down Expand Up @@ -1108,31 +1114,47 @@ def test_n_unique_one(self) -> None:


class TestAutoBarWidth(balance.testutil.BalanceTestCase):
    """Tests for _auto_bar_width and _auto_bar_width_columnar helpers."""

    def test_default_terminal_width(self) -> None:
        """Test bar width computed from mocked terminal width."""
        with patch(
            "shutil.get_terminal_size", return_value=os.terminal_size((120, 24))
        ):
            result = _auto_bar_width(10)
        # available = 120 - 10 - 3 - 9 = 98
        self.assertEqual(result, 98)

    def test_narrow_terminal(self) -> None:
        """Test that bar_width is clamped to minimum 10."""
        with patch("shutil.get_terminal_size", return_value=os.terminal_size((20, 24))):
            result = _auto_bar_width(15)
        # available = 20 - 15 - 3 - 9 = -7, clamped to 10
        self.assertEqual(result, 10)

    def test_standard_terminal(self) -> None:
        """Test with standard 80-column terminal."""
        with patch("shutil.get_terminal_size", return_value=os.terminal_size((80, 24))):
            result = _auto_bar_width(8)
        # available = 80 - 8 - 3 - 9 = 60
        self.assertEqual(result, 60)

    def test_columnar_default_terminal(self) -> None:
        """Test columnar bar width for two columns on a standard terminal."""
        with patch("shutil.get_terminal_size", return_value=os.terminal_size((80, 24))):
            result = _auto_bar_width_columnar(14, 2)
        # available = 80 - 14 - 4 = 62
        # per_col = max(10, (62 - 3) // 2 - 6) = max(10, 23) = 23
        self.assertEqual(result, 23)

    def test_columnar_narrow_terminal(self) -> None:
        """Test columnar bar width clamps to minimum 10."""
        with patch("shutil.get_terminal_size", return_value=os.terminal_size((30, 24))):
            result = _auto_bar_width_columnar(14, 3)
        # available = 30 - 14 - 4 = 12
        # per_col = max(10, (12 - 6) // 3 - 6) = max(10, -4) = 10
        self.assertEqual(result, 10)


class TestAutoDetectionIntegration(balance.testutil.BalanceTestCase):
"""Tests that auto-detection of n_bins and bar_width works end-to-end."""
Expand Down Expand Up @@ -1173,7 +1195,7 @@ def test_ascii_plot_dist_auto_detection(self) -> None:
with patch("sys.stdout", new_callable=io.StringIO):
result = ascii_plot_dist(dfs, names=["self"], numeric_n_values_threshold=0)
self.assertIn("(categorical)", result)
self.assertIn("(numeric)", result)
self.assertIn("(numeric, comparative)", result)

def test_ascii_comparative_hist_auto_detection(self) -> None:
"""Test ascii_comparative_hist without explicit n_bins or bar_width."""
Expand Down
28 changes: 1 addition & 27 deletions tutorials/balance_ascii_plots.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,7 @@
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# ASCII Plots Tutorial\n",
"\n",
"This tutorial demonstrates the text-based plotting functions in `balance`.\n",
"ASCII plots are useful in terminals, CI logs, notebooks with limited rendering,\n",
"or anywhere you want a quick visual comparison without a graphical backend.\n",
"\n",
"We cover:\n",
"1. **Grouped barplots** (`ascii_plot_bar`) for categorical variables\n",
"2. **Grouped histograms** (`ascii_plot_hist`) for numeric variables\n",
"3. **Comparative histograms** (`ascii_comparative_hist`) with baseline-relative rendering\n",
"4. **`ascii_plot_dist`** — the all-in-one dispatcher\n",
"5. Using `library=\"balance\"` with the `.covars().plot()` API\n",
"6. Options: `separate_categories`, `n_bins`, `bar_width`, and auto-detection"
]
"source": "# ASCII Plots Tutorial\n\nThis tutorial demonstrates the text-based plotting functions in `balance`.\nASCII plots are useful in terminals, CI logs, notebooks with limited rendering,\nor anywhere you want a quick visual comparison without a graphical backend.\n\nWe cover:\n1. **Grouped barplots** (`ascii_plot_bar`) for categorical variables\n2. **Grouped histograms** (`ascii_plot_hist`) for numeric variables\n3. **Comparative histograms** (`ascii_comparative_hist`) with baseline-relative rendering\n4. **`ascii_plot_dist`** — the all-in-one dispatcher\n5. Using `library=\"balance\"` with the `.covars().plot()` API"
},
{
"cell_type": "markdown",
Expand Down Expand Up @@ -169,18 +155,6 @@
"metadata": {},
"outputs": [],
"source": "dfs = [\n {\"df\": target.covars().df, \"weight\": target.weight_column},\n {\"df\": sample_with_target.covars().df, \"weight\": sample_with_target.weight_column},\n {\"df\": adjusted.covars().df, \"weight\": adjusted.weight_column},\n]\nprint(ascii_comparative_hist(\n dfs, names=[\"Target\", \"Unadjusted\", \"Adjusted\"],\n column=\"income\",\n))"
},
{
"cell_type": "markdown",
"metadata": {},
"source": "## 6. Overriding `n_bins` and `bar_width`\n\nAll the examples above rely on auto-detection:\n- `n_bins` uses Sturges' rule (`ceil(log2(n) + 1)`), capped at the number\n of unique values and clamped to `[2, 50]`.\n- `bar_width` is computed from the terminal width so bars fill the\n available space.\n\nYou can override either (or both) explicitly when you want finer control:"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": "# Explicit n_bins and bar_width\nadjusted.covars().plot(\n library=\"balance\", variables=[\"income\"], n_bins=5, bar_width=30,\n);"
}
],
"metadata": {
Expand Down