RubenHaisma/rl-augment

rl-augment

Learn an image-augmentation policy with reinforcement learning — then prove it was actually worth it.

A group-relative policy gradient (the GRPO trick — no value network) searches which augmentations to apply and how strongly, scored by a verifiable reward: held-out classifier accuracy. Then it does the part most repos skip — it honestly benchmarks the searched policy against RandAugment and TrivialAugment, because the interesting question was never "does augmentation help" but "was searching for the policy worth it, or would random augmentation have done just as well?"

CLI-first, --json on every command, load-bearing exit codes — so a coding agent (Claude Code, Codex, Cursor) can drive the whole search loop and parse the result with no UI and no notebook.

reinforcement-learning · data-augmentation · computer-vision · autoaugment · randaugment · policy-gradient · grpo · cli · agents


The honest result (read this first)

On scikit-learn digits, deliberately starved to 80 training images so augmentation matters, with a KNN proxy and a fixed seed:

| strategy | held-out val acc | note |
|---|---|---|
| no augmentation | 0.860 | the floor |
| TrivialAugment | 0.840 | worse than nothing |
| RandAugment | 0.860 | no better than nothing |
| RL search (this repo) | 0.900 | +0.040 over no-aug, and it beats both random strategies |
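
For readers unfamiliar with the two random baselines in the table, here is a minimal sketch of how each one picks operations, following the published descriptions of RandAugment and TrivialAugment. The op names, the number of magnitude bins, and the function names are assumptions for illustration; the repo's own baseline code may differ.

```python
import random

# Hypothetical op names and bin count, chosen only for illustration.
OPS = ["rotate", "shear", "zoom", "translate", "flip_h"]
N_BINS = 5

def randaugment_policy(rng, n_ops=2, magnitude=2):
    """RandAugment: n_ops ops drawn uniformly, all sharing one fixed magnitude."""
    return [(rng.choice(OPS), magnitude) for _ in range(n_ops)]

def trivialaugment_policy(rng):
    """TrivialAugment: one op drawn uniformly, with a uniformly random magnitude."""
    return [(rng.choice(OPS), rng.randrange(N_BINS))]

rng = random.Random(0)
ra = randaugment_policy(rng)
ta = trivialaugment_policy(rng)
print(ra, ta)
```

Neither baseline looks at the reward, which is what makes them a fair "was the search worth it" comparison.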

The search also rediscovers, from reward alone, that digits don't flip: flip_h sinks to the lowest include-probability (~0.11) while rotate / shear / zoom / translate rise. Nobody told it — it learned that a horizontal flip turns a valid digit into an invalid one.

But here's the part that makes this repo honest. On the large untouched test split (n≈1400), that same searched policy scores 0.889 vs 0.892 for no-aug: the lift vanishes. Searching augmentation on a small validation set partly overfits that validation set, and in this run none of the gain transfers. That is the real, load-bearing finding, and it echoes the motivation behind RandAugment: a policy tuned on a small proxy task may not carry over. Most portfolio repos would report only the +0.04 and call it a win. This one reports the test number too, because a metric without its disconfirming baseline is marketing.

So what's the takeaway? Targeted RL search beats blind random augmentation on the validation split in the small-data regime, but proxy-search gains need a held-out test to be believed. The repo is built to make that distinction impossible to hide.
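
To put the test-split numbers in perspective, the sketch below computes a binomial standard error for the two quoted accuracies. It uses only the figures reported above (0.889, 0.892, n≈1400) and assumes the two accuracies are independent, which is a simplification because both models are scored on the same images.

```python
import math

def accuracy_std_error(acc, n):
    """Binomial standard error of an accuracy measured on n examples."""
    return math.sqrt(acc * (1.0 - acc) / n)

# Numbers quoted above for the test split (n is approximate).
n_test = 1400
acc_search, acc_no_aug = 0.889, 0.892

lift = acc_search - acc_no_aug
se_search = accuracy_std_error(acc_search, n_test)
se_no_aug = accuracy_std_error(acc_no_aug, n_test)
# Independence assumption: likely overstates the noise, since the test images are shared.
se_diff = math.sqrt(se_search**2 + se_no_aug**2)

print(f"test lift = {lift:+.3f}, standard error of the difference ~ {se_diff:.3f}")
```

Under these assumptions the -0.003 difference sits well inside one standard error (about 0.012), so the test split does not distinguish the searched policy from no augmentation.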

The cnn backend (CIFAR-10 + a small conv net, on a GPU) targets the regime where augmentation is known to help reliably and a gain is more likely to transfer. It uses the same search loop on real images, but it is wired rather than verified in CI. See Backends.


Quickstart

Needs uv. The verified path is CPU-only, no torch, ~1 second.

uv sync --extra dev
uv run rl-augment doctor --json
uv run rl-augment search    configs/digits-search.yaml --json   # the RL search
uv run rl-augment baselines configs/digits-search.yaml --json   # vs RandAugment/TrivialAugment
uv run rl-augment eval      configs/digits-search.yaml --json   # the honest TEST-split number

Or run it as a tool with no clone:

uvx --from git+https://github.com/RubenHaisma/rl-augment rl-augment doctor --json

See the found policy applied to real images:

uv run rl-augment sample configs/digits-search.yaml --n 8   # writes artifacts/.../samples.png

How the search works

We optimise a distribution over augmentation policies, factored per operation:

  • a Bernoulli P(include op o) = sigmoid(θ_incl[o]), and
  • a categorical over magnitude bins softmax(θ_mag[o]).

Each step samples a group of G concrete policies, scores each by the reward (a trained proxy's held-out accuracy), and updates with the GRPO move:

advantage_i = (reward_i − group_mean) / (group_std + ε)     # group-relative — no value network
θ          += lr · mean_i [ advantage_i · ∇θ log p(policy_i) ]   # REINFORCE, closed-form score fn

The group mean is the baseline — that's the defining GRPO trick, here applied to a computer-vision search instead of an LLM. The reward stays verifiable: a deterministic accuracy given a fixed eval seed, no learned reward model anywhere. The empty policy's reward is the no-augmentation baseline, for free.
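
The update above can be sketched in a few lines of NumPy. This is an illustrative reimplementation under stated assumptions, not the code in lib/search.py: the function names, hyperparameters, and the toy reward are hypothetical, and the choice to mask the magnitude gradient for excluded ops is an assumption of this sketch.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def softmax(x):
    z = np.exp(x - x.max(axis=-1, keepdims=True))
    return z / z.sum(axis=-1, keepdims=True)

def grpo_step(theta_incl, theta_mag, reward_fn, rng, group_size=8, lr=0.5, eps=1e-8):
    """One group-relative policy-gradient update; returns new parameters and rewards."""
    n_ops, n_bins = theta_mag.shape
    p_incl, p_mag = sigmoid(theta_incl), softmax(theta_mag)

    # Sample G concrete policies: an include mask and a magnitude bin per op.
    incl = rng.random((group_size, n_ops)) < p_incl
    mags = np.array([[rng.choice(n_bins, p=p_mag[o]) for o in range(n_ops)]
                     for _ in range(group_size)])
    rewards = np.array([reward_fn(incl[i], mags[i]) for i in range(group_size)])

    # Group-relative advantage: the group mean is the baseline, no value network.
    adv = (rewards - rewards.mean()) / (rewards.std() + eps)

    # Closed-form score functions of the Bernoulli and categorical factors.
    g_incl = incl.astype(float) - p_incl                 # shape (G, n_ops)
    g_mag = np.eye(n_bins)[mags] - p_mag                 # shape (G, n_ops, n_bins)
    g_mag = g_mag * incl[..., None]                      # assumption: excluded ops carry no magnitude gradient

    theta_incl = theta_incl + lr * (adv[:, None] * g_incl).mean(axis=0)
    theta_mag = theta_mag + lr * (adv[:, None, None] * g_mag).mean(axis=0)
    return theta_incl, theta_mag, rewards

def toy_reward(incl, mags):
    # Hypothetical stand-in for held-out accuracy: op 0 hurts, the other ops help a little.
    return 0.8 - 0.1 * float(incl[0]) + 0.02 * float(incl[1:].sum())

rng = np.random.default_rng(0)
theta_incl, theta_mag = np.zeros(4), np.zeros((4, 3))
for _ in range(200):
    theta_incl, theta_mag, _ = grpo_step(theta_incl, theta_mag, toy_reward, rng)
probs = sigmoid(theta_incl)
print(np.round(probs, 2))  # include-probabilities after the toy search
```

With this toy reward the include-probability of op 0 should end up below the others, which mirrors the flip_h behaviour described earlier without claiming to reproduce it.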

The loop (lib/search.py) is backend-agnostic — it only needs a reward_fn(policy, seed) -> float. The sklearn and cnn backends inject different reward functions; the search is identical.
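
As an illustration of that injection point, the sketch below types the reward as a plain callable and runs the same selection code against two stand-in reward functions. Only the reward_fn(policy, seed) -> float signature comes from the text above; the Policy shape, pick_best, and both toy rewards are hypothetical.

```python
from typing import Callable, Sequence

# Assumed policy shape for this sketch: (op_name, magnitude_bin) pairs.
Policy = Sequence[tuple[str, int]]
RewardFn = Callable[[Policy, int], float]

def pick_best(candidates: Sequence[Policy], reward_fn: RewardFn, seed: int):
    """Score each candidate with the injected reward and return (policy, reward)."""
    scored = [(reward_fn(p, seed), p) for p in candidates]
    best_reward, best_policy = max(scored, key=lambda s: s[0])
    return best_policy, best_reward

# Two toy rewards standing in for different backends; the numbers are made up.
def toy_reward_a(policy: Policy, seed: int) -> float:
    return 1.0 - 0.5 * sum(op == "flip_h" for op, _ in policy)

def toy_reward_b(policy: Policy, seed: int) -> float:
    return 0.1 * len(policy)

candidates = [[], [("flip_h", 1)], [("rotate", 2), ("zoom", 1)]]
for name, fn in [("a", toy_reward_a), ("b", toy_reward_b)]:
    print(name, pick_best(candidates, fn, seed=0))
```

The caller never inspects which backend produced the number, so swapping the reward engine leaves the search code untouched.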


Backends

rl-augment search <config> reads the backend: key from the config and routes to the matching reward engine; the --json output has the same shape either way.

| backend | reward engine | compute | status |
|---|---|---|---|
| sklearn | held-out accuracy of a small scikit-learn proxy on digits | CPU, in-process | verified in CI, ~1s, demonstrably beats no-aug + random search on val |
| cnn | val accuracy of a small torch CNN on CIFAR-10 | compute: modal (rented GPU) or compute: local (your GPU) | 🔌 wired; needs a GPU (not run in CI) |

# Inspect the GPU plan + cost without spending anything:
uv run rl-augment search configs/cifar-search.yaml --dry-run --json

# Run it (rented GPU):  uv sync --extra modal && modal token set  (once)
# Run it (your GPU):    uv sync --extra gpu

What's verified vs scaffolded

| Claim | How it's checked | Status |
|---|---|---|
| The search beats no-aug on val | test_search_beats_no_aug_on_val (load-bearing invariant) | ✅ CI |
| The search at least matches RandAugment/TrivialAugment on val | test_search_at_least_matches_random_search | ✅ CI |
| The search learns to drop flip_h for digits | test_search_downweights_flip_for_digits | ✅ CI |
| The search is deterministic (seed → identical metrics) | make repro (trains twice, asserts equal) | ✅ CI |
| The README quickstart still runs | make readme runs the <!-- ci-test --> block | ✅ CI |
| Live metrics posted to the run summary | CI runs the search + ci_report.py | ✅ CI |
| The val gain does not fully transfer to test | rl-augment eval reports the test-split number | ✅ honest, reported |
| CNN/CIFAR search on a real GPU | scripts/modal_cnn.py | 🔌 wired, needs a GPU |

Drive it with Claude Code (or any agent)

Every command is non-interactive, takes --json, and uses load-bearing exit codes. The contract: with --json, stdout is exactly one JSON object (success or {"ok": false, "error": "..."}); exit 0 = success, non-zero = failure with one stderr line. Parse stdout, branch on the exit code.
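
A minimal sketch of how an agent-side wrapper might consume that contract: parse stdout as JSON and branch on the exit code. The run_json helper is hypothetical, and the demo uses a stand-in command with a made-up payload so it runs without the repo installed.

```python
import json
import subprocess
import sys

def run_json(cmd):
    """Run a command that follows the contract above; return (ok, payload)."""
    proc = subprocess.run(cmd, capture_output=True, text=True)
    try:
        payload = json.loads(proc.stdout)
    except json.JSONDecodeError:
        # Fall back to the single stderr line if stdout was not a JSON object.
        payload = {"ok": False, "error": proc.stderr.strip() or "stdout was not JSON"}
    return proc.returncode == 0, payload

# Stand-in command for the demo; an agent would pass something like
# ["rl-augment", "search", "configs/digits-search.yaml", "--json"] instead.
fake = [sys.executable, "-c", 'import json; print(json.dumps({"ok": True, "metrics": {}}))']
ok, payload = run_json(fake)
print(ok, payload)
```

Branching on the exit code first and the payload second keeps the wrapper correct even if a failing command prints nothing to stdout.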

A typical agent loop:

rl-augment search configs/digits-search.yaml --json   # → best_policy, include_probs, history, metrics
rl-augment baselines configs/digits-search.yaml --json # → did the search beat random? (search_beats_random)
rl-augment eval configs/digits-search.yaml --json      # → the honest test-split lift

Agent instructions live in AGENTS.md (the cross-tool standard; CLAUDE.md is a symlink to it).


Adapt it to your own task

Mostly config, not code:

  • Different proxy / data size — edit proxy (knn | svc | mlp), n_train, k_aug in the config.
  • Your own images on a GPU — use the cnn backend; point lib/cnn_runner.py at your dataset (it's a small, readable torch loop).
  • New augmentation op — add a function to the OPS registry in lib/augment.py (signature (img, magnitude, rng) -> img); the search picks it up automatically.
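
As an example of the last bullet, here is a sketch of a new op with the (img, magnitude, rng) -> img signature. It assumes images are NumPy arrays, that magnitude arrives as a float in [0, 1], and that rng is a NumPy Generator; none of that is confirmed by the text above, and the OPS dict here is a local stand-in for the real registry in lib/augment.py.

```python
import numpy as np

# Local stand-in for the OPS registry; the real registry's type is not shown in this README.
OPS = {}

def gaussian_noise(img, magnitude, rng):
    """Add zero-mean Gaussian noise; magnitude (assumed in [0, 1]) scales the noise level."""
    value_range = img.max() - img.min() + 1e-8
    noise = rng.normal(0.0, 0.1 * magnitude * value_range, size=img.shape)
    # Clip back to the original intensity range so the output stays a valid image.
    return np.clip(img + noise, img.min(), img.max())

OPS["gaussian_noise"] = gaussian_noise

rng = np.random.default_rng(0)
img = np.arange(64, dtype=float).reshape(8, 8)   # same 8x8 shape as a scikit-learn digit
out = OPS["gaussian_noise"](img, magnitude=0.5, rng=rng)
print(out.shape, out.min(), out.max())
```

Taking rng as an argument rather than using global random state is what lets a fixed seed reproduce the same augmented images.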

Layout

src/rl_augment/
  cli.py           # Typer app, one command per capability
  output.py        # emit()/fail() — output + exit-code contract
  commands/        # doctor, search, baselines, eval, sample, version
  lib/
    augment.py     # the op set + Policy (the actions)
    search.py      # group-relative policy gradient (the GRPO loop)
    reward.py      # verifiable reward = held-out accuracy + the baselines
    data.py        # starved digits splits
    backends.py    # sklearn (verified) | cnn (GPU) dispatch
    cnn_runner.py  # the torch/CIFAR reward engine (gpu extra)
configs/           # digits-search.yaml (verified) · cifar-search.yaml (GPU)
scripts/           # modal_cnn.py + stdlib CI helpers (ci_report, check_repro, test_readme)
notebooks/         # marimo (.py) — reward curve + which-ops-kept

Why this exists

A portfolio piece at the reinforcement-learning × computer-vision intersection that values rigor over a green number: a real policy-gradient implementation, a verifiable reward, honest baselines, and a finding it doesn't hide. Built in the same house style as rl-studio and ml-pipeline-template.

License

Apache-2.0.
