Skip to content

Commit 8e5084e

Browse files
author
Aegis AI Assistant
committed
test(gemma4): rewrite Gemma4Tests to use upstream Gemma4Model and JSON-decoded configuration
1 parent 1b87754 commit 8e5084e

1 file changed

Lines changed: 78 additions & 94 deletions

File tree

Tests/MLXLMTests/Gemma4Tests.swift

Lines changed: 78 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -7,109 +7,93 @@ import Testing
77

88
@Suite("Gemma 4 Architectural Integrity Tests")
99
struct Gemma4Tests {
10-
11-
/// Create a minimal test configuration for Gemma 4
12-
private func makeTinyConfig() -> Gemma4Configuration {
13-
Gemma4Configuration(
14-
modelType: "gemma4",
15-
hiddenSize: 64,
16-
hiddenLayers: 2,
17-
intermediateSize: 128,
18-
attentionHeads: 4,
19-
headDim: 16,
20-
rmsNormEps: 1e-6,
21-
vocabularySize: 100,
22-
kvHeads: 2,
23-
ropeTheta: 10000.0,
24-
ropeLocalBaseFreq: 10000.0,
25-
ropeTraditional: false,
26-
queryPreAttnScalar: 1.0,
27-
slidingWindow: 128,
28-
slidingWindowPattern: 1,
29-
maxPositionEmbeddings: 512,
30-
globalHeadDim: 64,
31-
numKvSharedLayers: 0,
32-
useDoubleWideMlp: false,
33-
tieWordEmbeddings: true,
34-
hiddenSizePerLayerInput: 32,
35-
vocabSizePerLayerInput: 10,
36-
globalRopePartialFactor: 0.25,
37-
finalLogitSoftcapping: 30.0
38-
)
10+
11+
/// Create a minimal test configuration for Gemma 4 using upstream's JSON-based init
12+
private func makeTinyConfigData() -> Data {
13+
let json = """
14+
{
15+
"model_type": "gemma4",
16+
"text_config": {
17+
"model_type": "gemma4_text",
18+
"hidden_size": 64,
19+
"num_hidden_layers": 2,
20+
"intermediate_size": 128,
21+
"num_attention_heads": 4,
22+
"head_dim": 16,
23+
"global_head_dim": 64,
24+
"rms_norm_eps": 1e-6,
25+
"vocab_size": 100,
26+
"num_key_value_heads": 2,
27+
"rope_traditional": false,
28+
"sliding_window": 128,
29+
"sliding_window_pattern": 1,
30+
"max_position_embeddings": 512,
31+
"num_kv_shared_layers": 0,
32+
"use_double_wide_mlp": false,
33+
"tie_word_embeddings": true,
34+
"hidden_size_per_layer_input": 32,
35+
"vocab_size_per_layer_input": 10,
36+
"final_logit_softcapping": 30.0,
37+
"enable_moe_block": false,
38+
"attention_k_eq_v": false
39+
},
40+
"vocab_size": 100
41+
}
42+
"""
43+
return json.data(using: .utf8)!
3944
}
4045

41-
@Test("Gemma 4 Forward Pass - Determinism & Shape")
42-
func testGemma4ForwardPass() throws {
43-
let config = makeTinyConfig()
44-
let model = Gemma4ModelInternal(config)
45-
46-
let input = MLXArray(0..<8).reshaped(1, 8)
47-
let output = model(input)
48-
49-
#expect(output.shape == [1, 8, config.vocabularySize])
50-
51-
// Secondary pass to ensure determinism
52-
let output2 = model(input)
53-
#expect(allClose(output, output2).item(Bool.self))
46+
@Test("Gemma 4 Configuration Decoding")
47+
func testGemma4ConfigDecoding() throws {
48+
let data = makeTinyConfigData()
49+
let config = try JSONDecoder().decode(Gemma4Configuration.self, from: data)
50+
// vocabSize is internal, verify via model
51+
let model = Gemma4Model(config)
52+
#expect(model.vocabularySize == 100)
5453
}
5554

56-
@Test("PLE Multimodal Signal Integrity")
57-
func testPLESignalIntegrity() throws {
58-
let config = makeTinyConfig()
59-
let model = Gemma4ModelInternal(config)
60-
61-
let input = MLXArray(Int32(0)..<Int32(5)).reshaped(1, 5)
62-
63-
// We expect the forward pass to finish without NaN or infinite values
64-
let output = model(input)
65-
let sum = output.sum().item(Float.self)
66-
#expect(!sum.isNaN)
67-
#expect(!sum.isInfinite)
55+
@Test("Gemma 4 Model Instantiation")
56+
func testGemma4ModelInstantiation() throws {
57+
let data = makeTinyConfigData()
58+
let config = try JSONDecoder().decode(Gemma4Configuration.self, from: data)
59+
let model = Gemma4Model(config)
60+
#expect(model.vocabularySize == 100)
6861
}
6962

70-
@Test("Weight Sanitization - PLE Mapping")
71-
func testGemma4Sanitization() throws {
72-
let config = makeTinyConfig()
73-
let model = Gemma4ModelInternal(config)
74-
75-
var weights = [String: MLXArray]()
76-
weights["model.layers.0.per_layer_conditioning.scale"] = MLXArray.ones([config.hiddenSize, config.hiddenSizePerLayerInput])
77-
weights["model.layers.0.per_layer_conditioning.bias"] = MLXArray.ones([config.hiddenSize])
78-
79-
let sanitized = model.sanitize(weights: weights, metadata: [:])
80-
81-
// Gemma 4 sanitization maps to model.layers...
82-
#expect(sanitized["model.layers.0.per_layer_model_projection.scale"] != nil || sanitized["layers.0.per_layer_input.scale"] != nil)
63+
@Test("Gemma 4 Forward Pass - Shape")
64+
func testGemma4ForwardPass() throws {
65+
let data = makeTinyConfigData()
66+
let config = try JSONDecoder().decode(Gemma4Configuration.self, from: data)
67+
let model = Gemma4Model(config)
68+
69+
let input = MLXArray(0..<8).reshaped(1, 8)
70+
let output = model(input, cache: nil)
71+
72+
#expect(output.shape == [1, 8, model.vocabularySize])
8373
}
8474

85-
@Test("Audio Configuration Dependency Safety")
86-
func testAudioConfigSafety() throws {
87-
let config = makeTinyConfig()
88-
let model = Gemma4ModelInternal(config)
89-
#expect(model.model.layers.count == config.hiddenLayers)
75+
@Test("Forward Pass Determinism")
76+
func testDeterminism() throws {
77+
let data = makeTinyConfigData()
78+
let config = try JSONDecoder().decode(Gemma4Configuration.self, from: data)
79+
let model = Gemma4Model(config)
80+
81+
let input = MLXArray(0..<8).reshaped(1, 8)
82+
let output1 = model(input, cache: nil)
83+
let output2 = model(input, cache: nil)
84+
#expect(allClose(output1, output2).item(Bool.self))
9085
}
9186

92-
@Test("Router Parameter Tree Dump")
93-
func testRouterParameterTree() throws {
94-
let config = makeTinyConfig()
95-
let model = Gemma4ModelInternal(config)
96-
97-
var weights = [String: MLXArray]()
98-
weights["model.layers.0.experts.router.scale"] = MLXArray.ones([config.hiddenSize])
99-
weights["model.layers.0.experts.router.proj.weight"] = MLXArray.ones([1, config.hiddenSize])
100-
101-
print("Model parameters:")
102-
for (k, _) in model.parameters() {
103-
if k.contains("router") {
104-
print(k)
105-
}
106-
}
107-
108-
do {
109-
try model.update(parameters: ModuleParameters.unflattened(weights))
110-
} catch {
111-
print("Update Error: \(error)")
112-
throw error
113-
}
87+
@Test("No NaN/Inf in Output")
88+
func testNoNaNInf() throws {
89+
let data = makeTinyConfigData()
90+
let config = try JSONDecoder().decode(Gemma4Configuration.self, from: data)
91+
let model = Gemma4Model(config)
92+
93+
let input = MLXArray(Int32(0)..<Int32(5)).reshaped(1, 5)
94+
let output = model(input, cache: nil)
95+
let sum = output.sum().item(Float.self)
96+
#expect(!sum.isNaN)
97+
#expect(!sum.isInfinite)
11498
}
11599
}

0 commit comments

Comments
 (0)