Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@
- Added focused tests for empty threshold iterables, mixed `None` threshold groups in dict mode, and explicit all-`None` threshold handling across return formats.
- **Expanded IPW diagnostics coverage for fit-parameter reporting**
- Refactored diagnostics tests to use a shared IPW setup helper (removing repeated fixture construction), added edge-case assertions for filtered non-string solver/penalty values and NaN coercion of non-scalar `tol`/`l1_ratio` inputs, and now assert solver/penalty labels match fitted model parameters.
- **Added unit coverage for CLI I/O and empty-batch handling**
- Added focused tests for `BalanceCLI.process_batch()` empty-sample failure payloads, `load_and_check_input()` CSV loading paths, and `write_outputs()` delimiter-aware output writing for both adjusted and diagnostics files.

# 0.16.0 (2026-02-09)

Expand Down
3 changes: 0 additions & 3 deletions balance/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -667,7 +667,6 @@ def process_batch(
set(result.keys()) == {"adjusted", "diagnostics"}
# True
"""
# TODO: add unit tests
sample_df, target_df = self.split_sample(batch_df)

if sample_df.shape[0] == 0:
Expand Down Expand Up @@ -890,7 +889,6 @@ def load_and_check_input(self) -> pd.DataFrame:
loaded.shape
# (1, 2)
"""
# TODO: Add unit tests for function
# Load and check input
input_df = pd.read_csv(self.args.input_file, sep=self.args.sep_input_file)
logger.info("Number of rows in input file: %d" % input_df.shape[0])
Expand Down Expand Up @@ -932,7 +930,6 @@ def write_outputs(
)
cli.write_outputs(output_df, diagnostics_df)
"""
# TODO: Add unit tests for function
# Write output
output_df.to_csv(
path_or_buf=self.args.output_file,
Expand Down
113 changes: 113 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,119 @@ def test_cli_weights_impact_on_outcome_method(self) -> None:
cli_none = BalanceCLI(args_none)
self.assertIsNone(cli_none.weights_impact_on_outcome_method())

def test_process_batch_returns_failure_payload_for_empty_sample(self) -> None:
with tempfile.TemporaryDirectory() as temp_dir:
input_file = os.path.join(temp_dir, "input.csv")
output_file = os.path.join(temp_dir, "output.csv")
parser = make_parser()
args = parser.parse_args(
[
"--input_file",
input_file,
"--output_file",
output_file,
"--sample_column",
"is_respondent",
"--covariate_columns",
"x",
]
)
cli = BalanceCLI(args)

batch_df = pd.DataFrame(
{
"is_respondent": [0, 0],
"id": [1, 2],
"weight": [1.0, 1.0],
"x": [1.0, 2.0],
}
)
result = cli.process_batch(batch_df)

self.assertTrue(result["adjusted"].empty)
self.assertEqual(
result["diagnostics"].to_dict("records"),
[
{
"metric": "adjustment_failure",
"var": None,
"val": 1,
},
{
"metric": "adjustment_failure_reason",
"var": None,
"val": "No input data",
},
],
)

def test_load_and_check_input_reads_file_and_columns(self) -> None:
with tempfile.TemporaryDirectory() as temp_dir:
input_file = os.path.join(temp_dir, "input.csv")
output_file = os.path.join(temp_dir, "output.csv")
parser = make_parser()
args = parser.parse_args(
[
"--input_file",
input_file,
"--output_file",
output_file,
"--sample_column",
"is_respondent",
"--covariate_columns",
"x",
"--keep_row_column",
"keep",
]
)
cli = BalanceCLI(args)

input_df = pd.DataFrame(
{
"is_respondent": [1, 0],
"id": [1, 2],
"weight": [1.0, 1.0],
"x": [1.0, 2.0],
"keep": [1, 0],
}
)
input_df.to_csv(input_file, index=False)

loaded = cli.load_and_check_input()
pd.testing.assert_frame_equal(loaded, input_df)

def test_write_outputs_skips_diagnostics_when_no_output_path(self) -> None:
with tempfile.TemporaryDirectory() as temp_dir:
input_file = os.path.join(temp_dir, "input.csv")
output_file = os.path.join(temp_dir, "output.csv")
parser = make_parser()
args = parser.parse_args(
[
"--input_file",
input_file,
"--output_file",
output_file,
"--sample_column",
"is_respondent",
"--covariate_columns",
"x",
]
)
cli = BalanceCLI(args)

output_df = pd.DataFrame({"id": [1], "weight": [1.25]})
diagnostics_df = pd.DataFrame(
{"metric": ["adjustment_failure"], "var": [None], "val": [0]}
)

cli.write_outputs(output_df, diagnostics_df)

pd.testing.assert_frame_equal(
pd.read_csv(output_file, sep=","),
output_df,
)
self.assertIsNone(cli.args.diagnostics_output_file)

def test_cli_help(self) -> None:
"""Test that CLI help command executes without errors."""
parser = make_parser()
Expand Down
Loading