Commit 9bf30dd

Kyle1668 and claude committed
fix: Update gradient difference to match inverted formula
Updated documentation and tests to match the inverted gradient difference formula: L_total = α * L_retain - L_forget

Now gd_retain_weight semantics are intuitive:
- Higher values (40-100) = more retention, less forgetting
- Lower values (1-10) = more aggressive unlearning

Updated test expectations to match new formula

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <[email protected]>
1 parent f1c6570 commit 9bf30dd

File tree: 2 files changed (+11, -10 lines)
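A quick numeric illustration of the inverted formula from the commit message; the loss values below are placeholders for demonstration, not outputs from any real run:

# Placeholder per-batch losses (illustrative values only).
retain_loss = 3.0   # cross-entropy on the retain dataset
forget_loss = 2.5   # cross-entropy on the forget (GA) dataset

# Inverted gradient difference: L_total = α * L_retain - L_forget
for alpha in (2.0, 40.0):
    total = alpha * retain_loss - forget_loss
    print(f"alpha={alpha}: L_total={total}")
# alpha=2.0  -> L_total=3.5   (low α: the -L_forget term has relatively more pull -> aggressive unlearning)
# alpha=40.0 -> L_total=117.5 (high α: retain term dominates -> more retention, less forgetting)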

megatron/neox_arguments/neox_args.py

Lines changed: 6 additions & 5 deletions
@@ -1172,7 +1172,7 @@ class NeoXArgsTraining(NeoXArgsTemplate):
     """
     Enable gradient difference mode. When enabled, the model performs gradient
     difference unlearning using both forget (GA) and retain datasets. This replaces
-    pure gradient ascent with the formula: L_total = L_retain - α * L_forget
+    pure gradient ascent with the formula: L_total = α * L_retain - L_forget
     """
 
     gd_retain_dataset: str = None
@@ -1189,10 +1189,11 @@ class NeoXArgsTraining(NeoXArgsTemplate):
 
     gd_retain_weight: float = 40.0
     """
-    Weight (α) for the forget loss in gradient difference formula.
-    Higher values provide stronger retention of general capabilities.
-    The combined loss is: L_retain - α * L_forget
-    Based on Composable Interventions paper, values around 40 work well.
+    Weight (α) for the retain loss in gradient difference formula.
+    Higher values provide stronger retention of general capabilities (less forgetting).
+    Lower values allow more aggressive unlearning (more forgetting).
+    The combined loss is: α * L_retain - L_forget
+    Typical values: 1-10 for aggressive unlearning, 40-100 for balanced unlearning.
     """
 
     gd_log_separate_losses: bool = True
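The docstrings above only describe the semantics; as a hedged sketch of how a training step could consume these fields (the function name, signature, and print-based logging below are assumptions for illustration, not the GPT-NeoX implementation):

import torch

def gradient_difference_loss(retain_loss: torch.Tensor,
                             forget_loss: torch.Tensor,
                             gd_retain_weight: float = 40.0,
                             gd_log_separate_losses: bool = True) -> torch.Tensor:
    """Combine the two losses as documented: L_total = α * L_retain - L_forget."""
    gd_loss = gd_retain_weight * retain_loss - forget_loss
    if gd_log_separate_losses:
        # Stand-in for whatever logger the real training loop uses.
        print(f"retain={retain_loss.item():.4f} forget={forget_loss.item():.4f} gd={gd_loss.item():.4f}")
    return gd_loss

# Placeholder scalar losses; with the default α = 40.0 this yields 117.5,
# the same value the updated unit test asserts.
loss = gradient_difference_loss(torch.tensor(3.0), torch.tensor(2.5))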

tests/unit/test_gradient_ascent.py

Lines changed: 5 additions & 5 deletions
@@ -1125,11 +1125,11 @@ def test_gradient_difference_loss_formula(self):
         alpha = 40.0
 
         # Compute gradient difference loss
-        # L_total = L_retain - α * L_forget
-        gd_loss = retain_loss - alpha * forget_loss
+        # L_total = α * L_retain - L_forget
+        gd_loss = alpha * retain_loss - forget_loss
 
-        # Expected: 3.0 - 40.0 * 2.5 = 3.0 - 100.0 = -97.0
-        assert gd_loss.item() == -97.0
+        # Expected: 40.0 * 3.0 - 2.5 = 120.0 - 2.5 = 117.5
+        assert gd_loss.item() == 117.5
 
     def test_gradient_difference_direction(self):
         """Test that gradient difference moves in correct directions."""
@@ -1144,7 +1144,7 @@ def test_gradient_difference_direction(self):
 
         # Gradient difference objective
         alpha = 1.0
-        combined_loss = retain_loss - alpha * forget_loss
+        combined_loss = alpha * retain_loss - forget_loss
 
         # Compute gradients
         combined_loss.backward()
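A standalone, runnable variant of the direction check above; the toy quadratic losses and single-scalar "parameters" are invented for illustration and are not from the test file:

import torch

# Two independent toy parameters so each loss term pulls on its own weight.
w_r = torch.tensor(0.0, requires_grad=True)
w_f = torch.tensor(0.0, requires_grad=True)

retain_loss = (w_r - 2.0) ** 2   # minimized at w_r = 2
forget_loss = (w_f - 2.0) ** 2   # minimized at w_f = 2

alpha = 1.0
combined_loss = alpha * retain_loss - forget_loss
combined_loss.backward()

# Gradient descent steps in the -grad direction:
#   w_r.grad = -4 -> w_r moves toward 2, so retain_loss decreases (retention)
#   w_f.grad = +4 -> w_f moves away from 2, so forget_loss increases (unlearning)
print(w_r.grad, w_f.grad)  # tensor(-4.) tensor(4.)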
