Skip to content

Commit b94b20f

Browse files
Merge pull request #108 from AdamBien/main
Add Q4_K/Q5_K/Q6_K GPU support via Q8_0 dequantization
2 parents f08c1f5 + 58f7e2a commit b94b20f

13 files changed

Lines changed: 507 additions & 38 deletions

llamaTornado

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ record Config(
1212
double temperature, double topP, long seed, int maxTokens,
1313
boolean stream, boolean echo, boolean interactive, boolean instruct,
1414
boolean useGpu, Backend backend, String gpuMemory,
15-
String heapMin, String heapMax,
15+
String heapMin, String heapMax, String directMemory,
1616
boolean debug, boolean profiler, String profilerDumpDir,
1717
boolean printBytecodes, boolean threads, boolean printKernel,
1818
boolean fullDump, boolean verboseInit,
@@ -37,6 +37,7 @@ Config parseArgs(String[] args) {
3737
String gpuMemory = "14GB";
3838
String heapMin = "20g";
3939
String heapMax = "20g";
40+
String directMemory = null;
4041
boolean debug = false;
4142
boolean profiler = false;
4243
String profilerDumpDir = null;
@@ -71,6 +72,7 @@ Config parseArgs(String[] args) {
7172
case "--gpu-memory" -> gpuMemory = args[++i];
7273
case "--heap-min" -> heapMin = args[++i];
7374
case "--heap-max" -> heapMax = args[++i];
75+
case "--direct-memory" -> directMemory = args[++i];
7476
case "--debug" -> debug = true;
7577
case "--profiler" -> profiler = true;
7678
case "--profiler-dump-dir" -> profilerDumpDir = args[++i];
@@ -101,12 +103,27 @@ Config parseArgs(String[] args) {
101103
profilerDumpDir = System.getenv("LLAMA_ROOT") + "/profiler-log.json";
102104
}
103105

106+
// Default direct memory to 3x heap to accommodate K-quant dequantization
107+
if (directMemory == null) {
108+
directMemory = parseAndScale(heapMax, 3);
109+
}
110+
104111
return new Config(modelPath, prompt, systemPrompt, temperature, topP, seed, maxTokens,
105-
stream, echo, interactive, instruct, useGpu, backend, gpuMemory, heapMin, heapMax,
112+
stream, echo, interactive, instruct, useGpu, backend, gpuMemory, heapMin, heapMax, directMemory,
106113
debug, profiler, profilerDumpDir, printBytecodes, threads, printKernel, fullDump,
107114
verboseInit, showCommand, executeAfterShow, openclFlags, maxWaitEvents, verbose);
108115
}
109116

117+
String parseAndScale(String memoryValue, int multiplier) {
118+
var matcher = java.util.regex.Pattern.compile("(\\d+)([gGmM]?)").matcher(memoryValue);
119+
if (matcher.matches()) {
120+
long value = Long.parseLong(matcher.group(1)) * multiplier;
121+
String suffix = matcher.group(2).isEmpty() ? "" : matcher.group(2);
122+
return value + suffix;
123+
}
124+
return memoryValue;
125+
}
126+
110127
void printUsage() {
111128
IO.println("""
112129
Usage: %s --model <path> [options]
@@ -138,6 +155,7 @@ void printUsage() {
138155
--gpu-memory <val> GPU memory allocation (default: 14GB)
139156
--heap-min <val> Min JVM heap (default: 20g)
140157
--heap-max <val> Max JVM heap (default: 20g)
158+
--direct-memory <val> Max direct buffer memory (default: 3x heap-max)
141159

142160
Debug:
143161
--debug Enable debug output
@@ -195,6 +213,7 @@ List<String> buildCommand(Config cfg, String javaHome, String tornadoSdk, String
195213
"-XX:+EnableJVMCI",
196214
"-Xms" + cfg.heapMin(),
197215
"-Xmx" + cfg.heapMax(),
216+
"-XX:MaxDirectMemorySize=" + cfg.directMemory(),
198217
"--enable-preview",
199218
"-Djava.library.path=" + tornadoSdk + "/lib",
200219
"-Djdk.module.showModuleResolution=false",

src/main/java/org/beehive/gpullama3/model/loader/AbstractModelLoader.java

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
package org.beehive.gpullama3.model.loader;
22

3+
import org.beehive.gpullama3.tensor.GGMLType;
34
import org.beehive.gpullama3.tensor.GGUF;
45
import org.beehive.gpullama3.tensor.GGMLTensorEntry;
56
import org.beehive.gpullama3.auxiliary.Pair;
@@ -8,6 +9,7 @@
89
import org.beehive.gpullama3.model.Model;
910
import org.beehive.gpullama3.tokenizer.Tokenizer;
1011
import org.beehive.gpullama3.tokenizer.Vocabulary;
12+
import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan;
1113

1214
import java.io.IOException;
1315
import java.nio.channels.FileChannel;
@@ -40,10 +42,39 @@ protected String getModelQuantization(Map<String, Object> metadata) {
4042
return switch (modelQuantizationAsInt) {
4143
case 1 -> "FP16";
4244
case 7 -> "Q8_0";
45+
case 14, 15 -> "Q8_0"; // Q4_K_S, Q4_K_M (K-quants use Q8_0 activations)
46+
case 16, 17 -> "Q8_0"; // Q5_K_S, Q5_K_M
47+
case 18 -> "Q8_0"; // Q6_K
4348
default -> throw new UnsupportedOperationException("Unsupported quantization format: " + modelQuantizationAsInt + " (as int).");
4449
};
4550
}
4651

52+
/**
53+
* Returns the effective GPU weight type for TornadoVM execution.
54+
* K-quant types (Q4_K, Q5_K, Q6_K) are dequantized to Q8_0 at load time.
55+
*/
56+
protected static GGMLType effectiveGpuWeightType(GGMLType ggmlType) {
57+
return switch (ggmlType) {
58+
case F16, F32, Q8_0 -> ggmlType;
59+
case Q4_K, Q5_K, Q6_K -> GGMLType.Q8_0;
60+
default -> ggmlType;
61+
};
62+
}
63+
64+
private static String fileTypeName(int fileType) {
65+
return switch (fileType) {
66+
case 0 -> "F32";
67+
case 1 -> "F16";
68+
case 7 -> "Q8_0";
69+
case 14 -> "Q4_K_S";
70+
case 15 -> "Q4_K_M";
71+
case 16 -> "Q5_K_S";
72+
case 17 -> "Q5_K_M";
73+
case 18 -> "Q6_K";
74+
default -> "type_" + fileType;
75+
};
76+
}
77+
4778
/**
4879
* Template method that defines the model loading workflow. Subclasses should not override this method.
4980
*
@@ -123,6 +154,11 @@ public Weights loadWeights(Map<String, GGMLTensorEntry> tensorEntries, C config)
123154

124155
// Delegate to specific implementation
125156
if (useTornadovm) {
157+
GGMLType gpuType = effectiveGpuWeightType(outputWeight.ggmlType());
158+
if (TornadoVMMasterPlan.ENABLE_TORNADOVM_INIT_TIME) {
159+
int fileType = (int) gguf.getMetadata().get("general.file_type");
160+
System.out.println("Loading model weights in TornadoVM format (" + fileTypeName(fileType) + " -> " + gpuType + ")");
161+
}
126162
return createTornadoVMWeights(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight);
127163
} else {
128164
return createStandardWeights(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight);

src/main/java/org/beehive/gpullama3/model/loader/DevstralModelLoader.java

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -143,11 +143,7 @@ protected Weights createStandardWeights(Map<String, GGMLTensorEntry> tensorEntri
143143
// @formatter:off
144144
@Override
145145
protected Weights createTornadoVMWeights(Map<String, GGMLTensorEntry> tensorEntries, DevstralConfiguration config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings, GGMLTensorEntry outputWeight) {
146-
GGMLType ggmlType = outputWeight.ggmlType();
147-
148-
if (TornadoVMMasterPlan.ENABLE_TORNADOVM_INIT_TIME) {
149-
System.out.println("Loading model weights in TornadoVM format (loading " + ggmlType + ")");
150-
}
146+
GGMLType ggmlType = effectiveGpuWeightType(outputWeight.ggmlType());
151147

152148
if (ggmlType != GGMLType.F16 && ggmlType != GGMLType.Q8_0) {
153149
throw new UnsupportedOperationException("Type: " + ggmlType + " currently not supported for TornadoVM weights.");

src/main/java/org/beehive/gpullama3/model/loader/GraniteLoader.java

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -136,11 +136,7 @@ protected Weights createTornadoVMWeights(Map<String, GGMLTensorEntry> tensorEntr
136136
Pair<float[], float[]> ropeFreqs,
137137
GGMLTensorEntry tokenEmbeddings,
138138
GGMLTensorEntry outputWeight) {
139-
GGMLType ggmlType = outputWeight.ggmlType();
140-
141-
if (TornadoVMMasterPlan.ENABLE_TORNADOVM_INIT_TIME) {
142-
System.out.println("Loading model weights in TornadoVM format (loading " + ggmlType + ")");
143-
}
139+
GGMLType ggmlType = effectiveGpuWeightType(outputWeight.ggmlType());
144140

145141
// Validate supported types
146142
if (ggmlType != GGMLType.F16 && ggmlType != GGMLType.Q8_0) {

src/main/java/org/beehive/gpullama3/model/loader/LlamaModelLoader.java

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -106,11 +106,7 @@ protected Weights createTornadoVMWeights(Map<String, GGMLTensorEntry> tensorEntr
106106
Pair<float[], float[]> ropeFreqs,
107107
GGMLTensorEntry tokenEmbeddings,
108108
GGMLTensorEntry outputWeight) {
109-
GGMLType ggmlType = outputWeight.ggmlType();
110-
111-
if (TornadoVMMasterPlan.ENABLE_TORNADOVM_INIT_TIME) {
112-
System.out.println("Loading model weights in TornadoVM format (loading " + ggmlType + ")");
113-
}
109+
GGMLType ggmlType = effectiveGpuWeightType(outputWeight.ggmlType());
114110

115111
// Validate supported types
116112
if (ggmlType != GGMLType.F16 && ggmlType != GGMLType.Q8_0) {

src/main/java/org/beehive/gpullama3/model/loader/MistralModelLoader.java

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -116,11 +116,7 @@ protected Weights createStandardWeights(Map<String, GGMLTensorEntry> tensorEntri
116116
// @formatter:off
117117
@Override
118118
protected Weights createTornadoVMWeights(Map<String, GGMLTensorEntry> tensorEntries, MistralConfiguration config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings, GGMLTensorEntry outputWeight) {
119-
GGMLType ggmlType = outputWeight.ggmlType();
120-
121-
if (TornadoVMMasterPlan.ENABLE_TORNADOVM_INIT_TIME) {
122-
System.out.println("Loading model weights in TornadoVM format (loading " + ggmlType + ")");
123-
}
119+
GGMLType ggmlType = effectiveGpuWeightType(outputWeight.ggmlType());
124120

125121
// Validate supported types
126122
if (ggmlType != GGMLType.F16 && ggmlType != GGMLType.Q8_0) {

src/main/java/org/beehive/gpullama3/model/loader/ModelLoader.java

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import uk.ac.manchester.tornado.api.types.arrays.*;
1717

1818
import java.io.IOException;
19+
import java.lang.foreign.Arena;
1920
import java.lang.foreign.MemorySegment;
2021
import java.lang.foreign.ValueLayout;
2122
import java.nio.ByteOrder;
@@ -122,6 +123,9 @@ public static FloatTensor loadTensor(GGMLTensorEntry entry) {
122123
case F32 -> new FP32FloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment());
123124
case Q8_0 -> new Q8_0FloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment());
124125
case Q4_0 -> new Q4_0FloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment());
126+
case Q4_K -> new Q4_KFloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment());
127+
case Q5_K -> new Q5_KFloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment());
128+
case Q6_K -> new Q6_KFloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment());
125129
case F16 -> new FP16FloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment());
126130
default -> throw new UnsupportedOperationException("Quantization format " + ggmlType);
127131
};
@@ -150,11 +154,69 @@ public static TornadoTensor loadTornadoTensor(GGMLTensorEntry entry) {
150154
case F32 -> FP32TornadoTensor.fromTornadoMemorySegment(entry.memorySegment());
151155
case F16 -> FP16TornadoTensor.fromTornadoMemorySegment(entry.memorySegment());
152156
case Q8_0 -> Q8_0TornadoTensor.fromTornadoMemorySegment(entry.memorySegment());
153-
case Q4_0 -> throw new UnsupportedOperationException("Q4 format not supported yet");
157+
case Q4_K, Q5_K, Q6_K -> dequantizeToQ8_0TornadoTensor(entry);
158+
case Q4_0 -> throw new UnsupportedOperationException("Q4_0 format not supported for TornadoVM yet");
154159
default -> throw new UnsupportedOperationException("Quantization format " + ggmlType);
155160
};
156161
}
157162

163+
/**
164+
* Dequantizes a K-quant tensor (Q4_K, Q5_K, Q6_K) to Q8_0 format for TornadoVM/GPU execution.
165+
* This is a load-time conversion that allows K-quant models to run on GPU with existing Q8_0 kernels.
166+
*/
167+
private static Q8_0TornadoTensor dequantizeToQ8_0TornadoTensor(GGMLTensorEntry entry) {
168+
// The entry's memorySegment includes a TornadoVM ARRAY_HEADER prefix (16 bytes of zeros).
169+
// Slice past it so the K-quant FloatTensor reads raw tensor data starting at byte 0.
170+
long headerBytes = TornadoNativeArray.ARRAY_HEADER;
171+
GGMLTensorEntry dataEntry = new GGMLTensorEntry(
172+
entry.mappedFile(), entry.name(), entry.ggmlType(), entry.shape(),
173+
entry.memorySegment().asSlice(headerBytes));
174+
FloatTensor sourceTensor = loadTensor(dataEntry);
175+
int numElements = sourceTensor.size();
176+
int blockSize = 32;
177+
int blocksNeeded = (numElements + blockSize - 1) / blockSize;
178+
int q8BlockBytes = 34; // 2 bytes scale + 32 bytes quants
179+
int q8BytesNeeded = blocksNeeded * q8BlockBytes;
180+
181+
byte[] q8Data = new byte[q8BytesNeeded];
182+
183+
for (int b = 0; b < blocksNeeded; b++) {
184+
int start = b * blockSize;
185+
int end = Math.min(start + blockSize, numElements);
186+
187+
// Find max absolute value for scale
188+
float maxAbs = 0;
189+
for (int i = start; i < end; i++) {
190+
maxAbs = Math.max(maxAbs, Math.abs(sourceTensor.getFloat(i)));
191+
}
192+
float scale = maxAbs / 127.0f;
193+
194+
// Write scale as fp16 (little-endian)
195+
short scaleF16 = Float.floatToFloat16(scale);
196+
int blockOff = b * q8BlockBytes;
197+
q8Data[blockOff] = (byte) (scaleF16 & 0xFF);
198+
q8Data[blockOff + 1] = (byte) ((scaleF16 >> 8) & 0xFF);
199+
200+
// Quantize values
201+
float invScale = scale != 0 ? 1.0f / scale : 0;
202+
for (int i = start; i < end; i++) {
203+
int qi = Math.round(sourceTensor.getFloat(i) * invScale);
204+
qi = Math.max(-128, Math.min(127, qi));
205+
q8Data[blockOff + 2 + (i - start)] = (byte) qi;
206+
}
207+
}
208+
209+
// Allocate native memory with TornadoNativeArray header, matching GGUF.loadTensorsTornado layout
210+
MemorySegment nativeSegment = Arena.ofAuto().allocate(headerBytes + q8BytesNeeded, 4);
211+
// Zero out the header
212+
for (int i = 0; i < headerBytes; i++) {
213+
nativeSegment.set(ValueLayout.JAVA_BYTE, i, (byte) 0);
214+
}
215+
// Copy Q8_0 data after header
216+
MemorySegment.copy(MemorySegment.ofArray(q8Data), 0, nativeSegment, headerBytes, q8BytesNeeded);
217+
return Q8_0TornadoTensor.fromTornadoMemorySegment(nativeSegment);
218+
}
219+
158220
/**
159221
* Dispatcher method for loading a TornadoVM tensor array based on type.
160222
* Used in GPU-path.

src/main/java/org/beehive/gpullama3/model/loader/Phi3ModelLoader.java

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -126,11 +126,7 @@ protected Weights createStandardWeights(Map<String, GGMLTensorEntry> tensorEntri
126126
// @formatter:off
127127
@Override
128128
protected Weights createTornadoVMWeights(Map<String, GGMLTensorEntry> tensorEntries, Phi3Configuration config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings, GGMLTensorEntry outputWeight) {
129-
GGMLType ggmlType = outputWeight.ggmlType();
130-
131-
if (TornadoVMMasterPlan.ENABLE_TORNADOVM_INIT_TIME) {
132-
System.out.println("Loading model weights in TornadoVM format (loading " + ggmlType + ")");
133-
}
129+
GGMLType ggmlType = effectiveGpuWeightType(outputWeight.ggmlType());
134130

135131
// Validate supported types
136132
if (ggmlType != GGMLType.F16 && ggmlType != GGMLType.Q8_0) {

src/main/java/org/beehive/gpullama3/model/loader/Qwen2ModelLoader.java

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -126,11 +126,7 @@ protected Weights createStandardWeights(Map<String, GGMLTensorEntry> tensorEntri
126126
@Override
127127
protected Weights createTornadoVMWeights(Map<String, GGMLTensorEntry> tensorEntries, Qwen2Configuration config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
128128
GGMLTensorEntry outputWeight) {
129-
GGMLType ggmlType = outputWeight.ggmlType();
130-
131-
if (TornadoVMMasterPlan.ENABLE_TORNADOVM_INIT_TIME) {
132-
System.out.println("Loading model weights in TornadoVM format (loading " + ggmlType + ")");
133-
}
129+
GGMLType ggmlType = effectiveGpuWeightType(outputWeight.ggmlType());
134130

135131
// Validate supported types
136132
if (ggmlType != GGMLType.F16 && ggmlType != GGMLType.Q8_0) {

src/main/java/org/beehive/gpullama3/model/loader/Qwen3ModelLoader.java

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -129,11 +129,7 @@ protected Weights createStandardWeights(Map<String, GGMLTensorEntry> tensorEntri
129129
protected Weights createTornadoVMWeights(Map<String, GGMLTensorEntry> tensorEntries, Qwen3Configuration config,
130130
Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
131131
GGMLTensorEntry outputWeight) {
132-
if (TornadoVMMasterPlan.ENABLE_TORNADOVM_INIT_TIME) {
133-
System.out.println("Loading model weights in TornadoVM format (loading " + outputWeight.ggmlType() + " -> " + GGMLType.F16 + ")");
134-
}
135-
136-
GGMLType ggmlType = outputWeight.ggmlType();
132+
GGMLType ggmlType = effectiveGpuWeightType(outputWeight.ggmlType());
137133

138134
final int nl = config.numberOfLayers();
139135

0 commit comments

Comments
 (0)