Implement MoNNA decentralized simulation #32
Draft
moh wants to merge 28 commits into
6b741b4 to
07d82bc
Compare
pevab requested changes on Jun 9, 2026
        raise NotImplementedError

    def run(self, rounds: int) -> list[StepResultT]:
        """Execute several simulation rounds.
Collaborator
Put this logic in the implementation.
from krum.simulations.base import Simulation

__all__ = ["Simulation"]
Select each node's received vectors from a random permutation of the other n - 1 nodes instead of always including Byzantine vectors. Run nearest-neighbor averaging over each local n - f coordination set so the final average keeps n - 2f vectors.
Replace the functional MoNNA implementation (MonnaConfig, MonnaState/initial_state, and the free-function protocol of run_simulation/run_round/mix_each_worker/...) with a single MonnaSimulation class that owns configuration, state, data streams, attack, and aggregator.
- step()/run() drive the round; each phase is a small named method (collect_worker_batches, compute_honest_worker_gradients, update_local_momentum, compute_local_parameter_updates, generate_byzantine_models, aggregate_over_received_nodes), and commit_step returns a primitive snapshot dict.
- Validate all constructor arguments up front with precise errors.
- Make aggregator a required Aggregator (no implicit NearestNeighbor default) and drop the callable/hasattr duck-typing.
- Delete config.py and state.py; export only MonnaSimulation.
- Rewrite tests against the class API.
The repackaged Aggregator base no longer takes n/f or validates them, so move the n/f checks into NearestNeighbor.__init__ and call super().__init__() with no arguments, matching the other aggregators (e.g. MultiKrum).
The rule only needs the number of vectors to keep, not n and f: it averages the closest vectors to a pivot. Replace the (n, f) constructor with a single num_closest, validating it is positive and not larger than the candidate count at aggregation time. Rename the class and module to NearestNeighborAverage to describe what it does. The choice of how many to keep (n - f, n - 2f, ...) is now the caller's policy; the primitive is reusable outside MoNNA.
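To illustrate the rule described in this commit, here is a minimal pure-Python sketch of averaging the closest vectors to a pivot. It is not the library's NearestNeighborAverage: the function name nearest_neighbor_average, the use of plain lists instead of tensors, and the Euclidean distance are assumptions made for the example.

```python
import math
from typing import List, Sequence


def nearest_neighbor_average(
    vectors: Sequence[Sequence[float]],
    pivot: Sequence[float],
    num_closest: int,
) -> List[float]:
    """Average the num_closest vectors nearest to pivot (Euclidean distance)."""
    # Validate at aggregation time, when the candidate count is known.
    if num_closest <= 0:
        raise ValueError("num_closest must be positive")
    if num_closest > len(vectors):
        raise ValueError("num_closest exceeds the number of candidate vectors")
    # Rank candidates by distance to the pivot; sorted() is stable, so ties
    # keep their input order.
    ranked = sorted(vectors, key=lambda v: math.dist(v, pivot))
    kept = ranked[:num_closest]
    # Coordinate-wise mean of the kept vectors.
    return [sum(coords) / num_closest for coords in zip(*kept)]


# Three nearby vectors plus one far outlier; keeping 3 drops the outlier.
honest = [[1.0, 1.0], [1.2, 0.8], [0.8, 1.2]]
outlier = [[100.0, -100.0]]
result = nearest_neighbor_average(honest + outlier, pivot=honest[0], num_closest=3)
```

How many vectors to keep stays with the caller, which is the point of the single num_closest argument.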
Make the aggregator optional again and default it, inside MonnaSimulation, to NearestNeighborAverage(num_closest=num_honest - num_byzantine) == n - 2f. The n - 2f sizing is intrinsic to MoNNA, so the simulation now owns it instead of the caller computing it when constructing the aggregator. A supplied aggregator still overrides the default for comparing other robust aggregation rules.
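As a worked example of the sizing described here, the sketch below derives the three counts from n and f. The helper name monna_sizes and the requirement that n - 2f be at least 1 are assumptions for illustration, not part of the PR.

```python
from typing import Tuple


def monna_sizes(n: int, f: int) -> Tuple[int, int, int]:
    """Return (num_honest, num_received, num_kept) for n workers, f Byzantine."""
    if n <= 0 or f < 0:
        raise ValueError("need n > 0 and f >= 0")
    num_kept = n - 2 * f
    if num_kept < 1:
        # Assumption for this sketch: at least one vector must survive.
        raise ValueError("n - 2f must be at least 1")
    num_honest = n - f    # workers that follow the protocol
    num_received = n - f  # models each honest worker aggregates per round
    return num_honest, num_received, num_kept


sizes = monna_sizes(n=10, f=3)  # (7, 7, 4)
```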
Each honest worker receives n - f models per round. The new byzantine_reach parameter controls how the f Byzantine models are placed in that set:
- "all" (default): every Byzantine model reaches every honest worker and only the honest responders are randomized, matching the reference MoNNA implementation's worst-case adversary.
- "sampled": responders are drawn uniformly from all other nodes, so a worker receives 0..f Byzantine models (the previous behavior).
Both modes keep the received-set size at n - f. Adds tests covering set size, self-first ordering, per-worker Byzantine composition in each mode, validation, and the f = 0 edge case.
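The two reach modes can be sketched with index sets alone. This is a hedged illustration, not the code in this PR: the function select_received, the node layout (honest nodes first, Byzantine nodes last), the assumption that the worker's own model counts toward the n - f, and the order of responders after the worker itself are all choices made for the example.

```python
import random
from typing import List


def select_received(
    worker: int, n: int, f: int, reach: str, rng: random.Random
) -> List[int]:
    """Pick the n - f node indices whose models `worker` aggregates, self first.

    Assumed layout (hypothetical): honest nodes are 0..n-f-1, Byzantine nodes
    are n-f..n-1, and `worker` is an honest node.
    """
    num_honest = n - f
    honest_others = [i for i in range(num_honest) if i != worker]
    byzantine = list(range(num_honest, n))
    if reach == "all":
        # Every Byzantine model arrives; only honest responders are random.
        responders = byzantine + rng.sample(honest_others, num_honest - 1 - f)
    elif reach == "sampled":
        # Responders are drawn uniformly from all other n - 1 nodes.
        responders = rng.sample(honest_others + byzantine, num_honest - 1)
    else:
        raise ValueError(f"unknown byzantine_reach: {reach!r}")
    return [worker] + responders


rng = random.Random(0)
all_set = select_received(0, n=7, f=2, reach="all", rng=rng)
sampled_set = select_received(0, n=7, f=2, reach="sampled", rng=rng)
```

In both calls the result has n - f = 5 entries and starts with the worker's own index; only the Byzantine composition differs.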
After rebasing onto 27, MoNNA's primitives are stateless classmethods, so adapt the simulation to them:
- NearestNeighborAverage becomes a stateless @classmethod (num_closest and pivot are per-call keyword args), imports the base from the package root, uses `from torch import ...`, and accepts the `out` parameter.
- MonnaSimulation holds the attack and aggregator as stateless callables. The default aggregator is partial(NearestNeighborAverage.aggregate, num_closest=n-2f); the attack is invoked as attack(models, f=...). The obsolete pivot try/except is dropped since `**specialized` absorbs an unused pivot.
Update the tests to call the classmethods directly and to pass SignFlipAttack.generate as the attack callable.
Model.gradients caches a flat view of each parameter's .grad. The per-worker loop calls zero_grad(set_to_none=True) then backward(), which replaces every .grad with a fresh tensor, leaving the cached flat view pointing at the prior worker's storage. From the second read on, model.gradients returned stale gradients. Call relink_gradients() after backward() to re-sync the flat view, as the Model contract prescribes for set_to_none=True.
Add Args/Returns/Raises sections to MonnaSimulation methods to match the style used across krum.primitives.
Attacks and aggregators no longer support __call__; they expose classmethods (Attack.generate, Aggregator.aggregate). MonnaSimulation now takes the attack/aggregator classes plus attack_kwargs/aggregator_kwargs and dispatches via the classmethods, matching CentralisedSimulation and giving the library one invocation convention.
Update MoNNA protocol tests for the stateless class API: pass SignFlipAttack and NearestNeighborAverage classes (plus aggregator_kwargs) instead of bound .generate methods and partials.
self.generator is a CPU generator, which torch.randperm rejects when called with device='cuda'. Draw the permutation on CPU and move the selected indices to the target device, so seeded sampling no longer crashes on GPU. CPU determinism is unchanged.
Add an abstract Simulation[StepResultT] base in krum.simulations with an abstract step() contract and a shared run(rounds) loop, exported from the package root. Gives simulation protocols one common interface.
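A minimal sketch of what such a base could look like, using only the standard library. The generic parameter name StepResultT comes from the diff above; the toy CountingSimulation subclass and the ValueError raised for a negative round count are assumptions for illustration.

```python
from abc import ABC, abstractmethod
from typing import Generic, List, TypeVar

StepResultT = TypeVar("StepResultT")


class Simulation(ABC, Generic[StepResultT]):
    """Base contract: subclasses implement step(); run() is shared."""

    @abstractmethod
    def step(self) -> StepResultT:
        """Advance the simulation by one round and return its snapshot."""

    def run(self, rounds: int) -> List[StepResultT]:
        """Call step() once per round, in order, and collect the snapshots."""
        if rounds < 0:
            # Assumption: the exception type is a guess for this sketch.
            raise ValueError("rounds must be non-negative")
        return [self.step() for _ in range(rounds)]


class CountingSimulation(Simulation[int]):
    """Toy subclass (hypothetical) whose snapshot is the round index."""

    def __init__(self) -> None:
        self.round = 0

    def step(self) -> int:
        self.round += 1
        return self.round


history = CountingSimulation().run(3)  # [1, 2, 3]
```

Because step() is abstract, instantiating Simulation directly raises TypeError, so every protocol has to supply its own round logic.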
Subclass Simulation[StepResult] and drop the now-redundant run() override, which is identical to the base loop. No behavior change.
ruff format: the make_simulation call fits on one line after the .generate-to-class change.
- Make StepResult a TypedDict so step() snapshot fields are typed as tensors (not int | Tensor), fixing allclose/.shape call sites.
- Narrow self.attack before .generate (guaranteed set when num_byzantine>0).
- Widen NearestNeighborAverage.aggregate gradients to Sequence[Tensor] | Tensor, matching the Aggregator base.
- Convert mypy-style '# type: ignore' to ty's '# ty: ignore' in the intentional negative tests.
Drop the abstract Simulation base in favor of a standalone MonnaSimulation: move the run loop into the concrete class and clean up the now-empty simulations package exports.
Add unit tests asserting run drives step once per round in order, no-ops on zero rounds, and rejects a negative round count.
Small CPU-friendly runner demonstrating the MonnaSimulation path on MNIST/FakeData with IID or Dirichlet partitioning. Shared as a discussion artifact; may be reverted depending on team feedback.
Match the docstring/signature convention introduced on main: drop the duplicated Args/Returns/Raises from the class docstring (kept on aggregate) and move the `out` parameter ahead of the keyword-only specialized params (num_closest, pivot). Update the keyword-only test accordingly.
9206038 to
cd9e6e4
Compare
Move the MoNNA simulation into a new krum.simulations.decentralised package and split the shared decentralised round loop — local momentum update, model mixing, Byzantine generation, snapshots, and the resumable run driver — into a DecentralisedSimulation base. MonnaSimulation now inherits the base and supplies only gather_received_models (the communication-topology seam) plus its reach helpers, mirroring the centralised package layout. Update the docs reference and the experiment import to the new module paths.
Relocate the protocol tests to tests/simulations/decentralised/ and point the imports at the new krum.simulations.decentralised module.
Take n (total workers) and f (Byzantine count) in the decentralised simulation constructors instead of num_honest/num_byzantine, deriving num_honest = n - f internally. Align the DecentralisedSimulation base argument order with MonnaSimulation for the shared parameters. Update the tests and the MoNNA experiment caller accordingly.
pevab approved these changes on Jun 12, 2026
        return self.net(inputs)


def parse_args() -> argparse.Namespace:
Collaborator
We can make a simpler script, with the values fixed in the script itself. No need to parse arguments.
Summary

Implements the MoNNA (Momentum Nearest-Neighbor Averaging) decentralized Byzantine-resilient learning simulation as a library module under krum.simulations.monna. Implements #26.

Each round runs one local momentum-SGD phase per honest worker, then a model-mixing phase: every worker aggregates the n - f models it receives, keeping the n - 2f closest to its own (pivot-anchored nearest-neighbor average by default).

What's in here

- MonnaSimulation: class-based runner. step() returns a per-round snapshot (StepResult); run(rounds) collects snapshots.
- Simulation base class: Simulation[StepResultT] ABC with the abstract step() contract and the shared run(rounds) loop; MoNNA subclasses it.
- Attack and aggregator are passed as classes and invoked via Attack.generate / Aggregator.aggregate (no __call__), consistent with the rest of the library, with attack_kwargs / aggregator_kwargs passthrough.
- byzantine_reach: "all" (worst case: every Byzantine model reaches every worker) and "sampled" (responders drawn uniformly).

Notable fixes

- Stale gradients after zero_grad(set_to_none=True) (re-sync via relink_gradients()).
- GPU crash in seeded sampling: torch.randperm now draws on CPU and moves indices to the data's device.

Testing

pytest tests/simulations/monna/: 13 passing. ruff check clean.

Out of scope (by design, per the brief)

Minimal class for the first milestone: no metrics/evaluation, no device management, no LR scheduling. Experiment scripts live outside the package.