diff --git a/benchmark/data/all_benchmark_data.csv b/benchmark/data/all_benchmark_data.csv index 7b9b0c1b0..a578d2622 100644 --- a/benchmark/data/all_benchmark_data.csv +++ b/benchmark/data/all_benchmark_data.csv @@ -2219,3 +2219,99 @@ jsd,liger,full,memory,MB,BT,total tokens,1024,3514.0009765625,3514.0009765625,35 jsd,liger,full,memory,MB,BT,total tokens,2048,7014.0009765625,7014.0009765625,7014.0009765625,"{""vocab_size"": 128256, ""bsz"": 1, ""seq_len"": 8192}",NVIDIA B200,2026-05-27 17:14:32,0.8.0 jsd,liger,full,memory,MB,BT,total tokens,4096,14028.0009765625,14028.0009765625,14028.0009765625,"{""vocab_size"": 128256, ""bsz"": 1, ""seq_len"": 8192}",NVIDIA B200,2026-05-27 17:14:32,0.8.0 jsd,liger,full,memory,MB,BT,total tokens,8192,28056.0,28056.0,28056.0,"{""vocab_size"": 128256, ""bsz"": 1, ""seq_len"": 8192}",NVIDIA B200,2026-05-27 17:14:32,0.8.0 +megatron_cross_entropy,liger,forward,speed,ms,V,vocab size,4096,0.2385600060224533,0.23590399324893951,0.24048000574111938,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:43:58,0.8.0 +megatron_cross_entropy,liger,forward,speed,ms,V,vocab size,8192,0.26820799708366394,0.26531198620796204,0.27125120162963867,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:43:58,0.8.0 +megatron_cross_entropy,liger,forward,speed,ms,V,vocab size,16384,0.3599199950695038,0.3569599986076355,0.36266239881515505,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:43:58,0.8.0 +megatron_cross_entropy,liger,forward,speed,ms,V,vocab size,32768,0.6502079963684082,0.6452223896980286,0.6563008189201356,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:43:58,0.8.0 +megatron_cross_entropy,liger,forward,speed,ms,V,vocab size,65536,1.087440013885498,1.0827391862869262,1.0942399978637696,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:43:58,0.8.0 +megatron_cross_entropy,liger,forward,speed,ms,V,vocab size,131072,2.1112000942230225,2.108524799346924,2.1132928848266603,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:43:58,0.8.0 +megatron_cross_entropy,torch,forward,speed,ms,V,vocab size,4096,0.2269120067358017,0.225913605093956,0.22775039970874789,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:43:59,0.8.0 +megatron_cross_entropy,torch,forward,speed,ms,V,vocab size,8192,0.48070400953292847,0.47954559326171875,0.48152959942817686,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:43:59,0.8.0 +megatron_cross_entropy,torch,forward,speed,ms,V,vocab size,16384,0.8124480247497559,0.8111680150032043,0.813696026802063,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:43:59,0.8.0 +megatron_cross_entropy,torch,forward,speed,ms,V,vocab size,32768,2.1381120681762695,2.1316160202026366,2.142361545562744,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:43:59,0.8.0 +megatron_cross_entropy,torch,forward,speed,ms,V,vocab size,65536,4.391056060791016,4.389933013916016,4.3939519882202145,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:43:59,0.8.0 +megatron_cross_entropy,torch,forward,speed,ms,V,vocab size,131072,8.81107234954834,8.806303977966309,8.81488037109375,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:43:59,0.8.0 +megatron_cross_entropy,megatron,forward,speed,ms,V,vocab size,4096,0.31273600459098816,0.31200000643730164,0.31331199407577515,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:01,0.8.0 +megatron_cross_entropy,megatron,forward,speed,ms,V,vocab size,8192,0.6373440027236938,0.636352002620697,0.6381440162658691,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:01,0.8.0 +megatron_cross_entropy,megatron,forward,speed,ms,V,vocab size,16384,1.3848639726638794,1.3825664043426515,1.3870784282684325,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:01,0.8.0 +megatron_cross_entropy,megatron,forward,speed,ms,V,vocab size,32768,2.829935908317566,2.8288448333740233,2.8316224098205565,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:01,0.8.0 +megatron_cross_entropy,megatron,forward,speed,ms,V,vocab size,65536,5.823584079742432,5.8137922286987305,5.826655864715576,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:01,0.8.0 +megatron_cross_entropy,megatron,forward,speed,ms,V,vocab size,131072,11.810944080352783,11.80013427734375,11.817446517944337,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:01,0.8.0 +megatron_cross_entropy,megatron-unfused,forward,speed,ms,V,vocab size,4096,0.5649279952049255,0.5639231920242309,0.5661375880241394,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:02,0.8.0 +megatron_cross_entropy,megatron-unfused,forward,speed,ms,V,vocab size,8192,1.094543993473053,1.093395209312439,1.0956799983978271,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:02,0.8.0 +megatron_cross_entropy,megatron-unfused,forward,speed,ms,V,vocab size,16384,2.0701760053634644,2.068671941757202,2.072096109390259,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:02,0.8.0 +megatron_cross_entropy,megatron-unfused,forward,speed,ms,V,vocab size,32768,4.027215957641602,4.024915313720703,4.028927993774413,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:02,0.8.0 +megatron_cross_entropy,megatron-unfused,forward,speed,ms,V,vocab size,65536,7.93833589553833,7.937132930755616,7.940275096893311,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:02,0.8.0 +megatron_cross_entropy,megatron-unfused,forward,speed,ms,V,vocab size,131072,15.817471981048584,15.814432144165039,15.821215629577637,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:02,0.8.0 +megatron_cross_entropy,liger,backward,speed,ms,V,vocab size,4096,0.3516159951686859,0.3489919900894165,0.35476480722427367,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:03,0.8.0 +megatron_cross_entropy,liger,backward,speed,ms,V,vocab size,8192,0.4639679938554764,0.46081281304359434,0.46724479794502255,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:03,0.8.0 +megatron_cross_entropy,liger,backward,speed,ms,V,vocab size,16384,0.7123200297355652,0.7092544078826905,0.7154880046844483,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:03,0.8.0 +megatron_cross_entropy,liger,backward,speed,ms,V,vocab size,32768,1.322975993156433,1.3184319734573364,1.328320026397705,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:03,0.8.0 +megatron_cross_entropy,liger,backward,speed,ms,V,vocab size,65536,2.3985120058059692,2.393356847763062,2.405248022079468,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:03,0.8.0 +megatron_cross_entropy,liger,backward,speed,ms,V,vocab size,131072,4.698495864868164,4.69042558670044,4.70184965133667,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:03,0.8.0 +megatron_cross_entropy,torch,backward,speed,ms,V,vocab size,4096,0.48710399866104126,0.4862975895404816,0.48793599009513855,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:04,0.8.0 +megatron_cross_entropy,torch,backward,speed,ms,V,vocab size,8192,0.9741439819335938,0.9731391787528991,0.9752448081970215,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:04,0.8.0 +megatron_cross_entropy,torch,backward,speed,ms,V,vocab size,16384,1.8759199976921082,1.8737216234207152,1.8779136180877687,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:04,0.8.0 +megatron_cross_entropy,torch,backward,speed,ms,V,vocab size,32768,4.405440092086792,4.398931121826172,4.407916831970215,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:04,0.8.0 +megatron_cross_entropy,torch,backward,speed,ms,V,vocab size,65536,8.920191764831543,8.916671752929688,8.923616409301758,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:04,0.8.0 +megatron_cross_entropy,torch,backward,speed,ms,V,vocab size,131072,17.845855712890625,17.843469619750977,17.847046661376954,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:04,0.8.0 +megatron_cross_entropy,megatron,backward,speed,ms,V,vocab size,4096,0.5368639826774597,0.536191999912262,0.5374848127365113,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:05,0.8.0 +megatron_cross_entropy,megatron,backward,speed,ms,V,vocab size,8192,1.0521279573440552,1.05141122341156,1.0528064489364624,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:05,0.8.0 +megatron_cross_entropy,megatron,backward,speed,ms,V,vocab size,16384,2.1722079515457153,2.169472026824951,2.1735936641693114,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:05,0.8.0 +megatron_cross_entropy,megatron,backward,speed,ms,V,vocab size,32768,4.448447942733765,4.447282981872559,4.449868965148926,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:05,0.8.0 +megatron_cross_entropy,megatron,backward,speed,ms,V,vocab size,65536,9.033520221710205,9.030751991271973,9.042854690551758,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:05,0.8.0 +megatron_cross_entropy,megatron,backward,speed,ms,V,vocab size,131072,17.958431243896484,17.9557315826416,17.960690689086913,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:05,0.8.0 +megatron_cross_entropy,megatron-unfused,backward,speed,ms,V,vocab size,4096,0.7786880135536194,0.7767040133476257,0.7803391933441162,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:06,0.8.0 +megatron_cross_entropy,megatron-unfused,backward,speed,ms,V,vocab size,8192,1.475823998451233,1.4742015838623046,1.4774847745895388,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:06,0.8.0 +megatron_cross_entropy,megatron-unfused,backward,speed,ms,V,vocab size,16384,2.787328004837036,2.785683250427246,2.7883071899414062,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:06,0.8.0 +megatron_cross_entropy,megatron-unfused,backward,speed,ms,V,vocab size,32768,5.420928001403809,5.419379329681396,5.423315238952636,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:06,0.8.0 +megatron_cross_entropy,megatron-unfused,backward,speed,ms,V,vocab size,65536,10.677472114562988,10.674201774597169,10.67763843536377,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:06,0.8.0 +megatron_cross_entropy,megatron-unfused,backward,speed,ms,V,vocab size,131072,21.254928588867188,21.25368995666504,21.2557315826416,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:06,0.8.0 +megatron_cross_entropy,liger,full,speed,ms,V,vocab size,4096,0.41488000750541687,0.4122175931930542,0.41767039299011227,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:07,0.8.0 +megatron_cross_entropy,liger,full,speed,ms,V,vocab size,8192,0.5923520028591156,0.5882560014724731,0.6005120277404785,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:07,0.8.0 +megatron_cross_entropy,liger,full,speed,ms,V,vocab size,16384,0.9725440144538879,0.9690751910209655,0.9779008150100708,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:07,0.8.0 +megatron_cross_entropy,liger,full,speed,ms,V,vocab size,32768,1.8396799564361572,1.8349119424819946,1.8454079627990723,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:07,0.8.0 +megatron_cross_entropy,liger,full,speed,ms,V,vocab size,65536,3.4398880004882812,3.429087924957275,3.4599680423736574,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:07,0.8.0 +megatron_cross_entropy,liger,full,speed,ms,V,vocab size,131072,6.780799865722656,6.763827323913574,6.796000003814697,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:07,0.8.0 +megatron_cross_entropy,torch,full,speed,ms,V,vocab size,4096,0.5540960133075714,0.5528640151023865,0.555296003818512,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:08,0.8.0 +megatron_cross_entropy,torch,full,speed,ms,V,vocab size,8192,1.1072640419006348,1.1059776306152345,1.1086400032043457,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:08,0.8.0 +megatron_cross_entropy,torch,full,speed,ms,V,vocab size,16384,2.139024019241333,2.1362879276275635,2.142067241668701,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:08,0.8.0 +megatron_cross_entropy,torch,full,speed,ms,V,vocab size,32768,4.9202880859375,4.91553258895874,4.924480152130127,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:08,0.8.0 +megatron_cross_entropy,torch,full,speed,ms,V,vocab size,65536,9.961983680725098,9.958450889587402,9.964262008666992,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:08,0.8.0 +megatron_cross_entropy,torch,full,speed,ms,V,vocab size,131072,19.926143646240234,19.913933563232423,19.93060417175293,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:08,0.8.0 +megatron_cross_entropy,megatron,full,speed,ms,V,vocab size,4096,0.6057599782943726,0.604960024356842,0.6063359975814819,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:09,0.8.0 +megatron_cross_entropy,megatron,full,speed,ms,V,vocab size,8192,1.1868000030517578,1.1857984066009521,1.1880639791488647,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:09,0.8.0 +megatron_cross_entropy,megatron,full,speed,ms,V,vocab size,16384,2.4351680278778076,2.4330944538116457,2.4368255615234373,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:09,0.8.0 +megatron_cross_entropy,megatron,full,speed,ms,V,vocab size,32768,4.968512058258057,4.967942237854004,4.969702434539795,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:09,0.8.0 +megatron_cross_entropy,megatron,full,speed,ms,V,vocab size,65536,10.074239730834961,10.06367359161377,10.077676963806152,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:09,0.8.0 +megatron_cross_entropy,megatron,full,speed,ms,V,vocab size,131072,20.022079467773438,20.015307998657228,20.02815971374512,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:09,0.8.0 +megatron_cross_entropy,megatron-unfused,full,speed,ms,V,vocab size,4096,0.846015989780426,0.8447743892669677,0.8472639918327332,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:10,0.8.0 +megatron_cross_entropy,megatron-unfused,full,speed,ms,V,vocab size,8192,1.6103359460830688,1.6088960409164428,1.6117696285247802,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:10,0.8.0 +megatron_cross_entropy,megatron-unfused,full,speed,ms,V,vocab size,16384,3.0503358840942383,3.049056053161621,3.0524160861968994,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:10,0.8.0 +megatron_cross_entropy,megatron-unfused,full,speed,ms,V,vocab size,32768,5.937999963760376,5.936895847320557,5.939583778381348,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:10,0.8.0 +megatron_cross_entropy,megatron-unfused,full,speed,ms,V,vocab size,65536,11.711487770080566,11.710764503479005,11.713843536376952,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:10,0.8.0 +megatron_cross_entropy,megatron-unfused,full,speed,ms,V,vocab size,131072,23.322416305541992,23.321644973754882,23.32369956970215,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:10,0.8.0 +megatron_cross_entropy,liger,full,memory,MB,V,vocab size,4096,192.0791015625,192.0791015625,192.0791015625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:10,0.8.0 +megatron_cross_entropy,liger,full,memory,MB,V,vocab size,8192,384.0791015625,384.0791015625,384.0791015625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:10,0.8.0 +megatron_cross_entropy,liger,full,memory,MB,V,vocab size,16384,768.0791015625,768.0791015625,768.0791015625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:10,0.8.0 +megatron_cross_entropy,liger,full,memory,MB,V,vocab size,32768,1536.0791015625,1536.0791015625,1536.0791015625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:10,0.8.0 +megatron_cross_entropy,liger,full,memory,MB,V,vocab size,65536,3072.0791015625,3072.0791015625,3072.0791015625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:10,0.8.0 +megatron_cross_entropy,liger,full,memory,MB,V,vocab size,131072,6144.0791015625,6144.0791015625,6144.0791015625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:10,0.8.0 +megatron_cross_entropy,torch,full,memory,MB,V,vocab size,4096,512.0947265625,512.0947265625,512.0947265625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:10,0.8.0 +megatron_cross_entropy,torch,full,memory,MB,V,vocab size,8192,1024.0947265625,1024.0947265625,1024.0947265625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:10,0.8.0 +megatron_cross_entropy,torch,full,memory,MB,V,vocab size,16384,2048.0947265625,2048.0947265625,2048.0947265625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:10,0.8.0 +megatron_cross_entropy,torch,full,memory,MB,V,vocab size,32768,4096.0947265625,4096.0947265625,4096.0947265625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:10,0.8.0 +megatron_cross_entropy,torch,full,memory,MB,V,vocab size,65536,8192.0947265625,8192.0947265625,8192.0947265625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:10,0.8.0 +megatron_cross_entropy,torch,full,memory,MB,V,vocab size,131072,16384.09375,16384.09375,16384.09375,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:10,0.8.0 +megatron_cross_entropy,megatron,full,memory,MB,V,vocab size,4096,448.1650390625,448.1650390625,448.1650390625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:10,0.8.0 +megatron_cross_entropy,megatron,full,memory,MB,V,vocab size,8192,896.1650390625,896.1650390625,896.1650390625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:10,0.8.0 +megatron_cross_entropy,megatron,full,memory,MB,V,vocab size,16384,1792.1650390625,1792.1650390625,1792.1650390625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:10,0.8.0 +megatron_cross_entropy,megatron,full,memory,MB,V,vocab size,32768,3584.1650390625,3584.1650390625,3584.1650390625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:10,0.8.0 +megatron_cross_entropy,megatron,full,memory,MB,V,vocab size,65536,7168.1650390625,7168.1650390625,7168.1650390625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:10,0.8.0 +megatron_cross_entropy,megatron,full,memory,MB,V,vocab size,131072,14336.1650390625,14336.1650390625,14336.1650390625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:10,0.8.0 +megatron_cross_entropy,megatron-unfused,full,memory,MB,V,vocab size,4096,320.1650390625,320.1650390625,320.1650390625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:11,0.8.0 +megatron_cross_entropy,megatron-unfused,full,memory,MB,V,vocab size,8192,640.1650390625,640.1650390625,640.1650390625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:11,0.8.0 +megatron_cross_entropy,megatron-unfused,full,memory,MB,V,vocab size,16384,1280.1650390625,1280.1650390625,1280.1650390625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:11,0.8.0 +megatron_cross_entropy,megatron-unfused,full,memory,MB,V,vocab size,32768,2560.1650390625,2560.1650390625,2560.1650390625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:11,0.8.0 +megatron_cross_entropy,megatron-unfused,full,memory,MB,V,vocab size,65536,5120.1650390625,5120.1650390625,5120.1650390625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:11,0.8.0 +megatron_cross_entropy,megatron-unfused,full,memory,MB,V,vocab size,131072,10240.1650390625,10240.1650390625,10240.1650390625,"{""S"": 2048, ""B"": 4}",NVIDIA H100 80GB HBM3,2026-06-10 07:44:11,0.8.0 diff --git a/benchmark/scripts/benchmark_megatron_cross_entropy.py b/benchmark/scripts/benchmark_megatron_cross_entropy.py new file mode 100644 index 000000000..d523054b2 --- /dev/null +++ b/benchmark/scripts/benchmark_megatron_cross_entropy.py @@ -0,0 +1,207 @@ +"""Benchmark Liger's Megatron-LM cross-entropy wrapper. + +Compares four providers on the per-token CE call shape ``[seq, batch, vocab]``: + + - **torch**: vanilla ``F.cross_entropy`` + - **megatron**: Megatron's *fused* ``fused_vocab_parallel_cross_entropy`` path + (``cross_entropy_loss_fusion=True``, JIT-fused via TorchScript) + - **megatron-unfused**: Megatron's *unfused* ``vocab_parallel_cross_entropy`` + path (``cross_entropy_loss_fusion=False``, eager Python; the path users on + ``label_smoothing`` typically end up on) + - **liger**: ``LigerMegatronCrossEntropy`` — Liger's Triton CE wrapped in the + Megatron fused signature. Same kernel regardless of which Megatron symbol + it was patched onto, so we only benchmark it once. + +Requires a Liger-supported accelerator (CUDA / ROCm). With megatron-core not +installed, both megatron providers are silently skipped. + +Output goes to the shared ``benchmark/data/all_benchmark_data.csv`` like every +other Liger benchmark — rows are tagged with ``kernel_name="megatron_cross_entropy"`` +and the standard visualizer renders them via: + + python benchmark/benchmarks_visualizer.py \\ + --kernel-name megatron_cross_entropy --metric-name speed + python benchmark/benchmarks_visualizer.py \\ + --kernel-name megatron_cross_entropy --metric-name memory +""" + +import os + +import torch +import torch.nn.functional as F +import triton + +from utils import QUANTILES +from utils import SingleBenchmarkRunInput +from utils import SingleBenchmarkRunOutput +from utils import _test_memory +from utils import parse_benchmark_script_args +from utils import run_benchmarks + +from liger_kernel.megatron import LigerMegatronCrossEntropy +from liger_kernel.utils import infer_device + +device = infer_device() + +try: + from megatron.core.fusions.fused_cross_entropy import fused_vocab_parallel_cross_entropy + from megatron.core.tensor_parallel.cross_entropy import vocab_parallel_cross_entropy + + _MEGATRON_AVAILABLE = True +except ImportError: + fused_vocab_parallel_cross_entropy = None + vocab_parallel_cross_entropy = None + _MEGATRON_AVAILABLE = False + + +def _make_inputs(s: int, b: int, v: int, requires_grad: bool = True): + logits = torch.randn(s, b, v, device=device, dtype=torch.bfloat16, requires_grad=requires_grad) + target = torch.randint(0, v, (s, b), device=device, dtype=torch.long) + return logits, target + + +def _pytorch_cross_entropy(logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + s, b, v = logits.shape + return F.cross_entropy( + logits.reshape(-1, v).float(), + target.reshape(-1), + reduction="none", + ).reshape(s, b) + + +def _ensure_single_rank_tp_group(): + """Initialize torch.distributed (single-rank) and return a usable TP group. + + For a single-process benchmark we use the world group of size 1; the + internal all-reduce becomes a no-op. + """ + import torch.distributed as dist + + if not dist.is_initialized(): + os.environ.setdefault("MASTER_ADDR", "localhost") + os.environ.setdefault("MASTER_PORT", "29500") + os.environ.setdefault("WORLD_SIZE", "1") + os.environ.setdefault("RANK", "0") + os.environ.setdefault("LOCAL_RANK", "0") + dist.init_process_group(backend="nccl") + return dist.group.WORLD + + +def _select_fwd(provider: str): + if provider == "liger": + ce = LigerMegatronCrossEntropy(reduction="none") + return lambda logits, target: ce(logits, target) + if provider == "torch": + return _pytorch_cross_entropy + if provider == "megatron": + if not _MEGATRON_AVAILABLE: + raise RuntimeError("megatron-core not installed; cannot benchmark 'megatron' provider") + tp_group = _ensure_single_rank_tp_group() + + def _megatron_fused_call(logits, target): + return fused_vocab_parallel_cross_entropy(logits, target, tp_group) + + return _megatron_fused_call + if provider == "megatron-unfused": + if not _MEGATRON_AVAILABLE: + raise RuntimeError("megatron-core not installed; cannot benchmark 'megatron-unfused' provider") + tp_group = _ensure_single_rank_tp_group() + + def _megatron_unfused_call(logits, target): + # Unfused signature: (logits, target, label_smoothing=0.0, tp_group=None) + return vocab_parallel_cross_entropy(logits, target, 0.0, tp_group) + + return _megatron_unfused_call + raise ValueError(f"unknown provider: {provider!r}") + + +def bench_speed_megatron_cross_entropy(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + v = input.x + provider = input.kernel_provider + mode = input.kernel_operation_mode + s = input.extra_benchmark_config["S"] + b = input.extra_benchmark_config["B"] + + logits, target = _make_inputs(s, b, v) + fwd_fn = _select_fwd(provider) + + def fwd(): + return fwd_fn(logits, target) + + if mode == "forward": + ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, rep=100, quantiles=QUANTILES) + elif mode == "backward": + # Megatron's fused CE writes gradients in-place into saved tensors during backward, + # which breaks the standard retain_graph=True / repeated-backward pattern do_bench + # uses elsewhere. Run a fresh fwd+bwd each iteration so each backward sees an + # unmodified autograd graph. Measurement therefore includes forward time — + # subtract the "forward" measurement to derive backward-only timing. + def _fwd_bwd(): + if logits.grad is not None: + logits.grad = None + out = fwd_fn(logits, target) + out.sum().backward() + + ms_50, ms_20, ms_80 = triton.testing.do_bench(_fwd_bwd, rep=100, quantiles=QUANTILES) + elif mode == "full": + + def full(): + y = fwd() + y.sum().backward() + + ms_50, ms_20, ms_80 = triton.testing.do_bench(full, rep=100, quantiles=QUANTILES) + else: + raise ValueError(f"unknown mode: {mode!r}") + + return SingleBenchmarkRunOutput(y_20=ms_20, y_50=ms_50, y_80=ms_80) + + +def bench_memory_megatron_cross_entropy(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + v = input.x + provider = input.kernel_provider + s = input.extra_benchmark_config["S"] + b = input.extra_benchmark_config["B"] + + logits, target = _make_inputs(s, b, v) + fwd_fn = _select_fwd(provider) + + def full(): + y = fwd_fn(logits, target) + y.sum().backward() + + mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES) + return SingleBenchmarkRunOutput(y_20=mem_20, y_50=mem_50, y_80=mem_80) + + +if __name__ == "__main__": + args = parse_benchmark_script_args() + + providers = ["liger", "torch"] + if _MEGATRON_AVAILABLE: + providers.append("megatron") + providers.append("megatron-unfused") + + common_configs = { + "kernel_name": "megatron_cross_entropy", + "x_name": "V", + "x_label": "vocab size", + "x_values": [2**i for i in range(12, 18)], + "kernel_providers": providers, + "extra_benchmark_configs": [{"S": 2048, "B": 4}], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_megatron_cross_entropy, + kernel_operation_modes=["forward", "backward", "full"], + metric_name="speed", + metric_unit="ms", + **common_configs, + ) + run_benchmarks( + bench_test_fn=bench_memory_megatron_cross_entropy, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **common_configs, + ) diff --git a/docs/High-Level-APIs.md b/docs/High-Level-APIs.md index 5433e03d3..6bbe008a9 100644 --- a/docs/High-Level-APIs.md +++ b/docs/High-Level-APIs.md @@ -91,3 +91,47 @@ You can also use the Patching APIs to use the kernels for a specific model archi extra: show_docstring: true show_signature: true + +--- + +## Megatron-LM + +Liger also exposes a patch for the [Megatron-LM](https://github.com/NVIDIA/Megatron-LM) +training framework, replacing Megatron's native RMSNorm and both vocab-parallel +cross-entropy paths (fused and unfused) with Liger's Triton kernels. + +| **Framework** | **API** | **Supported Operations** | +|---------------|--------------------------------------------------------|--------------------------| +| Megatron-LM | `liger_kernel.megatron.apply_liger_kernel_to_megatron` | RMSNorm, CrossEntropyLoss | + +**Scope**: Initial release supports `tensor_model_parallel_size=1` only for +cross-entropy. Vocab-parallel cross-entropy (TP>1) is follow-up work — with +TP>1, each rank holds a sharded `[N, V/tp]` logits slice and cross-entropy +requires cross-rank all-reduces that Liger's kernel does not perform. The +patch raises a `RuntimeError` at patch time or call time if TP>1 is detected. + +**Usage**: + +```python +from liger_kernel.megatron import apply_liger_kernel_to_megatron + +# Call before Megatron's forward pass reaches compute_language_model_loss. +# Defaults match Megatron's native CE behavior; no CE-specific config needed. +apply_liger_kernel_to_megatron(rms_norm=True, cross_entropy=True) +``` + +Both the fused (`config.cross_entropy_loss_fusion=True`, +`cross_entropy_fusion_impl='native'`) and unfused +(`config.cross_entropy_loss_fusion=False`) CE paths are patched in a single +call, so Megatron picks up Liger regardless of which path your config selects. + +For training setups that need explicit kernel configuration (custom +`ignore_index`, `label_smoothing`, etc.), instantiate +`LigerMegatronCrossEntropy` directly and wire it into your model — see +`examples/megatron/run_mode2_hand_spec.py`. + +::: liger_kernel.megatron.apply_liger_kernel_to_megatron + options: + extra: + show_docstring: true + show_signature: true diff --git a/examples/megatron/run_mode1_monkey_patch.py b/examples/megatron/run_mode1_monkey_patch.py index 5daff20fa..5e3f460df 100644 --- a/examples/megatron/run_mode1_monkey_patch.py +++ b/examples/megatron/run_mode1_monkey_patch.py @@ -1,20 +1,24 @@ -"""Mode 1 — monkey-patch Megatron-Core to use Liger RMSNorm. +"""Mode 1 — monkey-patch Megatron-Core to use Liger RMSNorm + cross-entropy. Adapted from Megatron's ``examples/run_simple_mcore_train_loop.py``. The relevant additions (vs. that file) are: - 1. ``apply_liger_kernel_to_megatron(rms_norm=True)`` called once at the - top of ``model_provider()``. This patches both - ``LocalSpecProvider.layer_norm`` (per-layer norm slots) and - ``transformer_block.LayerNormImpl`` (block-level ``final_layernorm``). + 1. ``apply_liger_kernel_to_megatron(rms_norm=True, cross_entropy=True)`` + called once at the top of ``model_provider()``. This patches: + - ``LocalSpecProvider.layer_norm`` (per-layer norm slots) + - ``transformer_block.LayerNormImpl`` (block-level ``final_layernorm``) + - ``fused_cross_entropy.fused_vocab_parallel_cross_entropy`` + (the fused CE path) + - ``tensor_parallel.cross_entropy.vocab_parallel_cross_entropy`` + (the unfused CE path) 2. ``normalization="RMSNorm"`` added to ``TransformerConfig`` so the model actually has RMSNorm slots to patch (Megatron defaults to ``LayerNorm``). - 3. ``_print_norm_classes`` after model construction — prints the - resolved class for every norm slot so you can verify Liger took - over. + 3. ``_print_norm_classes`` + ``_print_ce_symbols`` after model construction + — print the resolved class/function bindings so you can verify Liger + took over for every slot. Run with: torchrun --nproc_per_node=2 --master_addr=127.0.0.1 --master_port=29500 \\ @@ -70,7 +74,7 @@ def initialize_distributed(tp: int = 2, pp: int = 1) -> None: def model_provider() -> GPTModel: # ↓↓ Mode 1 — patch once, everything below picks up Liger ↓↓ - apply_liger_kernel_to_megatron(rms_norm=True) + apply_liger_kernel_to_megatron(rms_norm=True, cross_entropy=True) # ↑↑ ------------------------------------------------------ ↑↑ cfg = TransformerConfig( @@ -84,7 +88,7 @@ def model_provider() -> GPTModel: return GPTModel( config=cfg, transformer_layer_spec=get_gpt_layer_local_spec(normalization="RMSNorm"), - vocab_size=100, + vocab_size=128, max_sequence_length=_SEQUENCE_LENGTH, ) @@ -142,8 +146,21 @@ def _print_norm_classes(model: torch.nn.Module) -> None: print() +def _print_ce_symbols() -> None: + """Show the current bindings of Megatron's two CE entry points.""" + import megatron.core.fusions.fused_cross_entropy as fused + import megatron.core.tensor_parallel.cross_entropy as unfused + + print("\n=== Resolved CE symbols ===") + print(f" fused.fused_vocab_parallel_cross_entropy → {fused.fused_vocab_parallel_cross_entropy.__name__}") + print(f" unfused.vocab_parallel_cross_entropy → {unfused.vocab_parallel_cross_entropy.__name__}") + print() + + def main() -> None: - initialize_distributed(tp=2, pp=1) + # TP=1, DP=2 — CE patch (TP=1 only). Norms are correct under any TP value, so + # demonstrating both Liger features in one script means running data-parallel. + initialize_distributed(tp=1, pp=1) model_parallel_cuda_manual_seed(123) torch.manual_seed(123) @@ -153,6 +170,7 @@ def main() -> None: print("\n=== Full model tree (mode 1: monkey-patch) ===") print(gpt_model) _print_norm_classes(gpt_model) + _print_ce_symbols() ddp_cfg = DistributedDataParallelConfig( grad_reduce_in_fp32=False, diff --git a/examples/megatron/run_mode2_hand_spec.py b/examples/megatron/run_mode2_hand_spec.py index 1e17a6268..4053a7adc 100644 --- a/examples/megatron/run_mode2_hand_spec.py +++ b/examples/megatron/run_mode2_hand_spec.py @@ -1,22 +1,28 @@ -"""Mode 2 — hand-assembled TransformerBlockSubmodules using LigerMegatronRMSNorm. +"""Mode 2 — hand-assembled spec + GPTModel subclass using Liger directly. Adapted from Megatron's ``examples/run_simple_mcore_train_loop.py``. The relevant additions (vs. that file) are: - 1. Direct import of ``LigerMegatronRMSNorm`` (no monkey-patch). + 1. Direct imports of ``LigerMegatronRMSNorm`` and ``LigerMegatronCrossEntropy`` + (no monkey-patch). - 2. ``model_provider()`` assembles a ``TransformerBlockSubmodules`` by - hand, placing ``LigerMegatronRMSNorm`` into every norm slot: + 2. ``model_provider()`` assembles a ``TransformerBlockSubmodules`` by hand, + placing ``LigerMegatronRMSNorm`` into every norm slot: - per-layer ``input_layernorm`` and ``pre_mlp_layernorm`` - the block-level ``layer_norm`` field that backs ``decoder.final_layernorm`` - This is the slot-level integration path — verbose but maximally - explicit. It is the only way to control ``final_layernorm`` from - user code without monkey-patching. + This is the slot-level integration path — verbose but maximally explicit. + It is the only way to control ``final_layernorm`` from user code without + monkey-patching. - 3. ``_print_norm_classes`` after model construction — prints the - resolved class for every norm slot so you can verify Liger took - over. + 3. ``_LigerCEGPTModel(GPTModel)`` overrides + ``LanguageModule.compute_language_model_loss`` to route the loss through + a ``LigerMegatronCrossEntropy`` instance. Cross-entropy has no spec slot + in Megatron, so subclassing is the symmetric "hand-built" path. + + 4. ``_print_norm_classes`` + ``_print_ce_class`` after model construction + — print the resolved class for every norm slot AND the resolved CE + class on the model so you can verify Liger took over. Run with: torchrun --nproc_per_node=2 --master_addr=127.0.0.1 --master_port=29500 \\ @@ -65,6 +71,7 @@ from torch.utils.data import DataLoader # --- Liger integration: Mode 2 --------------------------------------------- +from liger_kernel.megatron import LigerMegatronCrossEntropy from liger_kernel.megatron import LigerMegatronRMSNorm # --------------------------------------------------------------------------- @@ -72,6 +79,31 @@ _SEQUENCE_LENGTH = 64 _NUM_ITERS = 5 _NUM_LAYERS = 2 +_LABEL_SMOOTHING = 0.1 + + +class _LigerCEGPTModel(GPTModel): + """``GPTModel`` subclass that routes its loss through ``LigerMegatronCrossEntropy``. + + Megatron's CE is not a spec slot — ``LanguageModule.compute_language_model_loss`` + calls ``fused_vocab_parallel_cross_entropy`` directly. The symmetric "hand-built" + integration is therefore to subclass ``GPTModel`` and override that method. + """ + + def __init__(self, *args, liger_ce_label_smoothing: float = 0.0, **kwargs): + super().__init__(*args, **kwargs) + self.liger_ce = LigerMegatronCrossEntropy( + ignore_index=-100, + label_smoothing=liger_ce_label_smoothing, + reduction="none", + ) + + def compute_language_model_loss(self, labels, logits): + # LanguageModule contract: input labels are [b, s], output loss is [b, s]. + # LigerMegatronCrossEntropy matches the fused signature, which expects [s, b]. + labels_sb = labels.transpose(0, 1).contiguous() # [s, b] + loss_sb = self.liger_ce(logits, labels_sb, self.pg_collection.tp) # [s, b] + return loss_sb.transpose(0, 1).contiguous() # [b, s] def initialize_distributed(tp: int = 2, pp: int = 1) -> None: @@ -133,11 +165,12 @@ def model_provider() -> GPTModel: ) # ↑↑ ----------------------------------------------------------------- ↑↑ - return GPTModel( + return _LigerCEGPTModel( config=cfg, transformer_layer_spec=block_spec, - vocab_size=100, + vocab_size=128, max_sequence_length=_SEQUENCE_LENGTH, + liger_ce_label_smoothing=_LABEL_SMOOTHING, ) @@ -194,8 +227,23 @@ def _print_norm_classes(model: torch.nn.Module) -> None: print() +def _print_ce_class(model: torch.nn.Module) -> None: + """Show that ``compute_language_model_loss`` will route through Liger.""" + ce = getattr(model, "liger_ce", None) + print("=== Resolved CE class ===") + if ce is None: + print(" model.liger_ce → (not set; subclass missing)") + else: + print(f" model.liger_ce → {type(ce).__module__}.{type(ce).__name__}") + print(f" ce.label_smoothing → {ce.label_smoothing}") + print(f" ce.ignore_index → {ce.ignore_index}") + print() + + def main() -> None: - initialize_distributed(tp=2, pp=1) + # TP=1, DP=2 — CE patch (TP=1 only). Norms are correct under any TP value, so + # demonstrating both Liger features in one script means running data-parallel. + initialize_distributed(tp=1, pp=1) model_parallel_cuda_manual_seed(123) torch.manual_seed(123) @@ -205,6 +253,7 @@ def main() -> None: print("\n=== Full model tree (mode 2: hand-built spec) ===") print(gpt_model) _print_norm_classes(gpt_model) + _print_ce_class(gpt_model) ddp_cfg = DistributedDataParallelConfig( grad_reduce_in_fp32=False, diff --git a/src/liger_kernel/megatron/__init__.py b/src/liger_kernel/megatron/__init__.py index 2208319a4..beb9b00a3 100644 --- a/src/liger_kernel/megatron/__init__.py +++ b/src/liger_kernel/megatron/__init__.py @@ -3,11 +3,20 @@ Public API: LigerMegatronRMSNorm — RMSNorm module conforming to Megatron-Core's LayerNormBuilder protocol. - apply_liger_kernel_to_megatron — patches Megatron-Core's BackendSpecProvider - so existing scripts pick up Liger kernels with one line. + LigerMegatronCrossEntropy — nn.Module drop-in for Megatron's vocab-parallel + cross-entropy (fused signature). + apply_liger_kernel_to_megatron — patches Megatron-Core so existing training + scripts pick up Liger kernels with one line. Currently supports + RMSNorm (via BackendSpecProvider) plus both the fused and unfused + vocab-parallel cross-entropy paths. """ +from liger_kernel.megatron.cross_entropy import LigerMegatronCrossEntropy from liger_kernel.megatron.monkey_patch import apply_liger_kernel_to_megatron from liger_kernel.megatron.rms_norm import LigerMegatronRMSNorm -__all__ = ["LigerMegatronRMSNorm", "apply_liger_kernel_to_megatron"] +__all__ = [ + "LigerMegatronCrossEntropy", + "LigerMegatronRMSNorm", + "apply_liger_kernel_to_megatron", +] diff --git a/src/liger_kernel/megatron/cross_entropy.py b/src/liger_kernel/megatron/cross_entropy.py new file mode 100644 index 000000000..fb899dd2f --- /dev/null +++ b/src/liger_kernel/megatron/cross_entropy.py @@ -0,0 +1,80 @@ +"""Megatron-Core compatible cross-entropy backed by Liger's Triton kernel.""" + +from __future__ import annotations + +import torch + +from liger_kernel.transformers.functional import liger_cross_entropy + + +class LigerMegatronCrossEntropy(torch.nn.Module): + """``nn.Module`` drop-in for Megatron's vocab-parallel cross-entropy. + + Conforms to ``megatron.core.fusions.fused_cross_entropy.fused_vocab_parallel_cross_entropy``'s + signature, ``(vocab_parallel_logits, target, tp_group=None)``. Public Mode-2 (hand-built) + API: instantiate once with the per-training-run config, then call from your overridden + ``LanguageModule.compute_language_model_loss`` (or wherever Megatron's CE would live in your + custom model). + + Mirrors the ``LigerMegatronRMSNorm`` pattern shipped in PR #1254: config-time kwargs on + ``__init__``, data-only ``forward``. Single source of truth for the underlying Liger call; + the monkey-patch wrappers in ``monkey_patch.py`` instantiate this class. + + Args: + ignore_index: Target index to ignore. + label_smoothing: Cross-entropy label smoothing factor. + reduction: Must be ``"none"`` — Megatron's vocab-parallel CE contract returns per-token + loss shaped ``[seq, batch]`` and handles reduction itself downstream. + + Scope: + TP=1 only. Vocab-parallel cross-entropy (TP>1) requires cross-rank reductions + that Liger's kernel does not perform; tracked as Phase 1.5 follow-up. Raises + ``RuntimeError`` at call time if a multi-rank ``tp_group`` is supplied. + """ + + def __init__( + self, + ignore_index: int = -100, + label_smoothing: float = 0.0, + reduction: str = "none", + ): + super().__init__() + if reduction != "none": + raise ValueError( + f"Megatron's vocab-parallel CE contract requires per-token loss; " + f"reduction must be 'none', got {reduction!r}." + ) + self.ignore_index = ignore_index + self.label_smoothing = label_smoothing + self.reduction = reduction + + def forward( + self, + vocab_parallel_logits: torch.Tensor, + target: torch.Tensor, + tp_group=None, + ) -> torch.Tensor: + if tp_group is not None and hasattr(tp_group, "size") and tp_group.size() > 1: + raise RuntimeError( + f"LigerMegatronCrossEntropy requires tensor_model_parallel_size=1, " + f"got tp_group.size()={tp_group.size()}. Vocab-parallel support is " + f"tracked as follow-up work." + ) + if vocab_parallel_logits.dim() != 3: + raise ValueError( + f"vocab_parallel_logits must be 3-D ([seq, batch, vocab]); " + f"got shape {tuple(vocab_parallel_logits.shape)}. (HuggingFace's " + f"[batch, seq, vocab] callers must transpose before calling.)" + ) + s, b, v = vocab_parallel_logits.shape + loss = liger_cross_entropy( + vocab_parallel_logits.reshape(-1, v), + target.reshape(-1), + ignore_index=self.ignore_index, + label_smoothing=self.label_smoothing, + reduction=self.reduction, + ) + return loss.reshape(s, b) + + def extra_repr(self) -> str: + return f"ignore_index={self.ignore_index}, label_smoothing={self.label_smoothing}, reduction={self.reduction!r}" diff --git a/src/liger_kernel/megatron/monkey_patch.py b/src/liger_kernel/megatron/monkey_patch.py index 8b47f6dd2..db0bb43d3 100644 --- a/src/liger_kernel/megatron/monkey_patch.py +++ b/src/liger_kernel/megatron/monkey_patch.py @@ -9,12 +9,16 @@ _PATCH_MARKER = "__liger_patched__" -def apply_liger_kernel_to_megatron(rms_norm: bool = True) -> None: +def apply_liger_kernel_to_megatron( + rms_norm: bool = True, + cross_entropy: bool = False, +) -> None: """Patch Megatron-Core to use Liger Triton kernels. - Idempotent. Targets Megatron's ``BackendSpecProvider`` and - ``transformer_block.LayerNormImpl`` so every model that routes through - the standard spec system benefits without per-model code. + Idempotent. Targets Megatron's ``BackendSpecProvider``, + ``transformer_block.LayerNormImpl``, and (optionally) both of Megatron's + vocab-parallel cross-entropy entry points so models that route through + the standard spec system pick up Liger without per-model code. Args: rms_norm: When ``True`` (default) replace both @@ -22,20 +26,69 @@ def apply_liger_kernel_to_megatron(rms_norm: bool = True) -> None: ``transformer_block.LayerNormImpl`` (the block-level ``final_layernorm`` slot) so all RMSNorm modules in the model become ``LigerMegatronRMSNorm``. + cross_entropy: When ``True`` replace both + ``megatron.core.fusions.fused_cross_entropy.fused_vocab_parallel_cross_entropy`` + (fused path) and + ``megatron.core.tensor_parallel.cross_entropy.vocab_parallel_cross_entropy`` + (unfused path) with Liger's Triton cross-entropy. Default + ``False`` because this path currently supports + ``tensor_model_parallel_size=1`` only. The fused wrapper matches + native's ``(logits, target, tp_group)`` signature exactly; the + unfused wrapper additionally honors a runtime ``label_smoothing`` + argument, matching native's + ``(logits, target, label_smoothing=0.0, tp_group=None)``. Notes: Call this BEFORE building your model. Patching after instantiation will not retroactively swap modules already created. - This only affects the local (non-TE) backend. Mixing Liger norms with - ``TESpecProvider`` requires a custom ``BackendSpecProvider`` subclass - because TE's ``TELayerNormColumnParallelLinear`` folds the norm into - the QKV linear; naive substitution would either double-norm or skip - the norm. See the project README for the mixing recipe. + The RMSNorm patches only affect the local (non-TE) backend. Mixing + Liger norms with ``TESpecProvider`` requires a custom + ``BackendSpecProvider`` subclass because TE's + ``TELayerNormColumnParallelLinear`` folds the norm into the QKV + linear; naive substitution would either double-norm or skip the norm. + + For explicit kernel configuration (custom ``ignore_index``, + ``label_smoothing``, etc.) instantiate ``LigerMegatronCrossEntropy`` + directly and wire it into your model (Mode 2). The monkey-patch path + is intentionally a transparent drop-in: it matches Megatron's native + defaults so callers can flip Liger on without touching loss config. + + Raises: + RuntimeError: When ``cross_entropy=True`` and Megatron's parallel + state already reports ``tensor_model_parallel_size > 1``. """ if rms_norm: _patch_local_spec_provider_layer_norm() _patch_transformer_block_layernorm_impl() + if cross_entropy: + _check_tensor_parallel_size_at_patch_time() + _patch_fused_vocab_parallel_cross_entropy() + _patch_vocab_parallel_cross_entropy() + + +def _check_tensor_parallel_size_at_patch_time() -> None: + """Raise RuntimeError if Megatron's parallel state already reports TP>1. + + If Megatron is importable but the parallel state is not yet initialized + (for example, ``apply_liger_kernel_to_megatron`` is called before + ``initialize_megatron``), silently defer; per-kernel wrappers check again + at call time against the ``tp_group`` argument Megatron supplies. + """ + try: + from megatron.core import parallel_state + except ImportError: + return + try: + tp_size = parallel_state.get_tensor_model_parallel_world_size() + except (AssertionError, RuntimeError): + return + if tp_size > 1: + raise RuntimeError( + f"apply_liger_kernel_to_megatron(cross_entropy=True) currently requires " + f"tensor_model_parallel_size=1, got {tp_size}. Vocab-parallel cross-entropy " + f"support is planned as follow-up work." + ) def _patch_local_spec_provider_layer_norm() -> None: @@ -120,3 +173,118 @@ def __new__(cls, config, hidden_size, eps=1e-5, **kwargs): "Patched megatron.core.transformer.transformer_block.LayerNormImpl " "to route RMSNorm configs through LigerMegatronRMSNorm." ) + + +# Sentinel for "caller did not pass this kwarg". A plain ``0.0`` default would be +# observationally indistinguishable from "user explicitly asked for 0.0" — and Megatron's +# native vocab_parallel_cross_entropy accepts call-time ``label_smoothing=0.0`` as a real +# request for that value. We must not silently override. +_LABEL_SMOOTHING_UNSET = object() + + +def _patch_fused_vocab_parallel_cross_entropy() -> None: + """Replace ``megatron.core.fusions.fused_cross_entropy.fused_vocab_parallel_cross_entropy``. + + Wraps a single ``LigerMegatronCrossEntropy`` instance (constructed with class defaults + that match Megatron's native fused-CE behavior) in a closure matching Megatron's fused-CE + signature ``(logits, target, tp_group)``. Idempotent: a sentinel attribute on the + replacement prevents wrappers from stacking. + """ + try: + import megatron.core.fusions.fused_cross_entropy as fused_ce + except ImportError as exc: + raise ImportError( + "apply_liger_kernel_to_megatron(cross_entropy=True) requires megatron-core to be " + "installed. Expected symbol path: " + "megatron.core.fusions.fused_cross_entropy.fused_vocab_parallel_cross_entropy." + ) from exc + + if not hasattr(fused_ce, "fused_vocab_parallel_cross_entropy"): + raise ImportError( + "megatron.core.fusions.fused_cross_entropy.fused_vocab_parallel_cross_entropy not " + "found. The symbol path may have changed in your Megatron-LM version. Please file " + "an issue on https://github.com/linkedin/Liger-Kernel with your megatron-core version." + ) + + if getattr(fused_ce.fused_vocab_parallel_cross_entropy, _PATCH_MARKER, False): + return # already patched + + original = fused_ce.fused_vocab_parallel_cross_entropy + + from liger_kernel.megatron.cross_entropy import LigerMegatronCrossEntropy + + ce = LigerMegatronCrossEntropy() + + def liger_fused_vocab_parallel_cross_entropy(vocab_parallel_logits, target, tp_group=None): + return ce(vocab_parallel_logits, target, tp_group=tp_group) + + setattr(liger_fused_vocab_parallel_cross_entropy, _PATCH_MARKER, True) + setattr(liger_fused_vocab_parallel_cross_entropy, "__wrapped__", original) + fused_ce.fused_vocab_parallel_cross_entropy = liger_fused_vocab_parallel_cross_entropy + + logger.info( + "Patched megatron.core.fusions.fused_cross_entropy.fused_vocab_parallel_cross_entropy with Liger cross-entropy." + ) + + +def _patch_vocab_parallel_cross_entropy() -> None: + """Replace ``megatron.core.tensor_parallel.cross_entropy.vocab_parallel_cross_entropy``. + + This is Megatron's *unfused* eager-Python vocab-parallel CE path, dispatched to when + ``config.cross_entropy_loss_fusion=False``. Its signature accepts ``label_smoothing`` + at call time, so the wrapper honors a runtime value when the caller actually passed + one. A sentinel disambiguates "caller passed 0.0" (use 0.0) from "caller didn't pass" + (use class default). + """ + try: + import megatron.core.tensor_parallel.cross_entropy as unfused_ce + except ImportError as exc: + raise ImportError( + "apply_liger_kernel_to_megatron(cross_entropy=True) requires megatron-core to be " + "installed. Expected symbol path: " + "megatron.core.tensor_parallel.cross_entropy.vocab_parallel_cross_entropy." + ) from exc + + if not hasattr(unfused_ce, "vocab_parallel_cross_entropy"): + raise ImportError( + "megatron.core.tensor_parallel.cross_entropy.vocab_parallel_cross_entropy not " + "found. The symbol path may have changed in your Megatron-LM version. Please file " + "an issue on https://github.com/linkedin/Liger-Kernel with your megatron-core version." + ) + + if getattr(unfused_ce.vocab_parallel_cross_entropy, _PATCH_MARKER, False): + return # already patched + + original = unfused_ce.vocab_parallel_cross_entropy + + from liger_kernel.megatron.cross_entropy import LigerMegatronCrossEntropy + + # Class-default instance; reused for every call where the caller doesn't pass + # label_smoothing. Avoids allocating a fresh module per CE call in the common case + # (Megatron's own LanguageModule.compute_language_model_loss dispatch does not pass + # label_smoothing — it always lands here). + default_ce = LigerMegatronCrossEntropy() + + def liger_vocab_parallel_cross_entropy( + vocab_parallel_logits, + target, + label_smoothing=_LABEL_SMOOTHING_UNSET, + tp_group=None, + ): + # Sentinel-based "did the caller pass this?" check so that an explicit + # label_smoothing=0.0 from the caller is honored verbatim (matching Megatron's + # native vocab_parallel_cross_entropy contract). Construct a fresh + # LigerMegatronCrossEntropy only on the runtime-override path; nn.Module + # construction is microseconds vs. CE-kernel milliseconds. + if label_smoothing is _LABEL_SMOOTHING_UNSET: + return default_ce(vocab_parallel_logits, target, tp_group=tp_group) + ce = LigerMegatronCrossEntropy(label_smoothing=label_smoothing) + return ce(vocab_parallel_logits, target, tp_group=tp_group) + + setattr(liger_vocab_parallel_cross_entropy, _PATCH_MARKER, True) + setattr(liger_vocab_parallel_cross_entropy, "__wrapped__", original) + unfused_ce.vocab_parallel_cross_entropy = liger_vocab_parallel_cross_entropy + + logger.info( + "Patched megatron.core.tensor_parallel.cross_entropy.vocab_parallel_cross_entropy with Liger cross-entropy." + ) diff --git a/test/megatron/test_cross_entropy.py b/test/megatron/test_cross_entropy.py new file mode 100644 index 000000000..a530e8607 --- /dev/null +++ b/test/megatron/test_cross_entropy.py @@ -0,0 +1,489 @@ +"""Correctness tests for ``LigerMegatronCrossEntropy``. + +The class is the public Mode-2 API; the monkey-patch wrappers in +``monkey_patch.py`` are thin closures around an instance of this class. Tests +target the class directly — that's the single source of truth for the CE math. + +Mirrors ``test/megatron/test_rms_norm.py``'s parametrize style for the +fp32/bf16 sweep so the visual symmetry across the two megatron-side files is +preserved. +""" + +import pytest +import torch +import torch.nn.functional as F + +from liger_kernel.megatron import LigerMegatronCrossEntropy +from liger_kernel.utils import infer_device +from test.utils import assert_verbose_allclose +from test.utils import set_seed +from test.utils import supports_bfloat16 + +device = infer_device() +set_seed(42) + + +def _reference_loss( + vocab_parallel_logits: torch.Tensor, + target: torch.Tensor, + ignore_index: int, + label_smoothing: float, +) -> torch.Tensor: + s, b, v = vocab_parallel_logits.shape + loss_flat = F.cross_entropy( + vocab_parallel_logits.reshape(-1, v).float(), + target.reshape(-1), + reduction="none", + ignore_index=ignore_index, + label_smoothing=label_smoothing, + ) + return loss_flat.reshape(s, b) + + +# --------------------------------------------------------------------------- +# Forward correctness vs. F.cross_entropy reference. +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "s, b, v", + [ + (8, 2, 128), + (16, 4, 4096), + (32, 1, 32000), + ], +) +@pytest.mark.parametrize( + "dtype, atol, rtol", + [ + (torch.float32, 1e-7, 1e-6), + pytest.param( + torch.bfloat16, + 1e-2, + 1e-2, + marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported"), + ), + ], +) +def test_class_matches_pytorch_cross_entropy(s, b, v, dtype, atol, rtol): + """Headline correctness — forward AND backward parity vs. PyTorch's F.cross_entropy. + + Liger writes the gradient back into the input tensor in-place during forward; both paths + are fed independent clones of the same starting tensor so the in-place write on the Liger + side can't corrupt the reference path.""" + ce = LigerMegatronCrossEntropy() + + base = torch.randn(s, b, v, device=device, dtype=dtype) * 0.5 + target = torch.randint(0, v, (s, b), device=device, dtype=torch.long) + + h_ref = base.detach().clone().requires_grad_(True) + h_got = base.detach().clone().requires_grad_(True) + + ref = _reference_loss(h_ref, target, ignore_index=-100, label_smoothing=0.0) + got = ce(h_got, target) + + assert got.shape == (s, b) + assert_verbose_allclose(got.float(), ref.float(), atol=atol, rtol=rtol) + + ref.sum().backward() + got.sum().backward() + assert_verbose_allclose(h_got.grad.float(), h_ref.grad.float(), atol=atol, rtol=rtol) + + +# --------------------------------------------------------------------------- +# Configuration plumbing — wrapper-specific contracts. +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("ignore_index", [-100, 0]) +def test_class_respects_ignore_index(ignore_index): + """ignore_index plumbing — forward AND backward parity. Ignored positions must + contribute zero loss AND zero gradient on the Liger side.""" + s, b, v = 16, 2, 1024 + ce = LigerMegatronCrossEntropy(ignore_index=ignore_index) + + base = torch.randn(s, b, v, device=device, dtype=torch.float32) + target = torch.randint(0, v, (s, b), device=device, dtype=torch.long) + target.view(-1)[: (s * b) // 4] = ignore_index + + h_ref = base.detach().clone().requires_grad_(True) + h_got = base.detach().clone().requires_grad_(True) + + ref = _reference_loss(h_ref, target, ignore_index=ignore_index, label_smoothing=0.0) + got = ce(h_got, target) + assert_verbose_allclose(got.float(), ref.float(), atol=1e-6, rtol=1e-5) + + ref.sum().backward() + got.sum().backward() + assert_verbose_allclose(h_got.grad.float(), h_ref.grad.float(), atol=1e-6, rtol=1e-5) + + +@pytest.mark.parametrize("label_smoothing", [0.0, 0.1]) +def test_class_respects_label_smoothing(label_smoothing): + """label_smoothing plumbing — forward AND backward parity. Liger and PyTorch share the + same smoothing formula but with different intermediate kernels; gradient check guards + against algebraic-equivalence-but-numerical-divergence bugs.""" + s, b, v = 8, 2, 512 + ce = LigerMegatronCrossEntropy(label_smoothing=label_smoothing) + + base = torch.randn(s, b, v, device=device, dtype=torch.float32) + target = torch.randint(0, v, (s, b), device=device, dtype=torch.long) + + h_ref = base.detach().clone().requires_grad_(True) + h_got = base.detach().clone().requires_grad_(True) + + ref = _reference_loss(h_ref, target, ignore_index=-100, label_smoothing=label_smoothing) + got = ce(h_got, target) + assert_verbose_allclose(got.float(), ref.float(), atol=1e-5, rtol=1e-4) + + ref.sum().backward() + got.sum().backward() + assert_verbose_allclose(h_got.grad.float(), h_ref.grad.float(), atol=1e-5, rtol=1e-4) + + +@pytest.mark.parametrize("bad_reduction", ["mean", "sum", "MEAN", "garbage"]) +def test_class_rejects_non_none_reduction(bad_reduction): + """Megatron's contract is per-token loss; mean/sum break the [s, b] return shape.""" + with pytest.raises(ValueError, match="reduction must be 'none'"): + LigerMegatronCrossEntropy(reduction=bad_reduction) + + +def test_class_rejects_non_3d_logits(): + """The class explicitly guards against HuggingFace-shape [b, s, v] callers etc.""" + ce = LigerMegatronCrossEntropy() + bad = torch.randn(8, 16, device=device) # 2-D + target = torch.randint(0, 16, (8,), device=device, dtype=torch.long) + with pytest.raises(ValueError, match="3-D"): + ce(bad, target) + + too_many = torch.randn(2, 2, 4, 16, device=device) # 4-D + target2 = torch.randint(0, 16, (2, 2, 4), device=device, dtype=torch.long) + with pytest.raises(ValueError, match="3-D"): + ce(too_many, target2) + + +# --------------------------------------------------------------------------- +# TP guard — the only safety net the class itself enforces. +# --------------------------------------------------------------------------- + + +def test_class_raises_on_tp_group_size_greater_than_one(): + ce = LigerMegatronCrossEntropy() + logits = torch.randn(4, 1, 32, device=device) + target = torch.randint(0, 32, (4, 1), device=device, dtype=torch.long) + + class _FakeGroup: + def size(self): + return 2 + + with pytest.raises(RuntimeError, match="tensor_model_parallel_size=1"): + ce(logits, target, tp_group=_FakeGroup()) + + +def test_class_accepts_single_rank_tp_group(): + ce = LigerMegatronCrossEntropy() + logits = torch.randn(4, 1, 32, device=device) + target = torch.randint(0, 32, (4, 1), device=device, dtype=torch.long) + + class _FakeGroup: + def size(self): + return 1 + + out = ce(logits, target, tp_group=_FakeGroup()) + assert out.shape == (4, 1) + + +# --------------------------------------------------------------------------- +# Gradient sanity — Liger's CE writes gradients in place; verify the class +# preserves them through Megatron's [s, b, v] reshape contract. +# --------------------------------------------------------------------------- + + +def test_class_preserves_gradients(): + """Backward smoke test — grad exists with correct shape AND matches PyTorch's reference. + + Previously asserted only ``grad is not None`` + shape; that gave a misleadingly green test + when grad values were wrong. Now compares against ``F.cross_entropy(...).sum().backward()``.""" + s, b, v = 8, 2, 256 + ce = LigerMegatronCrossEntropy() + + base = torch.randn(s, b, v, device=device, dtype=torch.float32) + target = torch.randint(0, v, (s, b), device=device, dtype=torch.long) + + h_ref = base.detach().clone().requires_grad_(True) + h_got = base.detach().clone().requires_grad_(True) + + ref = _reference_loss(h_ref, target, ignore_index=-100, label_smoothing=0.0) + got = ce(h_got, target) + + ref.sum().backward() + got.sum().backward() + + assert h_got.grad is not None + assert h_got.grad.shape == h_got.shape + assert_verbose_allclose(h_got.grad.float(), h_ref.grad.float(), atol=1e-6, rtol=1e-5) + + +def test_class_extra_repr(): + ce = LigerMegatronCrossEntropy(ignore_index=42, label_smoothing=0.07) + rep = ce.extra_repr() + assert "ignore_index=42" in rep + assert "label_smoothing=0.07" in rep + assert "reduction='none'" in rep + + +# --------------------------------------------------------------------------- +# Beefier sweeps adapted from test/transformers/test_cross_entropy.py. +# +# LigerMegatronCrossEntropy is a [s, b, v] -> [s, b] reshape around Liger's +# CE op (reduction='none'). Numerical behavior should match the kernel itself, +# so the same parametrization patterns the kernel suite uses are the right +# coverage shape here — just adapted to the 3-D contract. +# --------------------------------------------------------------------------- + + +def _assign_ignore_index(target: torch.Tensor, ignore_index: int, frac: float = 0.25) -> None: + """In-place: replace ~frac of target positions with ignore_index. + + Matches the transformers-side helpers that randomize the masked-out indices + so the test isn't degenerate on a particular row layout. + """ + flat = target.view(-1) + n = max(1, int(flat.numel() * frac)) + idx = torch.randperm(flat.numel(), device=flat.device)[:n] + flat[idx] = ignore_index + + +@pytest.mark.parametrize( + "s, b, v", + [ + (16, 1, 4096), + (32, 2, 32000), # llama-ish vocab + (5, 3, 123), # weird shape + ], +) +@pytest.mark.parametrize("scalar", [0.5, 1.0, 5.0]) +@pytest.mark.parametrize( + "dtype, atol, rtol", + [ + (torch.float32, 1e-7, 1e-6), + pytest.param( + torch.bfloat16, + 1e-2, + 5e-2, + marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported"), + ), + ], +) +def test_class_correctness_scalar_sweep(s, b, v, scalar, dtype, atol, rtol): + """Vary input magnitude — guards against numerical drift at large logit scales + (mirrors the ``scalar`` parametrize in ``test/transformers/test_cross_entropy.py``).""" + ce = LigerMegatronCrossEntropy() + + base = torch.randn(s, b, v, device=device, dtype=dtype) * scalar + target = torch.randint(0, v, (s, b), device=device, dtype=torch.long) + + # Backward parity: feed the same starting tensor through both paths. + h_ref = base.detach().clone().requires_grad_(True) + h_got = base.detach().clone().requires_grad_(True) + + ref = _reference_loss(h_ref, target, ignore_index=-100, label_smoothing=0.0) + got = ce(h_got, target) + + assert got.shape == (s, b) + assert_verbose_allclose(got.float(), ref.float(), atol=atol, rtol=rtol) + + ref.sum().backward() + got.sum().backward() + assert_verbose_allclose(h_got.grad.float(), h_ref.grad.float(), atol=atol, rtol=rtol) + + +@pytest.mark.parametrize( + "s, b, v, ignore_index", + [ + (16, 1, 4096, -100), # standard hf sentinel + (32, 2, 32000, 2), # positive id (valid vocab slot used as ignore) + (5, 3, 123, -123), # weird negative + ], +) +@pytest.mark.parametrize( + "dtype, atol, rtol", + [ + (torch.float32, 1e-7, 1e-6), + pytest.param( + torch.bfloat16, + 1e-2, + 5e-2, + marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported"), + ), + ], +) +def test_class_correctness_with_ignore_index_sweep(s, b, v, ignore_index, dtype, atol, rtol): + """Broader ignore_index sweep including positive/negative sentinels and forward+backward + correctness vs. PyTorch's reference. Mirrors transformers-side ``test_correctness_with_ignore_index``.""" + ce = LigerMegatronCrossEntropy(ignore_index=ignore_index) + + base = torch.randn(s, b, v, device=device, dtype=dtype) + target = torch.randint(0, v, (s, b), device=device, dtype=torch.long) + _assign_ignore_index(target, ignore_index, frac=0.3) + + h_ref = base.detach().clone().requires_grad_(True) + h_got = base.detach().clone().requires_grad_(True) + + ref = _reference_loss(h_ref, target, ignore_index=ignore_index, label_smoothing=0.0) + got = ce(h_got, target) + assert_verbose_allclose(got.float(), ref.float(), atol=atol, rtol=rtol) + + ref.sum().backward() + got.sum().backward() + assert_verbose_allclose(h_got.grad.float(), h_ref.grad.float(), atol=atol, rtol=rtol) + + +@pytest.mark.parametrize( + "s, b, v, ignore_index, label_smoothing", + [ + (16, 1, 4096, 1, 0.1), + (32, 2, 32000, -100, 0.2), + (5, 3, 123, -300, 0.05), + ], +) +@pytest.mark.parametrize( + "dtype, atol, rtol", + [ + (torch.float32, 1e-6, 1e-5), + pytest.param( + torch.bfloat16, + 1e-2, + 5e-2, + marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported"), + ), + ], +) +def test_class_correctness_with_label_smoothing_and_ignore_index( + s, + b, + v, + ignore_index, + label_smoothing, + dtype, + atol, + rtol, +): + """Combined ignore_index × label_smoothing sweep — the two are independent in Liger's CE + kernel but mixing them historically surfaced bugs in the smoothing math. Mirrors + ``test_correctness_with_label_smoothing_with_ignore_index_once`` from the kernel suite.""" + ce = LigerMegatronCrossEntropy(ignore_index=ignore_index, label_smoothing=label_smoothing) + + base = torch.randn(s, b, v, device=device, dtype=dtype) + target = torch.randint(0, v, (s, b), device=device, dtype=torch.long) + _assign_ignore_index(target, ignore_index, frac=0.25) + + h_ref = base.detach().clone().requires_grad_(True) + h_got = base.detach().clone().requires_grad_(True) + + ref = _reference_loss(h_ref, target, ignore_index=ignore_index, label_smoothing=label_smoothing) + got = ce(h_got, target) + assert_verbose_allclose(got.float(), ref.float(), atol=atol, rtol=rtol) + + ref.sum().backward() + got.sum().backward() + assert_verbose_allclose(h_got.grad.float(), h_ref.grad.float(), atol=atol, rtol=rtol) + + +@pytest.mark.parametrize( + "s, b, v", + [ + (16, 1, 4096), + (5, 3, 123), + ], +) +@pytest.mark.parametrize( + "dtype, atol, rtol", + [ + (torch.float32, 1e-6, 1e-5), + pytest.param( + torch.bfloat16, + 1e-2, + 5e-2, + marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported"), + ), + ], +) +def test_class_correctness_not_last_layer(s, b, v, dtype, atol, rtol): + """Loss is multiplied by a downstream factor before ``.backward(grad_output)`` — verifies + that Liger's in-place gradient write through the wrapper survives non-trivial chained + autograd (i.e. CE isn't the last op in the graph). Mirrors transformers-side + ``test_correctness_not_last_layer``.""" + ce = LigerMegatronCrossEntropy() + + base = torch.randn(s, b, v, device=device, dtype=dtype) + target = torch.randint(0, v, (s, b), device=device, dtype=torch.long) + + h_ref = base.detach().clone().requires_grad_(True) + h_got = base.detach().clone().requires_grad_(True) + + ref = _reference_loss(h_ref, target, ignore_index=-100, label_smoothing=0.0) + got = ce(h_got, target) + assert_verbose_allclose(got.float(), ref.float(), atol=atol, rtol=rtol) + + # Chain: loss = ref * 3 then backward with arbitrary grad_output. + loss_ref = ref * 3.0 + loss_got = got * 3.0 + grad_out = torch.rand_like(ref) + loss_ref.backward(gradient=grad_out) + loss_got.backward(gradient=grad_out) + assert_verbose_allclose(h_got.grad.float(), h_ref.grad.float(), atol=atol, rtol=rtol) + + +@pytest.mark.parametrize("ignore_index", [-100, 2]) +def test_class_rejects_out_of_bounds_target(ignore_index): + """Liger's CE kernel asserts target ∈ [0, V); a stray out-of-bounds target should + raise rather than silently produce garbage. Mirrors transformers-side + ``test_correctness_with_out_of_bounds_target_once``.""" + s, b, v = 8, 2, 64 + ce = LigerMegatronCrossEntropy(ignore_index=ignore_index) + + logits = torch.randn(s, b, v, device=device, dtype=torch.float32) + target = torch.randint(0, v, (s, b), device=device, dtype=torch.long) + # Plant a couple of out-of-bounds values; ignore_index is permitted but the + # >=V poisoned slots are not. + flat = target.view(-1) + poison = torch.randperm(flat.numel(), device=flat.device)[:2] + flat[poison] = v + 5 # >= V; the kernel-level assert should fire. + + with pytest.raises(AssertionError, match="out of bounds"): + ce(logits, target) + + +@pytest.mark.parametrize( + "dtype, atol, rtol", + [ + (torch.float32, 1e-7, 1e-6), + pytest.param( + torch.bfloat16, + 1e-2, + 5e-2, + marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported"), + ), + ], +) +def test_class_correctness_forward_only(dtype, atol, rtol): + """Forward-only path (under ``torch.no_grad()``) — verifies the wrapper still returns the + right loss when autograd is disabled, AND that a subsequent ``.backward()`` raises the + expected "does not require grad" error. Mirrors transformers-side ``test_correctness_with_forward_only``.""" + s, b, v = 16, 2, 1024 + ce = LigerMegatronCrossEntropy() + + logits_input = torch.randn(s, b, v, device=device, dtype=dtype) + target = torch.randint(0, v, (s, b), device=device, dtype=torch.long) + + with torch.no_grad(): + # Clone the input separately for each path because Liger writes gradient + # state in-place; sharing a buffer would corrupt the reference. + ref = _reference_loss(logits_input.clone(), target, ignore_index=-100, label_smoothing=0.0) + got = ce(logits_input.clone(), target) + assert_verbose_allclose(got.float(), ref.float(), atol=atol, rtol=rtol) + + # Attempting backward on a forward-only output should raise. + with pytest.raises(RuntimeError, match="does not require grad"): + got.sum().backward() diff --git a/test/megatron/test_monkey_patch.py b/test/megatron/test_monkey_patch.py new file mode 100644 index 000000000..4f1510b0f --- /dev/null +++ b/test/megatron/test_monkey_patch.py @@ -0,0 +1,1009 @@ +"""Tests for ``apply_liger_kernel_to_megatron``'s patch mechanism. + +Megatron-LM is not a test dependency. We inject stub modules into ``sys.modules`` so the +patch helpers can run entirely on CPU without a real megatron-core install. For each kernel +Liger patches into Megatron, this file verifies: + +- the patched Megatron symbol(s) are actually replaced +- patching is idempotent (calling apply twice doesn't stack wrappers) +- the patch is a no-op when the kernel flag is False +- missing megatron-core / missing symbol path raise helpful ``ImportError``\\s +- kernel-specific dispatch contracts (e.g. CE TP>1 raises; RMSNorm only displaces the + ``WrappedTorchNorm`` fallback, not TE / Apex) +- end-to-end: the patched symbol invoked with real tensors produces correct output + +File layout — extend by appending a new ``-- patch --`` section per kernel +Liger learns to patch: + + 1. Stub-megatron installers + fixtures + 2. Cross-entropy patch tests + 3. RMSNorm patch tests + 4. Cross-kernel public-API surface tests + 5. End-to-end integration through patched CE symbols +""" + +import sys +import types + +from unittest.mock import patch + +import pytest + +# =========================================================================== +# 1. Stub-megatron installers + fixtures +# =========================================================================== + + +def _ensure_megatron_roots(): + """Create / fetch ``megatron`` and ``megatron.core`` module stubs. + + Idempotent: a second installer (RMSNorm + CE called in either order) reuses + the roots rather than clobbering them, so a single test can stand up both + kernel surfaces at once. + """ + megatron = sys.modules.get("megatron") or types.ModuleType("megatron") + megatron_core = sys.modules.get("megatron.core") or types.ModuleType("megatron.core") + sys.modules["megatron"] = megatron + sys.modules["megatron.core"] = megatron_core + megatron.core = megatron_core + return megatron, megatron_core + + +def _install_fake_megatron_ce( + tp_size: int = 1, + with_fused_symbol: bool = True, + with_unfused_symbol: bool = True, +): + """Install the cross-entropy slice of the Megatron stub. + + Returns a tuple ``(fused_ce_module, unfused_ce_module)`` so tests can inspect what + the patch helpers wrote onto them. + """ + _, megatron_core = _ensure_megatron_roots() + fusions = types.ModuleType("megatron.core.fusions") + fused_ce = types.ModuleType("megatron.core.fusions.fused_cross_entropy") + tensor_parallel = types.ModuleType("megatron.core.tensor_parallel") + unfused_ce = types.ModuleType("megatron.core.tensor_parallel.cross_entropy") + parallel_state = types.ModuleType("megatron.core.parallel_state") + + if with_fused_symbol: + + def original_fused_vocab_parallel_cross_entropy(vocab_parallel_logits, target, tp_group=None): + raise AssertionError("original megatron fused kernel called — patch failed") + + fused_ce.fused_vocab_parallel_cross_entropy = original_fused_vocab_parallel_cross_entropy + + if with_unfused_symbol: + + def original_vocab_parallel_cross_entropy( + vocab_parallel_logits, + target, + label_smoothing=0.0, + tp_group=None, + ): + raise AssertionError("original megatron unfused kernel called — patch failed") + + unfused_ce.vocab_parallel_cross_entropy = original_vocab_parallel_cross_entropy + + parallel_state.get_tensor_model_parallel_world_size = lambda: tp_size + + sys.modules["megatron.core.fusions"] = fusions + sys.modules["megatron.core.fusions.fused_cross_entropy"] = fused_ce + sys.modules["megatron.core.tensor_parallel"] = tensor_parallel + sys.modules["megatron.core.tensor_parallel.cross_entropy"] = unfused_ce + sys.modules["megatron.core.parallel_state"] = parallel_state + + megatron_core.fusions = fusions + megatron_core.tensor_parallel = tensor_parallel + megatron_core.parallel_state = parallel_state + fusions.fused_cross_entropy = fused_ce + tensor_parallel.cross_entropy = unfused_ce + + return fused_ce, unfused_ce + + +def _install_fake_megatron_rms_norm( + layer_norm_is_wrapped_torch_norm: bool = True, + with_backends_module: bool = True, + with_transformer_block_module: bool = True, +): + """Install the RMSNorm slice of the Megatron stub. + + Returns ``(backends_module, transformer_block_module)`` so tests can inspect what the + patch helpers wrote onto them. Mirrors the CE installer's shape so future kernels can + grow alongside via the same pattern. + + Args: + layer_norm_is_wrapped_torch_norm: When True (default), seeds + ``transformer_block.LayerNormImpl`` as the stub ``WrappedTorchNorm``. The block- + level patch only displaces that fallback; set False to verify the no-op path + taken under TE / Apex. + """ + _, megatron_core = _ensure_megatron_roots() + + backends = None + if with_backends_module: + models = types.ModuleType("megatron.core.models") + backends = types.ModuleType("megatron.core.models.backends") + + class _OriginalNormSentinel: + """Sentinel class returned by the stub's ``layer_norm`` when ``rms_norm=False`` — + tests assert identity against this to confirm the patch delegated correctly.""" + + class _LocalSpecProvider: + def layer_norm(self, rms_norm=False, for_qk=False, has_residual=False): + # Echo the kwargs so tests can verify pass-through; the value returned is + # the sentinel class so identity checks work regardless. + return _OriginalNormSentinel + + backends.LocalSpecProvider = _LocalSpecProvider + backends._OriginalNormSentinel = _OriginalNormSentinel # exposed for tests + sys.modules["megatron.core.models"] = models + sys.modules["megatron.core.models.backends"] = backends + megatron_core.models = models + models.backends = backends + + transformer_block = None + if with_transformer_block_module: + transformer = types.ModuleType("megatron.core.transformer") + transformer_block = types.ModuleType("megatron.core.transformer.transformer_block") + torch_norm_mod = types.ModuleType("megatron.core.transformer.torch_norm") + + class _OriginalTorchNormInstance: + """Sentinel marker for "the original WrappedTorchNorm was instantiated." The stub + ``WrappedTorchNorm.__new__`` returns one of these; tests assert ``isinstance(...)``.""" + + def __init__(self, hidden_size, eps): + self.hidden_size = hidden_size + self.eps = eps + + class _WrappedTorchNorm: + """Stub of ``megatron.core.transformer.torch_norm.WrappedTorchNorm``.""" + + def __new__(cls, config=None, hidden_size=None, eps=1e-5, **kwargs): + return _OriginalTorchNormInstance(hidden_size=hidden_size, eps=eps) + + torch_norm_mod.WrappedTorchNorm = _WrappedTorchNorm + + if layer_norm_is_wrapped_torch_norm: + transformer_block.LayerNormImpl = _WrappedTorchNorm + else: + + class _SomeOtherNorm: + """Stand-in for TE / Apex LN — block-level patch should leave this alone.""" + + transformer_block.LayerNormImpl = _SomeOtherNorm + + # Expose sentinels for test assertions. + transformer_block._OriginalTorchNormInstance = _OriginalTorchNormInstance + transformer_block._WrappedTorchNorm = _WrappedTorchNorm + + sys.modules["megatron.core.transformer"] = transformer + sys.modules["megatron.core.transformer.transformer_block"] = transformer_block + sys.modules["megatron.core.transformer.torch_norm"] = torch_norm_mod + megatron_core.transformer = transformer + transformer.transformer_block = transformer_block + transformer.torch_norm = torch_norm_mod + + return backends, transformer_block + + +def _uninstall_fake_megatron(): + """Tear down every stub module installed by either installer.""" + for mod in [ + # CE side + "megatron.core.parallel_state", + "megatron.core.fusions.fused_cross_entropy", + "megatron.core.fusions", + "megatron.core.tensor_parallel.cross_entropy", + "megatron.core.tensor_parallel", + # RMSNorm side + "megatron.core.models.backends", + "megatron.core.models", + "megatron.core.transformer.transformer_block", + "megatron.core.transformer.torch_norm", + "megatron.core.transformer", + # Shared roots + "megatron.core", + "megatron", + ]: + sys.modules.pop(mod, None) + + +@pytest.fixture +def fake_megatron_ce(): + fused_ce, unfused_ce = _install_fake_megatron_ce(tp_size=1) + try: + yield fused_ce, unfused_ce + finally: + _uninstall_fake_megatron() + + +@pytest.fixture +def fake_megatron_rms_norm(): + backends, transformer_block = _install_fake_megatron_rms_norm() + try: + yield backends, transformer_block + finally: + _uninstall_fake_megatron() + + +# =========================================================================== +# 2. Cross-entropy patch tests +# =========================================================================== + + +# --------------------------------------------------------------------------- +# 2.1 Both CE symbols get replaced. +# --------------------------------------------------------------------------- + + +def test_patch_replaces_fused_symbol(fake_megatron_ce): + fused_ce, _ = fake_megatron_ce + from liger_kernel.megatron import apply_liger_kernel_to_megatron + + original = fused_ce.fused_vocab_parallel_cross_entropy + apply_liger_kernel_to_megatron(rms_norm=False, cross_entropy=True) + + assert fused_ce.fused_vocab_parallel_cross_entropy is not original + assert fused_ce.fused_vocab_parallel_cross_entropy.__name__ == "liger_fused_vocab_parallel_cross_entropy" + + +def test_patch_replaces_unfused_symbol(fake_megatron_ce): + _, unfused_ce = fake_megatron_ce + from liger_kernel.megatron import apply_liger_kernel_to_megatron + + original = unfused_ce.vocab_parallel_cross_entropy + apply_liger_kernel_to_megatron(rms_norm=False, cross_entropy=True) + + assert unfused_ce.vocab_parallel_cross_entropy is not original + assert unfused_ce.vocab_parallel_cross_entropy.__name__ == "liger_vocab_parallel_cross_entropy" + + +def test_patch_replaces_both_fused_and_unfused_symbols_in_one_call(fake_megatron_ce): + """A single ``cross_entropy=True`` call must replace both Megatron CE paths.""" + fused_ce, unfused_ce = fake_megatron_ce + from liger_kernel.megatron import apply_liger_kernel_to_megatron + + apply_liger_kernel_to_megatron(rms_norm=False, cross_entropy=True) + + assert fused_ce.fused_vocab_parallel_cross_entropy.__name__ == "liger_fused_vocab_parallel_cross_entropy" + assert unfused_ce.vocab_parallel_cross_entropy.__name__ == "liger_vocab_parallel_cross_entropy" + + +def test_patch_with_cross_entropy_false_leaves_ce_symbols_untouched(fake_megatron_ce): + """Default ``cross_entropy=False`` must not touch the CE symbols even if the call runs.""" + fused_ce, unfused_ce = fake_megatron_ce + from liger_kernel.megatron import apply_liger_kernel_to_megatron + + fused_before = fused_ce.fused_vocab_parallel_cross_entropy + unfused_before = unfused_ce.vocab_parallel_cross_entropy + + apply_liger_kernel_to_megatron(rms_norm=False, cross_entropy=False) + + assert fused_ce.fused_vocab_parallel_cross_entropy is fused_before + assert unfused_ce.vocab_parallel_cross_entropy is unfused_before + + +def test_patch_is_idempotent_for_both_symbols(fake_megatron_ce): + """Calling ``apply_liger_kernel_to_megatron(cross_entropy=True)`` twice must not stack + wrappers — the sentinel attribute guards against double-patching.""" + fused_ce, unfused_ce = fake_megatron_ce + from liger_kernel.megatron import apply_liger_kernel_to_megatron + + apply_liger_kernel_to_megatron(rms_norm=False, cross_entropy=True) + fused_first = fused_ce.fused_vocab_parallel_cross_entropy + unfused_first = unfused_ce.vocab_parallel_cross_entropy + + apply_liger_kernel_to_megatron(rms_norm=False, cross_entropy=True) + # Same identity → no stacked wrapping. + assert fused_ce.fused_vocab_parallel_cross_entropy is fused_first + assert unfused_ce.vocab_parallel_cross_entropy is unfused_first + # __wrapped__ still references the original Megatron symbol, not the first Liger wrapper. + assert fused_first.__wrapped__.__name__ == "original_fused_vocab_parallel_cross_entropy" + assert unfused_first.__wrapped__.__name__ == "original_vocab_parallel_cross_entropy" + + +def test_patch_fused_wrapper_passes_tp_group_through(fake_megatron_ce): + """The fused wrapper closure must forward ``tp_group`` to the underlying class. + + We swap the CE class for a recording fake so the call doesn't need CUDA — just + confirms ``tp_group`` reaches the class's ``__call__``.""" + import torch + + fused_ce, _ = fake_megatron_ce + from liger_kernel.megatron import apply_liger_kernel_to_megatron + from liger_kernel.megatron import cross_entropy as ce_mod + + captured = {} + + class _FakeCE: + def __init__(self, ignore_index=-100, label_smoothing=0.0, reduction="none"): + pass + + def __call__(self, logits, target, tp_group=None): + captured["tp_group"] = tp_group + captured["shape"] = tuple(logits.shape) + return torch.zeros(logits.shape[:2]) + + class _FakeGroup: + def size(self): + return 1 + + with patch.object(ce_mod, "LigerMegatronCrossEntropy", _FakeCE): + apply_liger_kernel_to_megatron(rms_norm=False, cross_entropy=True) + logits = torch.zeros(2, 1, 4) + target = torch.zeros(2, 1, dtype=torch.long) + group = _FakeGroup() + fused_ce.fused_vocab_parallel_cross_entropy(logits, target, group) + + assert captured["tp_group"] is group + assert captured["shape"] == (2, 1, 4) + + +# --------------------------------------------------------------------------- +# 2.2 CE TP-1 guard. +# --------------------------------------------------------------------------- + + +def test_patch_raises_on_tp_greater_than_one(): + _install_fake_megatron_ce(tp_size=2) + try: + from liger_kernel.megatron import apply_liger_kernel_to_megatron + + with pytest.raises(RuntimeError, match="tensor_model_parallel_size=1"): + apply_liger_kernel_to_megatron(rms_norm=False, cross_entropy=True) + finally: + _uninstall_fake_megatron() + + +def test_patch_defers_tp_check_when_parallel_state_not_initialized(): + """If get_tensor_model_parallel_world_size() raises, patch should still succeed.""" + fused_ce, unfused_ce = _install_fake_megatron_ce(tp_size=1) + + def raising_tp_size(): + raise AssertionError("parallel_state not initialized") + + sys.modules["megatron.core.parallel_state"].get_tensor_model_parallel_world_size = raising_tp_size + + try: + from liger_kernel.megatron import apply_liger_kernel_to_megatron + + apply_liger_kernel_to_megatron(rms_norm=False, cross_entropy=True) + assert fused_ce.fused_vocab_parallel_cross_entropy.__name__ == "liger_fused_vocab_parallel_cross_entropy" + assert unfused_ce.vocab_parallel_cross_entropy.__name__ == "liger_vocab_parallel_cross_entropy" + finally: + _uninstall_fake_megatron() + + +# --------------------------------------------------------------------------- +# 2.3 CE missing-megatron / missing-symbol errors. +# --------------------------------------------------------------------------- + + +def test_patch_raises_when_megatron_not_installed(): + _uninstall_fake_megatron() + real_import = __builtins__["__import__"] if isinstance(__builtins__, dict) else __builtins__.__import__ + + def blocking_import(name, *args, **kwargs): + if name == "megatron" or name.startswith("megatron."): + raise ImportError(f"No module named {name!r}") + return real_import(name, *args, **kwargs) + + with patch("builtins.__import__", side_effect=blocking_import): + from liger_kernel.megatron import apply_liger_kernel_to_megatron + + with pytest.raises(ImportError, match="requires megatron-core"): + apply_liger_kernel_to_megatron(rms_norm=False, cross_entropy=True) + + +def test_patch_raises_when_fused_symbol_missing(): + _install_fake_megatron_ce(tp_size=1, with_fused_symbol=False) + try: + from liger_kernel.megatron import apply_liger_kernel_to_megatron + + with pytest.raises(ImportError, match="symbol path may have changed"): + apply_liger_kernel_to_megatron(rms_norm=False, cross_entropy=True) + finally: + _uninstall_fake_megatron() + + +def test_patch_raises_when_unfused_symbol_missing(): + """Symmetric to the fused-missing case; the unfused module exists but its symbol doesn't.""" + _install_fake_megatron_ce(tp_size=1, with_unfused_symbol=False) + try: + from liger_kernel.megatron import apply_liger_kernel_to_megatron + + with pytest.raises(ImportError, match="symbol path may have changed"): + apply_liger_kernel_to_megatron(rms_norm=False, cross_entropy=True) + finally: + _uninstall_fake_megatron() + + +# --------------------------------------------------------------------------- +# 2.4 CE class-default construction + runtime label_smoothing override on the unfused path. +# --------------------------------------------------------------------------- + + +def test_patch_constructs_ce_with_class_defaults(fake_megatron_ce): + """The public ``apply_liger_kernel_to_megatron`` API exposes no CE-specific kwargs; + the patch must therefore construct ``LigerMegatronCrossEntropy`` with class defaults. + + This intentionally matches Megatron's native fused-CE behavior (no ignore_index, no + label_smoothing). Callers needing custom config use ``LigerMegatronCrossEntropy`` + directly (Mode 2).""" + from liger_kernel.megatron import apply_liger_kernel_to_megatron + from liger_kernel.megatron import cross_entropy as ce_mod + + captured = [] + real_ctor = ce_mod.LigerMegatronCrossEntropy.__init__ + + def recording_init(self, ignore_index=-100, label_smoothing=0.0, reduction="none"): + captured.append( + { + "ignore_index": ignore_index, + "label_smoothing": label_smoothing, + "reduction": reduction, + } + ) + real_ctor(self, ignore_index=ignore_index, label_smoothing=label_smoothing, reduction=reduction) + + with patch.object(ce_mod.LigerMegatronCrossEntropy, "__init__", recording_init): + apply_liger_kernel_to_megatron(rms_norm=False, cross_entropy=True) + + # Fused wrapper builds 1 instance; unfused wrapper builds 1 default instance. + assert len(captured) >= 2 + for entry in captured: + assert entry == {"ignore_index": -100, "label_smoothing": 0.0, "reduction": "none"} + + +def test_unfused_wrapper_honors_runtime_label_smoothing(fake_megatron_ce): + """The unfused signature takes ``label_smoothing`` as a runtime arg; the wrapper must honor it. + + When the caller passes a non-default value, the wrapper constructs a fresh + ``LigerMegatronCrossEntropy`` with that value rather than reusing the patch-time default. + + We verify this by replacing the class with a recording fake **before** calling + ``apply_liger_kernel_to_megatron`` — the patch helper does a fresh + ``from … import LigerMegatronCrossEntropy`` so the closure captures the fake. + """ + import torch + + _, unfused_ce = fake_megatron_ce + from liger_kernel.megatron import apply_liger_kernel_to_megatron + from liger_kernel.megatron import cross_entropy as ce_mod + + constructed = [] + + class _FakeCE: + def __init__(self, ignore_index=-100, label_smoothing=0.0, reduction="none"): + constructed.append(label_smoothing) + self.ignore_index = ignore_index + self.label_smoothing = label_smoothing + self.reduction = reduction + + def __call__(self, logits, target, tp_group=None): + # Skip Liger kernel — just return a CPU-friendly tensor in the right shape. + return torch.zeros(logits.shape[:2]) + + with patch.object(ce_mod, "LigerMegatronCrossEntropy", _FakeCE): + apply_liger_kernel_to_megatron(rms_norm=False, cross_entropy=True) + + # Reset the recorder to focus on calls triggered by the next line. + constructed.clear() + logits = torch.zeros(2, 1, 4) + target = torch.zeros(2, 1, dtype=torch.long) + unfused_ce.vocab_parallel_cross_entropy(logits, target, label_smoothing=0.3) + + assert constructed == [0.3], ( + f"unfused wrapper should construct one fresh instance with the runtime override; got: {constructed}" + ) + + +def test_unfused_wrapper_uses_default_when_caller_does_not_pass_label_smoothing(fake_megatron_ce): + """When the caller doesn't pass ``label_smoothing``, the wrapper reuses the patch-time + ``default_ce`` instance — no fresh allocation per call.""" + import torch + + _, unfused_ce = fake_megatron_ce + from liger_kernel.megatron import apply_liger_kernel_to_megatron + from liger_kernel.megatron import cross_entropy as ce_mod + + constructed = [] + + class _FakeCE: + def __init__(self, ignore_index=-100, label_smoothing=0.0, reduction="none"): + constructed.append(label_smoothing) + + def __call__(self, logits, target, tp_group=None): + return torch.zeros(logits.shape[:2]) + + with patch.object(ce_mod, "LigerMegatronCrossEntropy", _FakeCE): + apply_liger_kernel_to_megatron(rms_norm=False, cross_entropy=True) + constructed.clear() + logits = torch.zeros(2, 1, 4) + target = torch.zeros(2, 1, dtype=torch.long) + # No label_smoothing arg — wrapper reuses the default_ce instance. + unfused_ce.vocab_parallel_cross_entropy(logits, target) + # Second positional call also without label_smoothing — still no new construction. + unfused_ce.vocab_parallel_cross_entropy(logits, target) + + assert constructed == [], f"default-path calls must reuse default_ce — no fresh instances; got: {constructed}" + + +def test_unfused_wrapper_honors_explicit_zero_label_smoothing(fake_megatron_ce): + """Explicit ``label_smoothing=0.0`` at call time must be honored verbatim, not silently + replaced by the patch-time default. + + This guards against the bug where the wrapper used ``if label_smoothing == 0.0:`` to + detect "caller passed nothing" — that conflated "caller didn't pass" with "caller + explicitly asked for 0.0" and corrupted loss math for Megatron callers that pass 0.0 + positionally.""" + import torch + + _, unfused_ce = fake_megatron_ce + from liger_kernel.megatron import apply_liger_kernel_to_megatron + from liger_kernel.megatron import cross_entropy as ce_mod + + constructed = [] + + class _FakeCE: + def __init__(self, ignore_index=-100, label_smoothing=0.0, reduction="none"): + constructed.append(label_smoothing) + + def __call__(self, logits, target, tp_group=None): + return torch.zeros(logits.shape[:2]) + + with patch.object(ce_mod, "LigerMegatronCrossEntropy", _FakeCE): + apply_liger_kernel_to_megatron(rms_norm=False, cross_entropy=True) + constructed.clear() + logits = torch.zeros(2, 1, 4) + target = torch.zeros(2, 1, dtype=torch.long) + + # Explicit positional 0.0 — must construct a fresh instance with 0.0. + unfused_ce.vocab_parallel_cross_entropy(logits, target, 0.0) + + assert constructed == [0.0], ( + f"explicit label_smoothing=0.0 at call time must be honored verbatim; got: {constructed}" + ) + + +# =========================================================================== +# 3. RMSNorm patch tests +# =========================================================================== +# Liger's RMSNorm patch displaces two Megatron symbols: +# +# - ``megatron.core.models.backends.LocalSpecProvider.layer_norm`` (a method) +# fills the per-layer norm slots inside ``TransformerLayerSubmodules``. +# - ``megatron.core.transformer.transformer_block.LayerNormImpl`` (a class) +# fills the block-level ``final_layernorm`` slot when the caller passes a +# per-layer spec rather than a ``TransformerBlockSubmodules``. +# +# The block-level patch only displaces the pure-torch ``WrappedTorchNorm`` +# fallback — users on TE / Apex chose those deliberately and Liger should not +# undo their fusions. Tests below cover both replacement targets, dispatch +# behavior, idempotency, and the "skip on non-WrappedTorchNorm" contract. + + +# --------------------------------------------------------------------------- +# 3.1 Both RMSNorm symbols get replaced. +# --------------------------------------------------------------------------- + + +def test_rms_norm_patch_replaces_local_spec_provider_layer_norm(fake_megatron_rms_norm): + backends, _ = fake_megatron_rms_norm + from liger_kernel.megatron import apply_liger_kernel_to_megatron + + original = backends.LocalSpecProvider.layer_norm + apply_liger_kernel_to_megatron(rms_norm=True, cross_entropy=False) + + assert backends.LocalSpecProvider.layer_norm is not original + assert backends.LocalSpecProvider.layer_norm.__name__ == "patched_layer_norm" + # Marker + __wrapped__ chain back to the original method so future un-patching + # or introspection can find it. + assert getattr(backends.LocalSpecProvider.layer_norm, "__liger_patched__", False) is True + assert backends.LocalSpecProvider.layer_norm.__wrapped__ is original + + +def test_rms_norm_patch_replaces_transformer_block_layernorm_impl(fake_megatron_rms_norm): + _, transformer_block = fake_megatron_rms_norm + from liger_kernel.megatron import apply_liger_kernel_to_megatron + + original = transformer_block.LayerNormImpl + apply_liger_kernel_to_megatron(rms_norm=True, cross_entropy=False) + + assert transformer_block.LayerNormImpl is not original + assert transformer_block.LayerNormImpl.__name__ == "_LigerOrTorchNorm" + assert getattr(transformer_block.LayerNormImpl, "__liger_patched__", False) is True + assert transformer_block.LayerNormImpl.__wrapped__ is original + + +def test_rms_norm_patch_replaces_both_symbols_in_one_call(fake_megatron_rms_norm): + """A single ``rms_norm=True`` call must replace both Megatron RMSNorm paths.""" + backends, transformer_block = fake_megatron_rms_norm + from liger_kernel.megatron import apply_liger_kernel_to_megatron + + apply_liger_kernel_to_megatron(rms_norm=True, cross_entropy=False) + + assert backends.LocalSpecProvider.layer_norm.__name__ == "patched_layer_norm" + assert transformer_block.LayerNormImpl.__name__ == "_LigerOrTorchNorm" + + +def test_rms_norm_patch_with_rms_norm_false_leaves_norm_symbols_untouched(fake_megatron_rms_norm): + """Default ``rms_norm=False`` must not touch the norm symbols even if the call runs.""" + backends, transformer_block = fake_megatron_rms_norm + from liger_kernel.megatron import apply_liger_kernel_to_megatron + + lsp_before = backends.LocalSpecProvider.layer_norm + impl_before = transformer_block.LayerNormImpl + + apply_liger_kernel_to_megatron(rms_norm=False, cross_entropy=False) + + assert backends.LocalSpecProvider.layer_norm is lsp_before + assert transformer_block.LayerNormImpl is impl_before + + +def test_rms_norm_patch_is_idempotent(fake_megatron_rms_norm): + """Calling ``apply_liger_kernel_to_megatron(rms_norm=True)`` twice must not stack + wrappers — the sentinel attribute on each patched target guards against double-wrap.""" + backends, transformer_block = fake_megatron_rms_norm + from liger_kernel.megatron import apply_liger_kernel_to_megatron + + apply_liger_kernel_to_megatron(rms_norm=True, cross_entropy=False) + lsp_first = backends.LocalSpecProvider.layer_norm + impl_first = transformer_block.LayerNormImpl + + apply_liger_kernel_to_megatron(rms_norm=True, cross_entropy=False) + assert backends.LocalSpecProvider.layer_norm is lsp_first + assert transformer_block.LayerNormImpl is impl_first + # __wrapped__ still references the original Megatron symbols, not the first Liger + # wrapper (otherwise the second apply would have chained over the first). + assert lsp_first.__wrapped__.__qualname__.endswith("LocalSpecProvider.layer_norm") + assert impl_first.__wrapped__ is transformer_block._WrappedTorchNorm + + +# --------------------------------------------------------------------------- +# 3.2 RMSNorm dispatch behavior through the patched targets. +# --------------------------------------------------------------------------- + + +def test_rms_norm_patch_local_spec_provider_returns_liger_for_rms_norm_true(fake_megatron_rms_norm): + """When the patched ``layer_norm`` method is called with ``rms_norm=True``, it returns + ``LigerMegatronRMSNorm`` — the actual class binding callers use to construct the per-layer + norms inside ``TransformerLayerSubmodules``.""" + backends, _ = fake_megatron_rms_norm + from liger_kernel.megatron import LigerMegatronRMSNorm + from liger_kernel.megatron import apply_liger_kernel_to_megatron + + apply_liger_kernel_to_megatron(rms_norm=True, cross_entropy=False) + + provider = backends.LocalSpecProvider() + result = provider.layer_norm(rms_norm=True) + assert result is LigerMegatronRMSNorm + + +def test_rms_norm_patch_local_spec_provider_delegates_for_rms_norm_false(fake_megatron_rms_norm): + """When ``rms_norm=False``, the patched method must delegate to the original method — + Liger never touches LayerNorm users.""" + backends, _ = fake_megatron_rms_norm + from liger_kernel.megatron import apply_liger_kernel_to_megatron + + apply_liger_kernel_to_megatron(rms_norm=True, cross_entropy=False) + + provider = backends.LocalSpecProvider() + result = provider.layer_norm(rms_norm=False) + # The stub's original layer_norm returns its _OriginalNormSentinel. + assert result is backends._OriginalNormSentinel + + +def test_rms_norm_patch_transformer_block_routes_rmsnorm_through_liger(fake_megatron_rms_norm): + """The ``_LigerOrTorchNorm`` wrapping class dispatches on ``config.normalization`` — + when it's ``"RMSNorm"``, instantiation returns a ``LigerMegatronRMSNorm`` instance. + + Construction is enough — no kernel is actually invoked, so this runs on CPU.""" + from types import SimpleNamespace + + _, transformer_block = fake_megatron_rms_norm + from liger_kernel.megatron import LigerMegatronRMSNorm + from liger_kernel.megatron import apply_liger_kernel_to_megatron + + apply_liger_kernel_to_megatron(rms_norm=True, cross_entropy=False) + + config = SimpleNamespace( + normalization="RMSNorm", + sequence_parallel=False, + layernorm_zero_centered_gamma=False, + ) + instance = transformer_block.LayerNormImpl(config=config, hidden_size=64, eps=1e-5) + assert isinstance(instance, LigerMegatronRMSNorm) + + +def test_rms_norm_patch_transformer_block_routes_layernorm_through_original(fake_megatron_rms_norm): + """When ``config.normalization`` is not ``"RMSNorm"`` (e.g. ``"LayerNorm"``), the + wrapping class falls back to the original ``WrappedTorchNorm`` so LayerNorm users + keep their existing behavior.""" + from types import SimpleNamespace + + _, transformer_block = fake_megatron_rms_norm + from liger_kernel.megatron import apply_liger_kernel_to_megatron + + apply_liger_kernel_to_megatron(rms_norm=True, cross_entropy=False) + + config = SimpleNamespace(normalization="LayerNorm") + result = transformer_block.LayerNormImpl(config=config, hidden_size=64, eps=1e-5) + # The stub's WrappedTorchNorm.__new__ returns a sentinel instance. + assert isinstance(result, transformer_block._OriginalTorchNormInstance) + assert result.hidden_size == 64 + + +# --------------------------------------------------------------------------- +# 3.3 Block-level "skip when not WrappedTorchNorm" contract. +# --------------------------------------------------------------------------- + + +def test_rms_norm_patch_transformer_block_skips_when_layer_norm_is_not_wrapped_torch_norm(): + """If ``transformer_block.LayerNormImpl`` is TE / Apex / anything other than the pure-torch + ``WrappedTorchNorm`` fallback, the block-level patch must be a no-op. Replacing TE's + fused LN+Linear with Liger's standalone RMSNorm would double-norm or skip the norm + entirely; replacing Apex would surprise users who chose it deliberately.""" + _, transformer_block = _install_fake_megatron_rms_norm( + layer_norm_is_wrapped_torch_norm=False, + ) + try: + from liger_kernel.megatron import apply_liger_kernel_to_megatron + + impl_before = transformer_block.LayerNormImpl + apply_liger_kernel_to_megatron(rms_norm=True, cross_entropy=False) + # Unchanged — patch detected a non-WrappedTorchNorm value and bailed. + assert transformer_block.LayerNormImpl is impl_before + # And the spec-provider patch above DID still apply (it's independent of the block + # path), so we'd see the provider patched even though the block was skipped. + # The fixture didn't install backends — re-fetch from sys.modules. + backends_mod = sys.modules["megatron.core.models.backends"] + assert backends_mod.LocalSpecProvider.layer_norm.__name__ == "patched_layer_norm" + finally: + _uninstall_fake_megatron() + + +# =========================================================================== +# 4. Cross-kernel public-API surface checks +# =========================================================================== +# Mirrors transformers-side ``test_import_from_root`` and +# ``test_apply_liger_kernel_only_passes_valid_kwargs`` patterns. + + +def test_import_from_root(): + """All public Megatron symbols must be reachable from ``liger_kernel.megatron``. + + Mirrors the import-smoke pattern from ``test/transformers/test_monkey_patch.py``: catches + accidental __init__.py removals so the docs' import snippets keep working.""" + try: + from liger_kernel.megatron import LigerMegatronCrossEntropy # noqa: F401 + from liger_kernel.megatron import LigerMegatronRMSNorm # noqa: F401 + from liger_kernel.megatron import apply_liger_kernel_to_megatron # noqa: F401 + except Exception: + pytest.fail("Importing public Megatron symbols from liger_kernel.megatron failed.") + + +def test_public_apply_function_has_no_ce_specific_kwargs(): + """The framework-level patch entry point intentionally hides CE-specific knobs + (ignore_index, label_smoothing, reduction). Catch accidental re-introduction — + Mode-2 callers use ``LigerMegatronCrossEntropy`` directly for that config surface.""" + import inspect + + from liger_kernel.megatron import apply_liger_kernel_to_megatron + + sig = inspect.signature(apply_liger_kernel_to_megatron) + leaked = {"ignore_index", "label_smoothing", "reduction"} & set(sig.parameters) + assert not leaked, ( + f"apply_liger_kernel_to_megatron has re-grown CE-specific kwargs: {sorted(leaked)}. " + f"Those belong on LigerMegatronCrossEntropy, not on the framework patch entry point." + ) + + +# =========================================================================== +# 5. End-to-end integration through the patched CE symbols +# =========================================================================== +# Earlier tests verify symbol identity + stub plumbing; the suite was missing +# the "patch + call with real tensors + check the numbers" coverage. These +# tests install the fake megatron, apply the patch, then invoke the resulting +# wrapper with live torch tensors and compare against ``F.cross_entropy``. +# That's the only way to catch wrapper-math bugs that pass the identity tests. + + +import torch # noqa: E402 (deferred so the no-torch import-smoke tests above are unaffected) +import torch.nn.functional as F # noqa: E402 + +from liger_kernel.utils import infer_device # noqa: E402 +from test.utils import assert_verbose_allclose # noqa: E402 + +_device = infer_device() + + +def _ref_loss_sbv( + logits_sbv: torch.Tensor, target_sb: torch.Tensor, ignore_index: int = -100, label_smoothing: float = 0.0 +) -> torch.Tensor: + """Reference CE for [s, b, v] logits / [s, b] target, returning [s, b].""" + s, b, v = logits_sbv.shape + loss_flat = F.cross_entropy( + logits_sbv.reshape(-1, v).float(), + target_sb.reshape(-1), + reduction="none", + ignore_index=ignore_index, + label_smoothing=label_smoothing, + ) + return loss_flat.reshape(s, b) + + +def test_patched_fused_symbol_computes_correct_loss(fake_megatron_ce): + """End-to-end: install stub megatron, patch, invoke the resulting fused symbol with real + [s, b, v] logits, verify the loss matches ``F.cross_entropy``. Closes the gap between + "patch wired correctly" (existing tests) and "patched function does the right math".""" + fused_ce, _ = fake_megatron_ce + from liger_kernel.megatron import apply_liger_kernel_to_megatron + + apply_liger_kernel_to_megatron(rms_norm=False, cross_entropy=True) + + s, b, v = 16, 2, 1024 + torch.manual_seed(0) + logits = torch.randn(s, b, v, device=_device, dtype=torch.float32) + target = torch.randint(0, v, (s, b), device=_device, dtype=torch.long) + + ref = _ref_loss_sbv(logits.clone(), target) + # Call through the patched symbol. tp_group=None is what Megatron's + # LanguageModule passes when TP is uninitialized. + got = fused_ce.fused_vocab_parallel_cross_entropy(logits.clone(), target, None) + + assert got.shape == (s, b) + assert_verbose_allclose(got.float(), ref.float(), atol=1e-6, rtol=1e-5) + + +def test_patched_unfused_symbol_computes_correct_loss(fake_megatron_ce): + """Same as the fused case, but through the unfused symbol — verifies both wrappers + are exercised and exercises the no-label_smoothing default branch (caller doesn't pass).""" + _, unfused_ce = fake_megatron_ce + from liger_kernel.megatron import apply_liger_kernel_to_megatron + + apply_liger_kernel_to_megatron(rms_norm=False, cross_entropy=True) + + s, b, v = 8, 4, 512 + torch.manual_seed(1) + logits = torch.randn(s, b, v, device=_device, dtype=torch.float32) + target = torch.randint(0, v, (s, b), device=_device, dtype=torch.long) + + ref = _ref_loss_sbv(logits.clone(), target) + # Native unfused signature: (logits, target, label_smoothing=0.0, tp_group=None). + # Pass only positional args the caller normally would. + got = unfused_ce.vocab_parallel_cross_entropy(logits.clone(), target) + + assert got.shape == (s, b) + assert_verbose_allclose(got.float(), ref.float(), atol=1e-6, rtol=1e-5) + + +@pytest.mark.parametrize("label_smoothing", [0.0, 0.1]) +def test_patched_unfused_symbol_runtime_label_smoothing_matches_pytorch(fake_megatron_ce, label_smoothing): + """The unfused wrapper's main feature beyond the fused path is honoring a runtime + label_smoothing arg. Verify the resulting loss actually matches + ``F.cross_entropy(..., label_smoothing=...)``, not just that a fresh CE instance is built.""" + _, unfused_ce = fake_megatron_ce + from liger_kernel.megatron import apply_liger_kernel_to_megatron + + apply_liger_kernel_to_megatron(rms_norm=False, cross_entropy=True) + + s, b, v = 8, 2, 256 + torch.manual_seed(2) + logits = torch.randn(s, b, v, device=_device, dtype=torch.float32) + target = torch.randint(0, v, (s, b), device=_device, dtype=torch.long) + + ref = _ref_loss_sbv(logits.clone(), target, label_smoothing=label_smoothing) + got = unfused_ce.vocab_parallel_cross_entropy( + logits.clone(), + target, + label_smoothing=label_smoothing, + ) + assert_verbose_allclose(got.float(), ref.float(), atol=1e-5, rtol=1e-4) + + +def test_patched_fused_symbol_preserves_gradients(fake_megatron_ce): + """Backward through the patched fused symbol: gradient shape + parity vs. + PyTorch's reference. Liger writes the gradient back into the input buffer, + so verifying ``.grad`` after backward exercises both the reshape contract + and the in-place write.""" + fused_ce, _ = fake_megatron_ce + from liger_kernel.megatron import apply_liger_kernel_to_megatron + + apply_liger_kernel_to_megatron(rms_norm=False, cross_entropy=True) + + s, b, v = 8, 2, 256 + torch.manual_seed(3) + base = torch.randn(s, b, v, device=_device, dtype=torch.float32) + target = torch.randint(0, v, (s, b), device=_device, dtype=torch.long) + + h_ref = base.detach().clone().requires_grad_(True) + h_got = base.detach().clone().requires_grad_(True) + ref = _ref_loss_sbv(h_ref, target) + got = fused_ce.fused_vocab_parallel_cross_entropy(h_got, target, None) + + ref.sum().backward() + got.sum().backward() + + assert h_got.grad is not None + assert h_got.grad.shape == h_got.shape + assert_verbose_allclose(h_got.grad.float(), h_ref.grad.float(), atol=1e-6, rtol=1e-5) + + +def test_patched_unfused_symbol_preserves_gradients(fake_megatron_ce): + """Symmetric to the fused-gradient test; ensures the closure in + ``_patch_vocab_parallel_cross_entropy`` doesn't break autograd.""" + _, unfused_ce = fake_megatron_ce + from liger_kernel.megatron import apply_liger_kernel_to_megatron + + apply_liger_kernel_to_megatron(rms_norm=False, cross_entropy=True) + + s, b, v = 8, 2, 256 + torch.manual_seed(4) + base = torch.randn(s, b, v, device=_device, dtype=torch.float32) + target = torch.randint(0, v, (s, b), device=_device, dtype=torch.long) + + h_ref = base.detach().clone().requires_grad_(True) + h_got = base.detach().clone().requires_grad_(True) + ref = _ref_loss_sbv(h_ref, target) + got = unfused_ce.vocab_parallel_cross_entropy(h_got, target) + + ref.sum().backward() + got.sum().backward() + assert_verbose_allclose(h_got.grad.float(), h_ref.grad.float(), atol=1e-6, rtol=1e-5) + + +def test_patched_fused_symbol_default_ignore_index_minus_100(fake_megatron_ce): + """Patch-time defaults: targets containing -100 should be treated as ignored — Liger's + kernel zeros those loss positions, matching ``F.cross_entropy(ignore_index=-100)``. + + Native Megatron's fused CE has no ignore_index concept and would silently produce + garbage on -100; this is one place where Liger is strictly better than the symbol + it replaces, and the test pins that behavior.""" + fused_ce, _ = fake_megatron_ce + from liger_kernel.megatron import apply_liger_kernel_to_megatron + + apply_liger_kernel_to_megatron(rms_norm=False, cross_entropy=True) + + s, b, v = 8, 2, 128 + torch.manual_seed(5) + logits = torch.randn(s, b, v, device=_device, dtype=torch.float32) + target = torch.randint(0, v, (s, b), device=_device, dtype=torch.long) + # Plant some -100 sentinel positions. + flat = target.view(-1) + flat[: flat.numel() // 4] = -100 + + ref = _ref_loss_sbv(logits.clone(), target, ignore_index=-100) + got = fused_ce.fused_vocab_parallel_cross_entropy(logits.clone(), target, None) + + # Per-token loss at masked positions should be exactly 0. + mask = (target != -100).float() + assert torch.all(got * (1 - mask) == 0) + assert_verbose_allclose(got.float(), ref.float(), atol=1e-6, rtol=1e-5) + + +def test_rms_norm_only_patch_does_not_touch_ce_symbols(fake_megatron_ce): + """Symmetric to ``test_patch_with_cross_entropy_false_leaves_ce_symbols_untouched``, + but for the opposite split. With ``rms_norm=True, cross_entropy=False`` (RMSNorm + helpers require real megatron and will ImportError on the stub — that's fine, we + only need to confirm the CE symbols are not pre-emptively touched before the RMSNorm + helpers run). Documenting this protects against future apply_… reorderings that would + silently couple the two.""" + fused_ce, unfused_ce = fake_megatron_ce + from liger_kernel.megatron import apply_liger_kernel_to_megatron + + fused_before = fused_ce.fused_vocab_parallel_cross_entropy + unfused_before = unfused_ce.vocab_parallel_cross_entropy + + # RMSNorm helpers do their own megatron import; on the stub they'll raise. Catch + # any exception so the assertion at the end runs regardless — we only care that the + # CE symbols weren't touched. + try: + apply_liger_kernel_to_megatron(rms_norm=True, cross_entropy=False) + except Exception: + pass + + assert fused_ce.fused_vocab_parallel_cross_entropy is fused_before + assert unfused_ce.vocab_parallel_cross_entropy is unfused_before