Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion skpro/distributions/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,30 @@
# copyright: skpro developers, BSD-3-Clause License (see LICENSE file)
# adapted from sktime

__all__ = ["BaseDistribution", "_DelegatedDistribution", "_BaseArrayDistribution"]
__all__ = [
"BaseDistribution",
"_DelegatedDistribution",
"_BaseArrayDistribution",
"BaseSet",
"IntervalSet",
"FiniteSet",
"IntegerSet",
"UnionSet",
"IntersectionSet",
"EmptySet",
"RealSet",
]

from skpro.distributions.base._base import BaseDistribution
from skpro.distributions.base._base_array import _BaseArrayDistribution
from skpro.distributions.base._delegate import _DelegatedDistribution
from skpro.distributions.base._set import (
BaseSet,
EmptySet,
FiniteSet,
IntegerSet,
IntersectionSet,
IntervalSet,
RealSet,
UnionSet,
)
193 changes: 166 additions & 27 deletions skpro/distributions/base/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1355,6 +1355,44 @@ def _sample_mean(self, spl):
else:
return spl.mean().iloc[0]

@property
def support(self):
"""Return the support of the distribution as a symbolic set.

The support is the smallest closed set whose complement has
probability zero.

Returns
-------
BaseSet
Symbolic set representation of the distribution's support.

Examples
--------
>>> from skpro.distributions.normal import Normal
>>> n = Normal(mu=0, sigma=1)
>>> n.support
RealSet()
"""
return self._support()

def _support(self):
r"""Return the support of the distribution.

Private method, to be overridden by subclasses.

Default returns ``RealSet``, appropriate for continuous distributions
with support on all of :math:`\\mathbb{R}` (e.g., Normal, Cauchy).

Returns
-------
BaseSet
The support of this distribution.
"""
from skpro.distributions.base._set import RealSet

return RealSet(index=self.index, columns=self.columns)

def mean(self):
r"""Return expected value of the distribution.

Expand Down Expand Up @@ -1696,8 +1734,7 @@ def plot(self, fun=None, ax=None, **kwargs):

if self.ndim > 0:
if "x_bounds" not in kwargs:
upper = self.ppf(0.999).values.flatten().max()
lower = self.ppf(0.001).values.flatten().min()
lower, upper = self._get_plot_bounds()
x_bounds = (lower, upper)
else:
x_bounds = kwargs.pop("x_bounds")
Expand Down Expand Up @@ -1755,8 +1792,43 @@ def get_ax(ax, i, j, shape):
ax = getattr(self, plot_fun_name)(ax=ax, fun=fun, **kwargs)
return ax

def _get_plot_bounds(self):
"""Get x-axis bounds for plotting using support information.

Uses ``support.boundary()`` for exact bounds when available.
Falls back to ``ppf(0.001)``/``ppf(0.999)`` for infinite or
unavailable bounds.

Returns
-------
tuple of (lower, upper)
x-axis bounds for plotting.
"""
lower, upper = None, None
supp = getattr(self, "support", None)

if supp is not None and hasattr(supp, "boundary"):
bd = supp.boundary()
if isinstance(bd, tuple) and len(bd) == 2:
lower = float(bd[0]) if not np.isinf(bd[0]) else None
upper = float(bd[1]) if not np.isinf(bd[1]) else None

# fall back to ppf for any missing bounds
if lower is None or upper is None:
ppf_low = self.ppf(0.001)
ppf_high = self.ppf(0.999)
if self.ndim > 0:
ppf_low = ppf_low.values.flatten().min()
ppf_high = ppf_high.values.flatten().max()
if lower is None:
lower = float(ppf_low)
if upper is None:
upper = float(ppf_high)

return lower, upper

def _plot_single(self, ax=None, **kwargs):
"""Plot the pdf of the distribution."""
"""Plot the distribution function, handling discrete/continuous/mixed types."""
import matplotlib.pyplot as plt

fun = kwargs.pop("fun")
Expand All @@ -1767,57 +1839,124 @@ def _plot_single(self, ax=None, **kwargs):
if "x_bounds" in kwargs:
lower, upper = kwargs.pop("x_bounds")
elif fun != "ppf":
lower, upper = self.ppf(0.001), self.ppf(0.999)

if fun == "ppf":
lower, upper = self._get_plot_bounds()
else:
lower, upper = 0.001, 0.999

is_discrete = self.get_tag("distr:measuretype", "mixed") == "discrete"

x_arr = self._get_x_for_plot(fun, lower, upper, is_discrete)

y_arr = [getattr(self, fun)(x) for x in x_arr]
y_arr = np.array(y_arr)

if ax is None:
ax = plt.gca()

# Use stem plot for discrete PMF, line plot otherwise
if is_discrete and fun == "pmf":
ax.stem(x_arr, y_arr, basefmt=" ", **kwargs)
measuretype = self.get_tag("distr:measuretype", "mixed")

if measuretype == "discrete":
# pure discrete: stem plot at exact support points
x_arr = self._get_discrete_support_points(lower, upper)
y_arr = np.array([getattr(self, fun)(x) for x in x_arr])
if fun == "pmf":
ax.stem(x_arr, y_arr, basefmt=" ", **kwargs)
elif fun == "cdf":
ax.step(x_arr, y_arr, where="post", **kwargs)
else:
ax.plot(x_arr, y_arr, **kwargs)

elif measuretype == "mixed":
# mixed distribution (e.g., ZeroInflated):
discrete_pts = self._get_discrete_support_points(lower, upper)

# 1. plot the continuous part as a line (if pdf or cdf)
x_cont = np.linspace(lower, upper, 1000)
if fun in ("pdf", "cdf"):
y_cont = np.array([getattr(self, fun)(x) for x in x_cont])

# Mask out continuous density overlapping exactly with Dirac spikes
if fun == "pdf" and len(discrete_pts) > 0:
for pt in discrete_pts:
y_cont[np.isclose(x_cont, pt, atol=1e-3)] = np.nan

ax.plot(x_cont, y_cont, **kwargs)

# 2. always overlay discrete mass points as stems
if len(discrete_pts) > 0:
# get the mass at these points from pmf
y_disc = np.array([self.pmf(x) for x in discrete_pts])
# use different style from line to distinguish
stem_kwargs = {k: v for k, v in kwargs.items() if k != "label"}
ax.stem(
discrete_pts,
y_disc,
basefmt=" ",
linefmt="C1-",
markerfmt="C1o",
**stem_kwargs,
)
else:
# pure continuous: dense line plot
x_arr = np.linspace(lower, upper, 1000)
y_arr = np.array([getattr(self, fun)(x) for x in x_arr])
ax.plot(x_arr, y_arr, **kwargs)

if print_labels == "on":
ax.set_xlabel(f"{x_argname}")
ax.set_ylabel(f"{fun}({x_argname})")
return ax

def _get_x_for_plot(self, fun, lower, upper, is_discrete):
"""Get x values for plotting, handling discrete distributions for PMF."""
# general case: not discrete, or not pmf
if not is_discrete or fun != "pmf":
# in this case, the function is on a continuous domain,
# so we can plot on a dense grid of points
return np.linspace(lower, upper, 1000)
def _get_discrete_support_points(self, lower, upper, max_points=1000):
"""Get discrete support points within bounds for plotting.

# special case: discrete distribution and pmf - plot at the support points
Inspects the distribution support for FiniteSet or IntegerSet components.
Falls back to _pmf_support() or integer grid if no such components are found.
"""
from skpro.distributions.base._set import FiniteSet, IntegerSet, UnionSet

pts = []
supp = getattr(self, "support", None)

if supp is not None:
# Flatten to check components (handles UnionSet natively)
sets_to_check = supp.sets if isinstance(supp, UnionSet) else [supp]
for s in sets_to_check:
if isinstance(s, FiniteSet):
s_vals = s.values
pts.extend(s_vals[(s_vals >= lower) & (s_vals <= upper)])
elif isinstance(s, IntegerSet):
lo = max(int(np.floor(lower)) - 1, int(s.lower))
hi_bound = s.upper
hi = min(
int(np.ceil(upper)) + 1,
int(hi_bound) if not np.isinf(hi_bound) else lo + max_points,
)
pts.extend(range(lo, hi + 1))

if pts:
pts_arr = np.array(pts)
# Filter out points failing inclusion bounds (Open vs Closed)
if supp is not None:
mask = [np.all(supp.contains(p)) for p in pts_arr]
pts_arr = pts_arr[mask]

# Unify sub-precision duplicate stems
if len(pts_arr) > 0:
_, unique_idx = np.unique(
np.round(pts_arr, decimals=8), return_index=True
)
return np.sort(pts_arr[unique_idx])
return np.array([])

# Define fallback array construction (used when _pmf_support not available)
def _get_fallback_arr():
arr = np.linspace(lower, upper, 1000)
arr = np.round(arr).astype(int)
return np.unique(arr)

# Use _pmf_support if the method exists and is callable
# fallback: use _pmf_support or integer rounding
if hasattr(self, "_pmf_support") and callable(self._pmf_support):
x_arr = self._pmf_support(lower, upper, max_points=1000)
x_arr = self._pmf_support(lower, upper, max_points=max_points)
if x_arr.size != 0:
return x_arr

return _get_fallback_arr()

def _pmf_support(self, lower, upper, max_points=100):
def _pmf_support(self, lower, upper, max_points=1000):
"""Get support points for discrete distributions.

Returns the support points of the probability mass function (PMF)
Expand Down
Loading
Loading