
Commit d4329f8

[prf/dec] Separate inference paths (InferenceEngine, InferenceCore, CPU/GPU) for standard, prefill-decode and prefill-decode with batching
1 parent 7429a63 commit d4329f8

5 files changed: 464 additions & 290 deletions

src/main/java/org/beehive/gpullama3/inference/InferenceCoreBatchPrefillDecode.java

Lines changed: 199 additions & 0 deletions
@@ -0,0 +1,199 @@
package org.beehive.gpullama3.inference;

import org.beehive.gpullama3.auxiliary.Parallel;
import org.beehive.gpullama3.inference.state.State;
import org.beehive.gpullama3.inference.weights.standard.StandardWeights;
import org.beehive.gpullama3.model.Configuration;
import org.beehive.gpullama3.model.Model;
import org.beehive.gpullama3.tensor.standard.ArrayFloatTensor;
import org.beehive.gpullama3.tensor.standard.FloatTensor;
import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithBatchPrefillDecode;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;

/**
 * Low-level forward passes for the batched prefill/decode inference path (Phase 3/4).
 *
 * <p>Parallel to {@link InferenceCoreWithPrefillDecode} — does NOT modify it.</p>
 *
 * <p>Provides three operations:</p>
 * <ul>
 *   <li>{@link #batchForwardJavaPrefill} — CPU batch prefill: processes a chunk of
 *       prompt tokens in one pass using batch matmul, avoiding redundant weight
 *       traversals. Only the KV cache is populated; logits are intentionally omitted.</li>
 *   <li>{@link #batchForwardTornadoVMPrefill} — GPU batch prefill: delegates the chunk
 *       to {@link TornadoVMMasterPlanWithBatchPrefillDecode#tornadoVMForwardBatchPrefill}.</li>
 *   <li>{@link #forwardTornadoVMDecode} — GPU decode: delegates a single decode step to
 *       {@link TornadoVMMasterPlanWithBatchPrefillDecode#tornadoVMForwardDecode}, which
 *       handles the embedding copy and runs the full decode + logits graphs.</li>
 * </ul>
 */
public final class InferenceCoreBatchPrefillDecode {

    private InferenceCoreBatchPrefillDecode() {}

    /**
     * CPU batched prefill forward pass for LLaMA (Phase 3).
     *
     * <p>Processes {@code batchSize} prompt tokens simultaneously through all
     * transformer layers. For each layer, Q/K/V projections, output projection,
     * and FFN projections are computed via batch matmul
     * ({@link FloatTensor#matmul(int, FloatTensor[], FloatTensor[], int, int)}),
     * which parallelises over both output dimension and batch simultaneously.
     * Attention reuses {@code state.att} sequentially per token (parallel per
     * head within each token), keeping memory overhead minimal.</p>
     *
     * <p>The logits layer is intentionally omitted — only the KV cache matters
     * for prefill positions.</p>
     *
     * @param model     the LLaMA model (must carry {@link StandardWeights})
     * @param state     mutable inference state (KV cache, att buffer …)
     * @param tokens    input token ids, {@code tokens[b]} at position {@code startPos+b}
     * @param startPos  sequence position of {@code tokens[0]}
     * @param batchSize number of tokens in this chunk ({@code tokens.length})
     */
    public static void batchForwardJavaPrefill(Model model, State state, int[] tokens, int startPos, int batchSize) {
        final Configuration config = model.configuration();
        final StandardWeights weights = (StandardWeights) model.weights();
        int dim = config.dim();
        int headSize = config.headSize();
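        // Grouped-query attention bookkeeping: kvDim is the total width of the shared
        // K/V projections and kvMul the number of query heads per KV head. As a worked
        // example with a hypothetical config (dim = 4096, 32 heads, 8 KV heads):
        // kvDim = 4096 * 8 / 32 = 1024 and kvMul = 32 / 8 = 4.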
        int kvDim = (config.dim() * config.numberOfKeyValueHeads()) / config.numberOfHeads();
        int kvMul = config.numberOfHeads() / config.numberOfKeyValueHeads();
        float sqrtHeadSize = (float) Math.sqrt(headSize);

        // ── Batch activation tensors (allocated once per chunk) ───────────────
        FloatTensor[] x = new FloatTensor[batchSize];
        FloatTensor[] xb = new FloatTensor[batchSize];
        FloatTensor[] xb2 = new FloatTensor[batchSize];
        FloatTensor[] q = new FloatTensor[batchSize];
        FloatTensor[] k = new FloatTensor[batchSize];
        FloatTensor[] v = new FloatTensor[batchSize];
        FloatTensor[] hb = new FloatTensor[batchSize];
        FloatTensor[] hb2 = new FloatTensor[batchSize];
        for (int b = 0; b < batchSize; b++) {
            x[b] = ArrayFloatTensor.allocate(dim);
            xb[b] = ArrayFloatTensor.allocate(dim);
            xb2[b] = ArrayFloatTensor.allocate(dim);
            q[b] = ArrayFloatTensor.allocate(dim);
            k[b] = ArrayFloatTensor.allocate(kvDim);
            v[b] = ArrayFloatTensor.allocate(kvDim);
            hb[b] = ArrayFloatTensor.allocate(config.hiddenDim());
            hb2[b] = ArrayFloatTensor.allocate(config.hiddenDim());
        }

        // ── Token embeddings ──────────────────────────────────────────────────
        Parallel.parallelFor(0, batchSize, b ->
                weights.token_embedding_table.copyTo(tokens[b] * dim, x[b], 0, dim));

        // ── Transformer layers ────────────────────────────────────────────────
        for (int l = 0; l < config.numberOfLayers(); l++) {
            final int layer = l;

            // Attention RMSNorm (parallel per b)
            Parallel.parallelFor(0, batchSize, b ->
                    InferenceCore.rmsnorm(xb[b], x[b], weights.rms_att_weight[layer], 0, dim, config.rmsNormEps()));

            // QKV projections — batch matmul parallelises over (dim × batchSize)
            weights.wq[l].matmul(batchSize, xb, q, dim, dim);
            weights.wk[l].matmul(batchSize, xb, k, kvDim, dim);
            weights.wv[l].matmul(batchSize, xb, v, kvDim, dim);

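            // RoPE + KV cache store (parallel per b — different positions, no conflict).
            // Each even/odd pair (v0, v1) is rotated by its position-dependent angle:
            // (v0, v1) -> (v0*cos - v1*sin, v0*sin + v1*cos), with cos/sin (fcr/fci)
            // read from the precomputed freq_cis tables; pairs at i >= kvDim exist
            // only in q, so only q is rotated there.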
            Parallel.parallelFor(0, batchSize, b -> {
                int pos = startPos + b;
                for (int i = 0; i < dim; i += 2) {
                    int head_dim = i % headSize;
                    float fcr = weights.freq_cis_real.getFloat(pos * (headSize / 2) + (head_dim / 2));
                    float fci = weights.freq_cis_imag.getFloat(pos * (headSize / 2) + (head_dim / 2));
                    int rotn = i < kvDim ? 2 : 1;
                    for (int vv = 0; vv < rotn; vv++) {
                        FloatTensor vec = vv == 0 ? q[b] : k[b];
                        float v0 = vec.getFloat(i);
                        float v1 = vec.getFloat(i + 1);
                        vec.setFloat(i, v0 * fcr - v1 * fci);
                        vec.setFloat(i + 1, v0 * fci + v1 * fcr);
                    }
                }
                k[b].copyTo(0, state.keyCache[layer], pos * kvDim, kvDim);
                v[b].copyTo(0, state.valueCache[layer], pos * kvDim, kvDim);
            });

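            // Attention — sequential per b (state.att is shared), parallel per head.
            // Per head h: att[t] = softmax over t of (q · k_t / sqrt(headSize)) for
            // t = 0..pos_b, then xb[h] = sum over t of att[t] * v_t; h / kvMul maps
            // each query head to its shared KV head.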
            for (int b = 0; b < batchSize; b++) {
                final int pos_b = startPos + b;
                final int bFinal = b;
                Parallel.parallelFor(0, config.numberOfHeads(), h -> {
                    int qOffset = h * headSize;
                    int attOffset = h * config.contextLength();

                    for (int t = 0; t <= pos_b; t++) {
                        int keyCacheOffset = t * kvDim + (h / kvMul) * headSize;
                        float score = q[bFinal].dot(qOffset, state.keyCache[layer], keyCacheOffset, headSize) / sqrtHeadSize;
                        state.att.setFloat(attOffset + t, score);
                    }
                    state.att.softmaxInPlace(attOffset, pos_b + 1);

                    int xbOffset = h * headSize;
                    xb[bFinal].fillInPlace(xbOffset, headSize, 0f);
                    for (int t = 0; t <= pos_b; t++) {
                        int vOffset = t * kvDim + (h / kvMul) * headSize;
                        float a = state.att.getFloat(attOffset + t);
                        xb[bFinal].saxpyInPlace(xbOffset, state.valueCache[layer], vOffset, headSize, a);
                    }
                });
            }

            // Output projection — batch matmul
            weights.wo[l].matmul(batchSize, xb, xb2, dim, dim);

            // Residual + FFN RMSNorm (parallel per b)
            Parallel.parallelFor(0, batchSize, b -> {
                x[b].addInPlace(xb2[b]);
                InferenceCore.rmsnorm(xb[b], x[b], weights.rms_ffn_weight[layer], 0, dim, config.rmsNormEps());
            });

            // FFN projections — batch matmul
            weights.w1[l].matmul(batchSize, xb, hb, config.hiddenDim(), dim);
            weights.w3[l].matmul(batchSize, xb, hb2, config.hiddenDim(), dim);

            // SwiGLU (parallel per b): hb = SiLU(w1·xb) ⊙ (w3·xb)
            Parallel.parallelFor(0, batchSize, b -> {
                hb[b].mapInPlace(value -> value / (float) (1.0 + Math.exp(-value)));
                hb[b].multiplyInPlace(hb2[b]);
            });

            // w2 projection — batch matmul (output reuses xb)
            weights.w2[l].matmul(batchSize, hb, xb, dim, config.hiddenDim());

            // FFN residual (parallel per b)
            Parallel.parallelFor(0, batchSize, b -> x[b].addInPlace(xb[b]));
        }
        // Final RMSNorm and vocab projection intentionally omitted —
        // logits are not needed for any token in a prefill batch.
    }

    /**
     * GPU batched prefill forward pass (Phase 4).
     *
     * <p>Delegates the full chunk to
     * {@link TornadoVMMasterPlanWithBatchPrefillDecode#tornadoVMForwardBatchPrefill},
     * which handles embedding lookup and GPU execution internally.</p>
     *
     * @param model     the LLaMA model
     * @param tokens    token ids for this chunk
     * @param startPos  sequence position of {@code tokens[0]}
     * @param chunkSize number of tokens in this chunk
     * @param plan      the batched prefill/decode GPU plan
     */
    public static void batchForwardTornadoVMPrefill(Model model, int[] tokens, int startPos, int chunkSize,
            TornadoVMMasterPlanWithBatchPrefillDecode plan) {
        plan.tornadoVMForwardBatchPrefill(tokens, startPos, model, chunkSize);
    }

    /**
     * GPU decode forward pass (Phase 4).
     *
     * <p>Delegates a single-token decode step to
     * {@link TornadoVMMasterPlanWithBatchPrefillDecode#tornadoVMForwardDecode},
     * which copies the token embedding and runs the decode + logits graphs.</p>
     *
     * @param model    the LLaMA model
     * @param token    current token id
     * @param position sequence position
     * @param plan     the batched prefill/decode GPU plan
     * @return logits array for token sampling
     */
    public static FloatArray forwardTornadoVMDecode(Model model, int token, int position,
            TornadoVMMasterPlanWithBatchPrefillDecode plan) {
        return plan.tornadoVMForwardDecode(token, position, model);
    }
}
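
To show how the three entry points compose, here is a minimal caller sketch, assuming an illustrative chunk size, generation budget, and greedy argmax sampler (none of which are part of this commit); only the two InferenceCoreBatchPrefillDecode calls and TornadoVM's FloatArray accessors (get, getSize) come from the code above. Because batch prefill intentionally skips the logits layer, the last prompt token is routed through the decode path, which does produce logits.

import java.util.Arrays;

import org.beehive.gpullama3.inference.InferenceCoreBatchPrefillDecode;
import org.beehive.gpullama3.model.Model;
import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithBatchPrefillDecode;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;

/** Illustrative driver for the batched prefill/decode path; not part of this commit. */
final class BatchPrefillDecodeSketch {

    private static final int CHUNK = 16;   // assumed prefill chunk size
    private static final int MAX_NEW = 32; // assumed generation budget

    static void generate(Model model, int[] prompt, TornadoVMMasterPlanWithBatchPrefillDecode plan) {
        int n = prompt.length;
        int pos = 0;
        // Phase 4 prefill: all but the last prompt token go through the GPU in
        // chunks; only the KV cache is filled, no logits are produced.
        // (The Phase 3 CPU analogue is batchForwardJavaPrefill, which takes a State.)
        while (pos < n - 1) {
            int size = Math.min(CHUNK, n - 1 - pos);
            int[] chunk = Arrays.copyOfRange(prompt, pos, pos + size);
            InferenceCoreBatchPrefillDecode.batchForwardTornadoVMPrefill(model, chunk, pos, size, plan);
            pos += size;
        }
        // Decode: the last prompt token goes through the decode path, which does
        // return logits, and each sampled token is fed back in at the next position.
        int token = prompt[n - 1];
        for (int i = 0; i < MAX_NEW; i++, pos++) {
            FloatArray logits = InferenceCoreBatchPrefillDecode.forwardTornadoVMDecode(model, token, pos, plan);
            token = argmax(logits); // greedy sampling, for brevity
        }
    }

    // Minimal greedy sampler over the decode logits.
    private static int argmax(FloatArray logits) {
        int best = 0;
        for (int i = 1; i < logits.getSize(); i++) {
            if (logits.get(i) > logits.get(best)) {
                best = i;
            }
        }
        return best;
    }
}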

src/main/java/org/beehive/gpullama3/inference/InferenceCoreWithPrefillDecode.java

Lines changed: 0 additions & 142 deletions
@@ -6,7 +6,6 @@
 import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights;
 import org.beehive.gpullama3.model.Configuration;
 import org.beehive.gpullama3.model.Model;
-import org.beehive.gpullama3.tensor.standard.ArrayFloatTensor;
 import org.beehive.gpullama3.tensor.standard.FloatTensor;
 import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlanWithPrefillDecode;
 

@@ -127,147 +126,6 @@ public static void forwardJavaPrefill(Model model, State state, int token, int p
         // logits are not needed for prefill positions — only the KV cache matters.
     }
 
-    /**
-     * CPU batched prefill forward pass for LLaMA (Phase 3).
-     *
-     * <p>Processes {@code batchSize} prompt tokens simultaneously through all
-     * transformer layers. For each layer, Q/K/V projections, output projection,
-     * and FFN projections are computed via batch matmul
-     * ({@link FloatTensor#matmul(int, FloatTensor[], FloatTensor[], int, int)}),
-     * which parallelises over both output dimension and batch simultaneously.
-     * Attention reuses {@code state.att} sequentially per token (parallel per
-     * head within each token), keeping memory overhead minimal.</p>
-     *
-     * <p>The logits layer is intentionally omitted — only the KV cache matters
-     * for prefill positions.</p>
-     *
-     * @param model     the LLaMA model (must carry {@link StandardWeights})
-     * @param state     mutable inference state (KV cache, att buffer …)
-     * @param tokens    input token ids, {@code tokens[b]} at position {@code startPos+b}
-     * @param startPos  sequence position of {@code tokens[0]}
-     * @param batchSize number of tokens in this chunk ({@code tokens.length})
-     */
-    public static void batchForwardJavaPrefill(Model model, State state, int[] tokens, int startPos, int batchSize) {
-        final Configuration config = model.configuration();
-        final StandardWeights weights = (StandardWeights) model.weights();
-        int dim = config.dim();
-        int headSize = config.headSize();
-        int kvDim = (config.dim() * config.numberOfKeyValueHeads()) / config.numberOfHeads();
-        int kvMul = config.numberOfHeads() / config.numberOfKeyValueHeads();
-        float sqrtHeadSize = (float) Math.sqrt(headSize);
-
-        // ── Batch activation tensors (allocated once per chunk) ───────────────
-        FloatTensor[] x = new FloatTensor[batchSize];
-        FloatTensor[] xb = new FloatTensor[batchSize];
-        FloatTensor[] xb2 = new FloatTensor[batchSize];
-        FloatTensor[] q = new FloatTensor[batchSize];
-        FloatTensor[] k = new FloatTensor[batchSize];
-        FloatTensor[] v = new FloatTensor[batchSize];
-        FloatTensor[] hb = new FloatTensor[batchSize];
-        FloatTensor[] hb2 = new FloatTensor[batchSize];
-        for (int b = 0; b < batchSize; b++) {
-            x[b] = ArrayFloatTensor.allocate(dim);
-            xb[b] = ArrayFloatTensor.allocate(dim);
-            xb2[b] = ArrayFloatTensor.allocate(dim);
-            q[b] = ArrayFloatTensor.allocate(dim);
-            k[b] = ArrayFloatTensor.allocate(kvDim);
-            v[b] = ArrayFloatTensor.allocate(kvDim);
-            hb[b] = ArrayFloatTensor.allocate(config.hiddenDim());
-            hb2[b] = ArrayFloatTensor.allocate(config.hiddenDim());
-        }
-
-        // ── Token embeddings ──────────────────────────────────────────────────
-        Parallel.parallelFor(0, batchSize, b ->
-                weights.token_embedding_table.copyTo(tokens[b] * dim, x[b], 0, dim));
-
-        // ── Transformer layers ────────────────────────────────────────────────
-        for (int l = 0; l < config.numberOfLayers(); l++) {
-            final int layer = l;
-
-            // Attention RMSNorm (parallel per b)
-            Parallel.parallelFor(0, batchSize, b ->
-                    InferenceCore.rmsnorm(xb[b], x[b], weights.rms_att_weight[layer], 0, dim, config.rmsNormEps()));
-
-            // QKV projections — batch matmul parallelises over (dim × batchSize)
-            weights.wq[l].matmul(batchSize, xb, q, dim, dim);
-            weights.wk[l].matmul(batchSize, xb, k, kvDim, dim);
-            weights.wv[l].matmul(batchSize, xb, v, kvDim, dim);
-
-            // RoPE + KV cache store (parallel per b — different positions, no conflict)
-            Parallel.parallelFor(0, batchSize, b -> {
-                int pos = startPos + b;
-                for (int i = 0; i < dim; i += 2) {
-                    int head_dim = i % headSize;
-                    float fcr = weights.freq_cis_real.getFloat(pos * (headSize / 2) + (head_dim / 2));
-                    float fci = weights.freq_cis_imag.getFloat(pos * (headSize / 2) + (head_dim / 2));
-                    int rotn = i < kvDim ? 2 : 1;
-                    for (int vv = 0; vv < rotn; vv++) {
-                        FloatTensor vec = vv == 0 ? q[b] : k[b];
-                        float v0 = vec.getFloat(i);
-                        float v1 = vec.getFloat(i + 1);
-                        vec.setFloat(i, v0 * fcr - v1 * fci);
-                        vec.setFloat(i + 1, v0 * fci + v1 * fcr);
-                    }
-                }
-                k[b].copyTo(0, state.keyCache[layer], pos * kvDim, kvDim);
-                v[b].copyTo(0, state.valueCache[layer], pos * kvDim, kvDim);
-            });
-
-            // Attention — sequential per b (state.att is shared), parallel per head
-            for (int b = 0; b < batchSize; b++) {
-                final int pos_b = startPos + b;
-                final int bFinal = b;
-                Parallel.parallelFor(0, config.numberOfHeads(), h -> {
-                    int qOffset = h * headSize;
-                    int attOffset = h * config.contextLength();
-
-                    for (int t = 0; t <= pos_b; t++) {
-                        int keyCacheOffset = t * kvDim + (h / kvMul) * headSize;
-                        float score = q[bFinal].dot(qOffset, state.keyCache[layer], keyCacheOffset, headSize) / sqrtHeadSize;
-                        state.att.setFloat(attOffset + t, score);
-                    }
-                    state.att.softmaxInPlace(attOffset, pos_b + 1);
-
-                    int xbOffset = h * headSize;
-                    xb[bFinal].fillInPlace(xbOffset, headSize, 0f);
-                    for (int t = 0; t <= pos_b; t++) {
-                        int vOffset = t * kvDim + (h / kvMul) * headSize;
-                        float a = state.att.getFloat(attOffset + t);
-                        xb[bFinal].saxpyInPlace(xbOffset, state.valueCache[layer], vOffset, headSize, a);
-                    }
-                });
-            }
-
-            // Output projection — batch matmul
-            weights.wo[l].matmul(batchSize, xb, xb2, dim, dim);
-
-            // Residual + FFN RMSNorm (parallel per b)
-            Parallel.parallelFor(0, batchSize, b -> {
-                x[b].addInPlace(xb2[b]);
-                InferenceCore.rmsnorm(xb[b], x[b], weights.rms_ffn_weight[layer], 0, dim, config.rmsNormEps());
-            });
-
-            // FFN projections — batch matmul
-            weights.w1[l].matmul(batchSize, xb, hb, config.hiddenDim(), dim);
-            weights.w3[l].matmul(batchSize, xb, hb2, config.hiddenDim(), dim);
-
-            // SwiGLU (parallel per b)
-            Parallel.parallelFor(0, batchSize, b -> {
-                hb[b].mapInPlace(value -> value / (float) (1.0 + Math.exp(-value)));
-                hb[b].multiplyInPlace(hb2[b]);
-            });
-
-            // w2 projection — batch matmul (output reuses xb)
-            weights.w2[l].matmul(batchSize, hb, xb, dim, config.hiddenDim());
-
-            // FFN residual (parallel per b)
-            Parallel.parallelFor(0, batchSize, b -> x[b].addInPlace(xb[b]));
-        }
-
-        // Final RMSNorm and vocab projection intentionally omitted —
-        // logits are not needed for any token in a prefill batch.
-    }
-
     /**
      * GPU prefill-only forward pass for LLaMA (FP16, TornadoVM).
      *

0 commit comments
