Skip to content

Implement MoNNA decentralized simulation#32

Draft
moh wants to merge 28 commits into
mainfrom
26-implement-simulation-from-monna
Draft

Implement MoNNA decentralized simulation#32
moh wants to merge 28 commits into
mainfrom
26-implement-simulation-from-monna

Conversation

@moh

@moh moh commented Jun 8, 2026

Copy link
Copy Markdown
Collaborator

Summary

Implements the MoNNA (Momentum Nearest-Neighbor Averaging) decentralized
Byzantine-resilient learning simulation as a library module under
krum.simulations.monna. Implements #26.

Each round runs one local momentum-SGD phase per honest worker, then a
model-mixing phase: every worker aggregates the n - f models it receives,
keeping the n - 2f closest to its own (pivot-anchored nearest-neighbor
average by default).

What's in here

  • MonnaSimulation — class-based runner. step() returns a per-round
    snapshot (StepResult); run(rounds) collects snapshots.
  • Simulation base classSimulation[StepResultT] ABC with the
    abstract step() contract and the shared run(rounds) loop; MoNNA
    subclasses it.
  • Stateless primitive integration — attacks/aggregators are passed as
    classes and invoked via Attack.generate / Aggregator.aggregate
    (no __call__), consistent with the rest of the library, with
    attack_kwargs / aggregator_kwargs passthrough.
  • byzantine_reach"all" (worst case: every Byzantine model reaches
    every worker) and "sampled" (responders drawn uniformly).
  • Google-style docstrings throughout.

Notable fixes

  • Stale flat-gradient view after zero_grad(set_to_none=True) (re-sync via
    relink_gradients()).
  • Seeded responder sampling crashed on GPU (CPU generator vs CUDA device);
    now draws on CPU and moves indices to the data's device.

Testing

  • pytest tests/simulations/monna/ — 13 passing.
  • ruff check clean.

Out of scope (by design, per the brief)

Minimal class for the first milestone: no metrics/evaluation, no device
management, no LR scheduling. Experiment scripts live outside the package.

@moh moh linked an issue Jun 8, 2026 that may be closed by this pull request
@moh moh self-assigned this Jun 8, 2026
@moh moh force-pushed the 26-implement-simulation-from-monna branch from 6b741b4 to 07d82bc Compare June 8, 2026 22:17
Comment thread krum/simulations/base.py Outdated
Comment thread krum/simulations/base.py Outdated
Comment thread krum/simulations/base.py Outdated
raise NotImplementedError

def run(self, rounds: int) -> list[StepResultT]:
"""Execute several simulation rounds.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

put this logic in the implementation

Comment thread krum/simulations/__init__.py Outdated

from krum.simulations.base import Simulation

__all__ = ["Simulation"]

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need this line?

moh added 24 commits June 11, 2026 08:59
Select each node's received vectors from a random permutation of the other n - 1 nodes instead of always including Byzantine vectors.

Run nearest-neighbor averaging over each local n - f coordination set so the final average keeps n - 2f vectors.
Replace the functional MoNNA implementation (MonnaConfig, MonnaState/
initial_state, and the free-function protocol of run_simulation/run_round/
mix_each_worker/...) with a single MonnaSimulation class that owns
configuration, state, data streams, attack, and aggregator.

- step()/run() drive the round; each phase is a small named method
  (collect_worker_batches, compute_honest_worker_gradients,
  update_local_momentum, compute_local_parameter_updates,
  generate_byzantine_models, aggregate_over_received_nodes), and
  commit_step returns a primitive snapshot dict.
- Validate all constructor arguments up front with precise errors.
- Make aggregator a required Aggregator (no implicit NearestNeighbor
  default) and drop the callable/hasattr duck-typing.
- Delete config.py and state.py; export only MonnaSimulation.
- Rewrite tests against the class API.
The repackaged Aggregator base no longer takes n/f or validates them, so
move the n/f checks into NearestNeighbor.__init__ and call super().__init__()
with no arguments, matching the other aggregators (e.g. MultiKrum).
The rule only needs the number of vectors to keep, not n and f: it
averages the closest vectors to a pivot. Replace the (n, f) constructor
with a single num_closest, validating it is positive and not larger than
the candidate count at aggregation time. Rename the class and module to
NearestNeighborAverage to describe what it does.

The choice of how many to keep (n - f, n - 2f, ...) is now the caller's
policy; the primitive is reusable outside MoNNA.
Make the aggregator optional again and default it, inside MonnaSimulation,
to NearestNeighborAverage(num_closest=num_honest - num_byzantine) == n - 2f.
The n - 2f sizing is intrinsic to MoNNA, so the simulation now owns it
instead of the caller computing it when constructing the aggregator.
A supplied aggregator still overrides the default for comparing other
robust aggregation rules.
Each honest worker receives n - f models per round. The new
byzantine_reach parameter controls how the f Byzantine models are
placed in that set:

- "all" (default): every Byzantine model reaches every honest worker
  and only the honest responders are randomized, matching the reference
  MoNNA implementation's worst-case adversary.
- "sampled": responders are drawn uniformly from all other nodes, so a
  worker receives 0..f Byzantine models (the previous behavior).

Both modes keep the received-set size at n - f. Adds tests covering set
size, self-first ordering, per-worker Byzantine composition in each
mode, validation, and the f = 0 edge case.
After rebasing onto 27, MoNNA's primitives are stateless classmethods,
so adapt the simulation to them:

- NearestNeighborAverage becomes a stateless @classmethod (num_closest
  and pivot are per-call keyword args), imports the base from the package
  root, uses `from torch import ...`, and accepts the `out` parameter.
- MonnaSimulation holds the attack and aggregator as stateless callables.
  The default aggregator is partial(NearestNeighborAverage.aggregate,
  num_closest=n-2f); the attack is invoked as attack(models, f=...). The
  obsolete pivot try/except is dropped since `**specialized` absorbs an
  unused pivot.

Update the tests to call the classmethods directly and to pass
SignFlipAttack.generate as the attack callable.
Model.gradients caches a flat view of each parameter's .grad. The per-worker
loop calls zero_grad(set_to_none=True) then backward(), which replaces every
.grad with a fresh tensor, leaving the cached flat view pointing at the prior
worker's storage. From the second read on, model.gradients returned stale
gradients.

Call relink_gradients() after backward() to re-sync the flat view, as the
Model contract prescribes for set_to_none=True.
Add Args/Returns/Raises sections to MonnaSimulation methods to match the
style used across krum.primitives.
Attacks and aggregators no longer support __call__; they expose
classmethods (Attack.generate, Aggregator.aggregate). MonnaSimulation now
takes the attack/aggregator classes plus attack_kwargs/aggregator_kwargs
and dispatches via the classmethods, matching CentralisedSimulation and
giving the library one invocation convention.
Update MoNNA protocol tests for the stateless class API: pass SignFlipAttack
and NearestNeighborAverage classes (plus aggregator_kwargs) instead of bound
.generate methods and partials.
self.generator is a CPU generator, which torch.randperm rejects when called
with device='cuda'. Draw the permutation on CPU and move the selected
indices to the target device, so seeded sampling no longer crashes on GPU.
CPU determinism is unchanged.
Add an abstract Simulation[StepResultT] base in krum.simulations with an
abstract step() contract and a shared run(rounds) loop, exported from the
package root. Gives simulation protocols one common interface.
Subclass Simulation[StepResult] and drop the now-redundant run() override,
which is identical to the base loop. No behavior change.
ruff format: the make_simulation call fits on one line after the
.generate-to-class change.
- Make StepResult a TypedDict so step() snapshot fields are typed as
  tensors (not int | Tensor), fixing allclose/.shape call sites.
- Narrow self.attack before .generate (guaranteed set when num_byzantine>0).
- Widen NearestNeighborAverage.aggregate gradients to Sequence[Tensor] |
  Tensor, matching the Aggregator base.
- Convert mypy-style '# type: ignore' to ty's '# ty: ignore' in the
  intentional negative tests.
Drop the abstract Simulation base in favor of a standalone
MonnaSimulation: move the run loop into the concrete class and clean
up the now-empty simulations package exports.
Add unit tests asserting run drives step once per round in order,
no-ops on zero rounds, and rejects a negative round count.
Small CPU-friendly runner demonstrating the MonnaSimulation path on
MNIST/FakeData with IID or Dirichlet partitioning. Shared as a
discussion artifact; may be reverted depending on team feedback.
Match the docstring/signature convention introduced on main: drop the
duplicated Args/Returns/Raises from the class docstring (kept on
aggregate) and move out ahead of the keyword-only specialized params
(num_closest, pivot). Update the keyword-only test accordingly.
@moh moh force-pushed the 26-implement-simulation-from-monna branch from 9206038 to cd9e6e4 Compare June 11, 2026 09:23
Move the MoNNA simulation into a new krum.simulations.decentralised
package and split the shared decentralised round loop — local momentum
update, model mixing, Byzantine generation, snapshots, and the resumable
run driver — into a DecentralisedSimulation base.

MonnaSimulation now inherits the base and supplies only
gather_received_models (the communication-topology seam) plus its reach
helpers, mirroring the centralised package layout. Update the docs
reference and the experiment import to the new module paths.
moh and others added 3 commits June 12, 2026 12:56
Relocate the protocol tests to tests/simulations/decentralised/ and
point the imports at the new krum.simulations.decentralised module.
Take n (total workers) and f (Byzantine count) in the decentralised
simulation constructors instead of num_honest/num_byzantine, deriving
num_honest = n - f internally. Align the DecentralisedSimulation base
argument order with MonnaSimulation for the shared parameters.

Update the tests and the MoNNA experiment caller accordingly.
return self.net(inputs)


def parse_args() -> argparse.Namespace:

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On peut faire un script plus simple, avec des valeurs fixées dans le script lui-même. Pas besoin de parser des arguments.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Implement simulation from MoNNA

2 participants