diff --git a/CHANGELOG.md b/CHANGELOG.md index 2bc39a0c6..e59f4ede0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -44,6 +44,10 @@ appear in the final output CSV. Keep columns that are not id, weight, covariate, or outcome columns will be placed into ``ignore_columns`` during processing but are still retained and available in the output. +- **Clarified `_prepare_input_model_matrix` argument docs** + - Updated docstrings in `balance.utils.model_matrix` with + explicit descriptions for `sample`, `target`, `variables`, and `add_na` + behavior when preparing model-matrix inputs. ## Bug Fixes @@ -57,6 +61,11 @@ like `a`, `a_1`, and repeated `a` names appear together. - Duplicate columns are now renamed deterministically to guaranteed-unique names, preventing downstream clashes after formula sanitization. +- **`model_matrix` empty-sample errors now raise `ValueError`** + - `_prepare_input_model_matrix()` now raises a deterministic `ValueError` + when the input sample has zero rows, instead of relying on an assertion. + - This aligns runtime behavior with documented exceptions and avoids + optimization-dependent assert behavior. ## Tests diff --git a/balance/utils/model_matrix.py b/balance/utils/model_matrix.py index 74e72bee6..d4a32dc5a 100644 --- a/balance/utils/model_matrix.py +++ b/balance/utils/model_matrix.py @@ -285,14 +285,34 @@ def _prepare_input_model_matrix( - Add na indicator if required. Args: - sample (pd.DataFrame | Any): This can either be a DataFrame or a Sample object. TODO: add text. - target (pd.DataFrame | Any | None, optional): This can either be a DataFrame or a Sample object.. Defaults to None. - variables (List[str] | None, optional): Defaults to None. TODO: add text. - add_na (bool, optional): Defaults to True. TODO: add text. - fix_columns_names (bool, optional): Defaults to True. If to fix the column names of the DataFrame by changing special characters to '_'. 
+ sample (pd.DataFrame | Any): Input sample data as either a + ``pandas.DataFrame`` or a ``Sample`` object from + ``balance.sample_class`` (recognized via ``_isinstance_sample``). + target (pd.DataFrame | Any | None, optional): Optional target data as + either a ``DataFrame`` or a ``Sample`` object. If provided, the + model-matrix inputs are prepared from a sample/target union of + variables and rows. Defaults to None. + variables (List[str] | None, optional): Variables to use from both + inputs. If provided, ``choose_variables`` validates that each + requested variable exists in both sample and target (when target is + supplied), otherwise it raises ``ValueError``. For ``Sample`` + inputs, this validation/inference is based on covariate names + (``sample.covars().names()``), not all raw ``._df`` columns. If + None, variables are inferred by ``choose_variables``. + add_na (bool, optional): If True, add NA indicator columns before + model-matrix creation. If False, drop rows containing missing + values; this can raise ``ValueError`` if dropping rows empties the + sample or target. Defaults to True. + fix_columns_names (bool, optional): Whether to sanitize column names by + replacing non-word characters with ``_`` and making duplicate names + unique. Defaults to True. Raises: - Exception: "Variable names cannot contain characters '[' or ']'" + ValueError: If requested ``variables`` are not present in the + provided input frame(s) (and in both sample and target when target + is supplied), if variables contain ``[`` or ``]``, if + ``add_na=False`` drops all rows from sample/target, or if + sample has zero rows. Returns: Dict[str, Any]: returns a dictionary containing two keys: 'all_data' and 'sample_n'. 
@@ -312,7 +332,8 @@ def _prepare_input_model_matrix( sample_df = sample._df else: sample_df = sample - assert sample_df.shape[0] > 0, "sample must have more than zero rows" + if sample_df.shape[0] == 0: + raise ValueError("sample must have more than zero rows") # NOTE: .copy() not needed as it is copied anyway in _concat_frames sample_n = sample_df.shape[0] sample_df = sample_df.loc[:, variables] diff --git a/tests/test_util_model_matrix.py b/tests/test_util_model_matrix.py index 05b7a05f4..03b859f09 100644 --- a/tests/test_util_model_matrix.py +++ b/tests/test_util_model_matrix.py @@ -290,9 +290,9 @@ def test_model_matrix(self) -> None: t, ) - # Test zero rows warning: + # Test zero rows error: self.assertRaisesRegex( - AssertionError, + ValueError, "sample must have more than zero rows", model_matrix, pd.DataFrame(),