Skip to content

Commit 8585dd6

Browse files
[prf/dec][refactor] Move Q8_0 unsupported-weight checks out of the GPU prefill-decode and batched-prefill-decode inference engines into the corresponding TornadoVM master plans, rejecting unsupported weight types at execution-plan creation
1 parent d99a888 commit 8585dd6

4 files changed

Lines changed: 30 additions & 15 deletions

File tree

src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithBatchPrefillDecode.java

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@
33
import org.beehive.gpullama3.auxiliary.LastRunMetrics;
44
import org.beehive.gpullama3.inference.sampler.Sampler;
55
import org.beehive.gpullama3.inference.state.State;
6-
import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights;
7-
import org.beehive.gpullama3.tensor.GGMLType;
86
import org.beehive.gpullama3.model.Configuration;
97
import org.beehive.gpullama3.model.Model;
108
import org.beehive.gpullama3.tokenizer.Tokenizer;
@@ -156,11 +154,6 @@ public static List<Integer> generateTokensGPULlama(
156154
int maxTokens, Sampler sampler, boolean echo,
157155
IntConsumer onTokenGenerated, TornadoVMMasterPlan tornadoVMPlan) {
158156

159-
if (((TornadoWeights) model.weights()).getWeightType() == GGMLType.Q8_0) {
160-
throw new UnsupportedOperationException(
161-
"GPU batched prefill/decode path not yet implemented for Q8_0 weights");
162-
}
163-
164157
long startNanos = System.nanoTime();
165158

166159
final Configuration config = model.configuration();

src/main/java/org/beehive/gpullama3/inference/InferenceEngineWithPrefillDecode.java

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@
33
import org.beehive.gpullama3.auxiliary.LastRunMetrics;
44
import org.beehive.gpullama3.inference.sampler.Sampler;
55
import org.beehive.gpullama3.inference.state.State;
6-
import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights;
7-
import org.beehive.gpullama3.tensor.GGMLType;
86
import org.beehive.gpullama3.model.Configuration;
97
import org.beehive.gpullama3.model.Model;
108
import org.beehive.gpullama3.tokenizer.Tokenizer;
@@ -129,11 +127,6 @@ public static List<Integer> generateTokensGPULlama(
129127
int maxTokens, Sampler sampler, boolean echo,
130128
IntConsumer onTokenGenerated, TornadoVMMasterPlan tornadoVMPlan) {
131129

132-
if (((TornadoWeights) model.weights()).getWeightType() == GGMLType.Q8_0) {
133-
throw new UnsupportedOperationException(
134-
"GPU prefill/decode path not yet implemented for Q8_0 weights");
135-
}
136-
137130
long startNanos = System.nanoTime();
138131

139132
final Configuration config = model.configuration();

src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithBatchPrefillDecode.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import org.beehive.gpullama3.inference.state.LlamaState;
44
import org.beehive.gpullama3.inference.state.State;
5+
import org.beehive.gpullama3.tensor.GGMLType;
56
import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights;
67
import org.beehive.gpullama3.model.Model;
78
import org.beehive.gpullama3.model.llama.LlamaConfiguration;
@@ -142,9 +143,21 @@ private TaskGraph buildDecodeActivationGraph(KernelContext ctx, String lastBatch
142143

143144
/**
144145
* Creates the {@link TornadoExecutionPlan} for forward pass with *prefill in batches and separated decode*.
146+
*
147+
* TODO: support Q8_0 weights
148+
* To implement this, consult how {@link TornadoVMMasterPlanStandard} uses the {@link QuantizationPlannerFactory}
145149
*/
146150
@Override
147151
public TornadoExecutionPlan createExecutionPlan() {
152+
GGMLType weightType = model.weights().getWeightType();
153+
switch (weightType) {
154+
case F16 -> { /* supported — continue below */ }
155+
case Q8_0 -> throw new UnsupportedOperationException(
156+
"Batched prefill/decode GPU path not yet implemented for Q8_0 weights");
157+
default -> throw new UnsupportedOperationException(
158+
"Batched prefill/decode GPU path not supported for weight type: " + weightType);
159+
}
160+
148161
LlamaTornadoWeights weights = (LlamaTornadoWeights) model.weights();
149162
SchedulerType schedulerType = SchedulerDetectionService.determineSchedulerType(model);
150163

src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlanWithPrefillDecode.java

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import org.beehive.gpullama3.inference.state.LlamaState;
44
import org.beehive.gpullama3.inference.state.State;
5+
import org.beehive.gpullama3.tensor.GGMLType;
56
import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights;
67
import org.beehive.gpullama3.model.Model;
78
import org.beehive.gpullama3.model.llama.LlamaConfiguration;
@@ -121,9 +122,24 @@ private TaskGraph buildActivationGraph(KernelContext ctx) {
121122
}
122123

123124
// ── Plan construction ─────────────────────────────────────────────────────
124-
125+
/**
126+
* Creates the {@link TornadoExecutionPlan} for forward pass with *prefill/decode separation*.
127+
* Prefill is token-by-token but does not compute logits.
128+
*
129+
* TODO: support Q8_0 weights
130+
* To implement this, consult how {@link TornadoVMMasterPlanStandard} uses the {@link QuantizationPlannerFactory}
131+
*/
125132
@Override
126133
public TornadoExecutionPlan createExecutionPlan() {
134+
GGMLType weightType = model.weights().getWeightType();
135+
switch (weightType) {
136+
case F16 -> { /* supported — continue below */ }
137+
case Q8_0 -> throw new UnsupportedOperationException(
138+
"Prefill/decode GPU path not yet implemented for Q8_0 weights");
139+
default -> throw new UnsupportedOperationException(
140+
"Prefill/decode GPU path not supported for weight type: " + weightType);
141+
}
142+
127143
LlamaTornadoWeights weights = (LlamaTornadoWeights) model.weights();
128144
SchedulerType schedulerType = SchedulerDetectionService.determineSchedulerType(model);
129145

0 commit comments

Comments (0)