Skip to content
Closed
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@
appear in the final output CSV. Keep columns that are not id, weight,
covariate, or outcome columns will be placed into ``ignore_columns`` during
processing but are still retained and available in the output.
- **Clarified `_prepare_input_model_matrix` argument docs**
- Updated docstrings in `balance.utils.model_matrix` with
explicit descriptions for `sample`, `target`, `variables`, and `add_na`
behavior when preparing model-matrix inputs.

## Bug Fixes

Expand All @@ -57,6 +61,11 @@
like `a`, `a_1`, and repeated `a` names appear together.
- Duplicate columns are now renamed deterministically to guaranteed-unique
names, preventing downstream clashes after formula sanitization.
- **`model_matrix` empty-sample errors now raise `ValueError`**
- `_prepare_input_model_matrix()` now raises a deterministic `ValueError`
when the input sample has zero rows, instead of relying on an assertion.
- This aligns runtime behavior with documented exceptions and avoids
optimization-dependent assert behavior.

## Tests

Expand Down
35 changes: 28 additions & 7 deletions balance/utils/model_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,14 +285,34 @@ def _prepare_input_model_matrix(
- Add na indicator if required.

Args:
sample (pd.DataFrame | Any): This can either be a DataFrame or a Sample object. TODO: add text.
target (pd.DataFrame | Any | None, optional): This can either be a DataFrame or a Sample object.. Defaults to None.
variables (List[str] | None, optional): Defaults to None. TODO: add text.
add_na (bool, optional): Defaults to True. TODO: add text.
fix_columns_names (bool, optional): Defaults to True. If to fix the column names of the DataFrame by changing special characters to '_'.
sample (pd.DataFrame | Any): Input sample data as either a
``pandas.DataFrame`` or a ``Sample`` object from
``balance.sample_class`` (recognized via ``_isinstance_sample``).
target (pd.DataFrame | Any | None, optional): Optional target data as
either a ``DataFrame`` or a ``Sample`` object. If provided, the
model-matrix inputs are prepared from a sample/target union of
variables and rows. Defaults to None.
variables (List[str] | None, optional): Variables to use from both
inputs. If provided, `choose_variables` validates that each
requested variable exists in both sample and target (when target is
supplied), otherwise it raises ``ValueError``. For ``Sample``
inputs, this validation/inference is based on covariate names
(``sample.covars().names()``), not all raw ``._df`` columns. If
None, variables are inferred by `choose_variables`.
add_na (bool, optional): If True, add NA indicator columns before
model-matrix creation. If False, drop rows containing missing
values; this can raise ``ValueError`` if dropping rows empties the
sample or target. Defaults to True.
fix_columns_names (bool, optional): Whether to sanitize column names by
replacing non-word characters with ``_`` and making duplicate names
unique. Defaults to True.

Raises:
Exception: "Variable names cannot contain characters '[' or ']'"
ValueError: If requested ``variables`` are not present in the
provided input frame(s) (and in both sample and target when target
is supplied), if variables contain ``[`` or ``]``, or if
``add_na=False`` drops all rows from sample/target, or if
sample has zero rows.

Returns:
Dict[str, Any]: returns a dictionary containing two keys: 'all_data' and 'sample_n'.
Expand All @@ -312,7 +332,8 @@ def _prepare_input_model_matrix(
sample_df = sample._df
else:
sample_df = sample
assert sample_df.shape[0] > 0, "sample must have more than zero rows"
if sample_df.shape[0] == 0:
raise ValueError("sample must have more than zero rows")
# NOTE: .copy() not needed as it is copied anyway in _concat_frames
sample_n = sample_df.shape[0]
sample_df = sample_df.loc[:, variables]
Expand Down
4 changes: 2 additions & 2 deletions tests/test_util_model_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,9 +290,9 @@ def test_model_matrix(self) -> None:
t,
)

# Test zero rows warning:
# Test zero rows error:
self.assertRaisesRegex(
AssertionError,
ValueError,
"sample must have more than zero rows",
model_matrix,
pd.DataFrame(),
Expand Down
Loading