@@ -616,6 +616,185 @@ fn test_mixed_magnitudes() {
616616 assert_eq ! ( ternary[ 3 ] , 0 , "-0.001 should be 0" ) ;
617617}
618618
619+ // ============================================================================
620+ // 8. Layer Filter Tests (per ADR-017 AD-2)
621+ // ============================================================================
622+
#[test]
fn test_should_quantize_expert_layers() {
    // With `LayerMask::ExpertsOnly`, every expert FFN projection listed below
    // (gate_proj / up_proj / down_proj and the `experts.N.w3` naming) must be selected.
    use super::LayerMask;

    let mask = LayerMask::ExpertsOnly;

    // (tensor name, message reported if the tensor is wrongly skipped)
    let expected_quantized = [
        (
            "model.layers.0.mlp.gate_proj.weight",
            "gate_proj should be quantized",
        ),
        (
            "model.layers.0.mlp.up_proj.weight",
            "up_proj should be quantized",
        ),
        (
            "model.layers.0.mlp.down_proj.weight",
            "down_proj should be quantized",
        ),
        (
            "model.layers.15.block_sparse_moe.experts.7.w3.weight",
            "Expert w3 (up_proj) should be quantized",
        ),
    ];

    for &(tensor_name, failure_message) in expected_quantized.iter() {
        assert!(
            should_quantize_layer(tensor_name, &mask),
            "{}",
            failure_message
        );
    }
}
647+
#[test]
fn test_should_not_quantize_router() {
    // Per ADR-017 (AD-2), routing tensors (router and MoE gate) stay in FP16,
    // so `LayerMask::ExpertsOnly` must reject them.
    use super::LayerMask;

    let mask = LayerMask::ExpertsOnly;

    // (tensor name, message reported if the tensor is wrongly selected)
    let expected_skipped = [
        (
            "model.layers.0.mlp.router.weight",
            "Router should NOT be quantized",
        ),
        (
            "model.layers.0.block_sparse_moe.gate.weight",
            "MoE gate should NOT be quantized",
        ),
    ];

    for &(tensor_name, failure_message) in expected_skipped.iter() {
        assert!(
            !should_quantize_layer(tensor_name, &mask),
            "{}",
            failure_message
        );
    }
}
664+
#[test]
fn test_should_not_quantize_embed() {
    // Per ADR-017 (AD-2), embedding tables and the LM head stay in FP16,
    // so `LayerMask::ExpertsOnly` must reject them.
    use super::LayerMask;

    let mask = LayerMask::ExpertsOnly;

    // (tensor name, message reported if the tensor is wrongly selected)
    let expected_skipped = [
        (
            "model.embed_tokens.weight",
            "Embed tokens should NOT be quantized",
        ),
        ("lm_head.weight", "LM head should NOT be quantized"),
        (
            "model.embeddings.word_embeddings",
            "Word embeddings should NOT be quantized",
        ),
    ];

    for &(tensor_name, failure_message) in expected_skipped.iter() {
        assert!(
            !should_quantize_layer(tensor_name, &mask),
            "{}",
            failure_message
        );
    }
}
685+
#[test]
fn test_should_not_quantize_norm() {
    // Per ADR-017 (AD-2), normalization tensors stay in FP16,
    // so `LayerMask::ExpertsOnly` must reject them.
    use super::LayerMask;

    let mask = LayerMask::ExpertsOnly;

    // (tensor name, message reported if the tensor is wrongly selected)
    let expected_skipped = [
        (
            "model.layers.0.input_layernorm.weight",
            "Input layernorm should NOT be quantized",
        ),
        (
            "model.layers.0.post_attention_layernorm.weight",
            "Post-attention layernorm should NOT be quantized",
        ),
        ("model.norm.weight", "Final norm should NOT be quantized"),
        (
            "model.layers.0.self_attn.layer_norm",
            "Self-attention layer_norm should NOT be quantized",
        ),
    ];

    for &(tensor_name, failure_message) in expected_skipped.iter() {
        assert!(
            !should_quantize_layer(tensor_name, &mask),
            "{}",
            failure_message
        );
    }
}
710+
#[test]
fn test_layer_mask_all() {
    // `LayerMask::All` selects linear layers in general (attention projections
    // included) while still leaving the protected tensors untouched.
    use super::LayerMask;

    let mask = LayerMask::All;

    // Attention projections are expected to be selected.
    let expected_quantized = [
        (
            "model.layers.0.self_attn.q_proj.weight",
            "Query projection should be quantized with LayerMask::All",
        ),
        (
            "model.layers.0.self_attn.k_proj.weight",
            "Key projection should be quantized with LayerMask::All",
        ),
    ];
    for &(tensor_name, failure_message) in expected_quantized.iter() {
        assert!(
            should_quantize_layer(tensor_name, &mask),
            "{}",
            failure_message
        );
    }

    // Router and embeddings remain protected regardless of the mask.
    let expected_skipped = [
        (
            "model.layers.0.mlp.router.weight",
            "Router should be protected even with LayerMask::All",
        ),
        (
            "model.embed_tokens.weight",
            "Embeddings should be protected even with LayerMask::All",
        ),
    ];
    for &(tensor_name, failure_message) in expected_skipped.iter() {
        assert!(
            !should_quantize_layer(tensor_name, &mask),
            "{}",
            failure_message
        );
    }
}
738+
#[test]
fn test_layer_mask_custom() {
    // `LayerMask::Custom` selects a tensor only when its name contains one of
    // the supplied patterns.
    use super::LayerMask;

    let mask = LayerMask::Custom(vec!["w1".to_string(), "w3".to_string()]);

    // (tensor name, whether it should be selected, message reported on mismatch)
    let cases = [
        (
            "model.layers.0.mlp.experts.0.w1.weight",
            true,
            "w1 should match custom pattern",
        ),
        (
            "model.layers.0.mlp.experts.0.w3.weight",
            true,
            "w3 should match custom pattern",
        ),
        (
            "model.layers.0.mlp.experts.0.w2.weight",
            false,
            "w2 should NOT match custom pattern",
        ),
    ];

    for &(tensor_name, should_match, failure_message) in cases.iter() {
        let selected = should_quantize_layer(tensor_name, &mask);
        assert!(selected == should_match, "{}", failure_message);
    }
}
759+
760+ /// Helper function for layer filtering logic (matches ADR-017 AD-2 specification)
761+ fn should_quantize_layer ( layer_name : & str , mask : & super :: LayerMask ) -> bool {
762+ use super :: LayerMask ;
763+
764+ match mask {
765+ LayerMask :: ExpertsOnly => {
766+ // Quantize MoE expert FFN layers only (gate_proj, up_proj, down_proj, w1, w2, w3)
767+ // Exclude: router, gate, embed, norm, lm_head
768+ let is_expert_ffn = layer_name. contains ( "gate_proj" )
769+ || layer_name. contains ( "up_proj" )
770+ || layer_name. contains ( "down_proj" )
771+ || ( layer_name. contains ( "experts" )
772+ && ( layer_name. contains ( ".w1." ) || layer_name. contains ( ".w2." ) || layer_name. contains ( ".w3." ) ) ) ;
773+
774+ let is_protected = layer_name. contains ( "router" )
775+ || layer_name. contains ( ".gate." ) // MoE gate (not gate_proj)
776+ || layer_name. contains ( "embed" )
777+ || layer_name. contains ( "lm_head" )
778+ || layer_name. contains ( "norm" ) ;
779+
780+ is_expert_ffn && !is_protected
781+ }
782+ LayerMask :: All => {
783+ // Quantize all linear layers except protected ones
784+ let is_protected = layer_name. contains ( "router" )
785+ || layer_name. contains ( "embed" )
786+ || layer_name. contains ( "lm_head" )
787+ || layer_name. contains ( "norm" ) ;
788+
789+ !is_protected
790+ }
791+ LayerMask :: Custom ( patterns) => {
792+ // Match any custom pattern
793+ patterns. iter ( ) . any ( |p| layer_name. contains ( p) )
794+ }
795+ }
796+ }
797+
619798// ============================================================================
620799// Helper Functions
621800// ============================================================================
0 commit comments