3 changes: 2 additions & 1 deletion python/polars_ds/exprs/__init__.py
@@ -1,8 +1,9 @@
from .expr_knn import * # noqa : F403
from .expr_linear import * # noqa : F403
from .expr_iter import * # noqa : F403
from .metrics import * # noqa : F403
from .survival import * # noqa : F403
from .num import * # noqa : F403
from .stats import * # noqa : F403
from .string import * # noqa : F403
from .ts_features import * # noqa : F403
from .expr_iter import * # noqa : F403
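With the new `survival` star-import added here, the Kaplan-Meier helper is exposed through the package's expression namespace; the test at the end of this diff calls it as `pds.query_kaplan_meier_prob`. A minimal usage sketch (the column names and data below are hypothetical):

import polars as pl
import polars_ds as pds

# Hypothetical data: 1 = event observed, 0 = right-censored.
df = pl.DataFrame({"status": [1, 0, 1, 1], "time": [2.0, 3.0, 3.0, 5.0]})

km = df.select(
    pds.query_kaplan_meier_prob("status", "time").alias("km")
).unnest("km")  # yields the struct fields as columns: t, prob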
86 changes: 44 additions & 42 deletions python/polars_ds/exprs/metrics.py
@@ -30,8 +30,8 @@
"query_multi_roc_auc",
"query_cat_cross_entropy",
"query_confusion_matrix",
"query_fairness",
"query_p_pct_score",
# "query_fairness",
# "query_p_pct_score",
"query_mcc",
"query_dcg_score",
"query_ndcg_score",
@@ -619,46 +619,46 @@ def query_mcc(y_true: str | pl.Expr, y_pred: str | pl.Expr) -> pl.Expr:
)


def query_fairness(pred: str | pl.Expr, sensitive_cond: pl.Expr) -> pl.Expr:
"""
A simple fairness metric for regression output. Computes the absolute difference between
the average of the `pred` values when `sensitive_cond` is true and the average of the
values when `sensitive_cond` is false.

The lower this value, the fairer the model is with respect to the sensitive condition.

Parameters
----------
pred
The predictions
sensitive_cond
A boolean expression representing the sensitive condition
"""
p = to_expr(pred)
return (p.filter(sensitive_cond).mean() - p.filter(~sensitive_cond).mean()).abs()


def query_p_pct_score(pred: str | pl.Expr, sensitive_cond: pl.Expr) -> pl.Expr:
"""
Computes the 'p-percent score', which measures the fairness of a classification
model on a sensitive_cond. Let z = the sensitive_cond, then:

p-percent score = min(P(y = 1 | z = 1) / P(y = 1 | z = 0), P(y = 1 | z = 0) / P(y = 1 | z = 1))

Parameters
----------
pred
The predictions. Must be 0s and 1s.
sensitive_cond
A boolean expression representing the sensitive condition
"""
p = to_expr(pred)
p_y1_z1 = p.filter(
sensitive_cond
).mean() # since p is 0s and 1s, this is equal to P(pred = 1 | sensitive_cond)
p_y1_z0 = p.filter(~sensitive_cond).mean()
ratio = p_y1_z1 / p_y1_z0
return pl.min_horizontal(ratio, 1 / ratio)
# def query_fairness(pred: str | pl.Expr, sensitive_cond: pl.Expr) -> pl.Expr:
# """
# A simple fairness metric for regression output. Computes the absolute difference between
# the average of the `pred` values when `sensitive_cond` is true and the average of the
# values when `sensitive_cond` is false.

# The lower this value, the fairer the model is with respect to the sensitive condition.

# Parameters
# ----------
# pred
# The predictions
# sensitive_cond
# A boolean expression representing the sensitive condition
# """
# p = to_expr(pred)
# return (p.filter(sensitive_cond).mean() - p.filter(~sensitive_cond).mean()).abs()


# def query_p_pct_score(pred: str | pl.Expr, sensitive_cond: pl.Expr) -> pl.Expr:
# """
# Computes the 'p-percent score', which measures the fairness of a classification
# model on a sensitive_cond. Let z = the sensitive_cond, then:

# p-percent score = min(P(y = 1 | z = 1) / P(y = 1 | z = 0), P(y = 1 | z = 0) / P(y = 1 | z = 1))

# Parameters
# ----------
# pred
# The predictions. Must be 0s and 1s.
# sensitive_cond
# A boolean expression representing the sensitive condition
# """
# p = to_expr(pred)
# p_y1_z1 = p.filter(
# sensitive_cond
# ).mean() # since p is 0s and 1s, this is equal to P(pred = 1 | sensitive_cond)
# p_y1_z0 = p.filter(~sensitive_cond).mean()
# ratio = p_y1_z1 / p_y1_z0
# return pl.min_horizontal(ratio, 1 / ratio)


def query_dcg_score(
@@ -669,6 +669,8 @@ def query_dcg_score(
ignore_ties: bool = False,
) -> pl.Expr:
"""
Calculates the Discounted Cumulative Gain score.

Parameters
----------
y_true:
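With both fairness helpers commented out above, the p-percent score can still be computed directly with plain Polars expressions. The sketch below mirrors the commented-out logic; `pred` (0/1 predictions) and `group` (the sensitive condition) are hypothetical column names, and the mean-difference metric from `query_fairness` follows the same filter-and-aggregate pattern.

import polars as pl

df = pl.DataFrame({
    "pred": [1, 0, 1, 1, 0, 1],
    "group": [True, True, False, False, True, False],
})

p = pl.col("pred")
z = pl.col("group")

# Means of 0/1 predictions are P(pred = 1 | z = 1) and P(pred = 1 | z = 0).
ratio = p.filter(z).mean() / p.filter(~z).mean()

# p-percent score = min(ratio, 1 / ratio)
print(df.select(pl.min_horizontal(ratio, 1 / ratio).alias("p_pct_score")))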
35 changes: 35 additions & 0 deletions python/polars_ds/exprs/survival.py
@@ -0,0 +1,35 @@
from __future__ import annotations
import polars as pl
import warnings

# Internal dependencies
from polars_ds._utils import pl_plugin, to_expr

__all__ = [
"query_kaplan_meier_prob"
]

def query_kaplan_meier_prob(
status: str | pl.Expr,
time_exit: str | pl.Expr
) -> pl.Expr:
"""
Computes probabilities given by the Kaplan Meier estimator. This returns a time column and the corresponding probabilities.

Parameters
----------
status
Status column. Can be booleans or 0s and 1s. True or 1 indicates an event and False or 0 indicates right-censoring.
time_exit
Time of event or censoring.
"""
warnings.warn(
"This function's API is considered unstable and might have breaking changes in the future."
, FutureWarning
, stacklevel = 2
)

return pl_plugin(
symbol="pl_kaplan_meier",
args=[to_expr(status).cast(pl.UInt32), to_expr(time_exit)],
)
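For reference, the returned `prob` column follows the standard Kaplan-Meier product-limit estimator: with $d_i$ events and $n_i$ subjects at risk at each distinct time $t_i$,

\hat{S}(t) = \prod_{t_i \le t} \left(1 - \frac{d_i}{n_i}\right)

which is the cumulative product computed by the Rust plugin further down in this diff; a full usage example appears in tests/test_survival.py at the end.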
3 changes: 1 addition & 2 deletions src/num_ext/tp_fp.rs
@@ -1,7 +1,6 @@
use std::f64;

/// All things true positive, false positive related.
/// ROC AUC, Average Precision, precision, recall, etc. m
use std::f64;
use polars::prelude::*;
use pyo3_polars::derive::polars_expr;

38 changes: 38 additions & 0 deletions src/stats/kaplan_meier.rs
@@ -0,0 +1,38 @@
use polars::prelude::*;
use pyo3_polars::derive::polars_expr;

fn kaplan_meier_prob(_: &[Field]) -> PolarsResult<Field> {
let t = Field::new("t".into(), DataType::Float64);
let p = Field::new("prob".into(), DataType::Float64);
let v: Vec<Field> = vec![t, p];
Ok(Field::new("kaplan_meier".into(), DataType::Struct(v)))
}

#[polars_expr(output_type_func=kaplan_meier_prob)]
fn pl_kaplan_meier(inputs: &[Series]) -> PolarsResult<Series> {

let n1 = inputs[0].len();
let n2 = inputs[1].len();

if n1 != n2 {
return Err(PolarsError::ShapeMismatch("Length of status column is not the same as the length of survival time column.".into()));
}

if !inputs.iter().all(|s| s.dtype().is_numeric()) {
return Err(PolarsError::ComputeError("All columns must be numeric.".into()));
}

let df = df!("status"=>inputs[0].clone(), "time_exit"=>inputs[1].clone())?;
let table = df.lazy().group_by(["time_exit"]).agg([
len().alias("cnt")
, col("status").sum().alias("events")
]).sort(["time_exit"], SortMultipleOptions::default()).with_column(
(lit(n1 as u32) - col("cnt").cum_sum(false).shift_and_fill(1, lit(0))).alias("n_at_risk")
).select([
col("time_exit").alias("t"),
(lit(1f64) - col("events").cast(DataType::Float64) / col("n_at_risk").cast(DataType::Float64)).cum_prod(false).alias("prob")
]).collect()?;

let ca = table.into_struct("kaplan_meier".into());
Ok(ca.into_series())
}
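The lazy pipeline above reads more easily in Polars' Python API. A rough equivalent sketch (not part of the plugin; hypothetical data, and method names such as `cum_sum`/`cum_prod` and `shift(fill_value=...)` assume a recent Polars version):

import polars as pl

df = pl.DataFrame({
    "status": [1, 0, 1, 1, 0],            # 1 = event, 0 = censored
    "time_exit": [2.0, 3.0, 3.0, 5.0, 8.0],
})
n = df.height

km = (
    df.lazy()
    .group_by("time_exit")
    .agg(
        pl.len().alias("cnt"),                    # subjects leaving at this time
        pl.col("status").sum().alias("events"),   # observed events at this time
    )
    .sort("time_exit")
    .with_columns(
        # Number still at risk just before each time point.
        (pl.lit(n) - pl.col("cnt").cum_sum().shift(1, fill_value=0)).alias("n_at_risk")
    )
    .select(
        pl.col("time_exit").alias("t"),
        (1.0 - pl.col("events") / pl.col("n_at_risk")).cum_prod().alias("prob"),
    )
    .collect()
)
print(km)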
1 change: 1 addition & 0 deletions src/stats/mod.rs
@@ -7,6 +7,7 @@ mod normal_test;
mod sample;
mod t_test;
mod xi_corr;
mod kaplan_meier;

use polars::prelude::*;

16 changes: 6 additions & 10 deletions src/utils/mod.rs
@@ -1,4 +1,5 @@
use cfavml::safe_trait_distance_ops::DistanceOps;
use itertools::Itertools;
use num::Float;
use polars::{
datatypes::{DataType, Field},
@@ -42,22 +43,17 @@
return Err(PolarsError::NoData("Data is empty".into()));
}
if series.iter().any(|s| !s.dtype().is_numeric()) {
return Err(PolarsError::NoData("All columns need to be numeric.".into()));
return Err(PolarsError::ComputeError("All columns need to be numeric.".into()));
}
if !series.iter().map(|s| s.len()).all_equal() {
return Err(PolarsError::ShapeMismatch("Seires don't have the same length.".into()));
}

// Safe because series is not empty
let height: usize = series[0].len();
for s in &series[1..] {
if s.len() != height {
return Err(PolarsError::ShapeMismatch(
"Seires don't have the same length.".into(),
));
}
}
let m = series.len();
let mut membuf = Vec::with_capacity(height * m);
let ptr = membuf.as_ptr() as usize;
// let columns = self.get_columns();

POOL.install(|| {
series.par_iter().enumerate().try_for_each(|(col_idx, s)| {
let s = s.cast(&N::get_static_dtype())?;
3 changes: 2 additions & 1 deletion tests/requirements-test.txt
@@ -17,4 +17,5 @@ altair
vegafusion[embed]
vl-convert-python>=1.6
great-tables>=0.9
statsmodels
statsmodels
scikit-survival
34 changes: 34 additions & 0 deletions tests/test_survival.py
@@ -0,0 +1,34 @@
import polars as pl
import polars_ds as pds
import pytest
from polars.testing import assert_frame_equal

def test_kaplan_meier():

from sksurv.datasets import load_veterans_lung_cancer
from sksurv.nonparametric import kaplan_meier_estimator

_, y = load_veterans_lung_cancer()
time, prob_surv, conf_int = kaplan_meier_estimator(
y["Status"], y["Survival_in_days"], conf_type="log-log"
)

df_result = pl.from_dict({
"t": time
, "prob": prob_surv
})

df = pl.from_dict({
"status": y['Status']
, "survival_time": y["Survival_in_days"]
})

df_pds_result = df.select(
pds.query_kaplan_meier_prob(
"status"
, "survival_time"
).alias("estimate")
).unnest("estimate")

assert_frame_equal(df_result, df_pds_result)