Harden Keras DP accounting and fit validation #228
Draft
staticpayload wants to merge 8 commits into
This PR hardens the Keras DP fit() path so the runtime, privacy accounting, and examples all describe the same training contract. The core goal is simple: the sampling method, batch size, training step count, and validated inputs should match the assumptions used for privacy accounting.

What Changed
- Accounting now follows optimizer updates. With gradient accumulation enabled, privacy budget enforcement counts real optimizer update steps instead of raw minibatches. The optimizer iteration counter is now the source of truth, and train_steps is documented as optimizer updates.
- Sampling assumptions are explicit. jax_privacy/accounting/calibrate.py now accepts an explicit sampling method, so calibrated noise is tied to the same sampling assumption used by training. The Keras example and docs default to fixed-size batching with fixed-batch accounting, while Poisson sampling in fit() remains available as an opt-in.
- fit() rejects unaccountable inputs early. DP-incompatible setups now fail with readable errors before the jitted training step runs. This includes validation_split, generator-like inputs without steps_per_epoch, fixed-batch array inputs with implicit partial tails, and prebatched tf.data/PyDataset structures whose batch shape does not match the DP config.
- Noise caching and Poisson seeding are safer. Repeated fit() calls with the same seed no longer replay the same internal Poisson sample stream, and cached calibrated noise is preserved when fit() only resolves an implicit sampling method.
- Examples and tests assert post-fit quality correctly. The Keras example and e2e tests now use model.evaluate(...) after training for quality checks instead of relying on the final fit() history metric. That matters because the example uses Dropout, so the in-training accuracy can differ substantially from the actual trained model's evaluated accuracy.

Why
Before this branch, several configurations could make training and accounting quietly disagree. Gradient accumulation could make the privacy ledger count a different notion of step than the optimizer, calibration could assume a different sampling method than fit(), and some input pipelines could make the declared train_size or batch shape inaccurate without being rejected up front.

The stricter validation is intentional: if the library cannot account for a setup precisely, it should say so at call time instead of failing deep inside compiled code or, worse, training under the wrong privacy assumptions.
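
The two ideas above (count optimizer updates rather than minibatches, and reject unaccountable inputs at call time) can be sketched in plain Python. This is only an illustration: the function names `optimizer_updates` and `check_fixed_batch_inputs` are hypothetical and are not taken from jax_privacy, and the sketch assumes that an optimizer update fires only after a full group of accumulated minibatches.

```python
def optimizer_updates(num_minibatches: int, accumulation_steps: int) -> int:
    """Count optimizer updates, assuming one update per full accumulation group.

    Hypothetical helper: leftover minibatches that never complete a group
    do not trigger an update and therefore are not counted.
    """
    if accumulation_steps < 1:
        raise ValueError("accumulation_steps must be >= 1")
    return num_minibatches // accumulation_steps


def check_fixed_batch_inputs(
    train_size: int, batch_size: int, validation_split: float = 0.0
) -> int:
    """Fail fast on inputs that fixed-batch accounting cannot describe.

    Hypothetical helper: returns the number of minibatches per epoch.
    """
    if validation_split:
        # Assumption for this sketch: holding out data would make the
        # declared train_size disagree with the data actually trained on.
        raise ValueError("validation_split is not supported with DP training")
    if train_size % batch_size != 0:
        # An implicit partial tail batch would break the fixed batch size.
        raise ValueError(
            f"train_size={train_size} is not a multiple of batch_size={batch_size}"
        )
    return train_size // batch_size


batches_per_epoch = check_fixed_batch_inputs(train_size=6000, batch_size=100)
updates = optimizer_updates(batches_per_epoch * 2, accumulation_steps=4)
print(updates)  # 30 optimizer updates for 120 minibatches over 2 epochs
```

With accumulation over 4 minibatches, the 120 minibatches in this example correspond to 30 optimizer updates, which is the quantity a train_steps budget expressed in optimizer updates would be compared against.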
Review Guide
- jax_privacy/keras_api.py: step counting, fit() validation, Poisson seeding, and calibrated-noise cache behavior.
- jax_privacy/accounting/calibrate.py: explicit sampling-method plumbing for noise calibration.
- examples/keras_api_example.py, README.md, docs/keras_api.rst, and docs/overview.md: public contract and example defaults.
- tests/keras_api_test.py, tests/keras_api_e2e_test.py, tests/batch_selection_test.py, and tests/accounting/calibrate_test.py: regression coverage for the hardened behavior.

Validation
Most recent local validation after the final review fix:
- PYTHONPATH=. ./.venv/bin/python -m pytest -q tests/keras_api_test.py
- PYTHONPATH=. ./.venv/bin/python -m pytest -q tests/keras_api_e2e_test.py
- PYTHONPATH=. ./.venv/bin/python -m py_compile examples/keras_api_example.py tests/keras_api_e2e_test.py tests/keras_api_test.py
- ./.venv/bin/pyink --check --diff examples/keras_api_example.py tests/keras_api_e2e_test.py tests/keras_api_test.py
- ./.venv/bin/flake8 examples/keras_api_example.py tests/keras_api_e2e_test.py tests/keras_api_test.py
- PYTHONPATH=. ./.venv/bin/python examples/keras_api_example.py

The example now completes with evaluated DP train accuracy 0.6180 and evaluated validation accuracy 0.6111 for the seeded run.
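
As a reminder of why these numbers come from model.evaluate(...) rather than the fit() history, here is a minimal sketch of inverted dropout in plain Python. It is not the Keras implementation; the `dropout` function and its arguments are hypothetical and only show that a forward pass in training mode produces different activations than the same pass in inference mode.

```python
import random


def dropout(values, rate, training, rng):
    """Inverted dropout: zero and rescale in training mode, identity otherwise."""
    if not training or rate == 0.0:
        return list(values)
    keep = 1.0 - rate
    # Each unit survives with probability `keep` and is scaled by 1 / keep.
    return [v / keep if rng.random() < keep else 0.0 for v in values]


rng = random.Random(0)
activations = [0.2, 0.4, 0.6, 0.8]

train_out = dropout(activations, rate=0.5, training=True, rng=rng)
eval_out = dropout(activations, rate=0.5, training=False, rng=rng)

print("training mode: ", train_out)
print("inference mode:", eval_out)
```

Metrics reported during fit() are computed from training-mode forward passes, so they reflect the perturbed activations, while evaluate() runs the model in inference mode.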