Skip to content

Commit db49cd3

Browse files
committed
Merge branch 'devel' of github.com:HLT-ISTI/QuaPy into devel
2 parents 9da4fd5 + 5f6a151 commit db49cd3

File tree

6 files changed

+212
-5
lines changed

6 files changed

+212
-5
lines changed

quapy/data/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def __init__(self, instances, labels, classes=None):
4444

4545
@property
4646
def index(self):
47-
if self._index is None:
47+
if not hasattr(self, '_index') or self._index is None:
4848
self._index = {class_: np.arange(len(self))[self.labels == class_] for class_ in self.classes_}
4949
return self._index
5050

quapy/method/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
aggregative.KDEyHD,
3030
# aggregative.OneVsAllAggregative,
3131
confidence.BayesianCC,
32+
confidence.PQ,
3233
}
3334

3435
BINARY_METHODS = {
@@ -40,6 +41,7 @@
4041
aggregative.MAX,
4142
aggregative.MS,
4243
aggregative.MS2,
44+
confidence.PQ,
4345
}
4446

4547
MULTICLASS_METHODS = {

quapy/method/_bayesian.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,19 +2,22 @@
22
Utility functions for `Bayesian quantification <https://arxiv.org/abs/2302.09159>`_ methods.
33
"""
44
import numpy as np
5+
import importlib.resources
56

67
try:
78
import jax
89
import jax.numpy as jnp
910
import numpyro
1011
import numpyro.distributions as dist
12+
import stan
1113

1214
DEPENDENCIES_INSTALLED = True
1315
except ImportError:
1416
jax = None
1517
jnp = None
1618
numpyro = None
1719
dist = None
20+
stan = None
1821

1922
DEPENDENCIES_INSTALLED = False
2023

@@ -77,3 +80,56 @@ def sample_posterior(
7780
rng_key = jax.random.PRNGKey(seed)
7881
mcmc.run(rng_key, n_c_unlabeled=n_c_unlabeled, n_y_and_c_labeled=n_y_and_c_labeled)
7982
return mcmc.get_samples()
83+
84+
85+
86+
def load_stan_file():
87+
return importlib.resources.files('quapy.method').joinpath('stan/pq.stan').read_text(encoding='utf-8')
88+
89+
def pq_stan(stan_code, n_bins, pos_hist, neg_hist, test_hist, number_of_samples, num_warmup, stan_seed):
90+
"""
91+
Perform Bayesian prevalence estimation using a Stan model for probabilistic quantification.
92+
93+
This function builds and samples from a Stan model that implements a bin-based Bayesian
94+
quantifier. It uses the class-conditional histograms of the classifier
95+
outputs for positive and negative examples, along with the test histogram, to estimate
96+
the posterior distribution of prevalence in the test set.
97+
98+
Parameters
99+
----------
100+
stan_code : str
101+
The Stan model code as a string.
102+
n_bins : int
103+
Number of bins used to build the histograms for positive and negative examples.
104+
pos_hist : array-like of shape (n_bins,)
105+
Histogram counts of the classifier outputs for the positive class.
106+
neg_hist : array-like of shape (n_bins,)
107+
Histogram counts of the classifier outputs for the negative class.
108+
test_hist : array-like of shape (n_bins,)
109+
Histogram counts of the classifier outputs for the test set, binned using the same bins.
110+
number_of_samples : int
111+
Number of post-warmup samples to draw from the Stan posterior.
112+
num_warmup : int
113+
Number of warmup iterations for the sampler.
114+
stan_seed : int
115+
Random seed for Stan model compilation and sampling, ensuring reproducibility.
116+
117+
Returns
118+
-------
119+
prev_samples : numpy.ndarray
120+
An array of posterior samples of the prevalence (`prev`) in the test set.
121+
Each element corresponds to one draw from the posterior distribution.
122+
"""
123+
124+
stan_data = {
125+
'n_bucket': n_bins,
126+
'train_neg': neg_hist.tolist(),
127+
'train_pos': pos_hist.tolist(),
128+
'test': test_hist.tolist(),
129+
'posterior': 1
130+
}
131+
132+
stan_model = stan.build(stan_code, data=stan_data, random_seed=stan_seed)
133+
fit = stan_model.sample(num_chains=1, num_samples=number_of_samples,num_warmup=num_warmup)
134+
135+
return fit['prev']

quapy/method/confidence.py

Lines changed: 107 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,8 @@
55
import quapy as qp
66
import quapy.functional as F
77
from quapy.method import _bayesian
8-
from quapy.method.aggregative import AggregativeCrispQuantifier
98
from quapy.data import LabelledCollection
10-
from quapy.method.aggregative import AggregativeQuantifier
9+
from quapy.method.aggregative import AggregativeQuantifier, AggregativeCrispQuantifier, AggregativeSoftQuantifier, BinaryAggregativeQuantifier
1110
from scipy.stats import chi2
1211
from sklearn.utils import resample
1312
from abc import ABC, abstractmethod
@@ -587,8 +586,113 @@ def aggregate(self, classif_predictions):
587586
return np.asarray(samples.mean(axis=0), dtype=float)
588587

589588
def predict_conf(self, instances, confidence_level=None) -> (np.ndarray, ConfidenceRegionABC):
589+
if confidence_level is None:
590+
confidence_level = self.confidence_level
590591
classif_predictions = self.classify(instances)
591592
point_estimate = self.aggregate(classif_predictions)
592593
samples = self.get_prevalence_samples() # available after calling "aggregate" function
593-
region = WithConfidenceABC.construct_region(samples, confidence_level=self.confidence_level, method=self.region)
594+
region = WithConfidenceABC.construct_region(samples, confidence_level=confidence_level, method=self.region)
594595
return point_estimate, region
596+
597+
598+
class PQ(AggregativeSoftQuantifier, BinaryAggregativeQuantifier):
599+
"""
600+
`Precise Quantifier: Bayesian distribution matching quantifier <https://arxiv.org/abs/2507.06061>,
601+
which is a variant of :class:`HDy` that calculates the posterior probability distribution
602+
over the prevalence vectors, rather than providing a point estimate.
603+
604+
This method relies on extra dependencies, which have to be installed via:
605+
`$ pip install quapy[bayes]`
606+
607+
:param classifier: a scikit-learn's BaseEstimator, or None, in which case the classifier is taken to be
608+
the one indicated in `qp.environ['DEFAULT_CLS']`
609+
:param val_split: specifies the data used for generating classifier predictions. This specification
610+
can be made as float in (0, 1) indicating the proportion of stratified held-out validation set to
611+
be extracted from the training set; or as an integer (default 5), indicating that the predictions
612+
are to be generated in a `k`-fold cross-validation manner (with this integer indicating the value
613+
for `k`); or as a tuple `(X,y)` defining the specific set of data to use for validation. Set to
614+
None when the method does not require any validation data, in order to avoid that some portion of
615+
the training data be wasted.
616+
:param num_warmup: number of warmup iterations for the STAN sampler (default 500)
617+
:param num_samples: number of samples to draw from the posterior (default 1000)
618+
:param stan_seed: random seed for the STAN sampler (default 0)
619+
:param region: string, set to `intervals` for constructing confidence intervals (default), or to
620+
`ellipse` for constructing an ellipse in the probability simplex, or to `ellipse-clr` for
621+
constructing an ellipse in the Centered-Log Ratio (CLR) unconstrained space.
622+
"""
623+
def __init__(self,
624+
classifier: BaseEstimator=None,
625+
fit_classifier=True,
626+
val_split: int = 5,
627+
n_bins: int = 4,
628+
fixed_bins: bool = False,
629+
num_warmup: int = 500,
630+
num_samples: int = 1_000,
631+
stan_seed: int = 0,
632+
confidence_level: float = 0.95,
633+
region: str = 'intervals'):
634+
635+
if num_warmup <= 0:
636+
raise ValueError(f'parameter {num_warmup=} must be a positive integer')
637+
if num_samples <= 0:
638+
raise ValueError(f'parameter {num_samples=} must be a positive integer')
639+
640+
if not _bayesian.DEPENDENCIES_INSTALLED:
641+
raise ImportError("Auxiliary dependencies are required. "
642+
"Run `$ pip install quapy[bayes]` to install them.")
643+
644+
super().__init__(classifier, fit_classifier, val_split)
645+
self.n_bins = n_bins
646+
self.fixed_bins = fixed_bins
647+
self.num_warmup = num_warmup
648+
self.num_samples = num_samples
649+
self.stan_seed = stan_seed
650+
self.stan_code = _bayesian.load_stan_file()
651+
self.confidence_level = confidence_level
652+
self.region = region
653+
654+
def aggregation_fit(self, classif_predictions, labels):
655+
y_pred = classif_predictions[:, self.pos_label]
656+
657+
# Compute bin limits
658+
if self.fixed_bins:
659+
# Uniform bins in [0,1]
660+
self.bin_limits = np.linspace(0, 1, self.n_bins + 1)
661+
else:
662+
# Quantile bins
663+
self.bin_limits = np.quantile(y_pred, np.linspace(0, 1, self.n_bins + 1))
664+
665+
# Assign each prediction to a bin
666+
bin_indices = np.digitize(y_pred, self.bin_limits[1:-1], right=True)
667+
668+
# Positive and negative masks
669+
pos_mask = (labels == self.pos_label)
670+
neg_mask = ~pos_mask
671+
672+
# Count positives and negatives per bin
673+
self.pos_hist = np.bincount(bin_indices[pos_mask], minlength=self.n_bins)
674+
self.neg_hist = np.bincount(bin_indices[neg_mask], minlength=self.n_bins)
675+
676+
def aggregate(self, classif_predictions):
677+
Px_test = classif_predictions[:, self.pos_label]
678+
test_hist, _ = np.histogram(Px_test, bins=self.bin_limits)
679+
prevs = _bayesian.pq_stan(
680+
self.stan_code, self.n_bins, self.pos_hist, self.neg_hist, test_hist,
681+
self.num_samples, self.num_warmup, self.stan_seed
682+
).flatten()
683+
self.prev_distribution = np.vstack([1-prevs, prevs]).T
684+
return self.prev_distribution.mean(axis=0)
685+
686+
def aggregate_conf(self, predictions, confidence_level=None):
687+
if confidence_level is None:
688+
confidence_level = self.confidence_level
689+
point_estimate = self.aggregate(predictions)
690+
samples = self.prev_distribution
691+
region = WithConfidenceABC.construct_region(samples, confidence_level=confidence_level, method=self.region)
692+
return point_estimate, region
693+
694+
def predict_conf(self, instances, confidence_level=None) -> (np.ndarray, ConfidenceRegionABC):
695+
predictions = self.classify(instances)
696+
return self.aggregate_conf(predictions, confidence_level=confidence_level)
697+
698+

quapy/method/stan/pq.stan

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
data {
2+
int<lower=0> n_bucket;
3+
array[n_bucket] int<lower=0> train_pos;
4+
array[n_bucket] int<lower=0> train_neg;
5+
array[n_bucket] int<lower=0> test;
6+
int<lower=0,upper=1> posterior;
7+
}
8+
9+
transformed data{
10+
row_vector<lower=0>[n_bucket] train_pos_rv;
11+
row_vector<lower=0>[n_bucket] train_neg_rv;
12+
row_vector<lower=0>[n_bucket] test_rv;
13+
real n_test;
14+
15+
train_pos_rv = to_row_vector( train_pos );
16+
train_neg_rv = to_row_vector( train_neg );
17+
test_rv = to_row_vector( test );
18+
n_test = sum( test );
19+
}
20+
21+
parameters {
22+
simplex[n_bucket] p_neg;
23+
simplex[n_bucket] p_pos;
24+
real<lower=0,upper=1> prev_prior;
25+
}
26+
27+
model {
28+
if( posterior ) {
29+
target += train_neg_rv * log( p_neg );
30+
target += train_pos_rv * log( p_pos );
31+
target += test_rv * log( p_neg * ( 1 - prev_prior) + p_pos * prev_prior );
32+
}
33+
}
34+
35+
generated quantities {
36+
real<lower=0,upper=1> prev;
37+
prev = sum( binomial_rng(test, 1 / ( 1 + (p_neg./p_pos) *(1-prev_prior)/prev_prior ) ) ) / n_test;
38+
}
39+

setup.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,12 @@ def get_version(rel_path):
111111
#
112112
packages=find_packages(include=['quapy', 'quapy.*']), # Required
113113

114+
package_data={
115+
# For the 'quapy.method' package, include all files
116+
# in the 'stan' subdirectory that end with .stan
117+
'quapy.method': ['stan/*.stan']
118+
},
119+
114120
python_requires='>=3.8, <4',
115121

116122
install_requires=['scikit-learn', 'pandas', 'tqdm', 'matplotlib', 'joblib', 'xlrd', 'abstention', 'ucimlrepo', 'certifi'],
@@ -124,7 +130,7 @@ def get_version(rel_path):
124130
# Similar to `install_requires` above, these must be valid existing
125131
# projects.
126132
extras_require={ # Optional
127-
'bayes': ['jax', 'jaxlib', 'numpyro'],
133+
'bayes': ['jax', 'jaxlib', 'numpyro', 'pystan'],
128134
'neural': ['torch'],
129135
'tests': ['certifi'],
130136
'docs' : ['sphinx-rtd-theme', 'myst-parser'],

0 commit comments

Comments
 (0)