@@ -7,109 +7,93 @@ import Testing
77
88@Suite ( " Gemma 4 Architectural Integrity Tests " )
99struct Gemma4Tests {
10-
11- /// Create a minimal test configuration for Gemma 4
12- private func makeTinyConfig( ) -> Gemma4Configuration {
13- Gemma4Configuration (
14- modelType: " gemma4 " ,
15- hiddenSize: 64 ,
16- hiddenLayers: 2 ,
17- intermediateSize: 128 ,
18- attentionHeads: 4 ,
19- headDim: 16 ,
20- rmsNormEps: 1e-6 ,
21- vocabularySize: 100 ,
22- kvHeads: 2 ,
23- ropeTheta: 10000.0 ,
24- ropeLocalBaseFreq: 10000.0 ,
25- ropeTraditional: false ,
26- queryPreAttnScalar: 1.0 ,
27- slidingWindow: 128 ,
28- slidingWindowPattern: 1 ,
29- maxPositionEmbeddings: 512 ,
30- globalHeadDim: 64 ,
31- numKvSharedLayers: 0 ,
32- useDoubleWideMlp: false ,
33- tieWordEmbeddings: true ,
34- hiddenSizePerLayerInput: 32 ,
35- vocabSizePerLayerInput: 10 ,
36- globalRopePartialFactor: 0.25 ,
37- finalLogitSoftcapping: 30.0
38- )
10+
11+ /// Create a minimal test configuration for Gemma 4 using upstream's JSON-based init
12+ private func makeTinyConfigData( ) -> Data {
13+ let json = """
14+ {
15+ " model_type " : " gemma4 " ,
16+ " text_config " : {
17+ " model_type " : " gemma4_text " ,
18+ " hidden_size " : 64,
19+ " num_hidden_layers " : 2,
20+ " intermediate_size " : 128,
21+ " num_attention_heads " : 4,
22+ " head_dim " : 16,
23+ " global_head_dim " : 64,
24+ " rms_norm_eps " : 1e-6,
25+ " vocab_size " : 100,
26+ " num_key_value_heads " : 2,
27+ " rope_traditional " : false,
28+ " sliding_window " : 128,
29+ " sliding_window_pattern " : 1,
30+ " max_position_embeddings " : 512,
31+ " num_kv_shared_layers " : 0,
32+ " use_double_wide_mlp " : false,
33+ " tie_word_embeddings " : true,
34+ " hidden_size_per_layer_input " : 32,
35+ " vocab_size_per_layer_input " : 10,
36+ " final_logit_softcapping " : 30.0,
37+ " enable_moe_block " : false,
38+ " attention_k_eq_v " : false
39+ },
40+ " vocab_size " : 100
41+ }
42+ """
43+ return json. data ( using: . utf8) !
3944 }
4045
41- @Test ( " Gemma 4 Forward Pass - Determinism & Shape " )
42- func testGemma4ForwardPass( ) throws {
43- let config = makeTinyConfig ( )
44- let model = Gemma4ModelInternal ( config)
45-
46- let input = MLXArray ( 0 ..< 8 ) . reshaped ( 1 , 8 )
47- let output = model ( input)
48-
49- #expect( output. shape == [ 1 , 8 , config. vocabularySize] )
50-
51- // Secondary pass to ensure determinism
52- let output2 = model ( input)
53- #expect( allClose ( output, output2) . item ( Bool . self) )
46+ @Test ( " Gemma 4 Configuration Decoding " )
47+ func testGemma4ConfigDecoding( ) throws {
48+ let data = makeTinyConfigData ( )
49+ let config = try JSONDecoder ( ) . decode ( Gemma4Configuration . self, from: data)
50+ // vocabSize is internal, verify via model
51+ let model = Gemma4Model ( config)
52+ #expect( model. vocabularySize == 100 )
5453 }
5554
56- @Test ( " PLE Multimodal Signal Integrity " )
57- func testPLESignalIntegrity( ) throws {
58- let config = makeTinyConfig ( )
59- let model = Gemma4ModelInternal ( config)
60-
61- let input = MLXArray ( Int32 ( 0 ) ..< Int32 ( 5 ) ) . reshaped ( 1 , 5 )
62-
63- // We expect the forward pass to finish without NaN or infinite values
64- let output = model ( input)
65- let sum = output. sum ( ) . item ( Float . self)
66- #expect( !sum. isNaN)
67- #expect( !sum. isInfinite)
55+ @Test ( " Gemma 4 Model Instantiation " )
56+ func testGemma4ModelInstantiation( ) throws {
57+ let data = makeTinyConfigData ( )
58+ let config = try JSONDecoder ( ) . decode ( Gemma4Configuration . self, from: data)
59+ let model = Gemma4Model ( config)
60+ #expect( model. vocabularySize == 100 )
6861 }
6962
70- @Test ( " Weight Sanitization - PLE Mapping " )
71- func testGemma4Sanitization( ) throws {
72- let config = makeTinyConfig ( )
73- let model = Gemma4ModelInternal ( config)
74-
75- var weights = [ String: MLXArray] ( )
76- weights [ " model.layers.0.per_layer_conditioning.scale " ] = MLXArray . ones ( [ config. hiddenSize, config. hiddenSizePerLayerInput] )
77- weights [ " model.layers.0.per_layer_conditioning.bias " ] = MLXArray . ones ( [ config. hiddenSize] )
78-
79- let sanitized = model. sanitize ( weights: weights, metadata: [ : ] )
80-
81- // Gemma 4 sanitization maps to model.layers...
82- #expect( sanitized [ " model.layers.0.per_layer_model_projection.scale " ] != nil || sanitized [ " layers.0.per_layer_input.scale " ] != nil )
63+ @Test ( " Gemma 4 Forward Pass - Shape " )
64+ func testGemma4ForwardPass( ) throws {
65+ let data = makeTinyConfigData ( )
66+ let config = try JSONDecoder ( ) . decode ( Gemma4Configuration . self, from: data)
67+ let model = Gemma4Model ( config)
68+
69+ let input = MLXArray ( 0 ..< 8 ) . reshaped ( 1 , 8 )
70+ let output = model ( input, cache: nil )
71+
72+ #expect( output. shape == [ 1 , 8 , model. vocabularySize] )
8373 }
8474
85- @Test ( " Audio Configuration Dependency Safety " )
86- func testAudioConfigSafety( ) throws {
87- let config = makeTinyConfig ( )
88- let model = Gemma4ModelInternal ( config)
89- #expect( model. model. layers. count == config. hiddenLayers)
75+ @Test ( " Forward Pass Determinism " )
76+ func testDeterminism( ) throws {
77+ let data = makeTinyConfigData ( )
78+ let config = try JSONDecoder ( ) . decode ( Gemma4Configuration . self, from: data)
79+ let model = Gemma4Model ( config)
80+
81+ let input = MLXArray ( 0 ..< 8 ) . reshaped ( 1 , 8 )
82+ let output1 = model ( input, cache: nil )
83+ let output2 = model ( input, cache: nil )
84+ #expect( allClose ( output1, output2) . item ( Bool . self) )
9085 }
9186
92- @Test ( " Router Parameter Tree Dump " )
93- func testRouterParameterTree( ) throws {
94- let config = makeTinyConfig ( )
95- let model = Gemma4ModelInternal ( config)
96-
97- var weights = [ String: MLXArray] ( )
98- weights [ " model.layers.0.experts.router.scale " ] = MLXArray . ones ( [ config. hiddenSize] )
99- weights [ " model.layers.0.experts.router.proj.weight " ] = MLXArray . ones ( [ 1 , config. hiddenSize] )
100-
101- print ( " Model parameters: " )
102- for (k, _) in model. parameters ( ) {
103- if k. contains ( " router " ) {
104- print ( k)
105- }
106- }
107-
108- do {
109- try model. update ( parameters: ModuleParameters . unflattened ( weights) )
110- } catch {
111- print ( " Update Error: \( error) " )
112- throw error
113- }
87+ @Test ( " No NaN/Inf in Output " )
88+ func testNoNaNInf( ) throws {
89+ let data = makeTinyConfigData ( )
90+ let config = try JSONDecoder ( ) . decode ( Gemma4Configuration . self, from: data)
91+ let model = Gemma4Model ( config)
92+
93+ let input = MLXArray ( Int32 ( 0 ) ..< Int32 ( 5 ) ) . reshaped ( 1 , 5 )
94+ let output = model ( input, cache: nil )
95+ let sum = output. sum ( ) . item ( Float . self)
96+ #expect( !sum. isNaN)
97+ #expect( !sum. isInfinite)
11498 }
11599}
0 commit comments