This comprehensive guide explains how to add support for new transformer-based language models to GPULlama3.java.
Last Updated: November 1, 2025
Example Model: Google Gemma 3
Difficulty: Advanced (requires understanding of transformer architectures)
- Prerequisites
- Architecture Analysis
- Step-by-Step Implementation
- Testing and Debugging
- Common Patterns
- Troubleshooting
- ✅ Java programming (records, interfaces, generics)
- ✅ Transformer architecture basics (attention, FFN, normalization)
- ✅ Model formats (GGUF, safetensors)
- ✅ Tokenization (BPE, SentencePiece, WordPiece)
- Java 21+ with preview features enabled
- Maven build system
- GGUF model files
- (Optional) TornadoVM for GPU support
Study these existing implementations:
- Simple: Llama (standard transformer)
- With GQA: Mistral (grouped-query attention)
- With Q/K Norm: Qwen3 (query/key normalization)
- Complex: Gemma3 (sandwich normalization)
Research and document:
- Model family: Llama-based, GPT-based, custom?
- Architecture variants: Standard, MoE, multimodal?
- Normalization type: LayerNorm, RMSNorm, custom?
- Attention mechanism: MHA, GQA, MQA?
- Special features: RoPE, ALiBi, sliding window, etc.
Look for:
- Official HuggingFace transformers code
- llama.cpp implementation (C++)
- GGML format documentation
- Academic papers or blog posts
Example Resources:
# llama.cpp docs
https://github.com/ggml-org/llama.cpp/tree/master/docs
# HuggingFace model card
https://huggingface.co/[organization]/[model-name]
# Architecture diagrams
https://github.com/[org]/[repo]/blob/main/architecture.md
Compare with existing models:
| Feature | Llama | Mistral | Qwen3 | Your Model |
|---|---|---|---|---|
| Norm layers per block | 2 | 2 | 2 | ? |
| Attention type | MHA | GQA | GQA | ? |
| Q/K normalization | ❌ | ❌ | ✅ | ? |
| Embedding scaling | ❌ | ❌ | ❌ | ? |
| Special tokens | 3 | 5 | 4 | ? |
| Context window | 128K | 32K | 131K | ? |
File: src/main/java/org/beehive/gpullama3/model/{modelname}/{ModelName}Configuration.java
package org.beehive.gpullama3.model.{modelname};
import org.beehive.gpullama3.model.Configuration;
public record {ModelName}Configuration(
// Core dimensions
int dim, // Model dimension
int hiddenDim, // FFN hidden dimension
int numberOfLayers, // Number of transformer blocks
int numberOfHeads, // Number of attention heads
int numberOfKeyValueHeads, // For GQA (use numberOfHeads if MHA)
// Vocabulary and context
int vocabularySize, // Size of vocabulary
int contextLength, // Maximum sequence length
// Normalization
float rmsNormEps, // RMSNorm epsilon (usually 1e-5 or 1e-6)
// Position encoding
float ropeTheta // RoPE theta (usually 10000 or 500000)
// Add model-specific fields here:
// - int numberOfHeadsKey (if using Q/K norm like Qwen3/Gemma3)
// - int numberOfHeadsValue (if using Q/K norm)
// - boolean sharedWeights (if embeddings/output weights shared)
// - int slidingWindow (for Mistral)
) implements Configuration {
@Override
public int headSize() {
return dim / numberOfHeads;
}
// Implement other Configuration interface methods
@Override
public int contextLength() { return contextLength; }
@Override
public int dim() { return dim; }
// ... etc
}
Decision Points:
- ❓ Does the model use Grouped-Query Attention? → Add `numberOfKeyValueHeads` (see the sketch of derived sizes below)
- ❓ Does it have Q/K normalization? → Add `numberOfHeadsKey`, `numberOfHeadsValue`
- ❓ Are output and embedding weights shared? → Add a `sharedWeights` boolean
- ❓ Does it use sliding window attention? → Add a `slidingWindow` int
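The forward pass later calls `config.kvDim()` and `config.kvMul()`, which the record above does not yet implement. A minimal sketch of these derived accessors, assuming the `Configuration` interface declares them (verify against the actual interface):

// Derived GQA sizes used by the forward pass. For plain MHA both reduce to
// kvDim() == dim and kvMul() == 1.
@Override
public int kvDim() {
    // Total width of the K/V projections
    return (dim * numberOfKeyValueHeads) / numberOfHeads;
}

@Override
public int kvMul() {
    // Number of query heads that share each KV head
    return numberOfHeads / numberOfKeyValueHeads;
}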
File: src/main/java/org/beehive/gpullama3/inference/state/{ModelName}State.java
package org.beehive.gpullama3.inference.state;
import org.beehive.gpullama3.model.Configuration;
public class {ModelName}State extends State {
public {ModelName}State(Configuration config, int batchSize) {
super(config, batchSize);
// Add model-specific state buffers here if needed
// Most models can use the base State class
}
}
When to extend:
- Only create a custom state class if you need additional buffers (see the sketch below)
- Most models can use the base `State` class directly
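If the model does need extra buffers, a hypothetical sketch (the field is illustrative, and the import paths are assumptions based on the `FloatTensor` import used elsewhere in this guide; the `ArrayFloatTensor` constructor follows its use in the weight-loading section below):

import org.beehive.gpullama3.core.model.tensor.ArrayFloatTensor;
import org.beehive.gpullama3.core.model.tensor.FloatTensor;
import org.beehive.gpullama3.model.Configuration;

public class {ModelName}State extends State {
    // Hypothetical scratch buffer, e.g. to hold a post-attention norm result
    public final FloatTensor tempNorm;

    public {ModelName}State(Configuration config, int batchSize) {
        super(config, batchSize);
        this.tempNorm = new ArrayFloatTensor(new float[config.dim()]);
    }
}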
File: src/main/java/org/beehive/gpullama3/model/{modelname}/{ModelName}.java
package org.beehive.gpullama3.model.{modelname};
import org.beehive.gpullama3.inference.InferenceCore;
import org.beehive.gpullama3.inference.InferenceEngine;
import org.beehive.gpullama3.inference.sampler.Sampler;
import org.beehive.gpullama3.inference.state.State;
import org.beehive.gpullama3.inference.weights.Weights;
import org.beehive.gpullama3.model.AbstractModel;
import org.beehive.gpullama3.model.ModelType;
import org.beehive.gpullama3.model.format.ChatFormat;
import org.beehive.gpullama3.tokenizer.impl.Tokenizer;
import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan;
import java.util.List;
import java.util.Set;
import java.util.function.IntConsumer;
public class {ModelName} extends AbstractModel {
private final {ModelName}Configuration configuration;
public {ModelName}({ModelName}Configuration configuration,
Tokenizer tokenizer,
Weights weights,
ChatFormat chatFormat) {
super(tokenizer, weights, chatFormat, null);
this.configuration = configuration;
}
@Override
public {ModelName}Configuration configuration() {
return configuration;
}
@Override
public ModelType getModelType() {
return ModelType.{MODEL_NAME};
}
@Override
public State createNewState() {
State state = new {ModelName}State(configuration(), -1);
// Set initial token (usually BOS token)
state.latestToken = tokenizer.getSpecialTokens().get("<bos>");
return state;
}
@Override
public State createNewState(int batchSize) {
State state = new {ModelName}State(configuration(), batchSize);
state.latestToken = tokenizer.getSpecialTokens().get("<bos>");
return state;
}
@Override
public boolean shouldAddBeginOfText() {
return true; // Most models use BOS token
}
@Override
public void forward(State state, int token, int position) {
if (plan == null) {
// CPU inference path
InferenceCore.forwardJava{ModelName}(this, state, token, position);
} else {
// GPU inference path
InferenceCore.forwardTornadoVM(this, state, token, position, tornadoVMPlan());
}
}
@Override
public List<Integer> generateTokens(State state, int startPosition,
List<Integer> promptTokens,
Set<Integer> stopTokens, int maxTokens,
Sampler sampler, boolean echo,
IntConsumer onTokenGenerated) {
// Choose generation method based on architecture similarity:
// - Standard: InferenceEngine.generateTokensLlama()
// - With Q/K norm: InferenceEngine.generateTokensQwen3()
return InferenceEngine.generateTokensLlama(this, state, startPosition,
promptTokens, stopTokens,
maxTokens, sampler, echo,
onTokenGenerated);
}
@Override
public List<Integer> generateTokensGPU(State state, int startPosition,
List<Integer> promptTokens,
Set<Integer> stopTokens, int maxTokens,
Sampler sampler, boolean echo,
IntConsumer onTokenGenerated,
TornadoVMMasterPlan tornadoVMPlan) {
return InferenceEngine.generateTokensGPULlama(this, state, startPosition,
promptTokens, stopTokens,
maxTokens, sampler, echo,
onTokenGenerated, tornadoVMPlan);
}
}
File: src/main/java/org/beehive/gpullama3/tokenizer/impl/{ModelName}Tokenizer.java
package org.beehive.gpullama3.tokenizer.impl;
import org.beehive.gpullama3.tokenizer.vocabulary.Vocabulary;
import java.util.*;
public class {ModelName}Tokenizer implements Tokenizer {
private final Vocabulary vocabulary;
private final Map<String, Integer> specialTokens;
public {ModelName}Tokenizer(Map<String, Object> metadata, Vocabulary vocabulary) {
this.vocabulary = vocabulary;
// Load special tokens from vocabulary
this.specialTokens = new HashMap<>();
// Scan vocabulary for special tokens
for (int i = 0; i < vocabulary.size(); i++) {
String token = vocabulary.get(i);
if (isSpecialTokenPattern(token)) {
specialTokens.put(token, i);
}
}
}
private boolean isSpecialTokenPattern(String token) {
// Define what makes a token "special" for your model
// Common patterns: <bos>, <eos>, <pad>, etc.
return token.startsWith("<") && token.endsWith(">") &&
!token.matches("<0x[0-9a-fA-F]{2}>") && // Not byte tokens
!token.matches("<unused\\d+>"); // Not placeholders
}
@Override
public List<Integer> encodeAsList(String text) {
// Implement encoding logic
// For most models, can delegate to existing tokenizer
// or implement BPE/SentencePiece algorithm
return List.of(); // TODO: Implement
}
@Override
public String decode(List<Integer> tokens) {
StringBuilder sb = new StringBuilder();
for (int token : tokens) {
// Handle special cases:
// 1. Byte tokens (if model uses them)
// 2. Special tokens (skip)
// 3. Regular tokens
if (isSpecialToken(token)) {
continue; // Skip special tokens
}
String tokenString = vocabulary.get(token);
// Handle model-specific replacements
// Examples:
// - SentencePiece: replace ▁ with space
// - Some models: decode hex bytes
sb.append(tokenString);
}
return sb.toString();
}
@Override
public Map<String, Integer> getSpecialTokens() {
return specialTokens;
}
@Override
public boolean isSpecialToken(int tokenIndex) {
return specialTokens.containsValue(tokenIndex);
}
@Override
public boolean shouldDisplayToken(int token) {
return !isSpecialToken(token);
}
}
Key Decisions:
- Tokenization Algorithm: BPE, SentencePiece, WordPiece?
- Byte-Level Encoding: Does the model use raw bytes for the first 256 tokens? (handled in the sketch below)
- Special Characters: How are spaces represented? (▁ in SentencePiece)
- Metadata Keys: Where are merges, vocab, and scores stored in GGUF?
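For decisions 2 and 3, a sketch of decode handling for SentencePiece-style byte tokens and the ▁ space marker. Consecutive byte tokens must be buffered and flushed together as UTF-8, or multi-byte characters come out garbled. The helper name and token patterns are assumptions; check your model's actual vocabulary:

import java.io.ByteArrayOutputStream;
import java.nio.charset.StandardCharsets;

// Hypothetical helper: turn a list of raw token strings into text,
// buffering "<0xNN>" byte tokens and decoding them together as UTF-8.
private String decodeTokenStrings(List<String> tokenStrings) {
    StringBuilder out = new StringBuilder();
    ByteArrayOutputStream pendingBytes = new ByteArrayOutputStream();
    for (String t : tokenStrings) {
        if (t.matches("<0x[0-9a-fA-F]{2}>")) {
            // Raw byte token: buffer it, decode later as UTF-8
            pendingBytes.write(Integer.parseInt(t.substring(3, 5), 16));
        } else {
            if (pendingBytes.size() > 0) {
                out.append(pendingBytes.toString(StandardCharsets.UTF_8));
                pendingBytes.reset();
            }
            out.append(t.replace('\u2581', ' ')); // SentencePiece word boundary
        }
    }
    if (pendingBytes.size() > 0) {
        out.append(pendingBytes.toString(StandardCharsets.UTF_8));
    }
    return out.toString();
}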
File: src/main/java/org/beehive/gpullama3/model/format/{ModelName}ChatFormat.java
package org.beehive.gpullama3.model.format;
import org.beehive.gpullama3.tokenizer.impl.Tokenizer;
import java.util.*;
public class {ModelName}ChatFormat implements ChatFormat {
private final int beginOfText;
private final int endOfText;
private final Set<Integer> stopTokens;
private final Tokenizer tokenizer;
public {ModelName}ChatFormat(Tokenizer tokenizer) {
this.tokenizer = tokenizer;
Map<String, Integer> specialTokens = tokenizer.getSpecialTokens();
// Load special tokens
this.beginOfText = specialTokens.getOrDefault("<bos>", -1);
this.endOfText = specialTokens.getOrDefault("<eos>", -1);
// Define stop tokens
this.stopTokens = new HashSet<>();
if (endOfText != -1) {
stopTokens.add(endOfText);
}
// Add model-specific stop tokens
}
@Override
public List<Integer> encodeHeader(Message message) {
List<Integer> tokens = new ArrayList<>();
// Encode role header
// Example: <|start_header_id|>user<|end_header_id|>
return tokens;
}
@Override
public List<Integer> encodeMessage(Message message) {
List<Integer> tokens = new ArrayList<>();
// Encode complete message with header and content
// Follow the model's specific chat template
tokens.addAll(encodeHeader(message));
tokens.addAll(tokenizer.encodeAsList(message.content().strip()));
// Add end-of-message tokens
return tokens;
}
@Override
public int getBeginOfText() {
return beginOfText;
}
@Override
public Set<Integer> getStopTokens() {
return stopTokens;
}
}
Chat Template Research:
- Check the model card on HuggingFace for `tokenizer_config.json`
- Look for the `chat_template` field in the GGUF metadata
- Reference implementations in the transformers library
Common Templates:
- Llama 3: `<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n{msg}<|eot_id|>`
- Gemma: `<bos><start_of_turn>user\n{msg}<end_of_turn>\n<start_of_turn>model\n`
- ChatML: `<|im_start|>user\n{msg}<|im_end|>\n<|im_start|>assistant\n` (sketched below)
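As a concrete example, a sketch of `encodeMessage` for a ChatML-style model. The special-token names and the `Message`/`Role` accessors are assumptions; check the actual `ChatFormat` interfaces and the model's special-token vocabulary:

@Override
public List<Integer> encodeMessage(Message message) {
    // <|im_start|>role\n{content}<|im_end|>\n
    List<Integer> tokens = new ArrayList<>();
    tokens.add(tokenizer.getSpecialTokens().get("<|im_start|>"));
    tokens.addAll(tokenizer.encodeAsList(message.role().name() + "\n"));
    tokens.addAll(tokenizer.encodeAsList(message.content().strip()));
    tokens.add(tokenizer.getSpecialTokens().get("<|im_end|>"));
    tokens.addAll(tokenizer.encodeAsList("\n"));
    return tokens;
}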
CPU Weights - src/main/java/org/beehive/gpullama3/inference/weights/standard/{ModelName}StandardWeights.java:
package org.beehive.gpullama3.inference.weights.standard;
import org.beehive.gpullama3.core.model.GGMLType;
import org.beehive.gpullama3.core.model.tensor.FloatTensor;
public class {ModelName}StandardWeights extends StandardWeights {
// Add model-specific weight fields
// Example for sandwich normalization:
// public final FloatTensor[] postAttentionNorm;
// public final FloatTensor[] postFFNNorm;
public {ModelName}StandardWeights(
FloatTensor token_embedding_table,
FloatTensor[] rms_att_weight,
FloatTensor[] wq,
FloatTensor[] wk,
FloatTensor[] wv,
FloatTensor[] wo,
FloatTensor[] rms_ffn_weight,
FloatTensor[] w1,
FloatTensor[] w2,
FloatTensor[] w3,
FloatTensor rms_final_weight,
FloatTensor freq_cis_real,
FloatTensor freq_cis_imag,
FloatTensor wcls,
GGMLType ggmlType
// Add custom parameters
) {
super(token_embedding_table, rms_att_weight, wq, wk, wv, wo,
rms_ffn_weight, w1, w2, w3, rms_final_weight,
freq_cis_real, freq_cis_imag, wcls, ggmlType);
// Initialize custom fields
}
}
GPU Weights - src/main/java/org/beehive/gpullama3/inference/weights/tornado/{ModelName}TornadoWeights.java:
package org.beehive.gpullama3.inference.weights.tornado;
import org.beehive.gpullama3.core.model.GGMLType;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
public class {ModelName}TornadoWeights extends FP16Weights {
// Add model-specific weight arrays
// Use FloatArray for GPU memory
public {ModelName}TornadoWeights(/* parameters */) {
super(/* base parameters */);
// Initialize custom fields
}
}
File: src/main/java/org/beehive/gpullama3/model/loader/{ModelName}ModelLoader.java
package org.beehive.gpullama3.model.loader;
import org.beehive.gpullama3.core.model.GGUF;
import org.beehive.gpullama3.core.model.tensor.GGMLTensorEntry;
import org.beehive.gpullama3.core.types.Pair;
import org.beehive.gpullama3.inference.operation.RoPE;
import org.beehive.gpullama3.inference.weights.Weights;
import org.beehive.gpullama3.model.Configuration;
import org.beehive.gpullama3.model.format.ChatFormat;
import org.beehive.gpullama3.model.{modelname}.*;
import org.beehive.gpullama3.tokenizer.impl.Tokenizer;
import org.beehive.gpullama3.tokenizer.vocabulary.Vocabulary;
import java.io.IOException;
import java.nio.channels.FileChannel;
import java.util.Map;
public class {ModelName}ModelLoader extends ModelLoader {
public {ModelName}ModelLoader(FileChannel fileChannel, GGUF gguf,
int contextLength, boolean loadWeights,
boolean useTornadoVM) {
super(fileChannel, gguf, contextLength, loadWeights, useTornadoVM);
}
@Override
public {ModelName} loadModel() {
try {
Map<String, Object> metadata = gguf.getMetadata();
// 1. LOAD VOCABULARY
Vocabulary vocabulary = Vocabulary.loadLlamaVocabulary(metadata);
Tokenizer tokenizer = new {ModelName}Tokenizer(metadata, vocabulary);
// 2. DETECT METADATA PREFIX
// Try different prefixes: {model}. or llama. or mistral.
String prefix;
if (metadata.containsKey("{model}.embedding_length")) {
prefix = "{model}.";
} else if (metadata.containsKey("llama.embedding_length")) {
prefix = "llama.";
} else {
throw new RuntimeException("Unknown architecture");
}
// 3. LOAD CONFIGURATION FROM METADATA
int dim = (int) metadata.get(prefix + "embedding_length");
int hiddenDim = (int) metadata.get(prefix + "feed_forward_length");
int nLayers = (int) metadata.get(prefix + "block_count");
int nHeads = (int) metadata.get(prefix + "attention.head_count");
int nKVHeads = metadata.containsKey(prefix + "attention.head_count_kv")
? (int) metadata.get(prefix + "attention.head_count_kv")
: nHeads;
int ctxLength = (int) metadata.get(prefix + "context_length");
float rmsNormEps = (float) metadata.getOrDefault(
prefix + "attention.layer_norm_rms_epsilon", 1e-6f);
float ropeTheta = (float) metadata.getOrDefault(
prefix + "rope.freq_base", 10000f);
// 4. LOAD TENSOR ENTRIES
Map<String, GGMLTensorEntry> tensorEntries =
GGUF.loadTensors(fileChannel, gguf.getTensorDataOffset(),
gguf.getTensorInfos());
// 5. GET VOCAB SIZE FROM EMBEDDINGS TENSOR
GGMLTensorEntry tokenEmbeddings = tensorEntries.get("token_embd.weight");
int[] embShape = tokenEmbeddings.shape();
int vocabSize = embShape.length > 1 ? embShape[1] : embShape[0];
// 6. CREATE CONFIGURATION
int actualContextLength = (contextLength < 0) ? ctxLength : contextLength;
{ModelName}Configuration config = new {ModelName}Configuration(
dim, hiddenDim, nLayers, nHeads, nKVHeads,
vocabSize, actualContextLength, rmsNormEps, ropeTheta
// Add model-specific parameters
);
// 7. LOAD WEIGHTS
Weights weights = null;
if (loadWeights) {
weights = loadWeights(tensorEntries, config);
}
// 8. RETURN MODEL
return new {ModelName}(config, tokenizer, weights,
ChatFormat.create(tokenizer, null));
} catch (IOException e) {
throw new RuntimeException(e);
}
}
@Override
public Weights loadWeights(Map<String, GGMLTensorEntry> tensorEntries,
Configuration config) {
// Precompute RoPE frequencies
Pair<float[], float[]> ropeFreqs = RoPE.precomputeFreqsCis(
config.contextLength(),
config.headSize(),
config.ropeTheta(),
false, 0, 0, 0, 0
);
GGMLTensorEntry tokenEmbeddings = tensorEntries.get("token_embd.weight");
GGMLTensorEntry outputWeight = tensorEntries.getOrDefault(
"output.weight", tokenEmbeddings);
if (useTornadovm) {
return createTornadoVMWeights(tensorEntries, config, ropeFreqs,
tokenEmbeddings, outputWeight);
} else {
return createStandardWeights(tensorEntries, config, ropeFreqs,
tokenEmbeddings, outputWeight);
}
}
@Override
public Weights createStandardWeights(Map<String, GGMLTensorEntry> tensorEntries,
Configuration config,
Pair<float[], float[]> ropeFreqs,
GGMLTensorEntry tokenEmbeddings,
GGMLTensorEntry outputWeight) {
// Load all weight tensors
// Pattern: "blk.{layer}.{component}.weight"
return new {ModelName}StandardWeights(
loadQuantized(tokenEmbeddings),
loadArrayOfQuantized(config.numberOfLayers(),
i -> tensorEntries.get("blk." + i + ".attn_norm.weight")),
loadArrayOfQuantized(config.numberOfLayers(),
i -> tensorEntries.get("blk." + i + ".attn_q.weight")),
loadArrayOfQuantized(config.numberOfLayers(),
i -> tensorEntries.get("blk." + i + ".attn_k.weight")),
loadArrayOfQuantized(config.numberOfLayers(),
i -> tensorEntries.get("blk." + i + ".attn_v.weight")),
loadArrayOfQuantized(config.numberOfLayers(),
i -> tensorEntries.get("blk." + i + ".attn_output.weight")),
// ... load all tensors
loadQuantized(tensorEntries.get("output_norm.weight")),
new ArrayFloatTensor(ropeFreqs.first()),
new ArrayFloatTensor(ropeFreqs.second()),
loadQuantized(outputWeight),
outputWeight.ggmlType()
);
}
@Override
public Weights createTornadoVMWeights(/* ... */) {
// Similar to createStandardWeights but using FloatArray
// Use loadTensorAsFloatArray() and loadArrayAsFloatArrayFromBuffer()
return new {ModelName}TornadoWeights(/* ... */);
}
}
Debugging Tips:
- Print all tensor names: `tensorEntries.keySet().stream().sorted().forEach(System.err::println);`
- Check tensor shapes: `System.err.println("Shape: " + Arrays.toString(tensor.shape()));`
- Verify metadata keys: `metadata.keySet().stream().filter(k -> k.startsWith("llama")).forEach(System.err::println);`
File: src/main/java/org/beehive/gpullama3/inference/InferenceCore.java
Add method:
public static FloatTensor forwardJava{ModelName}(Model model, State state,
int token, int position) {
Configuration config = model.configuration();
{ModelName}StandardWeights weights = ({ModelName}StandardWeights) model.weights();
int dim = config.dim();
int kvDim = config.kvDim();
int kvMul = config.kvMul();
int headSize = config.headSize();
int hiddenDim = config.hiddenDim();
// 1. COPY TOKEN EMBEDDING
weights.token_embedding_table.copyTo(token * dim, state.x, 0, dim);
// 2. APPLY EMBEDDING SCALING (if model requires)
// Example for Gemma:
// float embeddingScale = (float) Math.sqrt(dim);
// for (int i = 0; i < dim; i++) {
// state.x.setFloat(i, state.x.getFloat(i) * embeddingScale);
// }
// 3. FORWARD THROUGH ALL LAYERS
for (int l = 0; l < config.numberOfLayers(); l++) {
int curLayer = l;
// ===== ATTENTION BLOCK =====
// 3.1 Pre-normalization
rmsnorm(state.xb, state.x, weights.rms_att_weight[curLayer],
dim, config.rmsNormEps());
// 3.2 QKV projections
weights.wq[l].matmul(state.xb, state.q, dim, dim);
weights.wk[l].matmul(state.xb, state.k, dim, kvDim);
weights.wv[l].matmul(state.xb, state.v, dim, kvDim);
// 3.3 Apply Q/K normalization (if model uses it)
// rmsnorm(state.q, state.q, weights.attnQNorm[curLayer], ...);
// rmsnorm(state.k, state.k, weights.attnKNorm[curLayer], ...);
// 3.4 Apply RoPE
for (int i = 0; i < dim; i += 2) {
int head_dim = i % headSize;
// Index the precomputed tables by (position, head_dim / 2): the frequency
// depends on the position within each head, as in the existing Llama code
float fcr = weights.freq_cis_real.getFloat(position * (headSize / 2) + head_dim / 2);
float fci = weights.freq_cis_imag.getFloat(position * (headSize / 2) + head_dim / 2);
float q0 = state.q.getFloat(i);
float q1 = state.q.getFloat(i + 1);
state.q.setFloat(i, q0 * fcr - q1 * fci);
state.q.setFloat(i + 1, q0 * fci + q1 * fcr);
}
// Apply RoPE to keys similarly
// 3.5 Store KV in cache
int loff = l * config.contextLength() * kvDim;
state.k.copyTo(0, state.key_cache, loff + position * kvDim, kvDim);
state.v.copyTo(0, state.value_cache, loff + position * kvDim, kvDim);
// 3.6 Multi-head attention
for (int h = 0; h < config.numberOfHeads(); h++) {
// Compute attention for this head
// See existing implementations for the detailed attention logic
// (a sketch of this loop appears after the Key Considerations list below)
}
// 3.7 Output projection
weights.wo[l].matmul(state.xb, state.xb2, dim, dim);
// 3.8 Apply post-attention normalization (if model uses it)
// rmsnorm(state.xb2, state.xb2, weights.postAttentionNorm[curLayer], ...);
// 3.9 Residual connection
state.x.addInPlace(state.xb2);
// ===== FFN BLOCK =====
// 3.10 Pre-normalization
rmsnorm(state.xb, state.x, weights.rms_ffn_weight[curLayer],
dim, config.rmsNormEps());
// 3.11 FFN computation (SwiGLU activation)
weights.w1[l].matmul(state.xb, state.hb, dim, hiddenDim);
weights.w3[l].matmul(state.xb, state.hb2, dim, hiddenDim);
// Apply activation
for (int i = 0; i < hiddenDim; i++) {
float val = state.hb.getFloat(i);
val = val / (1.0f + (float) Math.exp(-val)); // Swish
val *= state.hb2.getFloat(i); // Gate
state.hb.setFloat(i, val);
}
// 3.12 Output projection
weights.w2[l].matmul(state.hb, state.xb2, hiddenDim, dim);
// 3.13 Apply post-FFN normalization (if model uses it)
// rmsnorm(state.xb2, state.xb2, weights.postFFNNorm[curLayer], ...);
// 3.14 Residual connection
state.x.addInPlace(state.xb2);
}
// 4. FINAL LAYER NORM
rmsnorm(state.x, state.x, weights.rms_final_weight, dim, config.rmsNormEps());
// 5. CLASSIFIER
weights.wcls.matmul(state.x, state.logits, dim, config.vocabularySize());
return state.logits;
}
Key Considerations:
- Normalization: RMSNorm, LayerNorm, or custom?
- Activation: SwiGLU, GELU, ReLU?
- Attention: Standard, GQA, sliding window?
- Special operations: Embedding scaling, rope scaling, etc.
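To fill in step 3.6 above, a sketch of the per-head attention loop in the style of the existing Llama implementation. It assumes the base `State` provides an attention-score buffer `att` sized numberOfHeads × contextLength alongside the buffers already used above, and that `FloatTensor` offers the helpers `dot`, `softmaxInPlace`, `fillInPlace`, and `saxpyInPlace`; verify the exact names and signatures against the codebase:

// Per-head attention over the flat KV cache laid out as in step 3.5
for (int h = 0; h < config.numberOfHeads(); h++) {
    int qOffset = h * headSize;
    int attOffset = h * config.contextLength();
    // GQA: kvMul query heads share one KV head
    int kvHeadOffset = (h / kvMul) * headSize;
    // Scores against all cached keys up to the current position
    for (int t = 0; t <= position; t++) {
        float score = state.q.dot(qOffset, state.key_cache,
                loff + t * kvDim + kvHeadOffset, headSize);
        state.att.setFloat(attOffset + t, score / (float) Math.sqrt(headSize));
    }
    // Normalize the scores for positions 0..position
    state.att.softmaxInPlace(attOffset, position + 1);
    // Weighted sum of cached values, written into xb
    state.xb.fillInPlace(qOffset, headSize, 0f);
    for (int t = 0; t <= position; t++) {
        float a = state.att.getFloat(attOffset + t);
        state.xb.saxpyInPlace(qOffset, state.value_cache,
                loff + t * kvDim + kvHeadOffset, headSize, a);
    }
}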
File: src/main/java/org/beehive/gpullama3/model/ModelType.java
{MODEL_NAME} {
@Override
public Model loadModel(FileChannel fileChannel, GGUF gguf,
int contextLength, boolean loadWeights,
boolean useTornadovm) {
return new {ModelName}ModelLoader(fileChannel, gguf, contextLength,
loadWeights, useTornadovm).loadModel();
}
}
File: src/main/java/org/beehive/gpullama3/model/loader/ModelLoader.java
else if (lowerName.contains("{model}")) {
return ModelType.{MODEL_NAME};
}
File: src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java
case {MODEL_NAME} -> createLlamaPlanner(state, model); // or createQWEN3Planner
Planner Selection:
- Use `createLlamaPlanner` for standard transformers
- Use `createQWEN3Planner` for models with Q/K normalization
- Create a custom planner if the architecture is significantly different
Create test file: src/test/java/org/beehive/gpullama3/model/{modelname}/{ModelName}Test.java
@Test
public void testTokenization() {
// Test basic tokenization
}
@Test
public void testChatFormatting() {
// Test chat template
}
@Test
public void testModelLoading() {
// Test GGUF loading
}
# 1. Test model loading
./llama-tornado --model {model}.gguf --prompt "test"
# 2. Test with different quantizations
./llama-tornado --model {model}-Q8_0.gguf --prompt "Hello"
./llama-tornado --model {model}-f16.gguf --prompt "Hello"
# 3. Test CPU vs GPU
./llama-tornado --model {model}.gguf --prompt "test" # CPU
./llama-tornado --model {model}.gguf --prompt "test" --gpu # GPU
# 4. Test interactive mode
./llama-tornado --model {model}.gguf -i
# 5. Test with system prompt
./llama-tornado --model {model}.gguf --prompt "test" -sp "You are a helpful assistant"
- Model loads without errors
  - Check metadata keys match expected names
  - Verify all tensors are found
- Vocabulary size matches
  - Compare GGUF vocab size with config
  - Check embedding tensor shape
- Tokenization works
  - Test encode/decode round-trip
  - Verify special tokens are recognized
- Generates tokens
  - Not just stop tokens immediately
  - Token IDs are within vocabulary range
- Output is readable
  - Not garbled or nonsensical
  - Follows prompt context
- Performance is reasonable
  - CPU: 5-20 tok/s depending on size
  - GPU: 50-200 tok/s depending on size
Standard transformer (Llama-style):
- 2 norm layers per block
- Standard multi-head attention
- SwiGLU activation
- RoPE position encoding
Reuse:
- `StandardWeights` class
- `forwardJavaLlama` inference
- `LlamaChatFormat` (with modifications)
Grouped-query attention (Mistral-style):
- Fewer KV heads than Q heads
- Otherwise similar to Llama
Reuse:
- Same as Llama
- Adjust `numberOfKeyValueHeads` in the config
Q/K normalization (Qwen3-style):
- Per-head normalization of Q and K (see the sketch below)
- May use separate head dimensions
Reuse:
- `StandardWeightsWithQKNorm` base class
- `forwardJavaQwen3` inference
- `generateTokensQwen3` generation method
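A sketch of the per-head Q/K normalization step. It assumes Qwen3-style weight arrays `attnQNorm`/`attnKNorm` (as hinted in step 3.3 of the forward pass) and a hypothetical `rmsnorm` overload that takes an offset; both are assumptions to verify against the actual code:

// Hypothetical offset-aware rmsnorm: normalize each head slice independently
for (int h = 0; h < config.numberOfHeads(); h++) {
    rmsnorm(state.q, state.q, weights.attnQNorm[curLayer],
            h * headSize, headSize, config.rmsNormEps());
}
for (int h = 0; h < config.numberOfKeyValueHeads(); h++) {
    rmsnorm(state.k, state.k, weights.attnKNorm[curLayer],
            h * headSize, headSize, config.rmsNormEps());
}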
Sandwich normalization (Gemma3-style):
- 4 norm layers per block
- Pre and post normalization
New Implementation Required:
- Custom weights class with 4 norm arrays
- Custom forward pass with the extra norm steps (sketched below)
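The resulting block order, sketched with the field names suggested in the weights template above (`postAttentionNorm`, `postFFNNorm`); a sketch of the structure, not a complete implementation:

float eps = config.rmsNormEps();
// Sandwich normalization: each sub-block is wrapped in a pre- and post-norm
// before its residual add (compare steps 3.8 and 3.13 in the forward pass)
rmsnorm(state.xb, state.x, weights.rms_att_weight[l], dim, eps);        // pre-attention
// ... attention over state.xb, result in state.xb2 ...
rmsnorm(state.xb2, state.xb2, weights.postAttentionNorm[l], dim, eps);  // post-attention
state.x.addInPlace(state.xb2);                                          // residual
rmsnorm(state.xb, state.x, weights.rms_ffn_weight[l], dim, eps);        // pre-FFN
// ... FFN over state.xb, result in state.xb2 ...
rmsnorm(state.xb2, state.xb2, weights.postFFNNorm[l], dim, eps);        // post-FFN
state.x.addInPlace(state.xb2);                                          // residual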
Symptoms: Exception during model loading
Debug Steps:
- Print all metadata keys: `metadata.keySet().forEach(System.err::println);`
- Check the architecture name: `String arch = (String) metadata.get("general.architecture"); System.err.println("Architecture: " + arch);`
- Try different prefixes (`llama.`, `mistral.`, `{model}.`)
Symptoms: Model generates stop token as first token
Possible Causes:
- Chat format is wrong (missing model turn setup)
- Normalization epsilon is incorrect
- Embedding scaling is missing or wrong
- Weights are loaded incorrectly
Debug:
- Enable echo mode to see what's generated
- Check the prompt token IDs are correct (see the snippet below)
- Verify chat template matches model's expected format
- Add debug prints in forward pass to check tensor values
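A quick way to inspect the encoded prompt token by token. The `Message`/`Role` constructors follow the existing `ChatFormat` API pattern; verify the exact names:

// Print the encoded prompt to verify the chat template token by token
List<Integer> promptTokens = chatFormat.encodeMessage(
        new ChatFormat.Message(ChatFormat.Role.USER, "test"));
for (int t : promptTokens) {
    System.err.println(t + " -> '" + tokenizer.decode(List.of(t)) + "'");
}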
Symptoms: Nonsensical or random characters
Possible Causes:
- Tokenizer decode logic is wrong
- Byte tokens not handled correctly
- Special tokens not filtered
- Wrong vocabulary
Debug:
- Print token IDs being generated
- Check the token ID → string mapping (see the round-trip check below)
- Verify byte token handling
- Test with known-good prompts
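A minimal round-trip check. Plain ASCII text should survive encode/decode exactly; SentencePiece models may normalize whitespace, so treat a mismatch as a starting point for investigation rather than proof of a bug:

String text = "Hello, world!";
List<Integer> ids = tokenizer.encodeAsList(text);
String roundTrip = tokenizer.decode(ids);
System.err.println("ids=" + ids);
System.err.println(text.equals(roundTrip)
        ? "round-trip OK"
        : "MISMATCH: '" + roundTrip + "'");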
Symptoms: Much slower than expected
Possible Causes:
- Not using vectorization (Java Vector API)
- Memory layout inefficient
- Missing optimizations in matmul
Solutions:
- Check that the `USE_VECTOR_API` flag is enabled
- Profile with JMH
- Compare with reference implementation speeds
Symptoms: GPU mode crashes or falls back to CPU
Possible Causes:
- TornadoVM not installed correctly
- Wrong planner selected
- Memory insufficient
Debug:
- Check the TornadoVM installation: `tornado --devices`
- Try a smaller model first
- Enable verbose logging: `--verbose-init`
Before considering implementation complete:
- Model loads from GGUF file
- Tokenization encode/decode works
- Chat format is correct
- Generates coherent output
- Stop tokens work correctly
- Special tokens are handled
- Multiple quantization types work (Q8_0, F16)
- CPU inference speed is reasonable
- GPU inference works (if applicable)
- Memory usage is acceptable
- No memory leaks
- Follows existing code style
- Has inline documentation
- Complex logic is commented
- No debug prints in production code
- Exception handling is proper
- Manual testing with various prompts
- Tested with different quantization formats
- Tested in interactive mode
- Tested with system prompts
- Compared output with reference implementation
- Changes documented in CHANGES.md
- Added model to README.md
- Chat template documented
- Any quirks or limitations noted
- Model cards with architecture details
- `config.json` for hyperparameters
- `tokenizer_config.json` for tokenization
- Reference C++ implementations
- GGUF format documentation
- Performance benchmarks
- Original model papers
- Architecture variants
- Tokenization methods
- GitHub issues for similar models
- Discord/forums for Q&A
- Existing PRs as examples
# Download model from HuggingFace
huggingface-cli download {org}/{model}-GGUF {file}.gguf --local-dir .
# Build project
make clean && make
# Test basic inference
./llama-tornado --model {model}.gguf --prompt "Hello, how are you?"
# Test with echo to see tokens
./llama-tornado --model {model}.gguf --prompt "test" --echo true
# Interactive mode
./llama-tornado --model {model}.gguf -i
# GPU mode
./llama-tornado --model {model}.gguf --prompt "test" --gpu --gpu-memory 8GB
# Debug vocabulary
./llama-tornado --model {model}.gguf --prompt "test" 2>&1 | grep -i vocab
Adding a new model requires:
- Understanding the architecture deeply
- Implementing 8-10 core classes
- Testing thoroughly
- Debugging patiently
Estimated Time: 1-3 days for experienced developers
Difficulty Factors:
- Standard transformer: ⭐⭐ (Easy)
- With GQA: ⭐⭐⭐ (Medium)
- With Q/K norm: ⭐⭐⭐⭐ (Hard)
- Completely custom: ⭐⭐⭐⭐⭐ (Expert)
Good luck! 🚀