Harden Keras DP accounting and fit validation #228

Draft
staticpayload wants to merge 8 commits into google-deepmind:main from scalarian:fix/keras-accounting-sampling-method

Conversation

Contributor

@staticpayload staticpayload commented Apr 21, 2026

This PR hardens the Keras DP fit() path so the runtime, privacy accounting, and examples all describe the same training contract. The core goal is simple: the sampling method, batch size, training step count, and validated inputs should match the assumptions used for privacy accounting.

What Changed

Accounting now follows optimizer updates. With gradient accumulation enabled, privacy budget enforcement counts real optimizer update steps instead of raw minibatches. The optimizer iteration counter is now the source of truth, and train_steps is documented as optimizer updates.
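
As a rough illustration of the step-counting rule above, the sketch below counts optimizer updates rather than minibatches. The names `optimizer_updates` and `check_budget` are hypothetical and are not taken from `jax_privacy`; the sketch also assumes that only complete accumulation windows trigger an update.

```python
def optimizer_updates(num_minibatches: int, accumulation_steps: int) -> int:
  """Returns how many optimizer updates `num_minibatches` minibatches produce.

  Assumption: an update fires only after a full accumulation window.
  """
  if accumulation_steps < 1:
    raise ValueError("accumulation_steps must be >= 1")
  return num_minibatches // accumulation_steps


def check_budget(updates_done: int, train_steps: int) -> None:
  """Raises once the accounted number of optimizer updates is exhausted."""
  if updates_done >= train_steps:
    raise RuntimeError(
        f"Privacy budget exhausted: {updates_done} optimizer updates already "
        f"taken, but accounting covers only train_steps={train_steps}."
    )


# 8 minibatches with 4-step accumulation are 2 optimizer updates, so a
# budget of train_steps=2 is used up after 8 minibatches, not after 2.
updates = optimizer_updates(num_minibatches=8, accumulation_steps=4)
```

Counting minibatches instead would stop this run after 2 minibatches, before a single optimizer update, which is the kind of mismatch the change is meant to remove.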

Sampling assumptions are explicit. jax_privacy/accounting/calibrate.py now accepts an explicit sampling method, so calibrated noise is tied to the same sampling assumption used by training. The Keras example and docs default to fixed-size batching with fixed-batch accounting, while Poisson sampling in fit() remains available as an opt-in.
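
One way to picture the coupling between calibration and training is to store the sampling method next to the calibrated noise and compare it at fit time. This is only a sketch: `SamplingMethod`, `CalibratedNoise`, `resolve_sampling_method`, and `noise_for_training` are hypothetical names, not the actual API of `jax_privacy/accounting/calibrate.py`, and the fixed-batch default for an unset method is an assumption modeled on the example defaults described above.

```python
import dataclasses
import enum
from typing import Optional


class SamplingMethod(enum.Enum):
  FIXED_BATCH = "fixed_batch"
  POISSON = "poisson"


@dataclasses.dataclass(frozen=True)
class CalibratedNoise:
  """Noise multiplier together with the sampling assumption it was derived under."""
  noise_multiplier: float
  sampling_method: SamplingMethod


def resolve_sampling_method(
    requested: Optional[SamplingMethod],
) -> SamplingMethod:
  """Treats an unset method as fixed-batch (assumed default for this sketch)."""
  return SamplingMethod.FIXED_BATCH if requested is None else requested


def noise_for_training(
    cached: CalibratedNoise, requested: Optional[SamplingMethod]
) -> float:
  """Reuses cached noise only if training samples the way calibration assumed."""
  resolved = resolve_sampling_method(requested)
  if resolved is not cached.sampling_method:
    raise ValueError(
        f"Noise was calibrated for {cached.sampling_method.value} sampling, "
        f"but training uses {resolved.value} sampling."
    )
  return cached.noise_multiplier
```

In this sketch, resolving an implicit method to the one used for calibration keeps the cached value, while an explicit mismatch fails loudly instead of training with noise calibrated for a different sampling scheme.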

fit() rejects unaccountable inputs early. DP-incompatible setups now fail with readable errors before the jitted training step runs. This includes validation_split, generator-like inputs without steps_per_epoch, fixed-batch array inputs with implicit partial tails, and prebatched tf.data / PyDataset structures whose batch shape does not match the DP config.
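
The flavor of these early checks can be sketched in plain Python. `validate_fixed_batch_inputs` is a hypothetical helper, not the function in `jax_privacy/keras_api.py`, and it covers only two of the cases listed above (`validation_split` and partial tails for array inputs).

```python
def validate_fixed_batch_inputs(
    train_size: int, batch_size: int, validation_split: float = 0.0
) -> int:
  """Returns steps per epoch, or raises if the setup cannot be accounted for."""
  if validation_split:
    # Keras holds out part of the arrays for validation, so the number of
    # examples actually trained on would no longer equal the declared train_size.
    raise ValueError(
        "validation_split changes the effective training set size; "
        "it is incompatible with the declared train_size."
    )
  if train_size < 1 or batch_size < 1:
    raise ValueError("train_size and batch_size must be positive")
  if train_size % batch_size != 0:
    tail = train_size % batch_size
    raise ValueError(
        f"train_size={train_size} is not a multiple of batch_size={batch_size}; "
        f"the last batch would contain only {tail} examples."
    )
  return train_size // batch_size
```

Because the checks run in ordinary Python before any compiled step, the caller sees a readable `ValueError` rather than a shape error from inside jitted code.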

Noise caching and Poisson seeding are safer. Repeated fit() calls with the same seed no longer replay the same internal Poisson sample stream, and cached calibrated noise is preserved when fit() only resolves an implicit sampling method.
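
A minimal sketch of the seeding idea follows: derive a fresh sample stream for each `fit()` call from the base seed and a call counter, so the same seed stays reproducible across runs without replaying one stream within a run. `derive_call_seed`, `PoissonSampler`, and `poisson_sample` are hypothetical names, and Python's `random` module stands in for the library's real random number generation purely for illustration.

```python
import hashlib
import random


def derive_call_seed(base_seed: int, call_index: int) -> int:
  """Mixes the user seed with a per-call counter into a new 64-bit seed."""
  digest = hashlib.sha256(f"{base_seed}:{call_index}".encode()).digest()
  return int.from_bytes(digest[:8], "big")


def poisson_sample(
    train_size: int, sampling_prob: float, rng: random.Random
) -> list[int]:
  """Includes each example index independently with probability sampling_prob."""
  return [i for i in range(train_size) if rng.random() < sampling_prob]


class PoissonSampler:
  """Hands out a distinct generator for every simulated fit() call."""

  def __init__(self, seed: int):
    self._seed = seed
    self._calls = 0

  def start_fit(self) -> random.Random:
    rng = random.Random(derive_call_seed(self._seed, self._calls))
    self._calls += 1
    return rng


sampler = PoissonSampler(seed=0)
first_fit_batch = poisson_sample(1000, 0.1, sampler.start_fit())
second_fit_batch = poisson_sample(1000, 0.1, sampler.start_fit())
```

Two samplers built with the same seed produce the same sequence of batches, but the first and second `fit()` calls of one sampler draw from different streams.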

Examples and tests assert post-fit quality correctly. The Keras example and e2e tests now use model.evaluate(...) after training for quality checks instead of relying on the final fit() history metric. That matters because the example uses Dropout, so the in-training accuracy can differ substantially from the actual trained model's evaluated accuracy.
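
To make the Dropout point concrete, here is a toy inverted-dropout function in plain Python. It is an illustration only, not Keras's implementation: in training mode some activations are zeroed and the rest rescaled, while in inference mode the input passes through unchanged, which is why metrics computed during `fit()` and by `model.evaluate(...)` can differ.

```python
import random


def dropout(
    activations: list[float], rate: float, training: bool, rng: random.Random
) -> list[float]:
  """Toy inverted dropout: random masking in training, identity at inference."""
  if not training:
    return list(activations)
  keep_prob = 1.0 - rate
  # Kept units are scaled by 1 / keep_prob so the expected value is unchanged.
  return [a / keep_prob if rng.random() < keep_prob else 0.0 for a in activations]


rng = random.Random(0)
activations = [1.0, 2.0, 3.0, 4.0]
train_out = dropout(activations, rate=0.5, training=True, rng=rng)
eval_out = dropout(activations, rate=0.5, training=False, rng=rng)
```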

Why

Before this branch, several configurations could make training and accounting quietly disagree. Gradient accumulation could make the privacy ledger count a different notion of step than the optimizer, calibration could assume a different sampling method than fit(), and some input pipelines could make the declared train_size or batch shape inaccurate without being rejected up front.

The stricter validation is intentional: if the library cannot account for a setup precisely, it should say so at call time instead of failing deep inside compiled code or, worse, training under the wrong privacy assumptions.

Review Guide

  • jax_privacy/keras_api.py: step counting, fit() validation, Poisson seeding, and calibrated-noise cache behavior.
  • jax_privacy/accounting/calibrate.py: explicit sampling-method plumbing for noise calibration.
  • examples/keras_api_example.py, README.md, docs/keras_api.rst, and docs/overview.md: public contract and example defaults.
  • tests/keras_api_test.py, tests/keras_api_e2e_test.py, tests/batch_selection_test.py, and tests/accounting/calibrate_test.py: regression coverage for the hardened behavior.

Validation

Most recent local validation after the final review fix:

  • PYTHONPATH=. ./.venv/bin/python -m pytest -q tests/keras_api_test.py
  • PYTHONPATH=. ./.venv/bin/python -m pytest -q tests/keras_api_e2e_test.py
  • PYTHONPATH=. ./.venv/bin/python -m py_compile examples/keras_api_example.py tests/keras_api_e2e_test.py tests/keras_api_test.py
  • ./.venv/bin/pyink --check --diff examples/keras_api_example.py tests/keras_api_e2e_test.py tests/keras_api_test.py
  • ./.venv/bin/flake8 examples/keras_api_example.py tests/keras_api_e2e_test.py tests/keras_api_test.py
  • PYTHONPATH=. ./.venv/bin/python examples/keras_api_example.py

The example now completes with evaluated DP train accuracy 0.6180 and evaluated validation accuracy 0.6111 for the seeded run.

@staticpayload staticpayload marked this pull request as ready for review April 22, 2026 04:42

github-actions Bot commented May 1, 2026

This PR has been idle for 7 days. Please provide an update or review.

@github-actions github-actions Bot added the Stale label May 1, 2026
@staticpayload staticpayload marked this pull request as draft May 23, 2026 03:56
@github-actions github-actions Bot removed the Stale label May 24, 2026