Skip to content

Commit c66df12

Browse files
author
Aegis-AI
committed
Fix Swift compiler warnings and refine MTP output2D scatter logic
1 parent e3fb72a commit c66df12

5 files changed

Lines changed: 22 additions & 7 deletions

File tree

Libraries/MLXLLM/Models/Gemma4Text.swift

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1116,7 +1116,7 @@ public class Gemma4AssistantModel: Module, LLMModel, DualModelMTP, KVCacheDimens
11161116
// Use mlx scatter via the __setitem__ approach:
11171117
let scatterIdx2D = selectedCanonicalShaped.reshaped([B * S, totalCandidates]).asType(.int32)
11181118
let selectedLogits2D = selectedLogits.reshaped([B * S, totalCandidates])
1119-
var output2D = output.reshaped([B * S, vocabSize])
1119+
let output2D = output.reshaped([B * S, vocabSize])
11201120
let rowIndices = MLXArray.arange(B * S).asType(.int32).reshaped([B * S, 1])
11211121
output2D[rowIndices, scatterIdx2D] = selectedLogits2D
11221122
output = output2D.reshaped([B, S, vocabSize])

Libraries/MLXLMCommon/Load.swift

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -126,12 +126,10 @@ public func loadWeights(
126126
let allPrefixes = ["", "model.", "language_model.", "model.language_model."]
127127
let candidates = [expert0Name, stripped0Name, strippedMtpName] + allPrefixes.map { $0 + stripped0Name } + allPrefixes.map { $0 + strippedMtpName }
128128
var foundUnstacked = false
129-
var matchedCandidate = ""
130129

131130
for candidate in candidates {
132131
if ExpertStreamerManager.shared?.getFile(for: candidate) != nil {
133132
foundUnstacked = true
134-
matchedCandidate = candidate
135133
var map = [Int: (path: String, tensorName: String)]()
136134
for i in 0 ..< sl.numExperts {
137135
let c = candidate.replacingOccurrences(of: ".experts.0.", with: ".experts.\(i).")

Libraries/MLXLMCommon/SwitchLayers.swift

Lines changed: 1 addition & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -316,10 +316,7 @@ public class SwitchGLU: Module, @unchecked Sendable {
316316
var outShape = x.shape
317317
outShape[outShape.count - 1] = downProj.outputDims
318318
let result = MLXArray.zeros(outShape).asType(.float16)
319-
if doSort {
320-
return MLX.squeezed(scatterUnsort(x: result, invOrder: inverseOrder, shape: indices.shape), axis: -2)
321-
}
322-
return MLX.squeezed(result, axis: -2)
319+
return MLX.squeezed(scatterUnsort(x: result, invOrder: inverseOrder, shape: indices.shape), axis: -2)
323320
}
324321

325322
// Parse routing — `idx.asArray()` is the actual sync point on GPU.

test_array_init.swift

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,7 @@
1+
import Foundation
2+
import MLX
3+
MLX.GPU.set(cacheLimit: 10 * 1024 * 1024)
4+
5+
let size: Int = 10
6+
let arr = MLXArray(0 ..< size).asType(.int32)
7+
print(arr)

test_scatter.swift

Lines changed: 13 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,13 @@
1+
import Foundation
2+
import MLX
3+
4+
MLX.GPU.set(cacheLimit: 10 * 1024 * 1024)
5+
6+
var out = MLXArray.zeros([4, 10])
7+
let rows = MLXArray(0 ..< Int32(4)).reshaped([4, 1])
8+
let cols = MLXArray([1, 2, 0, 4, 3, 5, 2, 9]).reshaped([4, 2])
9+
let vals = MLXArray([10, 20, 30, 40, 50, 60, 70, 80]).reshaped([4, 2])
10+
11+
out[rows, cols] = vals
12+
MLX.eval(out)
13+
print(out)

0 commit comments

Comments
 (0)