-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsearch.py
More file actions
156 lines (117 loc) · 5.76 KB
/
Copy pathsearch.py
File metadata and controls
156 lines (117 loc) · 5.76 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
"""The RL search: group-relative policy gradient over augmentation policies.
This is the GRPO trick applied to a computer-vision problem. We optimise a
*distribution* over augmentation policies, parameterised by:
- ``incl_logit[o]`` — a Bernoulli ``P(include op o) = sigmoid(incl_logit[o])``.
- ``mag_logit[o, :]`` — a categorical over magnitude bins for op ``o``.
Each step samples a group of ``G`` concrete policies, scores each by a reward
(held-out classifier accuracy — see :mod:`rl_augment.lib.reward`), and updates the
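parameters by

    advantage_i = (reward_i - mean) / (std + eps)                   # group-relative, no value net
    theta      += lr * mean_i[ advantage_i * grad_logp(policy_i) ]  # REINFORCE

plus a small entropy bonus (weighted by ``entropy_coef``) on the include logits,
which keeps the include-Bernoullis from collapsing early.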
The group mean is the baseline — exactly GRPO's defining move. The reward stays
verifiable (a deterministic accuracy given a fixed eval seed), so the whole search
is reproducible and there is no learned reward model anywhere.
The search is backend-agnostic: it only needs a ``reward_fn(policy, seed) -> float``.
``sklearn`` and ``cnn`` backends inject different reward fns; the loop is identical.
"""
from __future__ import annotations
from collections.abc import Callable
from dataclasses import dataclass, field
import numpy as np
from rl_augment.lib.augment import MAG_BINS, N_MAG, N_OPS, OP_NAMES, Policy
# Reward callback signature: ``reward_fn(policy, eval_seed) -> float``.
# ``search`` calls it with ``cfg.eval_seed`` for the baseline and every sampled policy.
RewardFn = Callable[[Policy, int], float]
def _sigmoid(x: np.ndarray) -> np.ndarray:
    """Elementwise logistic function ``1 / (1 + exp(-x))``.

    For very negative logits ``exp(-x)`` overflows to ``inf``; the expression still
    evaluates to the correct limit ``0.0``, so that overflow is expected and its
    RuntimeWarning is silenced instead of leaking into callers' logs.
    """
    with np.errstate(over="ignore"):
        return 1.0 / (1.0 + np.exp(-x))
def _softmax(x: np.ndarray) -> np.ndarray:
    """Softmax over the last axis.

    The maximum along that axis is subtracted first so ``exp`` cannot overflow;
    softmax is shift-invariant, so the result is unchanged.
    """
    weights = np.exp(x - np.max(x, axis=-1, keepdims=True))
    totals = np.sum(weights, axis=-1, keepdims=True)
    return weights / totals
@dataclass
class SearchResult:
    """Outcome of :func:`search`.

    ``best_policy`` / ``best_reward`` describe the highest-reward policy seen;
    ``baseline_reward`` is the no-augmentation reference it is compared against;
    ``include_probs`` maps op name to its final inclusion probability; ``history``
    holds per-step curves keyed by metric name.
    """

    best_policy: Policy
    best_reward: float
    baseline_reward: float  # the empty (no-aug) policy under the same proxy + seed
    include_probs: dict[str, float]  # final P(include op), keyed by op name — interpretable
    history: dict[str, list[float]] = field(default_factory=dict)

    @property
    def lift(self) -> float:
        """Reward gained by the best policy over the no-augmentation baseline."""
        gain = self.best_reward - self.baseline_reward
        return gain
@dataclass
class SearchConfig:
    """Hyper-parameters for :func:`search`."""

    steps: int = 20  # number of policy-gradient updates
    group_size: int = 10  # policies sampled and scored per update (the "group")
    lr: float = 0.5  # step size for both include and magnitude logits
    entropy_coef: float = 0.01  # keeps the include-Bernoullis from collapsing early
    p_apply: float = 0.5  # forwarded unchanged to every Policy that is built
    seed: int = 0  # seeds the sampler's random generator
    eval_seed: int = 12345  # fixed → reward(policy) is a deterministic function
def _sample(theta_incl: np.ndarray, theta_mag: np.ndarray, p_apply: float, rng):
    """Draw one concrete policy from the factored distribution.

    Returns ``(policy, include, mag_bin)``: the :class:`Policy` itself plus the raw
    boolean include mask and per-op magnitude-bin indices, which the caller needs
    to evaluate the score function.
    """
    # One uniform draw per op decides inclusion: P(include) = sigmoid(logit).
    uniforms = rng.random(N_OPS)
    include = uniforms < _sigmoid(theta_incl)
    # A magnitude bin is drawn for every op, included or not, so the random
    # stream advances identically whichever ops were picked.
    bin_probs = _softmax(theta_mag)
    draws = []
    for op in range(N_OPS):
        draws.append(rng.choice(N_MAG, p=bin_probs[op]))
    mag_bin = np.array(draws)
    return Policy(include, mag_bin, p_apply), include, mag_bin
def _grad_logp(theta_incl, theta_mag, include, mag_bin):
    """Closed-form score function ∇θ log p(policy) for the factored distribution.

    Args:
        theta_incl: include logits, one per op.
        theta_mag: magnitude logits, one row of bin logits per op.
        include: boolean include mask of the sampled policy.
        mag_bin: sampled magnitude-bin index for each op.

    Returns:
        ``(g_incl, g_mag)``: gradients w.r.t. ``theta_incl`` and ``theta_mag``.
    """
    # Bernoulli log-likelihood w.r.t. its logit: d/dz log p(x) = x - sigmoid(z).
    g_incl = include.astype(float) - _sigmoid(theta_incl)
    # Categorical log-likelihood w.r.t. its logits: onehot(chosen) - softmax(logits).
    onehot = np.zeros_like(theta_mag)
    onehot[np.arange(N_OPS), mag_bin] = 1.0
    # Magnitude is only "chosen" for included ops, so only they carry a gradient.
    g_mag = (onehot - _softmax(theta_mag)) * include[:, None]
    return g_incl, g_mag
def _entropy_grad(theta_incl: np.ndarray) -> np.ndarray:
    """∇ of the Bernoulli entropy w.r.t. the logit — pushes probs toward 0.5.

    With p = sigmoid(z):  H = -p log p - (1-p) log(1-p),  dH/dp = log((1-p)/p) = -z,
    and dp/dz = p(1-p), so  dH/dz = -z * p * (1-p).

    Working from the logit avoids evaluating log(1 - p). Once sigmoid saturates to
    exactly 1.0 (logit above ~37), the log form computes 0 * log1p(-1) = 0 * -inf
    = nan, which would then poison theta_incl for the rest of the search.
    """
    p = _sigmoid(theta_incl)
    return -theta_incl * p * (1.0 - p)
def _policy_gradient(theta_incl, theta_mag, includes, magbins, adv):
    """Advantage-weighted mean of the per-sample score functions (the REINFORCE term)."""
    grad_incl = np.zeros(N_OPS)
    grad_mag = np.zeros((N_OPS, N_MAG))
    # Accumulate sample by sample, in group order, then average.
    for inc, mb, a in zip(includes, magbins, adv):
        g_incl, g_mag = _grad_logp(theta_incl, theta_mag, inc, mb)
        grad_incl += a * g_incl
        grad_mag += a * g_mag
    n = len(adv)
    grad_incl /= n
    grad_mag /= n
    return grad_incl, grad_mag


def search(reward_fn: RewardFn, cfg: SearchConfig) -> SearchResult:
    """Run the group-relative policy-gradient search over augmentation policies.

    Args:
        reward_fn: scores a concrete policy; called as ``reward_fn(policy, cfg.eval_seed)``.
        cfg: search hyper-parameters.

    Returns:
        A :class:`SearchResult` containing:
            - the best policy seen (the empty policy unless a sample strictly beat
              the baseline),
            - the empty-policy baseline reward,
            - the final include probabilities per op,
            - per-step ``mean_reward`` / ``max_reward`` / ``best_so_far`` curves.

    Raises:
        ValueError: if ``cfg.steps > 0`` but ``cfg.group_size < 1``; an empty group
            has no mean to use as a baseline.
    """
    if cfg.steps > 0 and cfg.group_size < 1:
        raise ValueError(f"group_size must be >= 1 when steps > 0, got {cfg.group_size}")

    rng = np.random.default_rng(cfg.seed)
    theta_incl = np.zeros(N_OPS)  # start at P(include)=0.5 for every op
    theta_mag = np.zeros((N_OPS, N_MAG))  # start uniform over magnitude bins

    # The empty policy is both the baseline and the initial incumbent, so a reported
    # "best" policy must strictly beat no augmentation.
    empty = Policy(np.zeros(N_OPS, bool), np.zeros(N_OPS, int), cfg.p_apply)
    baseline = reward_fn(empty, cfg.eval_seed)
    best_policy, best_reward = empty, baseline
    hist: dict[str, list[float]] = {"mean_reward": [], "max_reward": [], "best_so_far": []}

    for _ in range(cfg.steps):
        includes, magbins, rewards = [], [], []
        for _g in range(cfg.group_size):
            policy, inc, mb = _sample(theta_incl, theta_mag, cfg.p_apply, rng)
            r = reward_fn(policy, cfg.eval_seed)
            includes.append(inc)
            magbins.append(mb)
            rewards.append(r)
            if r > best_reward:
                best_reward, best_policy = r, policy

        rewards_arr = np.array(rewards)
        # Group-relative advantage: the group mean is the baseline (no value net).
        # The 1e-4 floor makes a zero-variance group yield zero advantage, not a 0/0.
        adv = (rewards_arr - rewards_arr.mean()) / (rewards_arr.std() + 1e-4)

        grad_incl, grad_mag = _policy_gradient(theta_incl, theta_mag, includes, magbins, adv)
        grad_incl += cfg.entropy_coef * _entropy_grad(theta_incl)
        theta_incl += cfg.lr * grad_incl
        theta_mag += cfg.lr * grad_mag

        hist["mean_reward"].append(float(rewards_arr.mean()))
        hist["max_reward"].append(float(rewards_arr.max()))
        hist["best_so_far"].append(float(best_reward))

    final_p = _sigmoid(theta_incl)
    include_probs = {OP_NAMES[o]: float(final_p[o]) for o in range(N_OPS)}
    return SearchResult(
        best_policy=best_policy,
        best_reward=float(best_reward),
        baseline_reward=float(baseline),
        include_probs=include_probs,
        history=hist,
    )
def favoured_magnitudes(theta_mag: np.ndarray) -> dict[str, float]:  # pragma: no cover - helper
    """Map each op name to the magnitude value of its most probable bin.

    ``theta_mag`` holds one row of magnitude-bin logits per op.
    """
    bin_probs = _softmax(theta_mag)
    favoured = {}
    for o in range(N_OPS):
        top = int(bin_probs[o].argmax())
        favoured[OP_NAMES[o]] = float(MAG_BINS[top])
    return favoured