diff --git a/.gitignore b/.gitignore index 75afd71..1d19e29 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ # Data +benchmarks/ *.csv *.tsv *.parquet diff --git a/README.md b/README.md index 49bb45a..25fcf69 100644 --- a/README.md +++ b/README.md @@ -26,6 +26,23 @@ uv run mrp run example_model.mrp.toml --input seed=42 --input max_gen=10 You can run `uv tool install cfa-mrp` to omit the `uv run`. +## Running a calibration + +The repository includes a complete calibration example for the bundled example model: + +```bash +uv sync --all-packages --all-extras +uv run python -m example_model.calibrate +``` + +This runs the ABC-SMC calibration workflow defined in [packages/example_model/src/example_model/calibrate.py](packages/example_model/src/example_model/calibrate.py) and prints the posterior summary and diagnostics. + +To compare serial and parallel execution for the same example, run: + +```bash +uv run python -m example_model.benchmark +``` + ## General Disclaimer This repository was created for use by CDC programs to collaborate on public health related projects in support of the [CDC mission](https://www.cdc.gov/about/organization/mission.htm). GitHub is not hosted by the CDC, but is a third party website used by CDC and its partners to share information and collaborate on software. CDC use of GitHub does not imply an endorsement of any one particular service, product, or enterprise. 
diff --git a/example_model.mrp.toml b/example_model.mrp.toml index db17211..43b0fca 100644 --- a/example_model.mrp.toml +++ b/example_model.mrp.toml @@ -6,7 +6,8 @@ version = "0.0.1" [runtime] env = "uv" -command = "example_model" +command = "python" +args = ["-m", "example_model"] [output] spec = "filesystem" diff --git a/packages/example_model/src/example_model/benchmark.py b/packages/example_model/src/example_model/benchmark.py index 66136f4..83a1446 100644 --- a/packages/example_model/src/example_model/benchmark.py +++ b/packages/example_model/src/example_model/benchmark.py @@ -2,6 +2,7 @@ import json import timeit +from pathlib import Path import numpy as np from mrp import Environment @@ -132,5 +133,8 @@ def outputs_to_distance(model_output, target_data): for result in benchmark_results: print(f"workers: {result['max_workers']}, time: {result['time']}") -with open("./benchmarks/parallelization_check.json", "w") as fp: +benchmark_dir = Path("./benchmarks") +benchmark_dir.mkdir(exist_ok=True) + +with open(benchmark_dir / "parallelization_check.json", "w") as fp: json.dump(benchmark_results, fp) diff --git a/src/calibrationtools/async_runner.py b/src/calibrationtools/async_runner.py new file mode 100644 index 0000000..6e55951 --- /dev/null +++ b/src/calibrationtools/async_runner.py @@ -0,0 +1,50 @@ +"""Run coroutines from synchronous sampler code. + +This module centralizes the event-loop bridging used by sampler execution +paths so synchronous orchestration can safely invoke async helpers in both +normal scripts and already-running event loops. +""" + +import asyncio +import threading +from typing import Any, Callable, NoReturn + + +def run_coroutine_from_sync(coroutine_factory: Callable[[], Any]) -> Any: + """Run an async workflow from synchronous code. + + This helper executes the coroutine directly when no event loop is active. 
+ If the caller already runs inside an event loop, it executes the coroutine + in a dedicated worker thread and re-raises any exception from that thread. + + Args: + coroutine_factory (Callable[[], Any]): Factory returning the coroutine + to execute. + + Returns: + Any: The value returned by the coroutine. + """ + + try: + asyncio.get_running_loop() + except RuntimeError: + return asyncio.run(coroutine_factory()) + + result: dict[str, Any] = {} + error: dict[str, BaseException] = {} + + def runner() -> None: + try: + result["value"] = asyncio.run(coroutine_factory()) + except BaseException as exc: # pragma: no cover - passthrough + error["value"] = exc + + def raise_worker_error(exc: BaseException) -> NoReturn: + raise exc + + thread = threading.Thread(target=runner, daemon=True) + thread.start() + thread.join() + if "value" in error: + raise_worker_error(error["value"]) + return result["value"] diff --git a/src/calibrationtools/batch_generation_runner.py b/src/calibrationtools/batch_generation_runner.py new file mode 100644 index 0000000..b45a2d6 --- /dev/null +++ b/src/calibrationtools/batch_generation_runner.py @@ -0,0 +1,568 @@ +"""Run batched ABC-SMC generations outside the sampler facade. + +This module contains the execution engine for the batched sampling path. It +keeps batch proposal sizing, evaluation, acceptance, and progress reporting +out of `ABCSampler`. 
+""" + +import asyncio +import time +from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass +from functools import partial +from typing import Any, Callable + +from .async_runner import run_coroutine_from_sync +from .particle import Particle +from .particle_population import ParticlePopulation +from .sampler_reporting import ProgressHandle, SamplerReporter +from .sampler_run_state import SamplerRunState +from .sampler_types import GenerationStats + + +@dataclass(frozen=True, slots=True) +class BatchGenerationConfig: + """Store static settings and callbacks for batched generations. + + This configuration object groups the sampler collaborators that remain + stable across batched generations so the runner constructor stays small + and execution methods operate on named fields. + + Attributes: + generation_particle_count (int): Number of accepted particles required + to complete a generation. + tolerance_values (list[float]): Acceptance tolerance for each + generation. + sample_particle_from_priors (Callable[[Any], Particle]): Proposal + function for the initial generation. + sample_and_perturb_particle (Callable[[Any], Particle]): Proposal + function for later generations. + particle_to_distance (Callable[..., float]): Function used to evaluate + the distance of one proposed particle. + calculate_weight (Callable[[Particle], float]): Function used to weight + accepted particles after the first generation. + replace_particle_population (Callable[[ParticlePopulation], None]): + Callback that stores the finalized population on the sampler. + reporter (SamplerReporter): Reporter used for progress and summary + output. 
+ """ + + generation_particle_count: int + tolerance_values: list[float] + sample_particle_from_priors: Callable[[Any], Particle] + sample_and_perturb_particle: Callable[[Any], Particle] + particle_to_distance: Callable[..., float] + calculate_weight: Callable[[Particle], float] + replace_particle_population: Callable[[ParticlePopulation], None] + reporter: SamplerReporter + + +@dataclass(frozen=True, slots=True) +class BatchGenerationRequest: + """Describe one batched generation execution. + + This carrier groups the runtime inputs for the batched generation path so + helper methods can depend on a single named object instead of long + positional argument lists. + + Attributes: + generation (int): Zero-based generation index being executed. + batchsize (int): Target batch size for proposal generation. + warmup (bool): Whether warmup sizing should be used for adaptive + proposal estimation. + chunksize (int): Number of particles evaluated per chunk. + executor (ThreadPoolExecutor | None): Executor used for concurrent + chunk evaluation when available. + overall_start_time (float): Timestamp recorded at the start of the full + sampler run. + generation_start_time (float): Timestamp recorded at the start of the + generation. + particle_kwargs (dict[str, Any]): Keyword arguments forwarded into + particle evaluation. + """ + + generation: int + batchsize: int + warmup: bool + chunksize: int + executor: ThreadPoolExecutor | None + overall_start_time: float + generation_start_time: float + particle_kwargs: dict[str, Any] + + +@dataclass(slots=True) +class BatchGenerationState: + """Store mutable state for one batched generation. + + This state groups the accepted population and the running attempt count so + helper methods can share batch-generation state without long parameter + lists. + + Attributes: + proposed_population (ParticlePopulation): Population being filled for + the active generation. + attempts (int): Total proposal attempts consumed so far. 
+ """ + + proposed_population: ParticlePopulation + attempts: int = 0 + + +class BatchGenerationRunner: + """Run one batched ABC-SMC generation in serial or threaded mode. + + This runner isolates the batched execution engine from `ABCSampler`. It + handles batch proposal sizing, chunk evaluation, acceptance accounting, + and generation-level run-state updates. + + Args: + config (BatchGenerationConfig): Static settings and callbacks used + across batched generations. + run_state (SamplerRunState): Mutable bookkeeping for the active sampler + run. + """ + + def __init__( + self, + config: BatchGenerationConfig, + run_state: SamplerRunState, + ) -> None: + self.config = config + self.run_state = run_state + + def resolve_settings( + self, + batchsize: int | None, + chunksize: int, + ) -> tuple[int, bool]: + """Resolve validated batch-execution settings. + + This helper normalizes the requested batch configuration and returns + both the resolved batch size and whether warmup estimation should be + used for the first adaptive batch. + + Args: + batchsize (int | None): Optional batch-size override. + chunksize (int): Number of particles processed serially inside one + executor task. + + Returns: + tuple[int, bool]: Resolved batch size and whether warmup mode is + enabled. + + Raises: + ValueError: Raised when `chunksize` or an explicit `batchsize` is + not positive. + """ + + if chunksize <= 0: + raise ValueError("chunksize must be positive") + if batchsize is None: + return self.config.generation_particle_count, True + if batchsize <= 0: + raise ValueError("batchsize must be positive") + return batchsize, False + + def run_generation( + self, request: BatchGenerationRequest + ) -> GenerationStats: + """Execute one batched generation and store its final population. + + This method coordinates adaptive proposal sizing, batch evaluation, + acceptance accounting, and final population storage for the batched + sampler path. 
+ + Args: + request (BatchGenerationRequest): Runtime inputs for the generation + being executed. + + Returns: + GenerationStats: Attempts, successes, and timing metrics recorded + for the completed generation. + """ + + state = BatchGenerationState(proposed_population=ParticlePopulation()) + description = ( + f"Generation {request.generation + 1} " + f"(tolerance {self.config.tolerance_values[request.generation]})..." + ) + with self.config.reporter.create_collection_progress() as progress: + handle = self.config.reporter.start_collection_task( + progress=progress, + description=description, + total=self.config.generation_particle_count, + ) + while ( + state.proposed_population.size + < self.config.generation_particle_count + ): + sample_size = self._get_batch_sample_size( + state=state, + batchsize=request.batchsize, + warmup=request.warmup, + ) + proposed_particles = self._sample_generation_particles( + generation=request.generation, + sample_size=sample_size, + ) + state.attempts += self._process_particle_batch( + request=request, + state=state, + proposed_particles=proposed_particles, + ) + self._update_progress( + handle=handle, + state=state, + generation_start_time=request.generation_start_time, + ) + + generation_stats = self._build_generation_stats( + request=request, + state=state, + ) + self.config.reporter.print_generation_summary( + generation=request.generation, + tolerance=self.config.tolerance_values[request.generation], + generation_stats=generation_stats, + ) + + self.run_state.record_attempts( + generation=request.generation, + attempts=generation_stats.attempts, + successes=generation_stats.successes, + ) + self.config.replace_particle_population(state.proposed_population) + self.config.reporter.print_timing_summary( + processing_time=generation_stats.processing_time, + total_time=generation_stats.total_time, + ) + return generation_stats + + def _get_batch_sample_size( + self, + state: BatchGenerationState, + batchsize: int, + warmup: bool, + 
) -> int: + """Estimate the proposal count for the next batch. + + This helper increases the early batch size during warmup and then + adapts future proposal counts using the observed acceptance rate from + the current generation. + + Args: + state (BatchGenerationState): Mutable state for the active + generation. + batchsize (int): Configured batch size for normal operation. + warmup (bool): Whether the warmup heuristic should still apply. + + Returns: + int: Number of particles to propose for the next batch. + """ + + effective_batchsize = ( + 10_000 + if warmup and state.proposed_population.size > 0 + else batchsize + ) + if state.proposed_population.size == 0: + return effective_batchsize + + remaining = ( + self.config.generation_particle_count + - state.proposed_population.size + ) + sample_size = min( + effective_batchsize, + remaining * state.attempts / state.proposed_population.size, + ) + return max(int(sample_size), 1) + + def _sample_generation_particles( + self, + generation: int, + sample_size: int, + ) -> list[Particle]: + """Sample a batch of proposed particles for one generation. + + This helper chooses the correct proposal function for the generation + and returns the requested number of proposed particles. + + Args: + generation (int): Zero-based generation index being sampled. + sample_size (int): Number of particles to propose. + + Returns: + list[Particle]: Proposed particles for the batch. + """ + + sample_method = ( + self.config.sample_particle_from_priors + if generation == 0 + else self.config.sample_and_perturb_particle + ) + return [sample_method(None) for _ in range(sample_size)] + + def _evaluate_particle_chunk( + self, + proposed_particles: list[Particle], + particle_kwargs: dict[str, Any], + ) -> list[float]: + """Evaluate a chunk of proposed particles serially. + + This helper keeps chunk evaluation reusable between the serial and + threaded batch-processing paths. 
+ + Args: + proposed_particles (list[Particle]): Proposed particles to score. + particle_kwargs (dict[str, Any]): Additional keyword arguments + forwarded into particle evaluation. + + Returns: + list[float]: Distances computed for the proposed particles. + """ + + return [ + self.config.particle_to_distance( + proposed_particle, + **particle_kwargs, + ) + for proposed_particle in proposed_particles + ] + + async def _process_particle_batch_async( + self, + request: BatchGenerationRequest, + state: BatchGenerationState, + proposed_particles: list[Particle], + ) -> int: + """Evaluate and accept one batch of particles concurrently. + + This helper splits proposed particles into chunks, evaluates each chunk + on the executor, and accepts particles in chunk order until the + generation population is full or all proposed particles are + considered. + + Args: + request (BatchGenerationRequest): Runtime inputs for the generation + being executed. + state (BatchGenerationState): Mutable state for the active + generation. + proposed_particles (list[Particle]): Proposed particles to + evaluate. + + Returns: + int: Number of proposed particles that were considered. 
+ """ + + assert request.executor is not None + + loop = asyncio.get_running_loop() + worker = partial( + self._evaluate_particle_chunk, + particle_kwargs=request.particle_kwargs, + ) + particle_chunks = [ + proposed_particles[index : index + request.chunksize] + for index in range(0, len(proposed_particles), request.chunksize) + ] + tasks = [] + for chunk in particle_chunks: + task = loop.run_in_executor(request.executor, worker, chunk) + tasks.append((task, chunk)) + + attempts = 0 + try: + for task, chunk in tasks: + chunk_results = await task + attempts += self._accept_particle_batch( + generation=request.generation, + proposed_population=state.proposed_population, + proposed_particles=chunk, + errs=chunk_results, + ) + if ( + state.proposed_population.size + >= self.config.generation_particle_count + ): + break + finally: + for task, _ in tasks: + task.cancel() + + return attempts + + def _process_particle_batch( + self, + request: BatchGenerationRequest, + state: BatchGenerationState, + proposed_particles: list[Particle], + ) -> int: + """Evaluate and accept one batch of proposed particles. + + This helper dispatches batch evaluation either serially or through the + executor-backed async path, then returns how many proposed particles + were considered before the population filled or the batch was + exhausted. + + Args: + request (BatchGenerationRequest): Runtime inputs for the generation + being executed. + state (BatchGenerationState): Mutable state for the active + generation. + proposed_particles (list[Particle]): Proposed particles to + evaluate. + + Returns: + int: Number of proposed particles that were considered. 
+ """ + + if ( + request.executor is None + or len(proposed_particles) <= request.chunksize + ): + attempts = 0 + for index in range(0, len(proposed_particles), request.chunksize): + chunk = proposed_particles[index : index + request.chunksize] + errs = self._evaluate_particle_chunk( + proposed_particles=chunk, + particle_kwargs=request.particle_kwargs, + ) + attempts += self._accept_particle_batch( + generation=request.generation, + proposed_population=state.proposed_population, + proposed_particles=chunk, + errs=errs, + ) + if ( + state.proposed_population.size + >= self.config.generation_particle_count + ): + break + return attempts + + return run_coroutine_from_sync( + lambda: self._process_particle_batch_async( + request=request, + state=state, + proposed_particles=proposed_particles, + ) + ) + + def _accept_particle_batch( + self, + generation: int, + proposed_population: ParticlePopulation, + proposed_particles: list[Particle], + errs: list[float], + ) -> int: + """Accept evaluated particles into the proposed population. + + This helper applies the generation tolerance to evaluated particles, + computes weights for accepted particles, and stops early once the + proposed population reaches the target size. + + Args: + generation (int): Zero-based generation index being executed. + proposed_population (ParticlePopulation): Population being filled + for the generation. + proposed_particles (list[Particle]): Proposed particles that were + evaluated. + errs (list[float]): Distances computed for the proposed particles. + + Returns: + int: Number of evaluated particles that were considered. 
+ """ + + considered = 0 + for err, proposed_particle in zip(errs, proposed_particles): + if ( + proposed_population.size + >= self.config.generation_particle_count + ): + break + + considered += 1 + if err <= self.config.tolerance_values[generation]: + particle_weight = ( + 1.0 + if generation == 0 + else self.config.calculate_weight(proposed_particle) + ) + proposed_population.add_particle( + proposed_particle, + particle_weight, + ) + + return considered + + def _update_progress( + self, + handle: ProgressHandle, + state: BatchGenerationState, + generation_start_time: float, + ) -> None: + """Update generation progress for the batched path. + + This helper computes ETA and acceptance rate for the current batched + generation state before delegating the actual Rich update to the + reporter. + + Args: + handle (ProgressHandle): Handle referencing the active progress + task. + state (BatchGenerationState): Mutable state for the active + generation. + generation_start_time (float): Timestamp recorded at the start of + the generation. + + Returns: + None: This helper does not return a value. + """ + + elapsed = time.time() - generation_start_time + completed = state.proposed_population.size + eta = ( + elapsed + * (self.config.generation_particle_count - completed) + / completed + if elapsed > 0 and completed > 0 + else 0.0 + ) + acceptance_rate = ( + 100.0 * completed / state.attempts if state.attempts > 0 else 0.0 + ) + self.config.reporter.update_collection_progress( + handle=handle, + completed=completed, + acceptance_rate=acceptance_rate, + eta_seconds=eta, + ) + + def _build_generation_stats( + self, + request: BatchGenerationRequest, + state: BatchGenerationState, + ) -> GenerationStats: + """Build summary metrics for a completed batched generation. + + This helper records attempts, accepted particles, and elapsed timing in + the shared `GenerationStats` carrier used by sampler execution paths. 
+ + Args: + request (BatchGenerationRequest): Runtime inputs for the generation + being executed. + state (BatchGenerationState): Mutable state for the active + generation. + + Returns: + GenerationStats: Summary metrics for the completed generation. + """ + + return GenerationStats( + attempts=state.attempts, + successes=state.proposed_population.size, + processing_time=time.time() - request.generation_start_time, + total_time=time.time() - request.overall_start_time, + ) diff --git a/src/calibrationtools/calibration_results.py b/src/calibrationtools/calibration_results.py index 1b1ad89..497fbf8 100644 --- a/src/calibrationtools/calibration_results.py +++ b/src/calibrationtools/calibration_results.py @@ -1,12 +1,11 @@ from typing import Any -from numpy.random import SeedSequence - from .particle import Particle from .particle_population import ParticlePopulation from .particle_population_metrics import ParticlePopulationMetrics from .particle_updater import _ParticleUpdater from .prior_distribution import nonseed_param_names +from .sampler_types import GeneratorSlot class CalibrationResults: @@ -15,7 +14,7 @@ class CalibrationResults: Args: _updater (_ParticleUpdater):The particle population updater availble from a fitted sampler object, which contains the final particle population and the perturbation kernel used for sampling particles in the final generation. - generator_history (dict[int, list[dict[int, int | SeedSequence]]]): A dictionary mapping generation indices to their corresponding lists of dictionaries containing particle IDs and their associated seed sequences, representing the history of particle sampling and perturbation across generations when called with the appropriate particle updater. 
+ generator_history (dict[int, list[GeneratorSlot]]): A dictionary mapping generation indices to their corresponding lists of generator slots containing particle IDs and their associated seed sequences, representing the history of particle sampling and perturbation across generations when called with the appropriate particle updater. population_archive (dict[int, ParticlePopulation]): A dictionary mapping generation indices to their corresponding particle populations, representing the history of particle populations across generations if saved during the sampler run. success_counts (dict[str, list[int]]): A dictionary containing lists of particles per generation, success counts, and attempt counts for each generation, with keys "generation_particle_count", "successes" and "attempts". tolerance_values (list[float]): A list of tolerance values for each generation @@ -33,7 +32,7 @@ class CalibrationResults: def __init__( self, _updater: _ParticleUpdater, - generator_history: dict[int, list[dict[int, int | SeedSequence]]], + generator_history: dict[int, list[GeneratorSlot]], population_archive: dict[int, ParticlePopulation], success_counts: dict[str, list[int]], tolerance_values: list[float], diff --git a/src/calibrationtools/formatting.py b/src/calibrationtools/formatting.py index c1f64bf..fa57395 100644 --- a/src/calibrationtools/formatting.py +++ b/src/calibrationtools/formatting.py @@ -1,8 +1,25 @@ +from io import StringIO + from rich.console import Console -from rich.progress import Progress, BarColumn, TextColumn, TimeElapsedColumn, TimeRemainingColumn -def get_console() -> Console: - return Console(force_terminal=True) + +def get_console(verbose: bool = True) -> Console: + """Return the console used for sampler reporting. + + This helper creates a visible Rich console for normal runs and a hidden + in-memory console when sampler output should be suppressed. + + Args: + verbose (bool): Whether console output should be visible. 
+ + Returns: + Console: Rich console configured for the requested verbosity. + """ + + if verbose: + return Console(force_terminal=True) + return Console(file=StringIO(), force_terminal=False) + def _format_time(seconds: float) -> str: """Format time duration in human-readable units. diff --git a/src/calibrationtools/load_priors.py b/src/calibrationtools/load_priors.py index 747c79b..f413c35 100644 --- a/src/calibrationtools/load_priors.py +++ b/src/calibrationtools/load_priors.py @@ -18,8 +18,10 @@ def load_schema() -> dict: """Load the JSON schema for validating priors from the package resources.""" - with importlib.resources.open_text( - "calibrationtools.assets", "schema.json" + with ( + importlib.resources.files("calibrationtools.assets") + .joinpath("schema.json") + .open("r", encoding="utf-8") ) as f: schema = json.load(f) return schema diff --git a/src/calibrationtools/particle_evaluator.py b/src/calibrationtools/particle_evaluator.py new file mode 100644 index 0000000..24210fd --- /dev/null +++ b/src/calibrationtools/particle_evaluator.py @@ -0,0 +1,77 @@ +"""Evaluate particles by running the model and scoring its outputs. + +This module isolates the particle-to-params, simulate, and distance-scoring +steps behind one small collaborator so sampler execution code stays focused on +proposal and acceptance flow. +""" + +from typing import Any, Callable + +from mrp import MRPModel + +from .particle import Particle + + +class ParticleEvaluator: + """Evaluate particles by running the model and scoring its outputs. + + This class holds the user-supplied mapping and scoring functions together + with the model runner so particle evaluation has a single, testable + boundary. + + Args: + particles_to_params (Callable[..., dict]): Function mapping a + particle to model parameters. + outputs_to_distance (Callable[..., float]): Function scoring simulated + outputs against target data. + target_data (Any): Observed data used for distance evaluation. 
+ model_runner (MRPModel): Model runner used to simulate outputs. + """ + + def __init__( + self, + particles_to_params: Callable[..., dict], + outputs_to_distance: Callable[..., float], + target_data: Any, + model_runner: MRPModel, + ) -> None: + self.particles_to_params = particles_to_params + self.outputs_to_distance = outputs_to_distance + self.target_data = target_data + self.model_runner = model_runner + + def simulate(self, particle: Particle, **kwargs: Any) -> Any: + """Run the model for a particle and return the simulated outputs. + + This method translates a particle into model parameters and runs the + model to produce simulated outputs. + + Args: + particle (Particle): Particle to evaluate. + **kwargs (Any): Additional keyword arguments forwarded to + `particles_to_params`. + + Returns: + Any: Simulated outputs produced by the model. + """ + + params = self.particles_to_params(particle, **kwargs) + return self.model_runner.simulate(params) + + def distance(self, particle: Particle, **kwargs: Any) -> float: + """Return the distance between simulated outputs and target data. + + This method translates a particle into model parameters, runs the + model, and scores the resulting outputs against the stored target data. + + Args: + particle (Particle): Particle to evaluate. + **kwargs (Any): Additional keyword arguments forwarded to + `particles_to_params`. + + Returns: + float: Distance between the simulated outputs and the target data. + """ + + outputs = self.simulate(particle, **kwargs) + return self.outputs_to_distance(outputs, self.target_data) diff --git a/src/calibrationtools/particlewise_generation_runner.py b/src/calibrationtools/particlewise_generation_runner.py new file mode 100644 index 0000000..ab18049 --- /dev/null +++ b/src/calibrationtools/particlewise_generation_runner.py @@ -0,0 +1,602 @@ +"""Run particlewise ABC-SMC generations outside the sampler facade. 
+ +This module contains the dedicated execution engine for the particlewise +sampling path. It keeps generation setup, proposal collection, progress +reporting, and population finalization out of `ABCSampler`. +""" + +import asyncio +import time +from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass +from functools import partial +from typing import Any, Callable + +from numpy.random import SeedSequence + +from .async_runner import run_coroutine_from_sync +from .particle import Particle +from .particle_population import ParticlePopulation +from .sampler_reporting import ProgressHandle, SamplerReporter +from .sampler_run_state import SamplerRunState +from .sampler_types import AcceptedProposal, GenerationStats, GeneratorSlot + + +@dataclass(frozen=True, slots=True) +class ParticlewiseGenerationConfig: + """Store static settings and callbacks for particlewise generations. + + This configuration object groups the sampler collaborators that remain + stable across generations so the runner constructor stays small and the + execution methods can depend on named fields instead of long parameter + lists. + + Attributes: + generation_particle_count (int): Number of accepted particles required + to complete a generation. + tolerance_values (list[float]): Acceptance tolerance for each + generation. + seed_sequence (SeedSequence): Root seed sequence used to spawn + deterministic generator slots. + max_attempts_per_proposal (int): Maximum number of proposal attempts + allowed for one generator slot. + sample_particle_from_priors (Callable[[SeedSequence | None], Particle]): + Proposal function for the initial generation. + sample_and_perturb_particle (Callable[[SeedSequence | None], Particle]): + Proposal function for later generations. + particle_to_distance (Callable[..., float]): Function used to evaluate + the distance of one proposed particle. 
+ calculate_weight (Callable[[Particle], float]): Function used to weight + accepted particles after the first generation. + replace_particle_population (Callable[[ParticlePopulation], None]): + Callback that stores the finalized population on the sampler. + reporter (SamplerReporter): Reporter used for progress and summary + output. + """ + + generation_particle_count: int + tolerance_values: list[float] + seed_sequence: SeedSequence + max_attempts_per_proposal: int + sample_particle_from_priors: Callable[[SeedSequence | None], Particle] + sample_and_perturb_particle: Callable[[SeedSequence | None], Particle] + particle_to_distance: Callable[..., float] + calculate_weight: Callable[[Particle], float] + replace_particle_population: Callable[[ParticlePopulation], None] + reporter: SamplerReporter + + +@dataclass(frozen=True, slots=True) +class ParticlewiseGenerationRequest: + """Describe one particlewise generation execution. + + This request object captures the per-generation runtime inputs that vary + across sampler iterations, including executor access, timing markers, and + keyword arguments forwarded into particle evaluation. + + Attributes: + generation (int): Zero-based generation index to execute. + n_workers (int): Number of workers available to the generation. + parallel_executor (ThreadPoolExecutor | None): Executor used for + threaded proposal collection when parallel execution is active. + overall_start_time (float): Timestamp recorded at the start of the full + sampler run. + generation_start_time (float): Timestamp recorded at the start of the + generation. + particle_kwargs (dict[str, Any]): Keyword arguments forwarded into + particle evaluation. + """ + + generation: int + n_workers: int + parallel_executor: ThreadPoolExecutor | None + overall_start_time: float + generation_start_time: float + particle_kwargs: dict[str, Any] + + +@dataclass(slots=True) +class ParticlewiseGenerationState: + """Store mutable state for one particlewise generation. 
+ + This state groups the proposed population, deterministic generator slots, + and generation-specific sample method so helper methods can share that data + without long positional argument lists. + + Attributes: + proposed_population (ParticlePopulation): Population being filled for + the active generation. + generator_slots (list[GeneratorSlot]): Proposal slots used to preserve + deterministic ordering across execution modes. + sample_method (Callable[[SeedSequence | None], Particle]): Proposal + function for the active generation. + """ + + proposed_population: ParticlePopulation + generator_slots: list[GeneratorSlot] + sample_method: Callable[[SeedSequence | None], Particle] + + +class ParticlewiseGenerationRunner: + """Run one particlewise ABC-SMC generation in serial or threaded mode. + + This runner isolates the particlewise execution engine from `ABCSampler`. + It handles proposal generation, progress reporting, acceptance accounting, + and population finalization while writing sampler run bookkeeping through a + dedicated run-state object. + + Args: + config (ParticlewiseGenerationConfig): Static settings and callbacks + used across particlewise generations. + run_state (SamplerRunState): Mutable bookkeeping for the active sampler + run. + """ + + def __init__( + self, + config: ParticlewiseGenerationConfig, + run_state: SamplerRunState, + ) -> None: + self.config = config + self.run_state = run_state + + def run_generation( + self, + request: ParticlewiseGenerationRequest, + ) -> GenerationStats: + """Execute one particlewise generation and store its final population. + + This method coordinates generation setup, proposal collection, and + population finalization for the particlewise sampler path while keeping + the caller focused on top-level orchestration. + + Args: + request (ParticlewiseGenerationRequest): Runtime inputs for the + generation being executed. 
+ + Returns: + GenerationStats: Attempts, successes, and timing metrics recorded + for the completed generation. + """ + + state = self._init_generation(request.generation) + accepted_list, generation_stats = self._collect_accepted_particles( + request=request, + state=state, + ) + self._finalize_generation( + request=request, + state=state, + accepted_list=accepted_list, + generation_stats=generation_stats, + ) + return generation_stats + + def sample_particles_until_accepted( + self, + generator: GeneratorSlot, + tolerance: float, + sample_method: Callable[[SeedSequence | None], Particle], + evaluation_kwargs: dict[str, Any], + max_attempts: int | None = None, + ) -> AcceptedProposal: + """Propose particles until one is accepted or attempts are exhausted. + + This rejection-sampling loop repeatedly generates one particle for the + provided slot until its distance is at most the generation tolerance or + the configured attempt limit is reached. + + Args: + generator (GeneratorSlot): Deterministic generator slot to evaluate. + tolerance (float): Maximum accepted distance for the proposal. + sample_method (Callable[[SeedSequence | None], Particle]): Proposal + function for the active generation. + evaluation_kwargs (dict[str, Any]): Keyword arguments forwarded into + particle evaluation. + max_attempts (int | None): Override for the maximum number of + proposal attempts allowed for the slot. + + Returns: + AcceptedProposal: Accepted particle data for the slot, or a record + showing that no particle was accepted before attempts were + exhausted. 
+ """ + + if max_attempts is None: + max_attempts = self.config.max_attempts_per_proposal + + for attempt in range(max_attempts): + proposed_particle = sample_method(generator.seed_sequence) + err = self.config.particle_to_distance( + proposed_particle, + **evaluation_kwargs, + ) + if err <= tolerance: + return AcceptedProposal( + slot_id=generator.id, + particle=proposed_particle, + attempts=attempt + 1, + ) + return AcceptedProposal( + slot_id=generator.id, + particle=None, + attempts=max_attempts, + ) + + def _get_sample_method( + self, + generation: int, + ) -> Callable[[SeedSequence | None], Particle]: + """Return the proposal function for the requested generation. + + The initial generation samples directly from the priors, while later + generations sample from the previous population and perturb the chosen + particle. + + Args: + generation (int): Zero-based generation index being executed. + + Returns: + Callable[[SeedSequence | None], Particle]: Proposal function for the + requested generation. + """ + + if generation == 0: + return self.config.sample_particle_from_priors + return self.config.sample_and_perturb_particle + + def _init_generation( + self, + generation: int, + ) -> ParticlewiseGenerationState: + """Create the mutable state needed to execute one generation. + + This method prepares the next empty population, spawns deterministic + generator slots, and selects the proposal function for the generation. + + Args: + generation (int): Zero-based generation index being initialized. + + Returns: + ParticlewiseGenerationState: Mutable generation state shared across + the collection and finalization steps. 
+ """ + + generator_slots = [ + GeneratorSlot(id=index, seed_sequence=seed_sequence) + for index, seed_sequence in enumerate( + self.config.seed_sequence.spawn( + self.config.generation_particle_count + ) + ) + ] + return ParticlewiseGenerationState( + proposed_population=ParticlePopulation(), + generator_slots=generator_slots, + sample_method=self._get_sample_method(generation), + ) + + def _collect_accepted_particles_serial( + self, + request: ParticlewiseGenerationRequest, + state: ParticlewiseGenerationState, + handle: ProgressHandle, + ) -> tuple[list[AcceptedProposal], int]: + """Collect accepted proposals serially for one generation. + + This path evaluates one generator slot at a time while keeping the + shared progress display up to date using deterministic proposal order. + + Args: + request (ParticlewiseGenerationRequest): Runtime inputs for the + generation being executed. + state (ParticlewiseGenerationState): Mutable state for the active + generation. + handle (ProgressHandle): Handle referencing the active progress + task. + + Returns: + tuple[list[AcceptedProposal], int]: Accepted proposals for the + generation and the total number of attempts consumed. 
+ """ + + accepted_list: list[AcceptedProposal] = [] + total_attempts = 0 + for completed, generator in enumerate(state.generator_slots, start=1): + accepted_proposal = self.sample_particles_until_accepted( + generator=generator, + tolerance=self.config.tolerance_values[request.generation], + sample_method=state.sample_method, + evaluation_kwargs=request.particle_kwargs, + ) + accepted_list.append(accepted_proposal) + total_attempts += accepted_proposal.attempts + self._update_progress( + handle=handle, + completed=completed, + total_attempts=total_attempts, + generation_start_time=request.generation_start_time, + ) + + return accepted_list, total_attempts + + async def _collect_accepted_particles_async( + self, + request: ParticlewiseGenerationRequest, + state: ParticlewiseGenerationState, + handle: ProgressHandle, + ) -> tuple[list[AcceptedProposal], int]: + """Collect accepted proposals concurrently over the executor. + + This path submits one generator slot per executor task, records results + as tasks complete, and keeps the shared progress display synchronized + with the aggregate attempt count. + + Args: + request (ParticlewiseGenerationRequest): Runtime inputs for the + generation being executed. + state (ParticlewiseGenerationState): Mutable state for the active + generation. + handle (ProgressHandle): Handle referencing the active progress + task. + + Returns: + tuple[list[AcceptedProposal], int]: Accepted proposals for the + generation and the total number of attempts consumed. + + Raises: + BaseException: Re-raises any exception raised while collecting + proposals after cancelling the outstanding tasks. 
+ """ + + assert request.parallel_executor is not None + + accepted_list: list[AcceptedProposal] = [] + total_attempts = 0 + completed = 0 + loop = asyncio.get_running_loop() + worker = partial( + self.sample_particles_until_accepted, + tolerance=self.config.tolerance_values[request.generation], + sample_method=state.sample_method, + evaluation_kwargs=request.particle_kwargs, + ) + tasks = [ + loop.run_in_executor(request.parallel_executor, worker, generator) + for generator in state.generator_slots + ] + + try: + for task in asyncio.as_completed(tasks): + accepted_proposal = await task + accepted_list.append(accepted_proposal) + total_attempts += accepted_proposal.attempts + completed += 1 + self._update_progress( + handle=handle, + completed=completed, + total_attempts=total_attempts, + generation_start_time=request.generation_start_time, + ) + except BaseException: + for task in tasks: + task.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + raise + + return accepted_list, total_attempts + + def _collect_accepted_particles( + self, + request: ParticlewiseGenerationRequest, + state: ParticlewiseGenerationState, + ) -> tuple[list[AcceptedProposal], GenerationStats]: + """Collect accepted proposals and emit progress output for one generation. + + This method owns the shared progress UI, dispatches to either the + serial or threaded collection path, and produces the generation summary + metrics once proposal collection is complete. + + Args: + request (ParticlewiseGenerationRequest): Runtime inputs for the + generation being executed. + state (ParticlewiseGenerationState): Mutable state for the active + generation. + + Returns: + tuple[list[AcceptedProposal], GenerationStats]: Accepted proposals + for the generation and the summary statistics derived from the + collection phase. + """ + + description = ( + f"Generation {request.generation + 1} " + f"(tolerance {self.config.tolerance_values[request.generation]})..." 
+ ) + with self.config.reporter.create_collection_progress() as progress: + handle = self.config.reporter.start_collection_task( + progress=progress, + description=description, + total=self.config.generation_particle_count, + ) + if request.n_workers == 1: + accepted_list, total_attempts = ( + self._collect_accepted_particles_serial( + request=request, + state=state, + handle=handle, + ) + ) + else: + accepted_list, total_attempts = run_coroutine_from_sync( + lambda: self._collect_accepted_particles_async( + request=request, + state=state, + handle=handle, + ) + ) + + generation_stats = self._build_generation_stats( + request=request, + total_attempts=total_attempts, + accepted_count=len(accepted_list), + ) + self.config.reporter.print_generation_summary( + generation=request.generation, + tolerance=self.config.tolerance_values[request.generation], + generation_stats=generation_stats, + ) + + return accepted_list, generation_stats + + def _update_progress( + self, + handle: ProgressHandle, + completed: int, + total_attempts: int, + generation_start_time: float, + ) -> None: + """Update generation progress using completed slots and total attempts. + + This helper keeps ETA and acceptance-rate calculation in one place so + both serial and threaded proposal collection report progress the same + way. + + Args: + handle (ProgressHandle): Handle referencing the active progress + task. + completed (int): Number of generator slots completed so far. + total_attempts (int): Total proposal attempts consumed so far. + generation_start_time (float): Timestamp recorded at the start of + the generation. + + Returns: + None: This helper does not return a value. 
+ """ + + elapsed = time.time() - generation_start_time + eta = ( + elapsed + * (self.config.generation_particle_count - completed) + / completed + if elapsed > 0 and completed > 0 + else 0.0 + ) + acceptance_rate = ( + 100.0 * completed / total_attempts if total_attempts > 0 else 0.0 + ) + self.config.reporter.update_collection_progress( + handle=handle, + completed=completed, + acceptance_rate=acceptance_rate, + eta_seconds=eta, + ) + + def _build_generation_stats( + self, + request: ParticlewiseGenerationRequest, + total_attempts: int, + accepted_count: int, + ) -> GenerationStats: + """Build summary metrics for a completed collection phase. + + This helper records attempts, accepted particles, and elapsed timing in + the shared `GenerationStats` carrier used by sampler execution paths. + + Args: + request (ParticlewiseGenerationRequest): Runtime inputs for the + generation being executed. + total_attempts (int): Total proposal attempts consumed. + accepted_count (int): Number of accepted particles collected. + + Returns: + GenerationStats: Summary metrics for the completed generation. + """ + + return GenerationStats( + attempts=total_attempts, + successes=accepted_count, + processing_time=time.time() - request.generation_start_time, + total_time=time.time() - request.overall_start_time, + ) + + def _finalize_generation( + self, + request: ParticlewiseGenerationRequest, + state: ParticlewiseGenerationState, + accepted_list: list[AcceptedProposal], + generation_stats: GenerationStats, + ) -> None: + """Convert accepted proposals into the next weighted population. + + This method sorts accepted proposals back into deterministic slot order, + computes particle weights, records generation bookkeeping, and stores + the finalized population on the sampler. + + Args: + request (ParticlewiseGenerationRequest): Runtime inputs for the + generation being executed. + state (ParticlewiseGenerationState): Mutable state for the active + generation. 
+ accepted_list (list[AcceptedProposal]): Accepted proposal records + collected for the generation. + generation_stats (GenerationStats): Summary metrics for the + completed generation. + + Returns: + None: This helper does not return a value. + + Raises: + UserWarning: Raised when a proposal slot exhausts all attempts + without producing an accepted particle. + """ + + with self.config.reporter.create_weight_progress() as progress: + handle = self.config.reporter.start_weight_task( + progress=progress, + total=self.config.generation_particle_count, + ) + for accepted_proposal in sorted( + accepted_list, + key=lambda proposal: proposal.slot_id, + ): + if accepted_proposal.particle is None: + raise UserWarning( + "Particle proposal attempt " + f"{accepted_proposal.slot_id} used " + f"{accepted_proposal.attempts} samples and found no " + "acceptable values." + ) + particle_weight = ( + 1.0 + if request.generation == 0 + else self.config.calculate_weight( + accepted_proposal.particle + ) + ) + state.proposed_population.add_particle( + accepted_proposal.particle, + particle_weight, + ) + self.config.reporter.advance(handle) + + self.run_state.record_generation_history( + request.generation, + state.generator_slots, + ) + self.run_state.record_attempts( + generation=request.generation, + attempts=generation_stats.attempts, + successes=generation_stats.successes, + ) + self.config.replace_particle_population(state.proposed_population) + weights_time = ( + time.time() + - request.generation_start_time + - generation_stats.processing_time + ) + self.config.reporter.print_timing_summary( + processing_time=generation_stats.processing_time, + total_time=generation_stats.total_time, + weights_time=weights_time, + ) diff --git a/src/calibrationtools/sampler.py b/src/calibrationtools/sampler.py index 5856bcc..a4f68d6 100644 --- a/src/calibrationtools/sampler.py +++ b/src/calibrationtools/sampler.py @@ -1,24 +1,33 @@ import copy -import multiprocessing as mp -import sys import 
time -from concurrent.futures import ProcessPoolExecutor -from functools import partial +from concurrent.futures import ThreadPoolExecutor from pathlib import Path from typing import Any, Callable, Literal, Sequence import numpy as np from mrp import MRPModel from numpy.random import SeedSequence -from rich.progress import BarColumn, Progress, TextColumn, TimeElapsedColumn -from . import formatting +from .batch_generation_runner import ( + BatchGenerationConfig, + BatchGenerationRequest, + BatchGenerationRunner, +) from .calibration_results import CalibrationResults from .particle import Particle +from .particle_evaluator import ParticleEvaluator from .particle_population import ParticlePopulation from .particle_updater import _ParticleUpdater +from .particlewise_generation_runner import ( + ParticlewiseGenerationConfig, + ParticlewiseGenerationRequest, + ParticlewiseGenerationRunner, +) from .perturbation_kernel import PerturbationKernel from .prior_distribution import PriorDistribution +from .sampler_reporting import SamplerReporter +from .sampler_run_state import SamplerRunState +from .sampler_types import GeneratorSlot from .variance_adapter import VarianceAdapter @@ -33,22 +42,27 @@ class ABCSampler: generation_particle_count (int): Number of particles to accept per generation for a complete population. tolerance_values (list[float]): List of tolerance values for each generation for evaluating acceptance criterion. priors (PriorDistribution | dict | Path): Prior distribution of the parameters being calibrated. Can be provided as a PriorDistribution object, a dictionary, or a path to a JSON file containing a valid priors schema. - particles_to_params (Callable[[Particle], dict]): Function to map particles to model parameters. + particles_to_params (Callable[..., dict]): Function to map particles to model parameters. outputs_to_distance (Callable[..., float]): Function to compute distance between model outputs and target data. 
target_data (Any): Observed data to compare against. model_runner (MRPModel): Model runner to simulate outputs given parameters. perturbation_kernel (PerturbationKernel): Initial kernel used to perturb particles across SMC steps. variance_adapter (VarianceAdapter): Adapter to adjust perturbation variance across SMC steps. max_attempts_per_proposal (int): Maximum number of sample and perturb attempts to propose a particle. + parallel_worker_count (int): Default number of workers to use for sampler parallel execution when `max_workers` is not supplied. seed (int | None): Random seed for reproducibility. verbose (bool): Whether to print verbose output during execution. - drop_previous_population_data (bool): Whether to drop previous population data when storing the accepted particles between SMC steps. + keep_previous_population_data (bool): Whether to retain previous + population data in the per-run archive when storing accepted + particles between SMC steps. seed_parameter_name (str | None): The name of the seed parameter to include in the priors if `incl_seed_parameter` is True when loading priors from a dictionary or JSON file. + Raises: + ValueError: If `parallel_worker_count` is not positive. + Methods: particle_population: - Getter and setter for the current particle population. Automatically archives - the previous population if `drop_previous_population_data` is False. + Getter and setter for the current particle population. run(**kwargs: Any): Executes the ABC-SMC algorithm. 
Raises an error if any keyword argument conflicts @@ -75,29 +89,38 @@ def __init__( generation_particle_count: int, tolerance_values: list[float], priors: PriorDistribution | dict | Path, - particles_to_params: Callable[[Particle], dict], + particles_to_params: Callable[..., dict], outputs_to_distance: Callable[..., float], target_data: Any, model_runner: MRPModel, perturbation_kernel: PerturbationKernel, variance_adapter: VarianceAdapter, max_attempts_per_proposal: int = np.iinfo(np.int32).max, + parallel_worker_count: int = 10, seed: int | None = None, verbose: bool = True, - drop_previous_population_data: bool = False, + keep_previous_population_data: bool = False, seed_parameter_name: str | None = "seed", ): + if parallel_worker_count <= 0: + raise ValueError("parallel_worker_count must be positive") self.generation_particle_count = generation_particle_count self.max_attempts_per_proposal = max_attempts_per_proposal + self.parallel_worker_count = parallel_worker_count self.tolerance_values = tolerance_values self._variance_adapter = variance_adapter self.particles_to_params = particles_to_params self.outputs_to_distance = outputs_to_distance self.target_data = target_data self.model_runner = model_runner + self._particle_evaluator = ParticleEvaluator( + particles_to_params=particles_to_params, + outputs_to_distance=outputs_to_distance, + target_data=target_data, + model_runner=model_runner, + ) self.seed = seed - self.drop_previous_population_data = drop_previous_population_data - self.population_archive: dict[int, ParticlePopulation] = {} + self.keep_previous_population_data = keep_previous_population_data self.verbose = verbose if isinstance(priors, PriorDistribution): @@ -116,54 +139,117 @@ def __init__( self._priors = load_priors_from_json(priors) self.init_updater(perturbation_kernel) - self.step_successes = [0] * len(self.tolerance_values) - self.step_attempts = [0] * len(self.tolerance_values) - self.generator_history = {} + self._run_state = 
SamplerRunState( + generation_count=len(self.tolerance_values), + keep_previous_population_data=keep_previous_population_data, + ) + + @property + def step_successes(self) -> list[int]: + """Return accepted-particle counts for the active run. + + This property exposes the generation-level success counts recorded in + the sampler run state. + + Returns: + list[int]: Accepted-particle count for each generation in the + active run. + """ + + return self._run_state.step_successes + + @property + def step_attempts(self) -> list[int]: + """Return proposal-attempt counts for the active run. + + This property exposes the generation-level attempt counts recorded in + the sampler run state. + + Returns: + list[int]: Proposal-attempt count for each generation in the active + run. + """ + + return self._run_state.step_attempts + + @property + def generator_history(self) -> dict[int, list[GeneratorSlot]]: + """Return generator slots used for each completed generation. + + This property exposes the deterministic generator slots recorded during + particlewise execution. + + Returns: + dict[int, list[GeneratorSlot]]: Generator slots keyed by generation + index. + """ + + return self._run_state.generator_history + + @property + def population_archive(self) -> dict[int, ParticlePopulation]: + """Return archived populations captured during the active run. + + This property exposes the per-run archive of previous populations that + was recorded while the active run progressed across generations. + + Returns: + dict[int, ParticlePopulation]: Archived populations keyed by their + archive step. + """ + + return self._run_state.population_archive @property def particle_population(self) -> ParticlePopulation: + """Return the current particle population. + + This property exposes the sampler's active population without changing + any run bookkeeping. + + Returns: + ParticlePopulation: Current particle population stored on the + updater. 
+ """ + return self._updater.particle_population @particle_population.setter - def particle_population(self, population: ParticlePopulation): - """ - Updates the particle population for the sampler. + def particle_population(self, population: ParticlePopulation) -> None: + """Set the current particle population without altering bookkeeping. - If `drop_previous_population_data` is set to False and there is existing - particle population data, the current particle population is archived - before updating to the new population. + This setter updates the population stored on the updater while leaving + archive and generation counters untouched. Args: - population (ParticlePopulation): The new particle population to set. - - Attributes: - drop_previous_population_data (bool): Determines whether to discard - previous population data or archive it. - _updater.particle_population (ParticlePopulation): The current particle - population managed by the updater. - population_archive (dict): A dictionary storing archived particle - populations, indexed by step. - - Behavior: - - If `drop_previous_population_data` is False and there is existing - particle population data, the current population is archived with - a step index. - - Updates the `_updater.particle_population` with the new population. - - Weights of the new population are normalized and the perturbation - variance is adapted by the particle updater's setter method. - """ - if ( - not self.drop_previous_population_data - and self._updater.particle_population.size > 0 - ): - step = ( - max(self.population_archive.keys()) + 1 - if self.population_archive - else 0 - ) - self.population_archive.update({step: self.particle_population}) + population (ParticlePopulation): Population to store as the current + sampler population. + + Returns: + None: This setter does not return a value. 
+ """ + self._updater.particle_population = population + def _replace_particle_population( + self, population: ParticlePopulation + ) -> None: + """Archive the current population in run state, then store the new one. + + This helper keeps archive bookkeeping explicit by recording the + outgoing population before updating the current population. + + Args: + population (ParticlePopulation): New population to store on the + sampler. + + Returns: + None: This helper does not return a value. + """ + + self._run_state.archive_population(self.particle_population) + self.particle_population = population + @property def perturbation_kernel(self) -> PerturbationKernel: return self._updater.perturbation_kernel @@ -238,76 +324,63 @@ def sample_and_perturb_particle( ) def particle_to_distance(self, particle: Particle, **kwargs: Any) -> float: - """ - Computes the distance between the model output generated from the given particle and the target data using the user-supplied `particles_to_params` and `outputs_to_distance` functions. + """Compute the distance for one proposed particle. + + This method keeps `ABCSampler` as the public entry point for particle + evaluation while delegating the actual model execution and scoring work + to the extracted `ParticleEvaluator`. + Args: particle (Particle): The particle for which to compute the distance. - **kwargs (Any): Additional keyword arguments that can be passed to the `particles_to_params` function. These arguments are supplied from the `run()` method and can include any user-defined parameters needed for mapping particles to model parameters. + **kwargs (Any): Additional keyword arguments forwarded to + `particles_to_params`. + Returns: - float: The computed distance between the model output generated from the given particle and the target data, as calculated by the `outputs_to_distance` function. + float: Distance between the simulated outputs and the target data. 
""" - params = self.particles_to_params(particle, **kwargs) - outputs = self.model_runner.simulate(params) - err = self.outputs_to_distance(outputs, self.target_data) - return err + return self._particle_evaluator.distance(particle, **kwargs) def calculate_weight(self, particle: Particle) -> float: - """ - Calculates the weight of a given particle based on its prior and perturbed probabilities using the particle updater's calculate_weight method. + """Calculate the importance weight for one accepted particle. + + This method preserves the public sampler API while delegating the + actual weight calculation to the particle updater. + Args: particle (Particle): The particle for which to calculate the weight. + Returns: - float: The calculated weight of the particle, which is based on the prior probability density and the + float: Importance weight for the particle under the current + population and perturbation kernel. """ return self._updater.calculate_weight(particle) def get_results_and_reset( self, perturbation_kernel: PerturbationKernel ) -> CalibrationResults: - """ - Compiles the results of the calibration process into a CalibrationResults object and resets the sampler for potential future runs. + """Build calibration results and reset mutable sampler state. + + This method validates that each generation produced a full accepted + population, creates the immutable `CalibrationResults` snapshot, and + then resets the sampler so it can be reused for a later run. + Args: - perturbation_kernel (PerturbationKernel): The originator perturbation kernel to reset to after the run. + perturbation_kernel (PerturbationKernel): Perturbation kernel to + restore on the sampler after result construction. + Returns: - CalibrationResults: An object containing the results of the calibration process, including the final particle population - and the history of successes and attempts for each generation. 
- Raises: - UserWarning: If the number of successful particles in any generation is less than the specified generation + CalibrationResults: Snapshot containing the final posterior, + generation history, archive data, and success statistics. """ - if any( - [ - count < self.generation_particle_count - for count in self.step_successes - ] - ): - raise UserWarning( - "The number of successful particles in at least one generation is less than the specified generation_particle_count. This may indicate that the maximum particle proposal attempts are too low or the error tolerance values are too strict for the model and target data." - ) - results = CalibrationResults( - copy.deepcopy(self._updater), - self.generator_history, - self.population_archive, - { - "generation_particle_count": [self.generation_particle_count] - * len(self.tolerance_values), - "successes": self.step_successes, - "attempts": self.step_attempts, - }, - self.tolerance_values, - ) - - # Reset particle sampler and successes - self.init_updater(perturbation_kernel) - self.step_successes = [0] * len(self.tolerance_values) - self.step_attempts = [0] * len(self.tolerance_values) - self.generator_history = {} + results = self._build_results() + self._reset_after_run(perturbation_kernel) return results def run_parallel( self, max_workers: int | None = None, **kwargs: Any ) -> CalibrationResults: """ - Executes the Sequential Monte Carlo (SMC) sampling process in parallel using a ProcessPoolExecutor. + Executes the Sequential Monte Carlo (SMC) sampling process in parallel using async orchestration over a thread pool. This method performs the SMC algorithm to generate a population of particles that approximate the posterior distribution of the model parameters. The process @@ -316,7 +389,7 @@ def run_parallel( The execution is parallelized to improve performance. Args: - max_workers (int | None): The maximum number of worker processes to use when running in parallel. 
If None, it defaults to the number of CPU cores available. + max_workers (int | None): The maximum number of worker threads to use when running in parallel. If None, it defaults to the sampler's configured `parallel_worker_count`. **kwargs (Any): Additional keyword arguments that can be passed to the method. These arguments are supplied to the particles_to_params function. Note that the keyword arguments must not conflict with existing @@ -348,6 +421,187 @@ def run_serial(self, **kwargs: Any) -> CalibrationResults: """ return self.run(execution="serial", **kwargs) + def _resolve_worker_count(self, max_workers: int | None) -> int: + """Resolve the worker count for a parallel sampler run. + + This helper applies the sampler default when the caller does not supply + `max_workers` and validates that the resolved value is positive. + + Args: + max_workers (int | None): Optional worker-count override supplied + by the caller. + + Returns: + int: Positive worker count for the run. + + Raises: + ValueError: Raised when the resolved worker count is not positive. + """ + worker_count = ( + max_workers + if max_workers is not None + else self.parallel_worker_count + ) + if worker_count <= 0: + raise ValueError("max_workers must be positive") + return worker_count + + def _build_reporter(self) -> SamplerReporter: + """Create the reporter used for one sampler run. + + This helper centralizes reporter construction so the public run + methods can share the same output behavior and honor the sampler's + `verbose` flag consistently. + + Returns: + SamplerReporter: Reporter configured for the current verbosity. + """ + + return SamplerReporter(verbose=self.verbose) + + def _build_particlewise_generation_runner( + self, + reporter: SamplerReporter, + ) -> ParticlewiseGenerationRunner: + """Create the particlewise execution engine for the active run. 
+ + This helper collects the stable callbacks and configuration needed by + the extracted particlewise runner while keeping the public sampler + facade small. + + Args: + reporter (SamplerReporter): Reporter used for progress and summary + output during the run. + + Returns: + ParticlewiseGenerationRunner: Runner configured for the current + sampler state. + """ + + return ParticlewiseGenerationRunner( + config=ParticlewiseGenerationConfig( + generation_particle_count=self.generation_particle_count, + tolerance_values=self.tolerance_values, + seed_sequence=self._seed_sequence, + max_attempts_per_proposal=self.max_attempts_per_proposal, + sample_particle_from_priors=self.sample_particle_from_priors, + sample_and_perturb_particle=self.sample_and_perturb_particle, + particle_to_distance=self.particle_to_distance, + calculate_weight=self.calculate_weight, + replace_particle_population=self._replace_particle_population, + reporter=reporter, + ), + run_state=self._run_state, + ) + + def _build_batch_generation_runner( + self, + reporter: SamplerReporter, + ) -> BatchGenerationRunner: + """Create the batched execution engine for the active run. + + This helper collects the stable callbacks and configuration needed by + the extracted batch runner while keeping the public sampler facade + focused on orchestration. + + Args: + reporter (SamplerReporter): Reporter used for progress and summary + output during the run. + + Returns: + BatchGenerationRunner: Runner configured for the current sampler + state. 
+ """ + + return BatchGenerationRunner( + config=BatchGenerationConfig( + generation_particle_count=self.generation_particle_count, + tolerance_values=self.tolerance_values, + sample_particle_from_priors=self.sample_particle_from_priors, + sample_and_perturb_particle=self.sample_and_perturb_particle, + particle_to_distance=self.particle_to_distance, + calculate_weight=self.calculate_weight, + replace_particle_population=self._replace_particle_population, + reporter=reporter, + ), + run_state=self._run_state, + ) + + def _validate_run_kwargs(self, kwargs: dict[str, Any]) -> None: + """Validate keyword arguments forwarded into particle evaluation. + + This helper protects sampler execution from accidental collisions + between run-time keyword arguments and existing class attributes. + + Args: + kwargs (dict[str, Any]): Keyword arguments supplied to a run method. + + Returns: + None: This helper does not return a value. + + Raises: + ValueError: Raised when a run-time keyword collides with an + existing class attribute. + """ + for key in kwargs: + if key in self.__class__.__dict__: + raise ValueError( + f"Keyword argument '{key}' conflicts with existing attribute. Please choose a different name for the argument. ABCSampler attributes cannot be set from `.run()`" + ) + + def _build_results(self) -> CalibrationResults: + """Build the immutable results snapshot for the completed run. + + This helper validates that every generation reached the target + population size and constructs the `CalibrationResults` object from the + sampler's current state. + + Returns: + CalibrationResults: Snapshot containing the final posterior, + generation history, archive data, and success statistics. + + Raises: + UserWarning: Raised when any generation finished with fewer accepted + particles than `generation_particle_count`. 
+ """ + + if any( + count < self.generation_particle_count + for count in self.step_successes + ): + raise UserWarning( + "The number of successful particles in at least one generation is less than the specified generation_particle_count. This may indicate that the maximum particle proposal attempts are too low or the error tolerance values are too strict for the model and target data." + ) + return CalibrationResults( + copy.deepcopy(self._updater), + self.generator_history, + self.population_archive, + self._run_state.build_success_counts( + self.generation_particle_count + ), + self.tolerance_values, + ) + + def _reset_after_run( + self, + perturbation_kernel: PerturbationKernel, + ) -> None: + """Reset mutable sampler state after a completed run. + + This helper restores the original perturbation kernel and clears all + per-run bookkeeping so the sampler can be reused safely. + + Args: + perturbation_kernel (PerturbationKernel): Perturbation kernel to + restore on the sampler. + + Returns: + None: This helper does not return a value. + """ + + self.init_updater(perturbation_kernel) + self._run_state.reset() + def run( self, execution: Literal["serial", "parallel"] = "parallel", @@ -364,259 +618,53 @@ def run( Args: execution (Literal['serial', 'parallel']): Determines whether to run the SMC sampling process in serial or parallel. Defaults to 'serial'. - max_workers (int | None): The maximum number of worker processes to use when running in parallel. If None, it defaults to the number of CPU cores available. This argument is ignored when execution is set to 'serial'. + max_workers (int | None): The maximum number of worker threads to use when running in parallel. If None, it defaults to the sampler's configured `parallel_worker_count`. This argument is ignored when execution is set to 'serial'. **kwargs (Any): Additional keyword arguments that can be passed to the method. These arguments are supplied to the particles_to_params function. 
Note that the keyword arguments must not conflict with existing attributes of the class. Returns: CalibrationResults: An object containing the results of the calibration process. - Raises: - ValueError: If any keyword argument in `kwargs` conflicts with existing attributes of the class - UserWarning: If the number of successful particles in any generation is less than the specified generation_particle_count """ - for k in kwargs.keys(): - if k in self.__class__.__dict__: - raise ValueError( - f"Keyword argument '{k}' conflicts with existing attribute. Please choose a different name for the argument. ABCSampler attributes cannot be set from `.run()`" - ) - + self._validate_run_kwargs(kwargs) originator_perturbation_kernel = copy.deepcopy( self.perturbation_kernel ) - - console = formatting.get_console() + reporter = self._build_reporter() overall_start_time = time.time() + n_workers = ( + self._resolve_worker_count(max_workers) + if execution == "parallel" + else 1 + ) + particlewise_runner = self._build_particlewise_generation_runner( + reporter=reporter + ) + parallel_executor = ( + ThreadPoolExecutor(max_workers=n_workers) + if execution == "parallel" and n_workers > 1 + else None + ) - if execution == "parallel": - if sys.platform.startswith("linux"): - import multiprocessing - - multiprocessing.set_start_method("fork", force=True) - n_procs = ( - min(max_workers, (max(mp.cpu_count(), 1))) - if max_workers - else (mp.cpu_count() or 1) - ) - else: - n_procs = 1 - - for generation in range(len(self.tolerance_values)): - generation_start_time = time.time() - - # Init the proposed population - proposed_population = ParticlePopulation() - - # Generate the seed sequences used for each particle in the proposed population - _sequences = self._seed_sequence.spawn( - self.generation_particle_count - ) - generator_list: list[dict[str, Any]] = [ - {"id": i, "seed_sequence": v} for i, v in enumerate(_sequences) - ] - - # Select sample method based on generation - from the 
priors if uninitiated, from the most recent population if available - if generation == 0: - sample_method = self.sample_particle_from_priors - else: - sample_method = self.sample_and_perturb_particle - - with Progress( - TextColumn("[progress.description]{task.description}"), - BarColumn(), - TextColumn("[progress.percentage]{task.percentage:>3.0f}%"), - TextColumn("•"), - TextColumn("acceptance: {task.fields[acceptance]}"), - TextColumn("•"), - TimeElapsedColumn(), - TextColumn("•"), - TextColumn("ETA: {task.fields[eta]}"), - console=console, - transient=True, - ) as progress: - running_progress_bar = progress.add_task( - f"Generation {generation + 1} (tolerance {self.tolerance_values[generation]})...", - total=self.generation_particle_count, - acceptance="N/A", - eta="calculating...", - ) - accepted_list = [] - total_attempts = 0 - completed = 0 - # For each particle generator id and seed sequence, create an accepted particle and map to the id. - # Serial execution - if n_procs == 1: - for generator in generator_list: - accepted_list.append( - self.sample_particles_until_accepted( - generator=generator, - tolerance=self.tolerance_values[generation], - sample_method=sample_method, - **kwargs, - ) - ) - total_attempts += accepted_list[-1][2] - completed += 1 - elapsed = time.time() - generation_start_time - eta = ( - elapsed - * (self.generation_particle_count - completed) - / (completed or 1) - if elapsed > 0 and completed > 0 - else 0.0 - ) - acceptance_rate = ( - 100.0 * completed / total_attempts - if total_attempts > 0 - else 0.0 - ) - progress.update( - running_progress_bar, - completed=completed, - acceptance=f"{acceptance_rate:.1f}%", - eta=formatting._format_time(eta), - ) - - # Parallel execution - else: - if mp.current_process().name == "MainProcess": - with ProcessPoolExecutor( - max_workers=n_procs - ) as executor: - for completed_generator in executor.map( - partial( - self.sample_particles_until_accepted, - tolerance=self.tolerance_values[ - 
generation - ], - sample_method=sample_method, - **kwargs, - ), - generator_list, - ): - accepted_list.append(completed_generator) - total_attempts += completed_generator[2] - completed += 1 - elapsed = time.time() - generation_start_time - eta = ( - elapsed - * ( - self.generation_particle_count - - completed - ) - / (completed or 1) - if elapsed > 0 and completed > 0 - else 0.0 - ) - acceptance_rate = ( - 100.0 * completed / total_attempts - if total_attempts > 0 - else 0.0 - ) - progress.update( - running_progress_bar, - completed=completed, - acceptance=f"{acceptance_rate:.1f}%", - eta=formatting._format_time(eta), - ) - - total_time = time.time() - overall_start_time - processing_time = time.time() - generation_start_time - - # Store acceptance statistics - acceptance_rate = ( - 100.0 * completed / total_attempts - if total_attempts > 0 - else 0.0 - ) - console.print( - f"[green]✓[/green] Generation {generation + 1} run complete! " - f"Tolerance: {self.tolerance_values[generation]}, acceptance rate: {acceptance_rate:.1f}% of {total_attempts} attempts" - ) - - with Progress( - TextColumn("[progress.description]{task.description}"), - BarColumn(), - TextColumn("[progress.percentage]{task.percentage:>3.0f}%"), - TextColumn("•"), - TimeElapsedColumn(), - console=console, - transient=True, - ) as progress: - task_id = progress.add_task( - "Calculating weights...", - total=self.generation_particle_count, - ) - # Collect the results from across the accepted particles - for id, accepted_particle, samples in sorted( - accepted_list, key=lambda x: x[0] - ): - if accepted_particle is not None: - self.step_successes[generation] += 1 - if generation == 0: - particle_weight = 1.0 - else: - particle_weight = self.calculate_weight( - accepted_particle - ) - proposed_population.add_particle( - accepted_particle, particle_weight - ) - progress.update(task_id, advance=1) - else: - raise UserWarning( - f"Particle proposal attempt {id} used {samples} samples and found no 
acceptable values." - ) - self.step_attempts[generation] += samples - - self.generator_history.update({generation: generator_list}) - self.particle_population = proposed_population - - weights_time = ( - time.time() - generation_start_time - processing_time - ) - # Summary with checkmark - console.print( - f"(Run: {formatting._format_time(processing_time)}, Weights calculation: {formatting._format_time(weights_time)}, total time: {formatting._format_time(total_time)})" + try: + for generation in range(len(self.tolerance_values)): + generation_stats = particlewise_runner.run_generation( + ParticlewiseGenerationRequest( + generation=generation, + n_workers=n_workers, + parallel_executor=parallel_executor, + overall_start_time=overall_start_time, + generation_start_time=time.time(), + particle_kwargs=dict(kwargs), + ) ) + finally: + if parallel_executor is not None: + parallel_executor.shutdown(wait=True) - # Summary with checkmark - console.print( - f"[green]✓[/green] Calibration complete! " - f"(total time: {formatting._format_time(total_time)})" - ) + reporter.print_run_summary(generation_stats.total_time) return self.get_results_and_reset(originator_perturbation_kernel) - def sample_particles_until_accepted( - self, - generator: dict[str, int | SeedSequence], - tolerance: float, - sample_method: Callable[[SeedSequence], Particle], - max_attempts: int | None = None, - **kwargs: Any, - ) -> tuple[int, Particle | None, int]: - """ - Rejection sampling routine to return a single value - - Args: - generator (dict[str, int | SeedSequence]): A dictionary containing the particle id and seed sequence generator for the random number generator spawn used in sampling. - tolerance (float): The tolerance value for accepting a particle based on the error returned from particle_to_distance(). 
- sample_method (Callable[[SeedSequence], Particle]): The method used to sample particles, which can be either from the priors or by perturbing existing particles when called from the sampler SMC routine. Any method that accepts a seed sequence and returns a particle is valid. - max_attempts (int | None): The maximum number of attempts to sample and perturb a particle before aborting. If None, it defaults to the sampler's `max_attempts_per_proposal` attribute. - **kwargs (Any): Additional keyword arguments that can be passed to the `particles_to_params` function. These arguments are supplied from the `run()` method and can include any user-defined parameters needed for mapping particles to model parameters. - Returns: - tuple[int, Particle | None, int]: A tuple containing the particle id, the accepted particle (or None if no acceptable particle was found within the maximum attempts), and the number of samples taken to find an acceptable particle below the provided tolerance. - """ - if not max_attempts: - max_attempts = self.max_attempts_per_proposal - - for i in range(max_attempts): - proposed_particle = sample_method(generator["seed_sequence"]) - err = self.particle_to_distance(proposed_particle, **kwargs) - if err <= tolerance: - return (generator["id"], proposed_particle, i + 1) - return (generator["id"], None, max_attempts) - def run_parallel_batches( self, chunksize: int = 1, @@ -625,7 +673,7 @@ def run_parallel_batches( **kwargs: Any, ) -> CalibrationResults: """ - Executes the Sequential Monte Carlo (SMC) sampling process in parallel using a LocalParallelExecutor. + Executes the Sequential Monte Carlo (SMC) sampling process in parallel using async orchestration over a thread pool. This method performs the SMC algorithm to generate a population of particles that approximate the posterior distribution of the model parameters. 
The process @@ -635,110 +683,50 @@ def run_parallel_batches( Args: chunksize (int): The approximate number of parameter sets to process in serial for each task when evaluating in parallel. Defaults to 1. - batchsize (int | None): The number of proposed particles to generate in each batch when evaluating in parallel. If None, it defaults to the generation_particle_count. This controls how many particles are proposed at once and submitted to the process pool. - max_workers (int | None): The maximum number of worker processes to use when running in parallel. If None, it defaults to the number of CPU cores available. + batchsize (int | None): The number of proposed particles to generate in each batch when evaluating in parallel. If None, it defaults to the generation_particle_count. This controls how many particles are proposed at once and submitted to the executor. + max_workers (int | None): The maximum number of worker threads to use when running in parallel. If None, it defaults to the sampler's configured `parallel_worker_count`. **kwargs (Any): Additional keyword arguments that can be passed to the method. These arguments are supplied to the particles_to_params function. Note that the keyword arguments must not conflict with existing attributes of the class. Returns: CalibrationResults: An object containing the results of the calibration process. - Raises: - ValueError: If any keyword argument in `kwargs` conflicts with existing attributes of the class """ - for k in kwargs.keys(): - if k in self.__class__.__dict__: - raise ValueError( - f"Keyword argument '{k}' conflicts with existing attribute. Please choose a different name for the argument. 
Args cannot be set from `.run()`" - ) - - actual_workers = ( - min(max_workers, (max(mp.cpu_count(), 1))) - if max_workers - else (mp.cpu_count() or 1) + self._validate_run_kwargs(kwargs) + actual_workers = self._resolve_worker_count(max_workers) + reporter = self._build_reporter() + batch_runner = self._build_batch_generation_runner(reporter=reporter) + resolved_batchsize, warmup = batch_runner.resolve_settings( + batchsize=batchsize, chunksize=chunksize ) - if not batchsize: - batchsize = self.generation_particle_count - warmup = True - else: - warmup = False - originator_perturbation_kernel = copy.deepcopy( self.perturbation_kernel ) + overall_start_time = time.time() + executor = ( + ThreadPoolExecutor(max_workers=actual_workers) + if actual_workers > 1 + else None + ) - if mp.current_process().name == "MainProcess": - with ProcessPoolExecutor(max_workers=actual_workers) as executor: - for generation in range(len(self.tolerance_values)): - if self.verbose: - print( - f"Running generation {generation + 1} with tolerance {self.tolerance_values[generation]}..." - ) - - proposed_population = ParticlePopulation() - - # Rejection sampling algorithm - attempts = 0 - while ( - proposed_population.size - < self.generation_particle_count - ): - if proposed_population.size > 0: - if warmup: - batchsize = 10_000 - sample_size = int( - min( - batchsize, - ( - self.generation_particle_count - - proposed_population.size - ) - * attempts - / proposed_population.size, - ) - ) - else: - sample_size = batchsize - if generation == 0: - proposed_particles = [ - self.sample_particle_from_priors() - for _ in range(sample_size) - ] - else: - proposed_particles = [ - self.sample_and_perturb_particle() - for _ in range(sample_size) - ] - if self.verbose and attempts > 0: - print( - f"Attempt {attempts}... current population size is {proposed_population.size}. 
Acceptance rate is {proposed_population.size / attempts if attempts > 0 else 0:.4f}", - end="\r", - ) - errs = executor.map( - partial(self.particle_to_distance, **kwargs), - proposed_particles, - chunksize=chunksize, - ) - for err, proposed_particle in zip( - errs, proposed_particles - ): - if ( - err < self.tolerance_values[generation] - and proposed_population.size - < self.generation_particle_count - ): - if generation == 0: - particle_weight = 1.0 - else: - particle_weight = self.calculate_weight( - proposed_particle - ) - proposed_population.add_particle( - proposed_particle, particle_weight - ) - attempts += len(proposed_particles) - self.step_successes[generation] = proposed_population.size - self.step_attempts[generation] = attempts - self.particle_population = proposed_population + try: + for generation in range(len(self.tolerance_values)): + generation_start_time = time.time() + generation_stats = batch_runner.run_generation( + BatchGenerationRequest( + generation=generation, + batchsize=resolved_batchsize, + warmup=warmup, + chunksize=chunksize, + executor=executor, + overall_start_time=overall_start_time, + generation_start_time=generation_start_time, + particle_kwargs=dict(kwargs), + ) + ) + finally: + if executor is not None: + executor.shutdown(wait=True) + reporter.print_run_summary(generation_stats.total_time) return self.get_results_and_reset(originator_perturbation_kernel) diff --git a/src/calibrationtools/sampler_reporting.py b/src/calibrationtools/sampler_reporting.py new file mode 100644 index 0000000..6290e65 --- /dev/null +++ b/src/calibrationtools/sampler_reporting.py @@ -0,0 +1,280 @@ +"""Report sampler progress and summaries through Rich. + +This module centralizes progress-bar construction and run-summary printing so +the execution runners can focus on sampling logic instead of Rich wiring. 
+""" + +from dataclasses import dataclass + +from rich.console import Console +from rich.progress import ( + BarColumn, + Progress, + TaskID, + TextColumn, + TimeElapsedColumn, +) + +from . import formatting +from .sampler_types import GenerationStats + + +@dataclass(frozen=True, slots=True) +class ProgressHandle: + """Reference an active Rich progress task. + + This carrier keeps the `Progress` instance and the active task id together + so runners can update progress through a single object. + + Attributes: + progress (Progress): Active Rich progress instance. + task_id (TaskID): Task identifier within that progress instance. + """ + + progress: Progress + task_id: TaskID + + +class SamplerReporter: + """Create progress displays and print run summaries. + + This helper owns the Rich console and the formatting of generation and run + summaries so execution engines do not need to duplicate UI setup. + + Args: + verbose (bool): Whether progress and summary output should be visible. + console (Console | None): Optional console override used for tests. + """ + + def __init__( + self, + verbose: bool, + console: Console | None = None, + ) -> None: + self.console = ( + console if console is not None else formatting.get_console(verbose) + ) + + def create_collection_progress(self) -> Progress: + """Create the progress layout used during proposal collection. + + This helper centralizes the Rich progress columns used by both the + particlewise and batched generation paths. + + Returns: + Progress: Configured Rich progress instance for collection work. 
+ """ + + return Progress( + TextColumn("[progress.description]{task.description}"), + BarColumn(), + TextColumn("[progress.percentage]{task.percentage:>3.0f}%"), + TextColumn("•"), + TextColumn("acceptance: {task.fields[acceptance]}"), + TextColumn("•"), + TimeElapsedColumn(), + TextColumn("•"), + TextColumn("ETA: {task.fields[eta]}"), + console=self.console, + transient=True, + ) + + def start_collection_task( + self, + progress: Progress, + description: str, + total: int, + ) -> ProgressHandle: + """Start a collection-phase progress task. + + This helper creates the task used to track accepted particles during a + generation and returns a handle for later updates. + + Args: + progress (Progress): Active Rich progress instance. + description (str): Description shown for the task. + total (int): Total items required to complete the task. + + Returns: + ProgressHandle: Handle referencing the created progress task. + """ + + task_id = progress.add_task( + description, + total=total, + acceptance="N/A", + eta="calculating...", + ) + return ProgressHandle(progress=progress, task_id=task_id) + + def update_collection_progress( + self, + handle: ProgressHandle, + completed: int, + acceptance_rate: float, + eta_seconds: float, + ) -> None: + """Update collection progress for a generation. + + This helper formats the acceptance rate and ETA consistently before + applying the update to the active task. + + Args: + handle (ProgressHandle): Handle referencing the active task. + completed (int): Number of completed items for the task. + acceptance_rate (float): Current acceptance rate as a percentage. + eta_seconds (float): Estimated remaining time in seconds. + + Returns: + None: This helper does not return a value. 
+ """ + + handle.progress.update( + handle.task_id, + completed=completed, + acceptance=f"{acceptance_rate:.1f}%", + eta=formatting._format_time(eta_seconds), + ) + + def create_weight_progress(self) -> Progress: + """Create the progress layout used during weight calculation. + + This helper centralizes the simplified Rich progress display used for + the weight-calculation phase after particle acceptance. + + Returns: + Progress: Configured Rich progress instance for weight updates. + """ + + return Progress( + TextColumn("[progress.description]{task.description}"), + BarColumn(), + TextColumn("[progress.percentage]{task.percentage:>3.0f}%"), + TextColumn("•"), + TimeElapsedColumn(), + console=self.console, + transient=True, + ) + + def start_weight_task( + self, + progress: Progress, + total: int, + ) -> ProgressHandle: + """Start a weight-calculation progress task. + + This helper creates the task used to track post-acceptance weight + calculation and returns a handle for later updates. + + Args: + progress (Progress): Active Rich progress instance. + total (int): Total number of accepted particles to process. + + Returns: + ProgressHandle: Handle referencing the created progress task. + """ + + task_id = progress.add_task("Calculating weights...", total=total) + return ProgressHandle(progress=progress, task_id=task_id) + + def advance(self, handle: ProgressHandle, steps: int = 1) -> None: + """Advance a progress task by the requested number of steps. + + This helper keeps direct task advancement out of the execution + runners. + + Args: + handle (ProgressHandle): Handle referencing the active task. + steps (int): Number of steps to advance the task. + + Returns: + None: This helper does not return a value. + """ + + handle.progress.update(handle.task_id, advance=steps) + + def print_generation_summary( + self, + generation: int, + tolerance: float, + generation_stats: GenerationStats, + ) -> None: + """Print the summary for a completed generation. 
+ + This helper prints the generation completion message with a consistent + acceptance-rate format across execution engines. + + Args: + generation (int): Zero-based generation index that completed. + tolerance (float): Tolerance used by the generation. + generation_stats (GenerationStats): Summary metrics for the + generation. + + Returns: + None: This helper does not return a value. + """ + + acceptance_rate = ( + 100.0 * generation_stats.successes / generation_stats.attempts + if generation_stats.attempts > 0 + else 0.0 + ) + self.console.print( + f"[green]✓[/green] Generation {generation + 1} run complete! " + f"Tolerance: {tolerance}, acceptance rate: {acceptance_rate:.1f}% " + f"of {generation_stats.attempts} attempts" + ) + + def print_timing_summary( + self, + processing_time: float, + total_time: float, + weights_time: float | None = None, + ) -> None: + """Print the timing summary for a completed generation. + + This helper keeps the generation timing message consistent across + particlewise and batched execution while allowing the weight phase to + be omitted for the batched path. + + Args: + processing_time (float): Seconds spent in the generation's main + processing phase. + total_time (float): Seconds elapsed since the run began. + weights_time (float | None): Optional seconds spent in the weight + calculation phase. + + Returns: + None: This helper does not return a value. + """ + + if weights_time is None: + self.console.print( + f"(Run: {formatting._format_time(processing_time)}, " + f"total time: {formatting._format_time(total_time)})" + ) + return + + self.console.print( + f"(Run: {formatting._format_time(processing_time)}, " + f"Weights calculation: {formatting._format_time(weights_time)}, " + f"total time: {formatting._format_time(total_time)})" + ) + + def print_run_summary(self, total_time: float) -> None: + """Print the summary for the completed sampler run. 
+ + This helper keeps the final run-completion message in one place so the + public sampler facade does not need to format Rich output directly. + + Args: + total_time (float): Seconds elapsed for the full sampler run. + + Returns: + None: This helper does not return a value. + """ + + self.console.print( + f"[green]✓[/green] Calibration complete! " + f"(total time: {formatting._format_time(total_time)})" + ) diff --git a/src/calibrationtools/sampler_run_state.py b/src/calibrationtools/sampler_run_state.py new file mode 100644 index 0000000..fb6f66e --- /dev/null +++ b/src/calibrationtools/sampler_run_state.py @@ -0,0 +1,137 @@ +"""Track mutable sampler bookkeeping for one execution. + +This module keeps per-run counters, generator history, and population archives +separate from the public sampler facade so state reset behavior stays explicit. +""" + +from .particle_population import ParticlePopulation +from .sampler_types import GeneratorSlot + + +class SamplerRunState: + """Track mutable per-run bookkeeping for an `ABCSampler` execution. + + This class owns the generation counters and archive data that should be + reset between sampler runs while allowing `ABCSampler` to stay focused on + orchestration and public API concerns. + + Args: + generation_count (int): Number of configured generations in the run. + keep_previous_population_data (bool): Whether previous populations + should be archived between generations. + """ + + def __init__( + self, + generation_count: int, + keep_previous_population_data: bool, + ) -> None: + self.generation_count = generation_count + self.keep_previous_population_data = keep_previous_population_data + self.reset() + + def reset(self) -> None: + """Reset generation counters and archived run data. + + This method clears all per-run bookkeeping so the sampler can start a + fresh execution without leaking counters or archived populations from a + previous run. 
+ + """ + + self.step_successes = [0] * self.generation_count + self.step_attempts = [0] * self.generation_count + self.generator_history: dict[int, list[GeneratorSlot]] = {} + self.population_archive: dict[int, ParticlePopulation] = {} + + def record_generation_history( + self, + generation: int, + generator_slots: list[GeneratorSlot], + ) -> None: + """Store the generator slots used to propose one generation. + + This method captures the deterministic generator slots for later result + inspection and for serial-versus-parallel comparisons in tests. + + Args: + generation (int): Zero-based generation index being recorded. + generator_slots (list[GeneratorSlot]): Generator slots used for the + generation. + + """ + + self.generator_history[generation] = list(generator_slots) + + def record_attempts( + self, + generation: int, + attempts: int, + successes: int, + ) -> None: + """Store the attempt and success counts for one generation. + + This method records the accepted-particle count and total proposal + attempts so result construction can report generation-level acceptance + diagnostics. + + Args: + generation (int): Zero-based generation index being recorded. + attempts (int): Total proposal attempts consumed by the generation. + successes (int): Total accepted particles produced by the + generation. + + """ + + self.step_attempts[generation] = attempts + self.step_successes[generation] = successes + + def archive_population( + self, + previous_population: ParticlePopulation, + ) -> None: + """Archive the previous population before a replacement is stored. + + This method records the outgoing population when archive retention is + enabled and the previous population is not empty. + + Args: + previous_population (ParticlePopulation): Population currently + stored on the sampler before replacement. 
+ + """ + + if ( + self.keep_previous_population_data + and not previous_population.is_empty() + ): + step = ( + max(self.population_archive.keys()) + 1 + if self.population_archive + else 0 + ) + self.population_archive[step] = previous_population + + def build_success_counts( + self, generation_particle_count: int + ) -> dict[str, list[int]]: + """Build the success-count payload for `CalibrationResults`. + + This method packages generation size, successes, and attempts into the + structure expected by the existing `CalibrationResults` API. + + Args: + generation_particle_count (int): Target accepted-particle count for + each generation. + + Returns: + dict[str, list[int]]: Success-count payload used by result + construction. + """ + + return { + "generation_particle_count": [generation_particle_count] + * self.generation_count, + "successes": list(self.step_successes), + "attempts": list(self.step_attempts), + } diff --git a/src/calibrationtools/sampler_types.py b/src/calibrationtools/sampler_types.py new file mode 100644 index 0000000..b3ee965 --- /dev/null +++ b/src/calibrationtools/sampler_types.py @@ -0,0 +1,67 @@ +"""Define typed carriers shared across sampler execution helpers. + +This module replaces positional tuples and ad hoc dictionaries with small data +objects so sampler control flow reads in terms of named responsibilities. +""" + +from dataclasses import dataclass + +from numpy.random import SeedSequence + +from .particle import Particle + + +@dataclass(frozen=True, slots=True) +class GeneratorSlot: + """Identify one deterministic proposal stream for a generation. + + This carrier keeps the slot identifier and spawned seed sequence together + so proposal ordering remains explicit across serial and parallel execution. + + Attributes: + id (int): Stable proposal-slot identifier within the generation. + seed_sequence (SeedSequence): Seed sequence used for the slot. 
+ """ + + id: int + seed_sequence: SeedSequence + + +@dataclass(frozen=True, slots=True) +class AcceptedProposal: + """Store the result of proposing until one particle is accepted. + + This carrier keeps the accepted particle, or an exhausted-attempt marker, + together with the slot identifier and attempt count used to produce it. + + Attributes: + slot_id (int): Proposal-slot identifier within the generation. + particle (Particle | None): Accepted particle for the slot, or `None` + when attempts were exhausted. + attempts (int): Proposal attempts consumed for the slot. + """ + + slot_id: int + particle: Particle | None + attempts: int + + +@dataclass(frozen=True, slots=True) +class GenerationStats: + """Summarize attempts, successes, and timing for one generation. + + This carrier provides a named summary object for generation-level metrics + that are shared across the particlewise and batched execution paths. + + Attributes: + attempts (int): Total proposal attempts consumed by the generation. + successes (int): Total accepted particles produced by the generation. + processing_time (float): Seconds spent in the generation processing + phase. + total_time (float): Seconds elapsed since the full sampler run began. 
+ """ + + attempts: int + successes: int + processing_time: float + total_time: float diff --git a/src/example_model/__init__.py b/src/example_model/__init__.py new file mode 100644 index 0000000..0513b7b --- /dev/null +++ b/src/example_model/__init__.py @@ -0,0 +1,15 @@ +from pathlib import Path + +_WORKSPACE_PACKAGE = ( + Path(__file__).resolve().parents[2] + / "packages" + / "example_model" + / "src" + / "example_model" +) +if _WORKSPACE_PACKAGE.is_dir(): + __path__.append(str(_WORKSPACE_PACKAGE)) + +from .example_model import Binom_BP_Model # noqa: E402 + +__all__ = ["Binom_BP_Model"] diff --git a/tests/test_batch_generation_runner.py b/tests/test_batch_generation_runner.py new file mode 100644 index 0000000..37130ce --- /dev/null +++ b/tests/test_batch_generation_runner.py @@ -0,0 +1,91 @@ +import time +from io import StringIO + +from rich.console import Console + +from calibrationtools.batch_generation_runner import ( + BatchGenerationConfig, + BatchGenerationRequest, + BatchGenerationRunner, +) +from calibrationtools.particle import Particle +from calibrationtools.particle_population import ParticlePopulation +from calibrationtools.sampler_reporting import SamplerReporter +from calibrationtools.sampler_run_state import SamplerRunState + + +def test_batch_generation_runner_accepts_equal_tolerance(): + reporter = SamplerReporter( + verbose=True, + console=Console(file=StringIO(), force_terminal=True), + ) + runner = BatchGenerationRunner( + config=BatchGenerationConfig( + generation_particle_count=1, + tolerance_values=[0.5], + sample_particle_from_priors=lambda _: Particle({"p": 0.25}), + sample_and_perturb_particle=lambda _: Particle({"p": 0.25}), + particle_to_distance=lambda particle, **_: abs( + particle["p"] - 0.25 + ), + calculate_weight=lambda _: 1.0, + replace_particle_population=lambda _: None, + reporter=reporter, + ), + run_state=SamplerRunState(1, False), + ) + proposed_population = ParticlePopulation() + + considered = 
runner._accept_particle_batch( + generation=0, + proposed_population=proposed_population, + proposed_particles=[Particle({"p": 0.25})], + errs=[0.5], + ) + + assert considered == 1 + assert proposed_population.size == 1 + + +def test_batch_generation_runner_run_generation_records_state(): + stored_populations: list[ParticlePopulation] = [] + reporter = SamplerReporter( + verbose=True, + console=Console(file=StringIO(), force_terminal=True), + ) + run_state = SamplerRunState(1, False) + runner = BatchGenerationRunner( + config=BatchGenerationConfig( + generation_particle_count=1, + tolerance_values=[0.5], + sample_particle_from_priors=lambda _: Particle({"p": 0.25}), + sample_and_perturb_particle=lambda _: Particle({"p": 0.8}), + particle_to_distance=lambda particle, **_: abs( + particle["p"] - 0.25 + ), + calculate_weight=lambda _: 1.0, + replace_particle_population=stored_populations.append, + reporter=reporter, + ), + run_state=run_state, + ) + generation_start_time = time.time() + + generation_stats = runner.run_generation( + BatchGenerationRequest( + generation=0, + batchsize=1, + warmup=False, + chunksize=1, + executor=None, + overall_start_time=generation_start_time, + generation_start_time=generation_start_time, + particle_kwargs={}, + ) + ) + + assert generation_stats.successes == 1 + assert generation_stats.attempts == 1 + assert run_state.step_successes == [1] + assert run_state.step_attempts == [1] + assert stored_populations[0].size == 1 diff --git a/tests/test_particle_evaluator.py b/tests/test_particle_evaluator.py new file mode 100644 index 0000000..79be3f3 --- /dev/null +++ b/tests/test_particle_evaluator.py @@ -0,0 +1,41 @@ +import pytest + +from calibrationtools.particle import Particle +from calibrationtools.particle_evaluator import ParticleEvaluator + + +class DummyModelRunner: + def simulate(self, params): + return 0.5 + params["p"] + + +def test_particle_evaluator_distance(): + evaluator = ParticleEvaluator( + particles_to_params=lambda 
particle: particle, + outputs_to_distance=lambda model_output, target_data: abs( + model_output - target_data + ), + target_data=0.75, + model_runner=DummyModelRunner(), + ) + + distance = evaluator.distance(Particle({"p": 0.1})) + + assert distance == pytest.approx(0.15) + + +def test_particle_evaluator_distance_passes_kwargs(): + evaluator = ParticleEvaluator( + particles_to_params=lambda particle, scale: { + "p": particle["p"] * scale + }, + outputs_to_distance=lambda model_output, target_data: abs( + model_output - target_data + ), + target_data=0.9, + model_runner=DummyModelRunner(), + ) + + distance = evaluator.distance(Particle({"p": 0.2}), scale=2.0) + + assert distance == pytest.approx(0.0) diff --git a/tests/test_particlewise_generation_runner.py b/tests/test_particlewise_generation_runner.py new file mode 100644 index 0000000..31da523 --- /dev/null +++ b/tests/test_particlewise_generation_runner.py @@ -0,0 +1,96 @@ +import time +from io import StringIO + +from numpy.random import SeedSequence +from rich.console import Console + +from calibrationtools.particle import Particle +from calibrationtools.particle_population import ParticlePopulation +from calibrationtools.particlewise_generation_runner import ( + ParticlewiseGenerationConfig, + ParticlewiseGenerationRequest, + ParticlewiseGenerationRunner, +) +from calibrationtools.sampler_reporting import SamplerReporter +from calibrationtools.sampler_run_state import SamplerRunState +from calibrationtools.sampler_types import GeneratorSlot + + +def test_particlewise_generation_runner_sample_particles_until_accepted(): + reporter = SamplerReporter( + verbose=True, + console=Console(file=StringIO(), force_terminal=True), + ) + runner = ParticlewiseGenerationRunner( + config=ParticlewiseGenerationConfig( + generation_particle_count=1, + tolerance_values=[0.5], + seed_sequence=SeedSequence(123), + max_attempts_per_proposal=5, + sample_particle_from_priors=lambda _: Particle({"p": 0.2}), + 
sample_and_perturb_particle=lambda _: Particle({"p": 0.8}), + particle_to_distance=lambda particle, **_: abs( + particle["p"] - 0.2 + ), + calculate_weight=lambda _: 1.0, + replace_particle_population=lambda _: None, + reporter=reporter, + ), + run_state=SamplerRunState(1, False), + ) + + accepted_proposal = runner.sample_particles_until_accepted( + generator=GeneratorSlot(id=7, seed_sequence=SeedSequence(456)), + tolerance=0.1, + sample_method=lambda _: Particle({"p": 0.2}), + evaluation_kwargs={}, + ) + + assert accepted_proposal.slot_id == 7 + assert accepted_proposal.particle == Particle({"p": 0.2}) + assert accepted_proposal.attempts == 1 + + +def test_particlewise_generation_runner_run_generation_records_state(): + stored_populations: list[ParticlePopulation] = [] + run_state = SamplerRunState(1, False) + reporter = SamplerReporter( + verbose=True, + console=Console(file=StringIO(), force_terminal=True), + ) + runner = ParticlewiseGenerationRunner( + config=ParticlewiseGenerationConfig( + generation_particle_count=1, + tolerance_values=[0.5], + seed_sequence=SeedSequence(123), + max_attempts_per_proposal=5, + sample_particle_from_priors=lambda _: Particle({"p": 0.2}), + sample_and_perturb_particle=lambda _: Particle({"p": 0.8}), + particle_to_distance=lambda particle, **_: abs( + particle["p"] - 0.2 + ), + calculate_weight=lambda _: 1.0, + replace_particle_population=stored_populations.append, + reporter=reporter, + ), + run_state=run_state, + ) + generation_start_time = time.time() + + generation_stats = runner.run_generation( + ParticlewiseGenerationRequest( + generation=0, + n_workers=1, + parallel_executor=None, + overall_start_time=generation_start_time, + generation_start_time=generation_start_time, + particle_kwargs={}, + ) + ) + + assert generation_stats.successes == 1 + assert generation_stats.attempts == 1 + assert run_state.step_successes == [1] + assert run_state.step_attempts == [1] + assert len(run_state.generator_history[0]) == 1 + assert 
stored_populations[0].size == 1 diff --git a/tests/test_sampler.py b/tests/test_sampler.py index 507044a..8f0cd0f 100644 --- a/tests/test_sampler.py +++ b/tests/test_sampler.py @@ -1,8 +1,12 @@ +import threading +import time from copy import deepcopy import pytest +import calibrationtools.sampler as sampler_module from calibrationtools.calibration_results import CalibrationResults +from calibrationtools.particle import Particle from calibrationtools.perturbation_kernel import ( IndependentKernels, NormalKernel, @@ -16,6 +20,28 @@ def simulate(self, params): return 0.5 + params["p"] +class UnpickleableModelRunner: + def __init__(self): + self.bad = lambda x: x + + def simulate(self, params): + return 0.5 + params["p"] + + +class NonThreadSafeModelRunner: + def __init__(self): + self._lock = threading.Lock() + + def simulate(self, params): + if not self._lock.acquire(blocking=False): + raise RuntimeError("concurrent simulate on shared runner") + try: + time.sleep(0.01) + return 0.5 + params["p"] + finally: + self._lock.release() + + def particles_to_params(particle): return particle @@ -40,37 +66,59 @@ def sampler(K, P, Vnorm) -> ABCSampler: ) -def test_abc_sampler_run(K, sampler: ABCSampler): +@pytest.fixture() +def sampler_with_archive(K, P, Vnorm) -> ABCSampler: + return ABCSampler( + generation_particle_count=5, + tolerance_values=[0.5, 0.1], + priors=P, + perturbation_kernel=K, + variance_adapter=Vnorm, + particles_to_params=particles_to_params, + outputs_to_distance=outputs_to_distance, + target_data=0.75, + model_runner=DummyModelRunner(), + seed=123, + keep_previous_population_data=True, + ) + + +def test_abc_sampler_run(K, sampler_with_archive: ABCSampler): original_std_dev = K.kernels[0].std_dev - results = sampler.run_serial() + results = sampler_with_archive.run_serial() assert isinstance(results, CalibrationResults) posterior_particles = results.posterior.particle_population # Assert success condition after run assert all( [ - count == 
sampler.generation_particle_count + count == sampler_with_archive.generation_particle_count for count in results.smc_step_successes ] ) # Assess population handling and updating assert ( - len(results.population_archive) == len(sampler.tolerance_values) - ) - 1 + len(results.population_archive) + == len(sampler_with_archive.tolerance_values) - 1 + ) for pop in results.population_archive.values(): - assert len(pop.particles) == sampler.generation_particle_count + assert ( + len(pop.particles) + == sampler_with_archive.generation_particle_count + ) assert pop.total_weight == pytest.approx(1.0) assert all( p not in posterior_particles.particles for p in pop.particles ) assert ( - len(posterior_particles.particles) == sampler.generation_particle_count + len(posterior_particles.particles) + == sampler_with_archive.generation_particle_count ) # Test that the perturbation kernel has been updated by adapter Vnorm - reset_perturbation = sampler._updater.perturbation_kernel + reset_perturbation = sampler_with_archive._updater.perturbation_kernel assert isinstance(reset_perturbation, IndependentKernels) reset_perturbation_kernels = reset_perturbation.kernels @@ -87,6 +135,14 @@ def test_abc_sampler_run(K, sampler: ABCSampler): assert isinstance(posterior_perturbation_kernels[1], SeedKernel) +def test_sampler_run_does_not_archive_previous_population_by_default( + sampler: ABCSampler, +): + results = sampler.run_serial() + + assert results.population_archive == {} + + def test_sampler_run_repeatable(sampler): # Sampler produces same results when seed is set results1 = sampler.run_serial() @@ -97,6 +153,43 @@ def test_sampler_run_repeatable(sampler): assert results1.acceptance_rates == results2.acceptance_rates +def test_sampler_particle_to_distance_delegates_to_evaluator(sampler): + class RecordingEvaluator: + def __init__(self): + self.calls = [] + + def distance(self, particle, **kwargs): + self.calls.append((particle, kwargs)) + return 1.23 + + recording_evaluator = 
RecordingEvaluator() + sampler._particle_evaluator = recording_evaluator + particle = Particle({"p": 0.1, "seed": 0}) + + distance = sampler.particle_to_distance(particle, scale=2.0) + + assert distance == 1.23 + assert recording_evaluator.calls == [(particle, {"scale": 2.0})] + + +def test_sampler_run_resets_internal_run_state(sampler_with_archive): + results1 = sampler_with_archive.run_serial() + results2 = sampler_with_archive.run_serial() + + expected_archive_size = len(sampler_with_archive.tolerance_values) - 1 + + assert sampler_with_archive.step_successes == [0] * len( + sampler_with_archive.tolerance_values + ) + assert sampler_with_archive.step_attempts == [0] * len( + sampler_with_archive.tolerance_values + ) + assert sampler_with_archive.generator_history == {} + assert sampler_with_archive.population_archive == {} + assert len(results1.population_archive) == expected_archive_size + assert len(results2.population_archive) == expected_archive_size + + def test_sample_from_priors(sampler): # Test that sampling from priors works before any population is set states = sampler.sample_priors(5) @@ -142,12 +235,150 @@ def test_sampler_run_parallel_equal(sampler: ABCSampler): for gen_serial, gen_parallel in zip( generator_list, parallel_generator_list ): - assert gen_serial["id"] == gen_parallel["id"] + assert gen_serial.id == gen_parallel.id assert ( - gen_serial["seed_sequence"].entropy - == gen_parallel["seed_sequence"].entropy + gen_serial.seed_sequence.entropy + == gen_parallel.seed_sequence.entropy ) assert ( - gen_serial["seed_sequence"].spawn_key - == gen_parallel["seed_sequence"].spawn_key + gen_serial.seed_sequence.spawn_key + == gen_parallel.seed_sequence.spawn_key ) + + +def test_sampler_run_parallel_with_unpickleable_runner(K, P, Vnorm): + sampler = ABCSampler( + generation_particle_count=5, + tolerance_values=[0.5, 0.1], + priors=P, + perturbation_kernel=K, + variance_adapter=Vnorm, + particles_to_params=particles_to_params, + 
outputs_to_distance=outputs_to_distance, + target_data=0.75, + model_runner=UnpickleableModelRunner(), + seed=123, + ) + + results = sampler.run_parallel(max_workers=2) + + assert isinstance(results, CalibrationResults) + + +def test_sampler_run_parallel_batches_repeatable(sampler): + results1 = sampler.run_parallel_batches( + max_workers=2, chunksize=2, batchsize=4 + ) + results2 = sampler.run_parallel_batches( + max_workers=1, chunksize=1, batchsize=4 + ) + + assert results1.point_estimates == results2.point_estimates + assert results1.ess == results2.ess + assert results1.acceptance_rates == results2.acceptance_rates + assert ( + results1.posterior.particle_population.particles + == results2.posterior.particle_population.particles + ) + + +def test_sampler_run_parallel_batches_with_unpickleable_runner(K, P, Vnorm): + sampler = ABCSampler( + generation_particle_count=5, + tolerance_values=[0.5, 0.1], + priors=P, + perturbation_kernel=K, + variance_adapter=Vnorm, + particles_to_params=particles_to_params, + outputs_to_distance=outputs_to_distance, + target_data=0.75, + model_runner=UnpickleableModelRunner(), + seed=123, + ) + + results = sampler.run_parallel_batches( + max_workers=2, chunksize=2, batchsize=4 + ) + + assert isinstance(results, CalibrationResults) + + +def test_sampler_parallel_worker_count_default_is_configured( + K, P, Vnorm, monkeypatch +): + recorded = {} + real_executor = sampler_module.ThreadPoolExecutor + + class RecordingExecutor(real_executor): + def __init__(self, *args, **kwargs): + recorded["max_workers"] = kwargs.get("max_workers") + super().__init__(*args, **kwargs) + + monkeypatch.setattr( + sampler_module, "ThreadPoolExecutor", RecordingExecutor + ) + + sampler = ABCSampler( + generation_particle_count=5, + tolerance_values=[0.5, 0.1], + priors=P, + perturbation_kernel=K, + variance_adapter=Vnorm, + particles_to_params=particles_to_params, + outputs_to_distance=outputs_to_distance, + target_data=0.75, + 
model_runner=DummyModelRunner(), + parallel_worker_count=3, + seed=123, + ) + + sampler.run_parallel() + + assert recorded["max_workers"] == 3 + + +def test_sampler_parallel_worker_failure_does_not_leak_future_errors( + K, P, Vnorm, capfd +): + sampler = ABCSampler( + generation_particle_count=5, + tolerance_values=[0.5], + priors=P, + perturbation_kernel=K, + variance_adapter=Vnorm, + particles_to_params=particles_to_params, + outputs_to_distance=outputs_to_distance, + target_data=0.75, + model_runner=NonThreadSafeModelRunner(), + seed=123, + verbose=False, + ) + + with pytest.raises( + RuntimeError, match="concurrent simulate on shared runner" + ): + sampler.run_parallel(max_workers=2) + + captured = capfd.readouterr() + assert "Future exception was never retrieved" not in captured.err + + +def test_sampler_verbose_false_suppresses_output(K, P, Vnorm, capfd): + sampler = ABCSampler( + generation_particle_count=5, + tolerance_values=[0.5], + priors=P, + perturbation_kernel=K, + variance_adapter=Vnorm, + particles_to_params=particles_to_params, + outputs_to_distance=outputs_to_distance, + target_data=0.75, + model_runner=DummyModelRunner(), + seed=123, + verbose=False, + ) + + sampler.run_serial() + + captured = capfd.readouterr() + assert captured.out == "" diff --git a/tests/test_sampler_run_state.py b/tests/test_sampler_run_state.py new file mode 100644 index 0000000..cb8777e --- /dev/null +++ b/tests/test_sampler_run_state.py @@ -0,0 +1,52 @@ +from calibrationtools.particle_population import ParticlePopulation +from calibrationtools.sampler_run_state import SamplerRunState +from calibrationtools.sampler_types import GeneratorSlot + + +def test_sampler_run_state_archives_previous_population_when_enabled( + particle_population, +): + state = SamplerRunState( + generation_count=2, + keep_previous_population_data=True, + ) + state.archive_population(ParticlePopulation()) + assert state.population_archive == {} + + state.archive_population(particle_population) + + 
assert state.population_archive == {0: particle_population} + + +def test_sampler_run_state_does_not_archive_previous_population_when_disabled( + particle_population, +): + state = SamplerRunState( + generation_count=2, + keep_previous_population_data=False, + ) + + state.archive_population(particle_population) + + assert state.population_archive == {} + + +def test_sampler_run_state_reset_clears_bookkeeping( + particle_population, seed_sequence +): + state = SamplerRunState( + generation_count=2, + keep_previous_population_data=True, + ) + generator_slots = [GeneratorSlot(id=0, seed_sequence=seed_sequence)] + + state.record_generation_history(0, generator_slots) + state.record_attempts(generation=0, attempts=4, successes=1) + state.archive_population(particle_population) + + state.reset() + + assert state.step_successes == [0, 0] + assert state.step_attempts == [0, 0] + assert state.generator_history == {} + assert state.population_archive == {}