
Commit cf794c4

Kyle1668 and claude committed
feat: Separate metric naming for gradient difference vs gradient ascent
- When gd_mode is active, use gd_ prefixed metrics instead of ga_ metrics
- Log gd_combined_loss, gd_forget_loss, gd_retain_loss for gradient difference
- Keep ga_actual_loss, ga_objective for gradient ascent only
- Update batch_type indicator: 0.0=normal, 1.0=GA, 2.0=GD
- Fix print statements to use appropriate terminology for each mode
- Tests pass: all gradient difference unit tests working correctly

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <[email protected]>
1 parent 9b51779 commit cf794c4
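
For readers skimming the diff, the naming convention described in the commit message can be condensed into a few lines. The sketch below is illustrative only: the helper name `unlearning_metric_names` and its signature are invented for the example and are not code from this commit.

```python
# Illustrative sketch of the metric-naming convention from the commit message.
# The helper name and signature are assumptions, not code from this commit.

def unlearning_metric_names(gd_mode: bool, loss_dict: dict, actual_loss: float) -> dict:
    """Pick gd_* metrics for gradient difference, ga_* metrics for gradient ascent."""
    if gd_mode:
        # Gradient Difference: log combined/forget/retain losses, batch_type 2.0
        return {
            "train/gd_combined_loss": loss_dict.get("combined_loss", 0),
            "train/gd_forget_loss": loss_dict.get("forget_loss", 0),
            "train/gd_retain_loss": loss_dict.get("retain_loss", 0),
            "train/batch_type": 2.0,
        }
    # Gradient Ascent: log the positive loss and its negated objective, batch_type 1.0
    return {
        "train/ga_actual_loss": actual_loss,
        "train/ga_objective": -actual_loss,
        "train/batch_type": 1.0,
    }


if __name__ == "__main__":
    print(unlearning_metric_names(True, {"combined_loss": 1.2, "forget_loss": 2.0, "retain_loss": 0.8}, 2.0))
    print(unlearning_metric_names(False, {}, 2.5))
```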

File tree

1 file changed: +179 -74 lines changed


megatron/training.py

Lines changed: 179 additions & 74 deletions
```diff
@@ -1799,53 +1799,114 @@ def train(
             # We already did forward pass, just counting it
             ga_forward_only_count += 1
 
-            # Log GA iteration
+            # Log GA/GD iteration
             if neox_args.rank == 0 and not ga_skipped_iter:
-                print(f"    GA iteration {ga_iter + 1}/{neox_args.ga_iters}, actual_loss: {actual_loss:.4f}, ga_objective: {ga_objective:.4f}")
-
-                # Debug: Show first GA loss values and pipeline mode
+                if neox_args.gd_mode and retain_data_iterator is not None:
+                    # Gradient Difference mode
+                    gd_combined_loss = ga_loss_dict.get('combined_loss', 0)
+                    print(f"    GD iteration {ga_iter + 1}/{neox_args.ga_iters}, combined_loss: {gd_combined_loss:.4f}, "
+                          f"forget_loss: {ga_loss_dict.get('forget_loss', 0):.4f}, "
+                          f"retain_loss: {ga_loss_dict.get('retain_loss', 0):.4f}")
+                else:
+                    # Gradient Ascent mode
+                    print(f"    GA iteration {ga_iter + 1}/{neox_args.ga_iters}, actual_loss: {actual_loss:.4f}, ga_objective: {ga_objective:.4f}")
+
+                # Debug: Show first iteration loss values and pipeline mode
                 if ga_iter == 0:
                     print(f"    [DEBUG] Pipeline parallel: {neox_args.is_pipe_parallel}")
-                    print(f"    [DEBUG] GA actual loss (positive): {actual_loss:.4f}, GA objective (negative): {ga_objective:.4f}")
+                    if neox_args.gd_mode and retain_data_iterator is not None:
+                        print(f"    [DEBUG] GD mode active with combined loss formula: α * L_retain - L_forget")
+                    else:
+                        print(f"    [DEBUG] GA actual loss (positive): {actual_loss:.4f}, GA objective (negative): {ga_objective:.4f}")
 
-            # Log average GA loss to all monitoring services
+            # Log average GA/GD loss to all monitoring services
             if neox_args.rank == 0 and ga_loss_count > 0:
-                avg_ga_actual_loss = ga_loss_sum / ga_loss_count  # This is the actual positive loss
-                avg_ga_objective = -avg_ga_actual_loss  # This is the negated value for GA
-
-                # TensorBoard logging
-                if neox_args.tensorboard_writer:
-                    neox_args.tensorboard_writer.add_scalar(
-                        "train/ga_actual_loss", avg_ga_actual_loss, iteration
-                    )
-                    neox_args.tensorboard_writer.add_scalar(
-                        "train/ga_objective", avg_ga_objective, iteration
-                    )
-                    neox_args.tensorboard_writer.add_scalar(
-                        "train/ga_iterations_total", ga_iter + 1, iteration
-                    )
-
-                # WandB logging
-                if neox_args.use_wandb:
-                    import wandb
-                    wandb.log({
-                        "train/ga_actual_loss": avg_ga_actual_loss,
-                        "train/ga_objective": avg_ga_objective,
-                        "train/ga_iterations": ga_iter + 1,
-                        "iteration": iteration,
-                    })
-
-                # Comet logging
-                if neox_args.comet_experiment:
-                    neox_args.comet_experiment.log_metric(
-                        "ga_actual_loss", avg_ga_actual_loss, step=iteration
-                    )
-                    neox_args.comet_experiment.log_metric(
-                        "ga_objective", avg_ga_objective, step=iteration
-                    )
-                    neox_args.comet_experiment.log_metric(
-                        "ga_iterations", ga_iter + 1, step=iteration
-                    )
+                if neox_args.gd_mode and retain_data_iterator is not None:
+                    # Gradient Difference mode - get the combined loss from the last iteration
+                    # Note: For interval mode with multiple iterations, we'd need to track averages differently
+                    # For now, log the last iteration's metrics
+                    gd_combined_loss = ga_loss_dict.get('combined_loss', 0)
+                    gd_forget_loss = ga_loss_dict.get('forget_loss', 0)
+                    gd_retain_loss = ga_loss_dict.get('retain_loss', 0)
+
+                    # TensorBoard logging
+                    if neox_args.tensorboard_writer:
+                        neox_args.tensorboard_writer.add_scalar(
+                            "train/gd_combined_loss", gd_combined_loss, iteration
+                        )
+                        neox_args.tensorboard_writer.add_scalar(
+                            "train/gd_forget_loss", gd_forget_loss, iteration
+                        )
+                        neox_args.tensorboard_writer.add_scalar(
+                            "train/gd_retain_loss", gd_retain_loss, iteration
+                        )
+                        neox_args.tensorboard_writer.add_scalar(
+                            "train/gd_iterations_total", ga_iter + 1, iteration
+                        )
+
+                    # WandB logging
+                    if neox_args.use_wandb:
+                        import wandb
+                        wandb.log({
+                            "train/gd_combined_loss": gd_combined_loss,
+                            "train/gd_forget_loss": gd_forget_loss,
+                            "train/gd_retain_loss": gd_retain_loss,
+                            "train/gd_iterations": ga_iter + 1,
+                            "iteration": iteration,
+                        })
+
+                    # Comet logging
+                    if neox_args.comet_experiment:
+                        neox_args.comet_experiment.log_metric(
+                            "gd_combined_loss", gd_combined_loss, step=iteration
+                        )
+                        neox_args.comet_experiment.log_metric(
+                            "gd_forget_loss", gd_forget_loss, step=iteration
+                        )
+                        neox_args.comet_experiment.log_metric(
+                            "gd_retain_loss", gd_retain_loss, step=iteration
+                        )
+                        neox_args.comet_experiment.log_metric(
+                            "gd_iterations", ga_iter + 1, step=iteration
+                        )
+                else:
+                    # Gradient Ascent mode
+                    avg_ga_actual_loss = ga_loss_sum / ga_loss_count  # This is the actual positive loss
+                    avg_ga_objective = -avg_ga_actual_loss  # This is the negated value for GA
+
+                    # TensorBoard logging
+                    if neox_args.tensorboard_writer:
+                        neox_args.tensorboard_writer.add_scalar(
+                            "train/ga_actual_loss", avg_ga_actual_loss, iteration
+                        )
+                        neox_args.tensorboard_writer.add_scalar(
+                            "train/ga_objective", avg_ga_objective, iteration
+                        )
+                        neox_args.tensorboard_writer.add_scalar(
+                            "train/ga_iterations_total", ga_iter + 1, iteration
+                        )
+
+                    # WandB logging
+                    if neox_args.use_wandb:
+                        import wandb
+                        wandb.log({
+                            "train/ga_actual_loss": avg_ga_actual_loss,
+                            "train/ga_objective": avg_ga_objective,
+                            "train/ga_iterations": ga_iter + 1,
+                            "iteration": iteration,
+                        })
+
+                    # Comet logging
+                    if neox_args.comet_experiment:
+                        neox_args.comet_experiment.log_metric(
+                            "ga_actual_loss", avg_ga_actual_loss, step=iteration
+                        )
+                        neox_args.comet_experiment.log_metric(
+                            "ga_objective", avg_ga_objective, step=iteration
+                        )
+                        neox_args.comet_experiment.log_metric(
+                            "ga_iterations", ga_iter + 1, step=iteration
+                        )
 
             ga_loss_sum = 0.0
             ga_loss_count = 0
```
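
The GD debug message in the hunk above refers to a combined objective of the form α * L_retain - L_forget. Here is a minimal sketch of that formula for reference; the function name and the assumption that α arrives as a plain float are illustrative, since the diff only touches the logging side.

```python
import torch


def gradient_difference_loss(forget_loss: torch.Tensor,
                             retain_loss: torch.Tensor,
                             alpha: float = 1.0) -> torch.Tensor:
    """Combined objective from the debug message: alpha * L_retain - L_forget.

    Minimizing this value lowers the retain-set loss while raising the
    forget-set loss (i.e., gradient ascent on the forget data).
    """
    return alpha * retain_loss - forget_loss


# Example: a higher forget loss and a lower retain loss both shrink the objective.
combined = gradient_difference_loss(torch.tensor(2.0), torch.tensor(0.5), alpha=1.0)
print(combined)  # tensor(-1.5000)
```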
```diff
@@ -1965,42 +2026,86 @@ def train(
             # Log interleaved GA/GD step
             if neox_args.rank == 0:
                 if neox_args.gd_mode and retain_data_iterator is not None:
-                    print(f"    [Interleaved GD] iteration {iteration}, actual_loss: {actual_loss:.4f}, ga_objective: {ga_objective:.4f}")
+                    # For gradient difference, use gd-specific naming
+                    print(f"    [Interleaved GD] iteration {iteration}, gd_combined_loss: {loss_dict.get('combined_loss', 0):.4f}, gd_objective: {ga_objective:.4f}")
                 else:
                     print(f"    [Interleaved GA] iteration {iteration}, actual_loss: {actual_loss:.4f}, ga_objective: {ga_objective:.4f}")
 
-            # Log GA metrics immediately (not periodically) for interleaved mode
+            # Log GA/GD metrics immediately (not periodically) for interleaved mode
             if neox_args.rank == 0 and iteration % neox_args.log_interval == 0:
-                # TensorBoard logging
-                if neox_args.tensorboard_writer:
-                    neox_args.tensorboard_writer.add_scalar(
-                        "train/ga_actual_loss", actual_loss, iteration
-                    )
-                    neox_args.tensorboard_writer.add_scalar(
-                        "train/ga_objective", ga_objective, iteration
-                    )
-                    neox_args.tensorboard_writer.add_scalar(
-                        "train/batch_type", 1.0, iteration  # 1.0 for GA
-                    )
-
-                # WandB logging
-                if neox_args.use_wandb:
-                    import wandb
-                    wandb.log({
-                        "train/ga_actual_loss": actual_loss,
-                        "train/ga_objective": ga_objective,
-                        "train/batch_type": 1.0,  # 1.0 for GA
-                        "iteration": iteration,
-                    })
-
-                # Comet logging
-                if neox_args.comet_experiment:
-                    neox_args.comet_experiment.log_metric(
-                        "ga_actual_loss", actual_loss, step=iteration
-                    )
-                    neox_args.comet_experiment.log_metric(
-                        "ga_objective", ga_objective, step=iteration
-                    )
+                if neox_args.gd_mode and retain_data_iterator is not None:
+                    # Gradient Difference mode metrics
+                    gd_combined_loss = loss_dict.get('combined_loss', 0)
+
+                    # TensorBoard logging
+                    if neox_args.tensorboard_writer:
+                        neox_args.tensorboard_writer.add_scalar(
+                            "train/gd_combined_loss", gd_combined_loss, iteration
+                        )
+                        neox_args.tensorboard_writer.add_scalar(
+                            "train/gd_forget_loss", loss_dict.get('forget_loss', 0), iteration
+                        )
+                        neox_args.tensorboard_writer.add_scalar(
+                            "train/gd_retain_loss", loss_dict.get('retain_loss', 0), iteration
+                        )
+                        neox_args.tensorboard_writer.add_scalar(
+                            "train/batch_type", 2.0, iteration  # 2.0 for GD
+                        )
+
+                    # WandB logging
+                    if neox_args.use_wandb:
+                        import wandb
+                        wandb.log({
+                            "train/gd_combined_loss": gd_combined_loss,
+                            "train/gd_forget_loss": loss_dict.get('forget_loss', 0),
+                            "train/gd_retain_loss": loss_dict.get('retain_loss', 0),
+                            "train/batch_type": 2.0,  # 2.0 for GD
+                            "iteration": iteration,
+                        })
+
+                    # Comet logging
+                    if neox_args.comet_experiment:
+                        neox_args.comet_experiment.log_metric(
+                            "gd_combined_loss", gd_combined_loss, step=iteration
+                        )
+                        neox_args.comet_experiment.log_metric(
+                            "gd_forget_loss", loss_dict.get('forget_loss', 0), step=iteration
+                        )
+                        neox_args.comet_experiment.log_metric(
+                            "gd_retain_loss", loss_dict.get('retain_loss', 0), step=iteration
+                        )
+                else:
+                    # Gradient Ascent mode metrics
+                    # TensorBoard logging
+                    if neox_args.tensorboard_writer:
+                        neox_args.tensorboard_writer.add_scalar(
+                            "train/ga_actual_loss", actual_loss, iteration
+                        )
+                        neox_args.tensorboard_writer.add_scalar(
+                            "train/ga_objective", ga_objective, iteration
+                        )
+                        neox_args.tensorboard_writer.add_scalar(
+                            "train/batch_type", 1.0, iteration  # 1.0 for GA
+                        )
+
+                    # WandB logging
+                    if neox_args.use_wandb:
+                        import wandb
+                        wandb.log({
+                            "train/ga_actual_loss": actual_loss,
+                            "train/ga_objective": ga_objective,
+                            "train/batch_type": 1.0,  # 1.0 for GA
+                            "iteration": iteration,
+                        })
+
+                    # Comet logging
+                    if neox_args.comet_experiment:
+                        neox_args.comet_experiment.log_metric(
+                            "ga_actual_loss", actual_loss, step=iteration
+                        )
+                        neox_args.comet_experiment.log_metric(
+                            "ga_objective", ga_objective, step=iteration
+                        )
 
             # Restore original learning rates
             if neox_args.ga_lr_scale != 1.0 and original_lrs:
```
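
Both hunks write the same metric values to TensorBoard, WandB, and Comet in three separate blocks. A possible consolidation is sketched below; the helper and the BATCH_TYPE mapping (0.0 = normal, 1.0 = GA, 2.0 = GD, taken from the commit message) are assumptions rather than part of this commit.

```python
# Hypothetical consolidation of the repeated logging blocks; not part of this commit.
BATCH_TYPE = {"normal": 0.0, "ga": 1.0, "gd": 2.0}  # indicator values from the commit message


def log_unlearning_metrics(neox_args, metrics: dict, iteration: int) -> None:
    """Send one metric dict to every configured backend.

    TensorBoard and WandB get a "train/" prefix, Comet gets the bare name,
    mirroring the conventions used in the diff above.
    """
    if neox_args.tensorboard_writer:
        for name, value in metrics.items():
            neox_args.tensorboard_writer.add_scalar(f"train/{name}", value, iteration)
    if neox_args.use_wandb:
        import wandb
        wandb.log({f"train/{k}": v for k, v in metrics.items()} | {"iteration": iteration})
    if neox_args.comet_experiment:
        for name, value in metrics.items():
            neox_args.comet_experiment.log_metric(name, value, step=iteration)


# Usage for an interleaved GD step:
# log_unlearning_metrics(
#     neox_args,
#     {"gd_combined_loss": 1.2, "gd_forget_loss": 2.0, "gd_retain_loss": 0.8,
#      "batch_type": BATCH_TYPE["gd"]},
#     iteration,
# )
```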
