Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions quapy/method/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
aggregative.KDEyHD,
# aggregative.OneVsAllAggregative,
confidence.BayesianCC,
confidence.PQ,
}

BINARY_METHODS = {
Expand All @@ -40,6 +41,7 @@
aggregative.MAX,
aggregative.MS,
aggregative.MS2,
confidence.PQ,
}

MULTICLASS_METHODS = {
Expand Down
56 changes: 56 additions & 0 deletions quapy/method/_bayesian.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,22 @@
Utility functions for `Bayesian quantification <https://arxiv.org/abs/2302.09159>`_ methods.
"""
import numpy as np
import importlib.resources

try:
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
import stan

DEPENDENCIES_INSTALLED = True
except ImportError:
jax = None
jnp = None
numpyro = None
dist = None
stan = None

DEPENDENCIES_INSTALLED = False

Expand Down Expand Up @@ -77,3 +80,56 @@ def sample_posterior(
rng_key = jax.random.PRNGKey(seed)
mcmc.run(rng_key, n_c_unlabeled=n_c_unlabeled, n_y_and_c_labeled=n_y_and_c_labeled)
return mcmc.get_samples()



def load_stan_file():
return importlib.resources.files('quapy.method').joinpath('stan/pq.stan').read_text(encoding='utf-8')

def pq_stan(stan_code, n_bins, pos_hist, neg_hist, test_hist, number_of_samples, num_warmup, stan_seed):
"""
Perform Bayesian prevalence estimation using a Stan model for probabilistic quantification.

This function builds and samples from a Stan model that implements a bin-based Bayesian
quantifier. It uses the class-conditional histograms of the classifier
outputs for positive and negative examples, along with the test histogram, to estimate
the posterior distribution of prevalence in the test set.

Parameters
----------
stan_code : str
The Stan model code as a string.
n_bins : int
Number of bins used to build the histograms for positive and negative examples.
pos_hist : array-like of shape (n_bins,)
Histogram counts of the classifier outputs for the positive class.
neg_hist : array-like of shape (n_bins,)
Histogram counts of the classifier outputs for the negative class.
test_hist : array-like of shape (n_bins,)
Histogram counts of the classifier outputs for the test set, binned using the same bins.
number_of_samples : int
Number of post-warmup samples to draw from the Stan posterior.
num_warmup : int
Number of warmup iterations for the sampler.
stan_seed : int
Random seed for Stan model compilation and sampling, ensuring reproducibility.

Returns
-------
prev_samples : numpy.ndarray
An array of posterior samples of the prevalence (`prev`) in the test set.
Each element corresponds to one draw from the posterior distribution.
"""

stan_data = {
'n_bucket': n_bins,
'train_neg': neg_hist.tolist(),
'train_pos': pos_hist.tolist(),
'test': test_hist.tolist(),
'posterior': 1
}

stan_model = stan.build(stan_code, data=stan_data, random_seed=stan_seed)
fit = stan_model.sample(num_chains=1, num_samples=number_of_samples,num_warmup=num_warmup)

return fit['prev']
103 changes: 101 additions & 2 deletions quapy/method/confidence.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@
import quapy as qp
import quapy.functional as F
from quapy.method import _bayesian
from quapy.method.aggregative import AggregativeCrispQuantifier
from quapy.data import LabelledCollection
from quapy.method.aggregative import AggregativeQuantifier
from quapy.method.aggregative import AggregativeQuantifier, AggregativeCrispQuantifier, AggregativeSoftQuantifier, BinaryAggregativeQuantifier
from scipy.stats import chi2
from sklearn.utils import resample
from abc import ABC, abstractmethod
Expand Down Expand Up @@ -571,3 +570,103 @@ def predict_conf(self, instances, confidence_level=None) -> (np.ndarray, Confide
samples = self.get_prevalence_samples() # available after calling "aggregate" function
region = WithConfidenceABC.construct_region(samples, confidence_level=self.confidence_level, method=self.region)
return point_estimate, region


class PQ(AggregativeSoftQuantifier, BinaryAggregativeQuantifier):
"""
`Precise Quantifier: Bayesian distribution matching quantifier <https://arxiv.org/abs/2507.06061>,
which is a variant of :class:`HDy` that calculates the posterior probability distribution
over the prevalence vectors, rather than providing a point estimate.

This method relies on extra dependencies, which have to be installed via:
`$ pip install quapy[bayes]`

:param classifier: a scikit-learn's BaseEstimator, or None, in which case the classifier is taken to be
the one indicated in `qp.environ['DEFAULT_CLS']`
:param val_split: specifies the data used for generating classifier predictions. This specification
can be made as float in (0, 1) indicating the proportion of stratified held-out validation set to
be extracted from the training set; or as an integer (default 5), indicating that the predictions
are to be generated in a `k`-fold cross-validation manner (with this integer indicating the value
for `k`); or as a tuple `(X,y)` defining the specific set of data to use for validation. Set to
None when the method does not require any validation data, in order to avoid that some portion of
the training data be wasted.
:param num_warmup: number of warmup iterations for the STAN sampler (default 500)
:param num_samples: number of samples to draw from the posterior (default 1000)
:param stan_seed: random seed for the STAN sampler (default 0)
:param region: string, set to `intervals` for constructing confidence intervals (default), or to
`ellipse` for constructing an ellipse in the probability simplex, or to `ellipse-clr` for
constructing an ellipse in the Centered-Log Ratio (CLR) unconstrained space.
"""
def __init__(self,
classifier: BaseEstimator=None,
fit_classifier=True,
val_split: int = 5,
n_bins: int = 4,
fixed_bins: bool = False,
num_warmup: int = 500,
num_samples: int = 1_000,
stan_seed: int = 0,
region: str = 'intervals'):

if num_warmup <= 0:
raise ValueError(f'parameter {num_warmup=} must be a positive integer')
if num_samples <= 0:
raise ValueError(f'parameter {num_samples=} must be a positive integer')

if _bayesian.DEPENDENCIES_INSTALLED is False:
raise ImportError("Auxiliary dependencies are required. "
"Run `$ pip install quapy[bayes]` to install them.")

super().__init__(classifier, fit_classifier, val_split)
self.n_bins = n_bins
self.fixed_bins = fixed_bins
self.num_warmup = num_warmup
self.num_samples = num_samples
self.region = region
self.stan_seed = stan_seed
self.stan_code = _bayesian.load_stan_file()

def aggregation_fit(self, classif_predictions, labels):
y_pred = classif_predictions[:, self.pos_label]

# Compute bin limits
if self.fixed_bins:
# Uniform bins in [0,1]
self.bin_limits = np.linspace(0, 1, self.n_bins + 1)
else:
# Quantile bins
self.bin_limits = np.quantile(y_pred, np.linspace(0, 1, self.n_bins + 1))

# Assign each prediction to a bin
bin_indices = np.digitize(y_pred, self.bin_limits[1:-1], right=True)

# Positive and negative masks
pos_mask = (labels == self.pos_label)
neg_mask = ~pos_mask

# Count positives and negatives per bin
self.pos_hist = np.bincount(bin_indices[pos_mask], minlength=self.n_bins)
self.neg_hist = np.bincount(bin_indices[neg_mask], minlength=self.n_bins)


def aggregate(self, classif_predictions):
Px_test = classif_predictions[:, self.pos_label]
test_hist, _ = np.histogram(Px_test, bins=self.bin_limits)
self.prev_distribution = _bayesian.pq_stan(self.stan_code,
self.n_bins,
self.pos_hist,
self.neg_hist,
test_hist,
self.num_samples,
self.num_warmup,
self.stan_seed)
return F.as_binary_prevalence(self.prev_distribution.mean())


def predict_conf(self, instances, confidence_level=None) -> (np.ndarray, ConfidenceRegionABC):
classif_predictions = self.classify(instances)
point_estimate = self.aggregate(classif_predictions)
samples = self.prev_distribution
region = WithConfidenceABC.construct_region(samples, confidence_level=confidence_level, method=self.region)
return point_estimate, region

2 changes: 1 addition & 1 deletion quapy/method/non_aggregative.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from sklearn.utils import resample
from sklearn.preprocessing import normalize

from method.confidence import WithConfidenceABC, ConfidenceRegionABC
from quapy.method.confidence import WithConfidenceABC, ConfidenceRegionABC
from quapy.functional import get_divergence
from quapy.method.base import BaseQuantifier, BinaryQuantifier
import quapy.functional as F
Expand Down
38 changes: 38 additions & 0 deletions quapy/method/stan/pq.stan
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
data {
int<lower=0> n_bucket;
array[n_bucket] int<lower=0> train_pos;
array[n_bucket] int<lower=0> train_neg;
array[n_bucket] int<lower=0> test;
int<lower=0,upper=1> posterior;
}

transformed data{
row_vector<lower=0>[n_bucket] train_pos_rv;
row_vector<lower=0>[n_bucket] train_neg_rv;
row_vector<lower=0>[n_bucket] test_rv;
real n_test;

train_pos_rv = to_row_vector( train_pos );
train_neg_rv = to_row_vector( train_neg );
test_rv = to_row_vector( test );
n_test = sum( test );
}

parameters {
simplex[n_bucket] p_neg;
simplex[n_bucket] p_pos;
real<lower=0,upper=1> prev_prior;
}

model {
if( posterior ) {
target += train_neg_rv * log( p_neg );
target += train_pos_rv * log( p_pos );
target += test_rv * log( p_neg * ( 1 - prev_prior) + p_pos * prev_prior );
}
}

generated quantities {
real<lower=0,upper=1> prev;
prev = sum( binomial_rng(test, 1 / ( 1 + (p_neg./p_pos) *(1-prev_prior)/prev_prior ) ) ) / n_test;
}
8 changes: 7 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,12 @@ def get_version(rel_path):
#
packages=find_packages(include=['quapy', 'quapy.*']), # Required

package_data={
# For the 'quapy.method' package, include all files
# in the 'stan' subdirectory that end with .stan
'quapy.method': ['stan/*.stan']
},

python_requires='>=3.8, <4',

install_requires=['scikit-learn', 'pandas', 'tqdm', 'matplotlib', 'joblib', 'xlrd', 'abstention', 'ucimlrepo', 'certifi'],
Expand All @@ -124,7 +130,7 @@ def get_version(rel_path):
# Similar to `install_requires` above, these must be valid existing
# projects.
extras_require={ # Optional
'bayes': ['jax', 'jaxlib', 'numpyro'],
'bayes': ['jax', 'jaxlib', 'numpyro', 'pystan'],
'neural': ['torch'],
'tests': ['certifi'],
'docs' : ['sphinx-rtd-theme', 'myst-parser'],
Expand Down