This module provides a custom GPT-2 architecture for serving with `max serve` using the Module V3 API.
**Current Status: Work in Progress**

The model compiles and serves successfully but produces incorrect (gibberish) output; the cause is still under investigation. The standalone model in `main.py` works correctly.
```bash
pixi install
pixi run max serve \
    --model openai-community/gpt2 \
    --custom-architectures gpt2_module_v3 \
    --port 8888
```

Note: We do NOT use `--use-module-v3` here because we're registering a new architecture. The `--use-module-v3` flag is only needed when adding a new version of an existing MAX-registered architecture (it automatically appends `_ModuleV3` to the architecture name).
GPT-2 is a base language model (not a chat model), so use the completions API:
```bash
curl http://localhost:8888/v1/completions \
    -H "Content-Type: application/json" \
    -d '{
        "model": "openai-community/gpt2",
        "prompt": "The future of AI",
        "max_tokens": 30
    }' | jq .
```

Note: Do NOT use `/v1/chat/completions`; GPT-2 does not have a chat template.
| File | Description |
|---|---|
| `__init__.py` | Exports `ARCHITECTURES` list for custom arch discovery |
| `arch.py` | Defines `SupportedArchitecture` for GPT-2 |
| `model.py` | `GPT2Model` class extending `PipelineModel` |
| `model_config.py` | `GPT2Config` dataclass for model configuration |
| `gpt2.py` | Neural network module definitions (attention, MLP, etc.) |
| `weight_adapters.py` | Converts HuggingFace safetensor weights to MAX format |
Key requirements for custom architecture registration:
- Export `ARCHITECTURES`: The `__init__.py` must export an `ARCHITECTURES` list:

  ```python
  ARCHITECTURES = [gpt2]
  ```

- Weight Adapter: Register a weight adapter for safetensor format:

  ```python
  weight_adapters={
      WeightsFormat.safetensors: weight_adapters.convert_safetensor_state_dict,
  }
  ```
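To make the discovery contract concrete, here is a minimal sketch of what a loader for `--custom-architectures` could look like: import the named package and read its `ARCHITECTURES` attribute. This is illustrative only; MAX's actual loader is internal, and the `load_custom_architectures` helper and `fake_arch` stand-in module below are hypothetical.

```python
import importlib
import sys
import types


def load_custom_architectures(module_name: str):
    """Sketch of the discovery contract: import the package named by
    --custom-architectures and read its ARCHITECTURES list."""
    module = importlib.import_module(module_name)
    archs = getattr(module, "ARCHITECTURES", None)
    if not archs:
        raise ValueError(
            f"{module_name} does not export a non-empty ARCHITECTURES list"
        )
    return archs


# Demo with a stand-in module (the real package would be gpt2_module_v3):
fake = types.ModuleType("fake_arch")
fake.ARCHITECTURES = ["gpt2"]
sys.modules["fake_arch"] = fake
print(load_custom_architectures("fake_arch"))  # ['gpt2']
```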
The following changes were required to adapt the model from `main.py` for serving:

Original (`main.py`): 2D input `[batch_size, seq_length]`

```python
batch_size, seq_length = input_ids.shape
```

Serving: 1D ragged input `[total_tokens]` with `input_row_offsets`

```python
seq_length = tokens.shape[0]  # Flattened tokens
# input_row_offsets tells where each sequence starts/ends
```

Both implementations use `Tensor.arange`:

```python
positions = Tensor.arange(seq_length, dtype=tokens.dtype, device=tokens.device)
pos_embeds = self.wpe(positions)
```

GPT-2 uses Conv1D layers, which store weights as `[in_features, out_features]`, but MAX's `Linear` expects `[out_features, in_features]`. Transposition is required for:

- `.c_attn.weight`
- `.c_proj.weight`
- `.c_fc.weight`
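To illustrate the ragged layout, here is a small NumPy sketch (the token values and sequence lengths are made up) showing how `input_row_offsets` delimits sequences in the flattened token buffer, and how per-sequence position ids would be derived from it:

```python
import numpy as np

# Hypothetical ragged batch: two sequences of length 3 and 2, flattened.
tokens = np.array([10, 11, 12, 20, 21])  # [total_tokens]
input_row_offsets = np.array([0, 3, 5])  # sequence i spans offsets[i]:offsets[i+1]

# Per-sequence position ids (0-based within each sequence), which is what
# a positional-embedding lookup needs for a ragged batch:
positions = np.concatenate([
    np.arange(end - start)
    for start, end in zip(input_row_offsets[:-1], input_row_offsets[1:])
])
print(positions)  # [0 1 2 0 1]
```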
Important: The weight adapter must return raw numpy arrays, not `WeightData` objects:

```python
# Correct:
new_state_dict[max_name] = arr

# Wrong (causes issues):
new_state_dict[max_name] = WeightData.from_numpy(arr, max_name)
```

Transposed arrays must be made contiguous:

```python
arr = np.ascontiguousarray(arr.T)
```

GPT-2 ties `lm_head.weight` to `wte.weight`:

```python
if "language_model.lm_head.weight" not in new_state_dict:
    new_state_dict["language_model.lm_head.weight"] = wte_array.copy()
```
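Putting these adapter rules together, a pure-NumPy sketch of the inner loop might look like the following. The `hf_state` contents and the `language_model.` name mapping are stand-ins for illustration; the real adapter operates on the HuggingFace safetensor state dict and its actual key names:

```python
import numpy as np

# Stand-in for the HuggingFace GPT-2 state dict (shapes match GPT-2 small).
hf_state = {
    "h.0.attn.c_attn.weight": np.ones((768, 2304), dtype=np.float32),  # Conv1D: [in, out]
    "wte.weight": np.ones((50257, 768), dtype=np.float32),
}

new_state_dict = {}
for name, arr in hf_state.items():
    if name.endswith(("c_attn.weight", "c_proj.weight", "c_fc.weight")):
        # Conv1D stores [in, out]; MAX Linear wants [out, in]. The transpose
        # is a strided view, so force a contiguous copy.
        arr = np.ascontiguousarray(arr.T)
    # Hypothetical name mapping; the real adapter's prefixes may differ.
    new_state_dict[f"language_model.{name}"] = arr

# Handle tied embeddings: reuse wte for lm_head when the checkpoint omits it.
if "language_model.lm_head.weight" not in new_state_dict:
    new_state_dict["language_model.lm_head.weight"] = (
        new_state_dict["language_model.wte.weight"].copy()
    )
```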
- `F.range` vs `Tensor.arange`: The functional `F.range` API was deprecated/changed; `Tensor.arange` had to be used instead.

- DLPack Conversion: Weight data from safetensors required careful conversion:

  ```python
  arr = np.array(np.from_dlpack(weight_data), copy=True)
  ```

- Non-Contiguous Tensor Errors: MAX doesn't support non-contiguous tensors. Error message:

  ```
  ValueError: Max does not currently support executing non-contiguous tensors.
  ```

  Solution: Always use `np.ascontiguousarray()` after a transpose.

- Weight Adapter Return Type: Despite the type hint `dict[str, WeightData]`, the actual return must be raw data (numpy arrays), following the pattern in `gpt_oss_module_v3`.

- Chat Template Error: GPT-2 is a base model without a chat template. Using `/v1/chat/completions` results in:

  ```
  ValueError: Cannot use chat template functions because tokenizer.chat_template is not set
  ```

  Solution: Use `/v1/completions` instead.
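The contiguity gotcha is easy to reproduce in plain NumPy: a transpose is a strided view, and only an explicit copy restores C-contiguity. (This sketch uses a toy array; `np.from_dlpack` on a NumPy source needs a reasonably recent NumPy.)

```python
import numpy as np

w = np.arange(6, dtype=np.float32).reshape(2, 3)
t = w.T  # transpose is a strided view, not a copy
print(t.flags["C_CONTIGUOUS"])    # False -- this is what MAX rejects

arr = np.ascontiguousarray(t)     # explicit contiguous copy
print(arr.flags["C_CONTIGUOUS"])  # True

# Same defensive-copy pattern as the DLPack conversion above:
owned = np.array(np.from_dlpack(w), copy=True)
print(owned.flags["C_CONTIGUOUS"])  # True
```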
The model currently produces gibberish output when served, despite:
- Weights loading correctly (verified via logging)
- Weight shapes being correct
- Transposition being applied
- Tied embeddings being handled
The standalone test comparing the `main.py` model with the serving model shows identical output when weights are loaded via `load_state_dict` from PyTorch, suggesting the issue may lie in how weights flow through the compile/serve pipeline.
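To help localize where the serving weights diverge, a small comparison helper (hypothetical, not part of this repo) can diff two name-to-array state dicts key by key:

```python
import numpy as np


def diff_state_dicts(a, b, atol=1e-6):
    """Return (key, reason) pairs for entries that differ between two
    name -> ndarray state dicts."""
    mismatched = []
    for key in sorted(set(a) | set(b)):
        if key not in a or key not in b:
            mismatched.append((key, "missing"))
        elif a[key].shape != b[key].shape:
            mismatched.append((key, "shape"))
        elif not np.allclose(a[key], b[key], atol=atol):
            mismatched.append((key, "values"))
    return mismatched


# Example: one tensor transposed relative to the reference.
ref = {"wte.weight": np.ones((4, 2), dtype=np.float32)}
got = {"wte.weight": np.ones((2, 4), dtype=np.float32)}
print(diff_state_dicts(ref, got))  # [('wte.weight', 'shape')]
```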