@@ -1799,53 +1799,114 @@ def train(
                 # We already did forward pass, just counting it
                 ga_forward_only_count += 1
 
-            # Log GA iteration
+            # Log GA/GD iteration
             if neox_args.rank == 0 and not ga_skipped_iter:
-                print(f"  GA iteration {ga_iter + 1}/{neox_args.ga_iters}, actual_loss: {actual_loss:.4f}, ga_objective: {ga_objective:.4f}")
-
-                # Debug: Show first GA loss values and pipeline mode
+                if neox_args.gd_mode and retain_data_iterator is not None:
+                    # Gradient Difference mode
+                    gd_combined_loss = ga_loss_dict.get('combined_loss', 0)
+                    print(f"  GD iteration {ga_iter + 1}/{neox_args.ga_iters}, combined_loss: {gd_combined_loss:.4f}, "
+                          f"forget_loss: {ga_loss_dict.get('forget_loss', 0):.4f}, "
+                          f"retain_loss: {ga_loss_dict.get('retain_loss', 0):.4f}")
+                else:
+                    # Gradient Ascent mode
+                    print(f"  GA iteration {ga_iter + 1}/{neox_args.ga_iters}, actual_loss: {actual_loss:.4f}, ga_objective: {ga_objective:.4f}")
+
+                # Debug: Show first iteration loss values and pipeline mode
                 if ga_iter == 0:
                     print(f"  [DEBUG] Pipeline parallel: {neox_args.is_pipe_parallel}")
-                    print(f"  [DEBUG] GA actual loss (positive): {actual_loss:.4f}, GA objective (negative): {ga_objective:.4f}")
+                    if neox_args.gd_mode and retain_data_iterator is not None:
+                        print(f"  [DEBUG] GD mode active with combined loss formula: α * L_retain - L_forget")
+                    else:
+                        print(f"  [DEBUG] GA actual loss (positive): {actual_loss:.4f}, GA objective (negative): {ga_objective:.4f}")
 
-            # Log average GA loss to all monitoring services
+            # Log average GA/GD loss to all monitoring services
             if neox_args.rank == 0 and ga_loss_count > 0:
-                avg_ga_actual_loss = ga_loss_sum / ga_loss_count  # This is the actual positive loss
-                avg_ga_objective = -avg_ga_actual_loss  # This is the negated value for GA
-
-                # TensorBoard logging
-                if neox_args.tensorboard_writer:
-                    neox_args.tensorboard_writer.add_scalar(
-                        "train/ga_actual_loss", avg_ga_actual_loss, iteration
-                    )
-                    neox_args.tensorboard_writer.add_scalar(
-                        "train/ga_objective", avg_ga_objective, iteration
-                    )
-                    neox_args.tensorboard_writer.add_scalar(
-                        "train/ga_iterations_total", ga_iter + 1, iteration
-                    )
-
-                # WandB logging
-                if neox_args.use_wandb:
-                    import wandb
-                    wandb.log({
-                        "train/ga_actual_loss": avg_ga_actual_loss,
-                        "train/ga_objective": avg_ga_objective,
-                        "train/ga_iterations": ga_iter + 1,
-                        "iteration": iteration,
-                    })
-
-                # Comet logging
-                if neox_args.comet_experiment:
-                    neox_args.comet_experiment.log_metric(
-                        "ga_actual_loss", avg_ga_actual_loss, step=iteration
-                    )
-                    neox_args.comet_experiment.log_metric(
-                        "ga_objective", avg_ga_objective, step=iteration
-                    )
-                    neox_args.comet_experiment.log_metric(
-                        "ga_iterations", ga_iter + 1, step=iteration
-                    )
+                if neox_args.gd_mode and retain_data_iterator is not None:
+                    # Gradient Difference mode - get the combined loss from the last iteration
+                    # Note: for interval mode with multiple iterations, we'd need to track averages differently.
+                    # For now, log the last iteration's metrics.
+                    gd_combined_loss = ga_loss_dict.get('combined_loss', 0)
+                    gd_forget_loss = ga_loss_dict.get('forget_loss', 0)
+                    gd_retain_loss = ga_loss_dict.get('retain_loss', 0)
+
+                    # TensorBoard logging
+                    if neox_args.tensorboard_writer:
+                        neox_args.tensorboard_writer.add_scalar(
+                            "train/gd_combined_loss", gd_combined_loss, iteration
+                        )
+                        neox_args.tensorboard_writer.add_scalar(
+                            "train/gd_forget_loss", gd_forget_loss, iteration
+                        )
+                        neox_args.tensorboard_writer.add_scalar(
+                            "train/gd_retain_loss", gd_retain_loss, iteration
+                        )
+                        neox_args.tensorboard_writer.add_scalar(
+                            "train/gd_iterations_total", ga_iter + 1, iteration
+                        )
+
+                    # WandB logging
+                    if neox_args.use_wandb:
+                        import wandb
+                        wandb.log({
+                            "train/gd_combined_loss": gd_combined_loss,
+                            "train/gd_forget_loss": gd_forget_loss,
+                            "train/gd_retain_loss": gd_retain_loss,
+                            "train/gd_iterations": ga_iter + 1,
+                            "iteration": iteration,
+                        })
+
+                    # Comet logging
+                    if neox_args.comet_experiment:
+                        neox_args.comet_experiment.log_metric(
+                            "gd_combined_loss", gd_combined_loss, step=iteration
+                        )
+                        neox_args.comet_experiment.log_metric(
+                            "gd_forget_loss", gd_forget_loss, step=iteration
+                        )
+                        neox_args.comet_experiment.log_metric(
+                            "gd_retain_loss", gd_retain_loss, step=iteration
+                        )
+                        neox_args.comet_experiment.log_metric(
+                            "gd_iterations", ga_iter + 1, step=iteration
+                        )
+                else:
+                    # Gradient Ascent mode
+                    avg_ga_actual_loss = ga_loss_sum / ga_loss_count  # This is the actual positive loss
+                    avg_ga_objective = -avg_ga_actual_loss  # This is the negated value for GA
+
+                    # TensorBoard logging
+                    if neox_args.tensorboard_writer:
+                        neox_args.tensorboard_writer.add_scalar(
+                            "train/ga_actual_loss", avg_ga_actual_loss, iteration
+                        )
+                        neox_args.tensorboard_writer.add_scalar(
+                            "train/ga_objective", avg_ga_objective, iteration
+                        )
+                        neox_args.tensorboard_writer.add_scalar(
+                            "train/ga_iterations_total", ga_iter + 1, iteration
+                        )
+
+                    # WandB logging
+                    if neox_args.use_wandb:
+                        import wandb
+                        wandb.log({
+                            "train/ga_actual_loss": avg_ga_actual_loss,
+                            "train/ga_objective": avg_ga_objective,
+                            "train/ga_iterations": ga_iter + 1,
+                            "iteration": iteration,
+                        })
+
+                    # Comet logging
+                    if neox_args.comet_experiment:
+                        neox_args.comet_experiment.log_metric(
+                            "ga_actual_loss", avg_ga_actual_loss, step=iteration
+                        )
+                        neox_args.comet_experiment.log_metric(
+                            "ga_objective", avg_ga_objective, step=iteration
+                        )
+                        neox_args.comet_experiment.log_metric(
+                            "ga_iterations", ga_iter + 1, step=iteration
+                        )
 
             ga_loss_sum = 0.0
             ga_loss_count = 0
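The debug print in this hunk names the Gradient Difference objective as α * L_retain - L_forget. A minimal sketch of that combined loss, assuming α maps to a config knob such as `neox_args.gd_alpha` (a hypothetical name, not confirmed by this diff) and that both terms are ordinary language-modeling cross-entropy losses:

import torch

def gradient_difference_loss(
    forget_loss: torch.Tensor, retain_loss: torch.Tensor, alpha: float = 1.0
) -> torch.Tensor:
    # Minimizing this combined value descends on the retain loss (preserving
    # general capability) while ascending on the forget loss (unlearning the
    # forget set). alpha trades off retention against forgetting.
    return alpha * retain_loss - forget_loss

# e.g. forget_loss=2.0, retain_loss=1.5, alpha=1.0 -> combined_loss = -0.5
print(gradient_difference_loss(torch.tensor(2.0), torch.tensor(1.5)))  # tensor(-0.5000)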
@@ -1965,42 +2026,86 @@ def train(
             # Log interleaved GA/GD step
             if neox_args.rank == 0:
                 if neox_args.gd_mode and retain_data_iterator is not None:
-                    print(f"  [Interleaved GD] iteration {iteration}, actual_loss: {actual_loss:.4f}, ga_objective: {ga_objective:.4f}")
+                    # For gradient difference, use GD-specific naming
+                    print(f"  [Interleaved GD] iteration {iteration}, gd_combined_loss: {loss_dict.get('combined_loss', 0):.4f}, gd_objective: {ga_objective:.4f}")
                 else:
                     print(f"  [Interleaved GA] iteration {iteration}, actual_loss: {actual_loss:.4f}, ga_objective: {ga_objective:.4f}")
 
-            # Log GA metrics immediately (not periodically) for interleaved mode
+            # Log GA/GD metrics for interleaved mode (every log_interval iterations)
             if neox_args.rank == 0 and iteration % neox_args.log_interval == 0:
-                # TensorBoard logging
-                if neox_args.tensorboard_writer:
-                    neox_args.tensorboard_writer.add_scalar(
-                        "train/ga_actual_loss", actual_loss, iteration
-                    )
-                    neox_args.tensorboard_writer.add_scalar(
-                        "train/ga_objective", ga_objective, iteration
-                    )
-                    neox_args.tensorboard_writer.add_scalar(
-                        "train/batch_type", 1.0, iteration  # 1.0 for GA
-                    )
-
-                # WandB logging
-                if neox_args.use_wandb:
-                    import wandb
-                    wandb.log({
-                        "train/ga_actual_loss": actual_loss,
-                        "train/ga_objective": ga_objective,
-                        "train/batch_type": 1.0,  # 1.0 for GA
-                        "iteration": iteration,
-                    })
-
-                # Comet logging
-                if neox_args.comet_experiment:
-                    neox_args.comet_experiment.log_metric(
-                        "ga_actual_loss", actual_loss, step=iteration
-                    )
-                    neox_args.comet_experiment.log_metric(
-                        "ga_objective", ga_objective, step=iteration
-                    )
+                if neox_args.gd_mode and retain_data_iterator is not None:
+                    # Gradient Difference mode metrics
+                    gd_combined_loss = loss_dict.get('combined_loss', 0)
+
+                    # TensorBoard logging
+                    if neox_args.tensorboard_writer:
+                        neox_args.tensorboard_writer.add_scalar(
+                            "train/gd_combined_loss", gd_combined_loss, iteration
+                        )
+                        neox_args.tensorboard_writer.add_scalar(
+                            "train/gd_forget_loss", loss_dict.get('forget_loss', 0), iteration
+                        )
+                        neox_args.tensorboard_writer.add_scalar(
+                            "train/gd_retain_loss", loss_dict.get('retain_loss', 0), iteration
+                        )
+                        neox_args.tensorboard_writer.add_scalar(
+                            "train/batch_type", 2.0, iteration  # 2.0 for GD
+                        )
+
+                    # WandB logging
+                    if neox_args.use_wandb:
+                        import wandb
+                        wandb.log({
+                            "train/gd_combined_loss": gd_combined_loss,
+                            "train/gd_forget_loss": loss_dict.get('forget_loss', 0),
+                            "train/gd_retain_loss": loss_dict.get('retain_loss', 0),
+                            "train/batch_type": 2.0,  # 2.0 for GD
+                            "iteration": iteration,
+                        })
+
+                    # Comet logging
+                    if neox_args.comet_experiment:
+                        neox_args.comet_experiment.log_metric(
+                            "gd_combined_loss", gd_combined_loss, step=iteration
+                        )
+                        neox_args.comet_experiment.log_metric(
+                            "gd_forget_loss", loss_dict.get('forget_loss', 0), step=iteration
+                        )
+                        neox_args.comet_experiment.log_metric(
+                            "gd_retain_loss", loss_dict.get('retain_loss', 0), step=iteration
+                        )
+                else:
+                    # Gradient Ascent mode metrics
+                    # TensorBoard logging
+                    if neox_args.tensorboard_writer:
+                        neox_args.tensorboard_writer.add_scalar(
+                            "train/ga_actual_loss", actual_loss, iteration
+                        )
+                        neox_args.tensorboard_writer.add_scalar(
+                            "train/ga_objective", ga_objective, iteration
+                        )
+                        neox_args.tensorboard_writer.add_scalar(
+                            "train/batch_type", 1.0, iteration  # 1.0 for GA
+                        )
+
+                    # WandB logging
+                    if neox_args.use_wandb:
+                        import wandb
+                        wandb.log({
+                            "train/ga_actual_loss": actual_loss,
+                            "train/ga_objective": ga_objective,
+                            "train/batch_type": 1.0,  # 1.0 for GA
+                            "iteration": iteration,
+                        })
+
+                    # Comet logging
+                    if neox_args.comet_experiment:
+                        neox_args.comet_experiment.log_metric(
+                            "ga_actual_loss", actual_loss, step=iteration
+                        )
+                        neox_args.comet_experiment.log_metric(
+                            "ga_objective", ga_objective, step=iteration
+                        )
 
         # Restore original learning rates
         if neox_args.ga_lr_scale != 1.0 and original_lrs:
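Both hunks repeat the same TensorBoard/WandB/Comet triple in every branch. A follow-up refactor could funnel each metric dict through one helper; a minimal sketch under the same `neox_args` attributes used in the diff (the helper name and the "train/" prefixing convention are assumptions, not part of this change):

def log_metrics_all_backends(neox_args, metrics: dict, iteration: int) -> None:
    # Log a flat {name: value} dict to every configured backend. Bare names
    # go to Comet; a "train/" prefix is added for TensorBoard and WandB,
    # matching the metric naming used in the diff above.
    if neox_args.tensorboard_writer:
        for name, value in metrics.items():
            neox_args.tensorboard_writer.add_scalar(f"train/{name}", value, iteration)
    if neox_args.use_wandb:
        import wandb
        payload = {f"train/{name}": value for name, value in metrics.items()}
        payload["iteration"] = iteration
        wandb.log(payload)
    if neox_args.comet_experiment:
        for name, value in metrics.items():
            neox_args.comet_experiment.log_metric(name, value, step=iteration)

# e.g. the GD branch would collapse to a single call:
# log_metrics_all_backends(neox_args, {"gd_combined_loss": gd_combined_loss,
#                                      "gd_forget_loss": gd_forget_loss,
#                                      "gd_retain_loss": gd_retain_loss}, iteration)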