diff --git a/docs/tutorials/observation_processes_counts.qmd b/docs/tutorials/observation_processes_counts.qmd index 9ad4724c..19b6e1a2 100644 --- a/docs/tutorials/observation_processes_counts.qmd +++ b/docs/tutorials/observation_processes_counts.qmd @@ -1,5 +1,5 @@ --- -title: "Observation processes for count data" +title: Observation processes for count data format: gfm: code-fold: true @@ -17,8 +17,6 @@ jupyter: name: python3 --- -This tutorial demonstrates how to use the `Counts` observation process to model count data such as hospital admissions, emergency department visits, or deaths. - ```{python} # | label: setup # | output: false @@ -33,7 +31,12 @@ from plotnine.exceptions import PlotnineWarning warnings.filterwarnings("ignore", category=PlotnineWarning) from _tutorial_theme import theme_tutorial -from pyrenew.observation import Counts, NegativeBinomialNoise, PoissonNoise +from pyrenew.observation import ( + CountBase, + Counts, + NegativeBinomialNoise, + PoissonNoise, +) from pyrenew.deterministic import DeterministicVariable, DeterministicPMF from pyrenew import datasets ``` @@ -49,6 +52,14 @@ In the renewal modeling framework, observations are generated in two steps: Observed data can be aggregated or available as subpopulation-level counts, which are modeled by the classes `Counts` and `CountsBySubpop`, respectively. +### The generative model + +In PyRenew, the observation process maps latent daily infections *forward* to observed data. +The generative direction is always:
+   latent infections $\to$ ascertainment $\times$ delay convolution $\to$ predicted counts $\to$ noise $\to$ observed data.
+This forward direction is fundamental: the observation process transforms predictions to match the scale and resolution of the data, never the reverse. +During inference, the likelihood compares model predictions to observed data at the data's own resolution. + ### Observation equation The deterministic transformation is given by the observation equation: @@ -92,7 +103,68 @@ This yields a two-layer model: **Note on terminology:** In real-world inference, incident infections $I(t)$ are typically *latent* (unobserved) and must be inferred from observed data. In this tutorial, we simulate the observation process by specifying infections directly and showing how they produce observed counts through convolution and sampling. -## Hospital admissions example + +`CountBase` provides the core ascertainment $\times$ delay convolution operations that all subclasses inherit: + +| Inherited method | What it does | +|---|---| +| `_predicted_obs(infections)` | Convolves infections with the delay PMF, scaled by the ascertainment rate. Returns daily predicted counts. | +| `validate()` | Validates the delay PMF, ascertainment rate, noise model, and optional day-of-week / right-truncation parameters. | +| `lookback_days()` | Returns `len(delay_pmf) - 1`: the number of initialization days needed. | +| `_apply_day_of_week(predicted, first_day_dow)` | Applies a multiplicative 7-day periodic pattern (optional). | +| `_apply_right_truncation(predicted, offset)` | Scales recent predictions for incomplete reporting (optional). | + +The last two methods are optional adjustments that subclasses can apply in their `sample()` method. Day-of-week effects model systematic within-week reporting patterns (see [Day-of-Week Effects](day_of_week_effects.md)). Right-truncation adjustment accounts for incomplete reporting of recent observations (see [Right Truncation](right_truncation.md)). 
+ +Subclasses must implement three methods: + +| Required method | Purpose | +|---|---| +| `infection_resolution()` | Return `"aggregate"` or `"subpop"` to tell the model which latent infections to route to this observation. | +| `sample(infections, ...)` | The main forward model: compute predictions, apply any transformations, then sample from the noise distribution. | +| `validate_data(n_total, n_subpops, ...)` | Validate observation data shapes before JAX tracing begins. | + +### Built-in subclasses + +PyRenew provides two built-in `CountBase` subclasses: + +- **`Counts`**: Aggregate daily counts. `infection_resolution()` returns `"aggregate"`. Accepts 1D infections of shape `(n_total,)`. Observations are on the shared dense time axis with NaN masking for initialization and missing data. +- **`CountsBySubpop`**: Subpopulation-level daily counts. `infection_resolution()` returns `"subpop"`. Accepts 2D infections of shape `(n_total, n_subpops)`. Uses sparse indexing via `times` and `subpop_indices`. + + +### Subclassing CountBase + +The following code sketch outlines the functions that must be defined on a subclass of `CountBase`. + +```python +from pyrenew.observation import CountBase +from pyrenew.observation.types import ObservationSample + +class MyCustomCounts(CountBase): + def __init__(self, name, ascertainment_rate_rv, delay_distribution_rv, noise): + super().__init__( + name=name, + ascertainment_rate_rv=ascertainment_rate_rv, + delay_distribution_rv=delay_distribution_rv, + noise=noise, + ) + + def infection_resolution(self): + return "aggregate" # or "subpop" + + def validate_data(self, n_total, n_subpops, **kwargs): + # Check observation data shapes + ... + + def sample(self, infections, ...): + predicted = self._predicted_obs(infections) # inherited + # ... transform predictions (e.g., aggregate to weekly) ... + observed = self.noise.sample(name=..., predicted=..., obs=...) 
+ return ObservationSample(observed=observed, predicted=predicted) +``` + + +## Using the `Counts` class for Hospital Admissions Data For hospital admissions data, we construct a `Counts` observation process. @@ -116,7 +188,7 @@ As $\phi \to \infty$, the negative binomial distribution approaches a Poisson di In this example, we use fixed parameter values for illustration; in practice, these parameters would be estimated from data using weakly informative priors. -## Infection-to-hospitalization delay distribution +### Infection-to-hospitalization delay distribution The delay distribution specifies the probability that an infected person is hospitalized $d$ days after infection, conditional on the infection leading to a hospitalization. For example, if `hosp_delay_pmf[5] = 0.2`, then 20% of infections that result in hospitalization will appear as hospital admissions 5 days after infection. @@ -180,9 +252,9 @@ plot_delay = ( plot_delay ``` -## Creating a Counts observation process +### Defining a Counts observation process -A `Counts` object takes the following arguments: +A `Counts` object inherits the full convolution pipeline from `CountBase`. It takes the following arguments: - **`name`**: unique, meaningful identifier for this observation process (e.g., `"hospital"`, `"deaths"`) - **`ascertainment_rate_rv`**: the probability an infection results in an observation (e.g., IHR) @@ -234,7 +306,6 @@ def first_valid_observation_day(obs_process) -> int: To demonstrate how a `Counts` observation process works, we examine how infections occurring on a single day result in observed hospital admissions. - ```{python} # | label: simulate-spike n_days = 100 @@ -247,6 +318,7 @@ infections = infections.at[infection_spike_day].set(2000) ``` We plot the infections starting from day_one (the first valid observation day, after the lookback period). 
+ ```{python} # | label: plot-infections # Plot relative to first valid observation day @@ -288,7 +360,7 @@ plot_infections Because all infections occur on a single day, this example shows how a single pulse of infections produces observed events over time through the delay distribution. -## Predicted admissions without observation noise. +### Predicted admissions without observation noise First, we compute the predicted admissions from the convolution alone, without observation noise. This gives the predicted number of observations $\mu(t)$. @@ -372,7 +444,7 @@ plot_predicted The predicted admissions follow the shape of the delay distribution, shifted by the infection spike day and scaled by the ascertainment rate. -## Observation Noise (Negative Binomial) +### Observation noise (Negative Binomial) The negative binomial distribution adds stochastic variation around $\mu(t)$, corresponding to the stochastic observation layer. Sampling multiple times from the same infections shows the range of possible observations: @@ -472,7 +544,7 @@ print( ) ``` -## Effect of the ascertainment rate +### Effect of the ascertainment rate The ascertainment rate (here, the infection-hospitalization rate or IHR) directly scales the number of predicted hospital admissions. We compare two contrasting IHR values: **0.5%** and **2.5%**. 
@@ -505,7 +577,6 @@ for ihr_val in ihr_values: ) ``` - ```{python} # | label: plot-ihr-comparisons results_df = pd.DataFrame(results_list) @@ -525,7 +596,7 @@ plot_ihr = ( plot_ihr ``` -## Negative binomial concentration parameter +### Negative binomial concentration parameter The concentration parameter $\phi$ controls overdispersion: @@ -598,14 +669,14 @@ plot_concentration = ( + p9.labs( x="Day", y="Hospital Admissions", - title="Effect of Concentration Parameter on Variability", + title="Effect of Negative Binomial Concentration Parameter on Variability", ) + theme_tutorial ) plot_concentration ``` -## Swapping noise models +### Swapping noise models To use Poisson noise instead of negative binomial, change the noise model: @@ -629,44 +700,373 @@ print( ) ``` -We can visualize the Poisson noise model using the same constant infection scenario as the concentration comparison above. Since Poisson assumes variance equals the mean, it produces less variability than the negative binomial with low concentration values. - -To see the reduction in noise, it is necessary to keep the y-axis on the same scale as in the previous plot. +To compare Poisson noise directly against negative binomial, we plot 10 replicates from three noise models side by side using the same constant infection input. The shared y-axis makes the difference in variability immediately visible: Poisson ($\text{Var} = \mu$) is the tightest, negative binomial with $\phi = 100$ is nearly identical, and $\phi = 10$ shows noticeably more spread. 
```{python} # | label: poisson-realizations -# Sample multiple realizations with Poisson noise -n_replicates_poisson = 10 +noise_comparison = [] +noise_configs = [ + ("Poisson", PoissonNoise()), + ( + "NegBin $\\phi$=100", + NegativeBinomialNoise(DeterministicVariable("c100", 100.0)), + ), + ( + "NegBin $\\phi$=10", + NegativeBinomialNoise(DeterministicVariable("c10", 10.0)), + ), +] + +for label, noise_model in noise_configs: + process_tmp = Counts( + name="hospital", + ascertainment_rate_rv=ihr_rv, + delay_distribution_rv=delay_rv, + noise=noise_model, + ) + for seed in range(10): + with numpyro.handlers.seed(rng_seed=seed): + result_tmp = process_tmp.sample( + infections=infections_constant, obs=None + ) + for i, admit in enumerate(result_tmp.observed[day_one:]): + noise_comparison.append( + { + "day": i, + "admissions": float(admit), + "noise": label, + "replicate": seed, + } + ) + +noise_df = pd.DataFrame(noise_comparison) +noise_df["noise"] = pd.Categorical( + noise_df["noise"], + categories=["Poisson", "NegBin $\\phi$=100", "NegBin $\\phi$=10"], + ordered=True, +) + +( + p9.ggplot(noise_df, p9.aes(x="day", y="admissions", group="replicate")) + + p9.geom_line(alpha=0.5, size=0.8, color="steelblue") + + p9.facet_wrap("~ noise", ncol=3) + + p9.labs( + x="Day", + y="Hospital Admissions", + title="Noise Model Comparison: Poisson vs. Negative Binomial", + ) + + theme_tutorial +) +``` + +## Weekly Observations with WeeklyCounts + +Some surveillance signals are reported at coarser temporal resolution than daily. For example, NHSN hospital admissions are now reported as weekly (MMWR epiweek) totals rather than daily counts. + +The correct approach is to aggregate *predictions* up to the observation's temporal resolution, not to disaggregate observations down to daily. +Disaggregating weekly counts to daily values would fabricate within-week timing information that does not exist in the data. 
+Instead, the latent model produces daily predictions via `_predicted_obs()` (inherited from `CountBase`), and the observation process sums them into weekly totals. +The likelihood then evaluates at the weekly resolution, comparing weekly predicted totals to weekly observed totals. +This preserves the generative model's causal direction: latent daily infections flow forward through the observation process to produce predictions at whatever resolution the data requires. + +The predicted weekly admissions for epiweek $w$ are: + +$$\mu_w = \sum_{d \in w} \mu(d)$$ + +where $\mu(d)$ is the daily predicted count. Observations are weekly totals with negative binomial noise: + +$$Y_w \sim \text{NegativeBinomial}(\text{mean} = \mu_w, \text{concentration} = \phi)$$ + +Weekly aggregation naturally reduces relative variability (the coefficient of variation) compared with daily counts, so weekly observations typically use a higher concentration parameter (less overdispersion) than daily observations. +The choice of $\phi$ at each resolution reflects prior knowledge about the noise structure of the data. +Daily counts are subject to day-to-day reporting irregularities: staffing variation, batch reporting, and weekday/weekend effects all introduce overdispersion beyond what the Poisson model predicts. +A moderate $\phi$ (e.g., 10) captures this extra daily noise. +Weekly totals average over these within-week fluctuations, so the remaining noise after aggregation is closer to Poisson. +A high $\phi$ (e.g., 100) is appropriate because most of the reporting-driven overdispersion has been smoothed out by summing over 7 days. +In practice, both concentration parameters would be given informative priors and estimated from data, but the prior for the weekly $\phi$ should be centered higher than the prior for the daily $\phi$. + +Day-of-week effects and right-truncation are not applicable to weekly data: weekly aggregation absorbs within-week patterns and mitigates reporting delays. 
+ +### Implementing the WeeklyCounts class + +```{python} +# | label: weekly-counts-class +from jax.typing import ArrayLike +from pyrenew.observation import CountBase +from pyrenew.observation.noise import CountNoise +from pyrenew.observation.types import ObservationSample +from pyrenew.metaclass import RandomVariable +from pyrenew.time import daily_to_mmwr_epiweekly + + +class WeeklyCounts(CountBase): + """Weekly (MMWR epiweek) aggregate count observation process.""" + + def __init__( + self, + name: str, + ascertainment_rate_rv: RandomVariable, + delay_distribution_rv: RandomVariable, + noise: CountNoise, + ) -> None: + """ + Initialize weekly count observation process. + + Parameters + ---------- + name : str + Unique name for this observation process. + ascertainment_rate_rv : RandomVariable + Ascertainment rate in [0, 1] (e.g., IHR). + delay_distribution_rv : RandomVariable + Delay distribution PMF (must sum to ~1.0). + noise : CountNoise + Noise model for weekly count observations. + """ + super().__init__( + name=name, + ascertainment_rate_rv=ascertainment_rate_rv, + delay_distribution_rv=delay_distribution_rv, + noise=noise, + ) + + def infection_resolution(self) -> str: + """Return 'aggregate' for jurisdiction-level observations.""" + return "aggregate" + + def validate_data( + self, + n_total: int, + n_subpops: int, + first_day_dow: int | None = None, + week_indices: ArrayLike | None = None, + obs: ArrayLike | None = None, + **kwargs, + ) -> None: + """ + Validate weekly observation data. + + Parameters + ---------- + n_total : int + Total time steps on the shared daily axis. + n_subpops : int + Number of subpopulations (unused). + first_day_dow : int | None + Day of the week for element 0 of the shared time axis. + week_indices : ArrayLike | None + Indices into the weekly-aggregated predictions array. + obs : ArrayLike | None + Weekly observed counts. + **kwargs + Additional keyword arguments (ignored). 
+ """ + if obs is not None and week_indices is not None: + obs = jnp.asarray(obs) + week_indices = jnp.asarray(week_indices) + if obs.shape != week_indices.shape: + raise ValueError( + f"Observation '{self.name}': obs shape {obs.shape} " + f"must match week_indices shape {week_indices.shape}" + ) + + def sample( + self, + infections: ArrayLike, + first_day_dow: int, + week_indices: ArrayLike, + obs: ArrayLike | None = None, + ) -> ObservationSample: + """ + Sample weekly aggregated counts. + + Parameters + ---------- + infections : ArrayLike + Daily aggregate infections, shape (n_total,). + first_day_dow : int + ISO day-of-week for element 0 of the shared time axis + (0=Monday, 6=Sunday). + week_indices : ArrayLike + Indices into the weekly predictions array identifying + which weeks have observations. + obs : ArrayLike | None + Weekly observed counts, shape (n_obs_weeks,). + None for prior predictive sampling. + + Returns + ------- + ObservationSample + Named tuple with observed (weekly) and predicted (daily). + """ + daily_predicted = self._predicted_obs(infections) + self._deterministic("predicted_daily", daily_predicted) + + weekly_predicted = daily_to_mmwr_epiweekly( + daily_predicted, input_data_first_dow=first_day_dow + ) + self._deterministic("predicted_weekly", weekly_predicted) + + predicted_at_obs = weekly_predicted[week_indices] + + observed = self.noise.sample( + name=self._sample_site_name("obs"), + predicted=predicted_at_obs, + obs=obs, + ) + + return ObservationSample(observed=observed, predicted=daily_predicted) +``` + +Key design choices: + +- **No day-of-week or right-truncation**: The constructor passes neither `day_of_week_rv` nor `right_truncation_rv` to `CountBase`. Weekly aggregation absorbs within-week patterns and mitigates reporting delays. +- **`week_indices`**: Maps observed weeks to positions in the aggregated predictions. This handles partial weeks at the start/end of the time series and allows for missing weeks. 
+- **Two deterministic sites**: `predicted_daily` (full daily time series) and `predicted_weekly` (aggregated epiweek totals) are both recorded for posterior analysis. + +### Configuring a weekly hospital admissions process + +```{python} +# | label: create-weekly-process +weekly_ihr_rv = DeterministicVariable("weekly_ihr", 0.01) +weekly_concentration_rv = DeterministicVariable("weekly_concentration", 100.0) + +weekly_hosp_process = WeeklyCounts( + name="hospital_weekly", + ascertainment_rate_rv=weekly_ihr_rv, + delay_distribution_rv=delay_rv, + noise=NegativeBinomialNoise(weekly_concentration_rv), +) + +print(f"Required lookback: {weekly_hosp_process.lookback_days()} days") +``` + +### Comparing daily and weekly observations from the same infections + +Using the exponentially decaying infection curve from earlier, we can see how the same underlying epidemic produces different observations at daily vs. weekly resolution. + +```{python} +# | label: weekly-simulate +import datetime as dt + +peak_value = 3000 +infections_decay = peak_value * jnp.exp(-jnp.arange(n_days) / 20.0) + +# The shared time axis starts on a Sunday (2023-01-01 was a Sunday = ISO dow 6) +first_dow = 6 + +# Compute weekly predictions to determine valid week indices +with numpyro.handlers.seed(rng_seed=0): + daily_predicted = weekly_hosp_process._predicted_obs(infections_decay) + +weekly_predicted = daily_to_mmwr_epiweekly( + daily_predicted, input_data_first_dow=first_dow +) +n_valid_weeks = int(jnp.sum(~jnp.isnan(weekly_predicted))) +n_total_weeks = len(weekly_predicted) + +# Use all valid (non-NaN) weeks +all_week_indices = jnp.arange(n_total_weeks) +valid_mask = ~jnp.isnan(weekly_predicted) +week_indices = all_week_indices[valid_mask] + +print( + f"Total weeks: {n_total_weeks}, " + f"valid weeks (after lookback): {n_valid_weeks}" +) +``` + +```{python} +# | label: weekly-daily-comparison +# Sample daily observations (using existing Counts process) +daily_process = Counts( + 
name="hospital_daily", + ascertainment_rate_rv=weekly_ihr_rv, + delay_distribution_rv=delay_rv, + noise=NegativeBinomialNoise(DeterministicVariable("conc_daily", 10.0)), +) + +# Collect daily and weekly samples +comparison_list = [] + +for seed in range(50): + with numpyro.handlers.seed(rng_seed=seed): + daily_result = daily_process.sample( + infections=infections_decay, obs=None + ) + + for i, val in enumerate(daily_result.observed[day_one:]): + comparison_list.append( + { + "time": i, + "admissions": float(val), + "resolution": "Daily ($\\phi$=10)", + "replicate": seed, + } + ) -poisson_results = [] -for seed in range(n_replicates_poisson): with numpyro.handlers.seed(rng_seed=seed): - poisson_temp = hosp_process_poisson.sample( - infections=infections_constant, + weekly_result = weekly_hosp_process.sample( + infections=infections_decay, + first_day_dow=first_dow, + week_indices=week_indices, obs=None, ) - # Slice from day_one to align with valid observation period - for i, admit in enumerate(poisson_temp.observed[day_one:]): - poisson_results.append( + for j, (wi, val) in enumerate(zip(week_indices, weekly_result.observed)): + comparison_list.append( { - "day": i, - "admissions": float(admit), + "time": int(wi) * 7 + 3 - day_one, + "admissions": float(val), + "resolution": "Weekly ($\\phi$=100)", "replicate": seed, } ) -poisson_df = pd.DataFrame(poisson_results) -plot_poisson = ( - p9.ggplot(poisson_df, p9.aes(x="day", y="admissions", group="replicate")) - + p9.geom_line(alpha=0.5, size=0.8, color="steelblue") +comparison_df = pd.DataFrame(comparison_list) +``` + +```{python} +# | label: plot-weekly-daily-comparison +daily_comp = comparison_df[comparison_df["resolution"] == "Daily ($\\phi$=10)"] +weekly_comp = comparison_df[ + comparison_df["resolution"] == "Weekly ($\\phi$=100)" +] + +( + p9.ggplot() + + p9.geom_line( + p9.aes(x="time", y="admissions", group="replicate"), + data=daily_comp, + color="steelblue", + alpha=0.2, + size=0.5, + ) + + p9.geom_jitter( + 
p9.aes(x="time", y="admissions", group="replicate"), + data=weekly_comp, + color="orange", + alpha=0.6, + size=2, + width=1.2, + height=0, + ) + p9.labs( - x="Day", + x="Day (relative to first valid observation day)", y="Hospital Admissions", - title="Poisson Noise Model (Variance = Mean)", + title="Daily vs. Weekly: 50 Sample Observations from Same Infections ", + subtitle="Blue lines: daily ($\\phi$=10) | Orange points: weekly totals ($\\phi$=100)", ) + theme_tutorial - + p9.ylim(0, 105) ) -plot_poisson ``` + +Weekly aggregation collapses seven daily values into a single total per epiweek. +The weekly points may appear more dispersed than the daily lines, even though the weekly process uses a much higher concentration parameter ($\phi = 100$ vs. $\phi = 10$). +This is not a modeling error. +The negative binomial variance is $\text{Var}[Y] = \mu + \mu^2 / \phi$. +Weekly totals have means roughly 7 times larger than daily means ($\mu_w \approx 7 \mu_d$), so the quadratic term $\mu_w^2 / \phi$ grows with the square of the mean. +Even with $\phi = 100$, the absolute spread of the weekly distribution is wider than the daily distribution with $\phi = 10$, because the weekly mean is so much larger. +In relative terms (coefficient of variation), the weekly observations are tighter, which is why weekly data is often considered less noisy for inference. + +In a multi-signal model, pairing weekly hospital admissions with a daily signal (such as ED visits) allows the daily signal to resolve within-week dynamics that the weekly signal cannot capture. diff --git a/pyrenew/datasets/hospital_admissions.py b/pyrenew/datasets/hospital_admissions.py index 0aa1cda4..cfc52aad 100644 --- a/pyrenew/datasets/hospital_admissions.py +++ b/pyrenew/datasets/hospital_admissions.py @@ -3,7 +3,7 @@ Load hospital admissions data for use in tutorials and examples. This module provides functions to load COVID-19 hospital admissions -data from the CDC's cfa-forecast-renewal-ww project. 
+data (daily and weekly) from the CDC's cfa-forecast-renewal-ww project. """ from importlib.resources import files diff --git a/pyrenew/observation/__init__.py b/pyrenew/observation/__init__.py index fcbb9394..4f47d696 100644 --- a/pyrenew/observation/__init__.py +++ b/pyrenew/observation/__init__.py @@ -4,6 +4,7 @@ ``BaseObservationProcess`` is the abstract base. Concrete subclasses: +- ``CountBase``: Base class for count observations (ascertainment x delay convolution) - ``Counts``: Aggregate counts (admissions, deaths) - ``CountsBySubpop``: Subpopulation-level counts - ``Measurements``: Continuous subpopulation-level signals (e.g., wastewater) @@ -19,7 +20,7 @@ """ from pyrenew.observation.base import BaseObservationProcess -from pyrenew.observation.count_observations import Counts, CountsBySubpop +from pyrenew.observation.count_observations import CountBase, Counts, CountsBySubpop from pyrenew.observation.measurements import Measurements from pyrenew.observation.negativebinomial import NegativeBinomialObservation from pyrenew.observation.noise import ( @@ -44,6 +45,7 @@ "MeasurementNoise", "HierarchicalNormalNoise", # Observation processes + "CountBase", "Counts", "CountsBySubpop", "Measurements", diff --git a/pyrenew/observation/count_observations.py b/pyrenew/observation/count_observations.py index 80026061..15323d08 100644 --- a/pyrenew/observation/count_observations.py +++ b/pyrenew/observation/count_observations.py @@ -19,9 +19,9 @@ from pyrenew.time import get_sequential_day_of_week_indices -class _CountBase(BaseObservationProcess): +class CountBase(BaseObservationProcess): """ - Internal base for count observation processes. + Base class for count observation processes. Implements ascertainment x delay convolution with pluggable noise model. """ @@ -115,17 +115,6 @@ def lookback_days(self) -> int: """ return len(self.temporal_pmf_rv()) - 1 - def infection_resolution(self) -> str: - """ - Return required infection resolution. 
- - Returns - ------- - str - "aggregate" or "subpop". - """ - raise NotImplementedError("Subclasses must implement infection_resolution()") - def _predicted_obs( self, infections: ArrayLike, @@ -246,7 +235,7 @@ def _apply_day_of_week( return predicted * daily_effect -class Counts(_CountBase): +class Counts(CountBase): """ Aggregated count observation. @@ -399,7 +388,7 @@ def sample( return ObservationSample(observed=observed, predicted=predicted_counts) -class CountsBySubpop(_CountBase): +class CountsBySubpop(CountBase): """ Subpopulation-level count observation. diff --git a/test/test_interface_coverage.py b/test/test_interface_coverage.py index 500b65dc..d9360062 100644 --- a/test/test_interface_coverage.py +++ b/test/test_interface_coverage.py @@ -196,11 +196,11 @@ def test_measurements_infection_resolution(): def test_base_count_observation_infection_resolution_raises(): - """Base _CountBase.infection_resolution() raises NotImplementedError.""" - from pyrenew.observation.count_observations import _CountBase + """Subclass of CountBase without infection_resolution cannot be instantiated.""" + from pyrenew.observation.count_observations import CountBase - class _MinimalCounts(_CountBase): - """Minimal subclass that inherits infection_resolution unchanged.""" + class _MinimalCounts(CountBase): + """Minimal subclass missing infection_resolution.""" def sample(self, *args, **kwargs): # numpydoc ignore=GL08 pass @@ -208,14 +208,13 @@ def sample(self, *args, **kwargs): # numpydoc ignore=GL08 def validate_data(self, n_total, n_subpops, **obs_data): # numpydoc ignore=GL08 pass - obs = _MinimalCounts( - name="test_base", - ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), - delay_distribution_rv=DeterministicPMF("delay", jnp.array([1.0])), - noise=PoissonNoise(), - ) - with pytest.raises(NotImplementedError): - obs.infection_resolution() + with pytest.raises(TypeError, match="infection_resolution"): + _MinimalCounts( + name="test_base", + 
ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), + delay_distribution_rv=DeterministicPMF("delay", jnp.array([1.0])), + noise=PoissonNoise(), + ) # =============================================================================