TritonForge is a server-based SFT + RL post-training framework that runs on both AMD and NVIDIA GPUs and fully disaggregates its training, inference, and evaluation components.
graph TB
subgraph DL["Data Layer"]
KB["KernelBook Dataset<br/>18.2k PyTorch-to-Triton pairs"]
KBench["KernelBench Dataset<br/>200+ problems L1-L4"]
subgraph Pipeline["SFT Data Pipeline"]
DP1[Multi-Turn Data Generator]
DP2[Thinking Tags Injector]
DP3["Length & Quality Filter"]
DP1 --> DP2
DP2 --> DP3
end
KB --> DP1
KBench --> DP1
DP3 --> ProcessedData["Processed Training Data<br/>~17k filtered samples"]
end
subgraph TI["Training Infrastructure"]
subgraph SLIME["SLIME Framework (Server-Based)"]
Router["SLIME Router<br/>Orchestration Layer"]
subgraph ActorG["Actor Group (Training)"]
MegatronServer["Megatron-LM Server<br/>Distributed Training<br/>TP=2, CP=4, PP=1"]
GRPO["GRPO Actor<br/>Group Relative Policy Optimization"]
WeightSync[Weight Synchronizer]
end
subgraph RolloutG["Rollout Group (Generation)"]
SGLangServer["SGLang Server Pool<br/>High-Performance Inference"]
RolloutBuffer["Async Rollout Buffer<br/>Experience Collection"]
subgraph Gens[Generators]
SingleTurnGen[Single-Turn Generator]
MultiTurnGen["Multi-Turn Generator<br/>Max 3 iterations"]
end
end
Router --> MegatronServer
Router --> SGLangServer
Router --> RolloutBuffer
WeightSync -.->|Weight Updates| SGLangServer
end
end
subgraph EI["Evaluation Infrastructure"]
subgraph EvalServers["Disaggregated Eval Servers"]
EvalServer1["KernelBench Eval Server<br/>Port 18188"]
EvalServer2["Remote Eval Server<br/>Compilation & Correctness"]
EvalServer3["Performance Eval Server<br/>Speedup Metrics"]
end
subgraph Backends["Platform-Specific Backends"]
NVBackend["NVIDIA Backend<br/>CUDA 12.6+<br/>Triton Compiler"]
AMDBackend["AMD Backend<br/>ROCm 6.3+<br/>HIP Translation Layer"]
end
MultiTurnGen -->|Eval Request| EvalServer1
SingleTurnGen -->|Eval Request| EvalServer1
EvalServer1 --> EvalServer2
EvalServer2 --> EvalServer3
EvalServer3 --> NVBackend
EvalServer3 --> AMDBackend
end
subgraph ME["Model Evolution"]
BaseModel["Qwen3-8B<br/>Base Model"]
SFTModel["Qwen3-8B-Kernelbook-SFT<br/>Fine-tuned Model"]
RLModel["TritonForge-8B<br/>RL-Optimized Model"]
BaseModel -->|SFT Training| SFTModel
SFTModel -->|RL Training| RLModel
end
ProcessedData --> MegatronServer
MegatronServer --> SFTModel
subgraph RS["Reward System"]
CompileReward["Compilation Success<br/>+0.1 base"]
CorrectReward["Functional Correctness<br/>+0.3 if correct"]
SpeedupReward["Performance Speedup<br/>+(speedup-1), max 2.0"]
DiscountFactor["γ = 0.4<br/>Multi-turn discount"]
CompileReward --> RewardAgg[Reward Aggregator]
CorrectReward --> RewardAgg
SpeedupReward --> RewardAgg
DiscountFactor --> RewardAgg
end
NVBackend --> RewardAgg
AMDBackend --> RewardAgg
RewardAgg --> GRPO
style Router fill:#f9f,stroke:#333,stroke-width:4px
style SFTModel fill:#9f9,stroke:#333,stroke-width:2px
style RLModel fill:#99f,stroke:#333,stroke-width:4px
flowchart LR
subgraph RDS["Raw Data Sources"]
KB["KernelBook<br/>18.2k samples"]
end
subgraph DA["Data Augmentation"]
MT[Multi-Turn Converter]
TT["Add Thinking Tags<br/>⟨thinking⟩...⟨/thinking⟩"]
Conv[Chat Template Format]
KB --> MT
KBench["KBench Eval Backend"] --> MT
MT --> TT
TT --> Conv
end
subgraph QC["Quality Control"]
LenFilter["Length Filter<br/>Min: 100 tokens<br/>Max: 8192 tokens"]
QualFilter["Quality Filter<br/>Remove invalid code"]
Dedup[Deduplication]
Conv --> LenFilter
LenFilter --> QualFilter
QualFilter --> Dedup
end
subgraph OF["Output Formats"]
JSONL["JSONL Format<br/>kernel_bench_triton_all_levels.jsonl"]
Dedup --> JSONL
end
JSONL --> FinalData[17k Filtered Samples]
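The quality-control stage above (length filter, quality filter, deduplication) could be sketched roughly as follows. This is a minimal illustration, not the project's actual pipeline. The sample fields `prompt` and `code`, the whitespace-based `count_tokens` helper, and the use of `ast.parse` as the "invalid code" check are all assumptions made for the example. Only the 100 and 8192 token bounds come from the diagram.

```python
import ast
import hashlib

MIN_TOKENS = 100    # lower bound from the Length Filter node
MAX_TOKENS = 8192   # upper bound from the Length Filter node


def count_tokens(text):
    """Hypothetical stand-in for a real tokenizer: counts whitespace-separated words."""
    return len(text.split())


def is_valid_python(code):
    """Assumed quality check: keep only code that parses as Python."""
    try:
        ast.parse(code)
        return True
    except (SyntaxError, ValueError):
        return False


def filter_samples(samples):
    """Apply the length filter, the quality filter, then exact-match deduplication."""
    seen = set()
    kept = []
    for sample in samples:
        code = sample["code"]
        n_tokens = count_tokens(sample["prompt"] + " " + code)
        if not (MIN_TOKENS <= n_tokens <= MAX_TOKENS):
            continue  # too short or too long
        if not is_valid_python(code):
            continue  # code does not parse
        digest = hashlib.sha256(code.encode("utf-8")).hexdigest()
        if digest in seen:
            continue  # exact duplicate of an earlier sample
        seen.add(digest)
        kept.append(sample)
    return kept
```

A real pipeline would count tokens with the model's tokenizer and might deduplicate on normalized or near-duplicate code rather than exact hashes.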
flowchart TB
subgraph ATL["Async Training Loop"]
Init[Initialize Ray Cluster]
CreatePG["Create Placement Groups<br/>Actor GPUs | Rollout GPUs"]
StartServers["Start All Servers<br/>Megatron | SGLang | Eval"]
Init --> CreatePG
CreatePG --> StartServers
subgraph TIter["Training Iteration"]
GenExp["Generate Experience<br/>Async Rollout"]
CollectData[Collect to Buffer]
TrainStep[GRPO Training Step]
UpdateWeights[Update Weights]
GenExp --> CollectData
CollectData --> TrainStep
TrainStep --> UpdateWeights
UpdateWeights -->|Every N steps| GenExp
end
StartServers --> GenExp
end
subgraph SC["Server Communication"]
HTTP["HTTP/REST APIs<br/>Between servers"]
Ray["Ray Remote Calls<br/>Actor coordination"]
Queue["Message Queues<br/>Async buffering"]
HTTP -.-> GenExp
Ray -.-> CollectData
Queue -.-> TrainStep
end
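The producer/consumer shape of the training iteration above can be illustrated with a small sketch. It uses Python's `asyncio` in place of Ray and real servers, so it shows only the control flow: a rollout worker fills a bounded buffer while a trainer drains it in batches. The function names, the placeholder trajectory fields, and the batch size are hypothetical.

```python
import asyncio


async def rollout_worker(buffer, num_rollouts):
    """Stand-in for the rollout group: pushes placeholder trajectories into the buffer."""
    for i in range(num_rollouts):
        await asyncio.sleep(0)  # yield control, as a real generation call would
        await buffer.put({"rollout_id": i, "reward": 0.1 * i})
    await buffer.put(None)  # sentinel: no more experience


async def trainer(buffer, batch_size):
    """Stand-in for the actor group: consumes batches and counts training steps."""
    steps = 0
    batch = []
    while True:
        item = await buffer.get()
        if item is None:
            break
        batch.append(item)
        if len(batch) == batch_size:
            steps += 1     # a GRPO training step would run here
            batch.clear()  # a weight update to the rollout servers would follow
    return steps


async def run_training(num_rollouts=8, batch_size=4):
    """Run generation and training concurrently, connected by a bounded queue."""
    buffer = asyncio.Queue(maxsize=batch_size * 2)
    producer = asyncio.create_task(rollout_worker(buffer, num_rollouts))
    steps = await trainer(buffer, batch_size)
    await producer
    return steps
```

Calling `asyncio.run(run_training(8, 4))` performs two simulated training steps. The bounded queue applies backpressure, so generation cannot run arbitrarily far ahead of training.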
sequenceDiagram
participant User
participant Router as SLIME Router
participant Gen as Multi-Turn Generator
participant SGLang as SGLang Server
participant Eval as Eval Server
participant Backend as GPU Backend
User->>Router: Start RL Training
Router->>Gen: Initialize Multi-Turn Rollout
loop For each problem (up to 3 turns)
Gen->>SGLang: Generate Kernel (Turn N)
SGLang-->>Gen: Triton Code
Gen->>Eval: Evaluate Kernel
Eval->>Backend: Compile & Execute
alt Compilation Success
Backend-->>Eval: Runtime & Correctness
Eval-->>Gen: Reward + Metrics
alt If Correct & Fast
Gen-->>Gen: Mark Complete
else If Issues Found
Gen-->>Gen: Prepare Feedback
Note over Gen: Include error messages<br/>and performance data
Gen->>SGLang: Generate Improved Kernel
end
else Compilation Failed
Backend-->>Eval: Error Message
Eval-->>Gen: Negative Reward
Gen-->>Gen: Prepare Error Feedback
end
end
Gen->>Router: Return Trajectory
Router->>Router: Aggregate Returns (γ=0.4)
Router-->>User: Training Metrics
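The refinement loop in the sequence diagram can be sketched in a few lines. This is an illustration under stated assumptions: the `generate` and `evaluate` callables and the result keys (`compiled`, `correct`, `reward`, `message`) are hypothetical stand-ins for the SGLang and eval server calls. The sketch also stops as soon as a kernel is correct, because the diagram does not say what counts as "fast".

```python
MAX_TURNS = 3  # "Max 3 iterations" in the architecture diagram


def multi_turn_rollout(problem, generate, evaluate, max_turns=MAX_TURNS):
    """Run up to max_turns generate/evaluate rounds and return the trajectory.

    generate(problem, feedback) -> kernel source (hypothetical signature)
    evaluate(kernel) -> dict with 'compiled', 'correct', 'reward', 'message'
    (hypothetical result shape)
    """
    trajectory = []
    feedback = None
    for turn in range(max_turns):
        kernel = generate(problem, feedback)
        result = evaluate(kernel)
        trajectory.append({"turn": turn, "kernel": kernel, "reward": result["reward"]})
        if result["compiled"] and result["correct"]:
            break  # corresponds to "Mark Complete"
        # Feed error messages or performance data into the next turn.
        feedback = result["message"]
    return trajectory
```

The per-turn rewards in the returned trajectory are what the router would then aggregate with the discount factor γ = 0.4.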
Hardware:
GPU: H100 80GB
CUDA: 12.6.1
Triton: 3.0+
Training Config:
Tensor Parallel: 2
Context Parallel: 4
Pipeline Parallel: 1
Batch Size: 32
Learning Rate: 1e-5
Evaluation:
Backend: CUDA + Triton JIT
Profiling: NSight Systems
Metrics: FLOPS, Memory Bandwidth

Hardware:
GPU: MI300X 192GB
ROCm: 6.3.4
HIP: 6.3+
Training Config:
Tensor Parallel: 2
Context Parallel: 4
Pipeline Parallel: 1
Batch Size: 32
Learning Rate: 1e-5
Evaluation:
Backend: ROCm + HIP Translation
Profiling: rocprof
Metrics: FLOPS, Memory Bandwidth
Special Handling:
- Subprocess isolation for memory faults
- HIP_PLATFORM=amd environment
- PYTORCH_ROCM_ARCH=gfx942

- Purpose: Orchestrates communication between all components
- Features:
- Async message passing
- Load balancing across servers
- Fault tolerance and retry logic
- Weight synchronization management
- Purpose: Asynchronously collects generation experiences
- Features:
- Multi-process data collection
- Experience replay buffer
- Priority sampling support
- Trajectory aggregation
- Purpose: Isolated evaluation environments
- Features:
- Sandboxed execution
- Resource limit enforcement
- Performance profiling
- Error recovery mechanisms
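One common way to get isolation and error recovery of the kind listed above is to run each candidate kernel in its own subprocess with a timeout, so a crash or hang cannot take down the evaluation server. The sketch below assumes that approach. `run_isolated` and `SandboxResult` are hypothetical names, and the sketch provides process isolation and a time limit only, not filesystem or memory sandboxing.

```python
import os
import subprocess
import sys
from dataclasses import dataclass
from typing import Optional


@dataclass
class SandboxResult:
    ok: bool
    returncode: Optional[int]  # None when the run timed out
    stdout: str
    stderr: str


def run_isolated(source, timeout_s=30.0, extra_env=None):
    """Run Python source in a child process; crashes and hangs stay contained."""
    env = dict(os.environ)
    env.update(extra_env or {})  # e.g. {"HIP_PLATFORM": "amd"} on AMD hosts
    try:
        proc = subprocess.run(
            [sys.executable, "-c", source],
            capture_output=True,
            text=True,
            timeout=timeout_s,
            env=env,
        )
    except subprocess.TimeoutExpired:
        # subprocess.run kills the child before raising, so nothing lingers.
        return SandboxResult(False, None, "", "timed out")
    return SandboxResult(proc.returncode == 0, proc.returncode, proc.stdout, proc.stderr)
```

On POSIX systems, a negative `returncode` means the child was killed by a signal, which is how a GPU memory fault in the child would typically surface to the parent.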
- GRPO (Group Relative Policy Optimization):
  - Scores each sampled response relative to the other responses in its group, instead of estimating advantages with a separate learned value model
  - Better suited for multi-turn refinement scenarios
  - Reduces variance in policy gradient estimation
- Multi-Turn Aggregation (γ = 0.4, with `t` the zero-based turn index):
  `total_return = sum(reward_t * (gamma ** t) for t, reward_t in enumerate(rewards))`
- Reward Components:
  - Compilation: 0.1 base reward (if the code compiles successfully)
  - Correctness: 0.3 total reward (if functionally correct; this includes the 0.1 compilation reward rather than adding to it)
  - Performance: +(speedup - 1.0), capped at 2.0
    - Example: 2x speedup → +1.0 reward, 3x speedup → +2.0 reward (max)
  - Maximum total reward: 2.3 (0.3 correctness + 2.0 max performance)
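The reward rules and the multi-turn aggregation above can be put together in a short sketch. Several details are assumptions, because the text does not specify them: a kernel that fails to compile scores 0.0, a speedup below 1.0 earns no bonus and no penalty, and the group-relative advantage uses the commonly described GRPO normalization (subtract the group mean, divide by the group standard deviation). All function names are hypothetical.

```python
import statistics

GAMMA = 0.4  # multi-turn discount factor


def kernel_reward(compiled, correct, speedup=None):
    """Per-turn reward: 0.1 for compiling, 0.3 for correct, plus a capped speedup bonus."""
    if not compiled:
        return 0.0  # assumption: the value for a failed compile is not stated
    if not correct:
        return 0.1
    bonus = 0.0
    if speedup is not None:
        # Assumption: slowdowns are clamped to zero bonus rather than penalized.
        bonus = min(max(speedup - 1.0, 0.0), 2.0)
    return 0.3 + bonus


def discounted_return(rewards, gamma=GAMMA):
    """Aggregate per-turn rewards; turn t is weighted by gamma ** t."""
    return sum(r * gamma ** t for t, r in enumerate(rewards))


def group_relative_advantages(returns, eps=1e-6):
    """Normalize each return against the mean and spread of its sampling group."""
    mean = statistics.fmean(returns)
    std = statistics.pstdev(returns)
    return [(r - mean) / (std + eps) for r in returns]
```

As a worked example, a three-turn trajectory that compiles, then becomes correct, then reaches a 2x speedup has rewards 0.1, 0.3, and 1.3, giving a return of 0.1 + 0.4 × 0.3 + 0.16 × 1.3 = 0.428. With γ = 0.4, improvements that arrive in later turns are discounted heavily compared with getting the kernel right on the first attempt.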
- Input: PyTorch operations from KernelBook/KernelBench
- SFT Pipeline: Multi-turn augmentation → Thinking tags → Filtering → 17k samples
- SFT Training: Distributed training via Megatron-LM → Fine-tuned model
- RL Pipeline: Multi-turn generation → Evaluation → Reward → GRPO updates
- Output: Optimized TritonForge model capable of generating efficient GPU kernels
- Green: Component operational
- Yellow: Component under load
- Red: Component failed/recovering
- Dotted Lines: Async communication
- Solid Lines: Sync communication
graph LR
Generate[Generate Kernel] --> Evaluate[Evaluate Performance]
Evaluate --> Feedback[Compilation + Runtime Feedback]
Feedback --> Improve[Multi-Turn Refinement]
Improve --> Generate
Evaluate --> Reward[Calculate Reward]
Reward --> Train[GRPO Training]
Train --> Model[Update Model]
Model --> Generate
This architecture enables:
- ✅ Cross-platform GPU kernel generation (NVIDIA & AMD)
- ✅ Server-based disaggregated training at scale
- ✅ Multi-turn iterative kernel refinement
- ✅ Comprehensive evaluation with compilation and performance metrics
- ✅ Asynchronous training for improved efficiency
- ✅ Fault tolerance and recovery mechanisms