
Commit d99a888

[prf/dec] Add unsupported operation exceptions for CPU/GPU prefill-decode and batched-prefill-decode paths in Mistral, Phi3, Qwen2, and Qwen3 models.
1 parent d4329f8 · commit d99a888

4 files changed: 58 additions & 0 deletions
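All four models gain the same pair of guards, keyed off two flags on TornadoVMMasterPlan: the statically imported boolean WITH_PREFILL_DECODE and the integer PREFILL_BATCH_SIZE. Only those two names appear in the diff; as a minimal sketch, assuming the flags are plain static fields read from system properties (the property names and defaults below are invented for illustration), they might be wired up like this:

package org.beehive.gpullama3.tornadovm;

public class TornadoVMMasterPlan {

    // Hypothetical wiring: enable the experimental prefill/decode path.
    public static final boolean WITH_PREFILL_DECODE =
            Boolean.parseBoolean(System.getProperty("llama.prefillDecode", "false"));

    // Hypothetical wiring: prompt tokens processed per prefill step;
    // any value > 1 selects the batched variant of the path.
    public static final int PREFILL_BATCH_SIZE =
            Integer.getInteger("llama.prefillBatchSize", 1);

    // ... rest of the execution plan omitted
}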


src/main/java/org/beehive/gpullama3/model/mistral/Mistral.java

Lines changed: 16 additions & 0 deletions
@@ -2,6 +2,8 @@
 
 import org.beehive.gpullama3.inference.InferenceCore;
 import org.beehive.gpullama3.inference.InferenceEngine;
+import org.beehive.gpullama3.inference.InferenceEngineWithBatchPrefillDecode;
+import org.beehive.gpullama3.inference.InferenceEngineWithPrefillDecode;
 import org.beehive.gpullama3.inference.sampler.Sampler;
 import org.beehive.gpullama3.inference.state.LlamaState;
 import org.beehive.gpullama3.inference.state.State;
@@ -17,6 +19,8 @@
 import java.util.Set;
 import java.util.function.IntConsumer;
 
+import static org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan.WITH_PREFILL_DECODE;
+
 public class Mistral extends AbstractModel {
 
     MistralConfiguration configuration;
@@ -61,12 +65,24 @@ public void forward(State state, int token, int position) {
     @Override
     public List<Integer> generateTokens(State state, int startPosition, List<Integer> promptTokens, Set<Integer> stopTokens, int maxTokens, Sampler sampler, boolean echo,
             IntConsumer onTokenGenerated) {
+        if (WITH_PREFILL_DECODE && TornadoVMMasterPlan.PREFILL_BATCH_SIZE > 1) {
+            throw new UnsupportedOperationException("Batch prefill/decode on CPU not yet implemented for Mistral");
+        }
+        if (WITH_PREFILL_DECODE) {
+            throw new UnsupportedOperationException("Prefill/decode on CPU not yet implemented for Mistral");
+        }
         return InferenceEngine.generateTokensLlama(this, state, startPosition, promptTokens, stopTokens, maxTokens, sampler, echo, onTokenGenerated);
     }
 
     @Override
     public List<Integer> generateTokensGPU(State state, int startPosition, List<Integer> promptTokens, Set<Integer> stopTokens, int maxTokens, Sampler sampler, boolean echo,
             IntConsumer onTokenGenerated, TornadoVMMasterPlan tornadoVMPlan) {
+        if (WITH_PREFILL_DECODE && TornadoVMMasterPlan.PREFILL_BATCH_SIZE > 1) {
+            throw new UnsupportedOperationException("Batch prefill/decode on GPU not yet implemented for Mistral");
+        }
+        if (WITH_PREFILL_DECODE) {
+            throw new UnsupportedOperationException("Prefill/decode on GPU not yet implemented for Mistral");
+        }
         return InferenceEngine.generateTokensGPULlama(this, state, startPosition, promptTokens, stopTokens, maxTokens, sampler, echo, onTokenGenerated, tornadoVMPlan);
     }
 
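The guard body is identical in all four models up to the device and model name, so a hypothetical follow-up (not part of this commit) could hoist it into a shared helper, for example on AbstractModel; the method and parameter names below are invented, and the same static import of WITH_PREFILL_DECODE is assumed:

    protected static void rejectUnsupportedPrefillDecode(String device, String modelName) {
        // Check the batched variant first so the more specific message wins
        // whenever PREFILL_BATCH_SIZE > 1.
        if (WITH_PREFILL_DECODE && TornadoVMMasterPlan.PREFILL_BATCH_SIZE > 1) {
            throw new UnsupportedOperationException(
                    "Batch prefill/decode on " + device + " not yet implemented for " + modelName);
        }
        if (WITH_PREFILL_DECODE) {
            throw new UnsupportedOperationException(
                    "Prefill/decode on " + device + " not yet implemented for " + modelName);
        }
    }

Note that Mistral.java alone also imports InferenceEngineWithPrefillDecode and InferenceEngineWithBatchPrefillDecode; the guards themselves do not reference either class yet.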

src/main/java/org/beehive/gpullama3/model/phi3/Phi3.java

Lines changed: 14 additions & 0 deletions
@@ -17,6 +17,8 @@
 import java.util.Set;
 import java.util.function.IntConsumer;
 
+import static org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan.WITH_PREFILL_DECODE;
+
 public class Phi3 extends AbstractModel {
 
     Phi3Configuration configuration;
@@ -73,12 +75,24 @@ public void forward(State state, int token, int position) {
     @Override
     public List<Integer> generateTokens(State state, int startPosition, List<Integer> promptTokens, Set<Integer> stopTokens, int maxTokens, Sampler sampler, boolean echo,
             IntConsumer onTokenGenerated) {
+        if (WITH_PREFILL_DECODE && TornadoVMMasterPlan.PREFILL_BATCH_SIZE > 1) {
+            throw new UnsupportedOperationException("Batch prefill/decode on CPU not yet implemented for Phi3");
+        }
+        if (WITH_PREFILL_DECODE) {
+            throw new UnsupportedOperationException("Prefill/decode on CPU not yet implemented for Phi3");
+        }
         return InferenceEngine.generateTokensPhi3(this, state, startPosition, promptTokens, stopTokens, maxTokens, sampler, echo, onTokenGenerated);
     }
 
     @Override
     public List<Integer> generateTokensGPU(State state, int startPosition, List<Integer> promptTokens, Set<Integer> stopTokens, int maxTokens, Sampler sampler, boolean echo,
             IntConsumer onTokenGenerated, TornadoVMMasterPlan tornadoVMPlan) {
+        if (WITH_PREFILL_DECODE && TornadoVMMasterPlan.PREFILL_BATCH_SIZE > 1) {
+            throw new UnsupportedOperationException("Batch prefill/decode on GPU not yet implemented for Phi3");
+        }
+        if (WITH_PREFILL_DECODE) {
+            throw new UnsupportedOperationException("Prefill/decode on GPU not yet implemented for Phi3");
+        }
         return InferenceEngine.generateTokensGPUPhi3(this, state, startPosition, promptTokens, stopTokens, maxTokens, sampler, echo, onTokenGenerated, tornadoVMPlan);
     }
 }
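Because every guard throws before any inference work starts, a driver can fail fast on a misconfigured run. A sketch of caller-side handling, with illustrative variable names (none of this code is in the commit):

    try {
        List<Integer> generated = model.generateTokens(
                state, /* startPosition */ 0, promptTokens, stopTokens,
                /* maxTokens */ 256, sampler, /* echo */ false, token -> {});
        // ... detokenize and print `generated`
    } catch (UnsupportedOperationException e) {
        // e.g. "Prefill/decode on CPU not yet implemented for Phi3"
        System.err.println("Unsupported configuration: " + e.getMessage());
        System.err.println("Re-run with the prefill/decode option disabled for this model.");
    }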

src/main/java/org/beehive/gpullama3/model/qwen2/Qwen2.java

Lines changed: 14 additions & 0 deletions
@@ -17,6 +17,8 @@
 import java.util.Set;
 import java.util.function.IntConsumer;
 
+import static org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan.WITH_PREFILL_DECODE;
+
 public class Qwen2 extends AbstractModel {
 
     Qwen2Configuration configuration;
@@ -92,12 +94,24 @@ public void forward(State state, int token, int position) {
     @Override
     public List<Integer> generateTokens(State state, int startPosition, List<Integer> promptTokens, Set<Integer> stopTokens, int maxTokens, Sampler sampler, boolean echo,
             IntConsumer onTokenGenerated) {
+        if (WITH_PREFILL_DECODE && TornadoVMMasterPlan.PREFILL_BATCH_SIZE > 1) {
+            throw new UnsupportedOperationException("Batch prefill/decode on CPU not yet implemented for Qwen2/Deepseek-R1-Distill-Qwen");
+        }
+        if (WITH_PREFILL_DECODE) {
+            throw new UnsupportedOperationException("Prefill/decode on CPU not yet implemented for Qwen2/Deepseek-R1-Distill-Qwen");
+        }
         return InferenceEngine.generateTokensQwen3(this, state, startPosition, promptTokens, stopTokens, maxTokens, sampler, echo, onTokenGenerated);
     }
 
     @Override
     public List<Integer> generateTokensGPU(State state, int startPosition, List<Integer> promptTokens, Set<Integer> stopTokens, int maxTokens, Sampler sampler, boolean echo,
             IntConsumer onTokenGenerated, TornadoVMMasterPlan tornadoVMPlan) {
+        if (WITH_PREFILL_DECODE && TornadoVMMasterPlan.PREFILL_BATCH_SIZE > 1) {
+            throw new UnsupportedOperationException("Batch prefill/decode on GPU not yet implemented for Qwen2/Deepseek-R1-Distill-Qwen");
+        }
+        if (WITH_PREFILL_DECODE) {
+            throw new UnsupportedOperationException("Prefill/decode on GPU not yet implemented for Qwen2/Deepseek-R1-Distill-Qwen");
+        }
         return InferenceEngine.generateTokensGPUQwen3(this, state, startPosition, promptTokens, stopTokens, maxTokens, sampler, echo, onTokenGenerated, tornadoVMPlan);
     }
 }
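As the return statements above show, Qwen2 (and the Deepseek-R1-Distill-Qwen variants it serves) delegates to the Qwen3 entry points generateTokensQwen3 and generateTokensGPUQwen3, which is why the exception messages name both families. A hypothetical JUnit 5 check that the new CPU guard fires, assuming a test build where WITH_PREFILL_DECODE is true with a batch size of 1 and the model/state fixtures exist:

    @Test
    void qwen2RejectsPrefillDecodeOnCpu() {
        // Expect the guard to throw before any inference work begins.
        UnsupportedOperationException e = org.junit.jupiter.api.Assertions.assertThrows(
                UnsupportedOperationException.class,
                () -> qwen2.generateTokens(state, 0, promptTokens, stopTokens, 16, sampler, false, t -> {}));
        org.junit.jupiter.api.Assertions.assertTrue(
                e.getMessage().contains("not yet implemented for Qwen2"));
    }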

src/main/java/org/beehive/gpullama3/model/qwen3/Qwen3.java

Lines changed: 14 additions & 0 deletions
@@ -17,6 +17,8 @@
 import java.util.Set;
 import java.util.function.IntConsumer;
 
+import static org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan.WITH_PREFILL_DECODE;
+
 public class Qwen3 extends AbstractModel {
 
     Qwen3Configuration configuration;
@@ -73,12 +75,24 @@ public void forward(State state, int token, int position) {
     @Override
     public List<Integer> generateTokens(State state, int startPosition, List<Integer> promptTokens, Set<Integer> stopTokens, int maxTokens, Sampler sampler, boolean echo,
             IntConsumer onTokenGenerated) {
+        if (WITH_PREFILL_DECODE && TornadoVMMasterPlan.PREFILL_BATCH_SIZE > 1) {
+            throw new UnsupportedOperationException("Batch prefill/decode on CPU not yet implemented for Qwen3");
+        }
+        if (WITH_PREFILL_DECODE) {
+            throw new UnsupportedOperationException("Prefill/decode on CPU not yet implemented for Qwen3");
+        }
         return InferenceEngine.generateTokensQwen3(this, state, startPosition, promptTokens, stopTokens, maxTokens, sampler, echo, onTokenGenerated);
     }
 
     @Override
     public List<Integer> generateTokensGPU(State state, int startPosition, List<Integer> promptTokens, Set<Integer> stopTokens, int maxTokens, Sampler sampler, boolean echo,
             IntConsumer onTokenGenerated, TornadoVMMasterPlan tornadoVMPlan) {
+        if (WITH_PREFILL_DECODE && TornadoVMMasterPlan.PREFILL_BATCH_SIZE > 1) {
+            throw new UnsupportedOperationException("Batch prefill/decode on GPU not yet implemented for Qwen3");
+        }
+        if (WITH_PREFILL_DECODE) {
+            throw new UnsupportedOperationException("Prefill/decode on GPU not yet implemented for Qwen3");
+        }
         return InferenceEngine.generateTokensGPUQwen3(this, state, startPosition, promptTokens, stopTokens, maxTokens, sampler, echo, onTokenGenerated, tornadoVMPlan);
     }
 