Merged
13 changes: 7 additions & 6 deletions AGENTS.md
@@ -20,13 +20,14 @@
```
src/
├── lib.rs # Public API re-exports + module declarations
├── config.rs # AttnResConfig — validated builder pattern
├── config.rs # AttnResConfig — validated builder pattern (JSON save/load)
├── attn_res_op.rs # Core AttnRes operation (depth-wise softmax attention)
├── block_state.rs # BlockState — cumulative block representation tracking
├── layer.rs # AttnResLayer — transformer layer with dual AttnRes
├── model.rs # AttnResTransformer — full model (embed → layers → LM head)
├── model.rs # AttnResTransformer — full model with standard + two-phase forward
├── rms_norm.rs # RMSNorm implementation
├── two_phase.rs # Two-phase inference optimization
├── serialization.rs # Model weight save/load (NamedMpk, binary, compact formats)
├── two_phase.rs # Two-phase inference primitives (phase1_batched, online_softmax_merge)
├── attention.rs # Multi-head self-attention
├── feed_forward.rs # SwiGLU-style MLP
└── utils.rs # Causal mask generation helpers
@@ -54,7 +55,7 @@ fixtures/ # Reference outputs from PyTorch

```bash
cargo build # Build the project
cargo test --all-features # Run all 57 tests
cargo test --all-features # Run all 66 tests
cargo test test_name # Run specific test
cargo clippy -- -D warnings # Lint (warnings = errors)
cargo fmt # Format code
@@ -112,7 +113,7 @@ Input IDs → Embedding → [AttnResLayer × N] → RMSNorm → LM Head → Logi

## Known Gaps

- No safetensors serialization
- Two-phase inference not integrated into main forward path
- No PyTorch checkpoint loading (safetensors format)
- GPU backends (wgpu, CUDA, Metal) untested
- No distributed training support
- No pre-trained weight import/export utilities
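The structure tree above names `online_softmax_merge` as one of the two-phase primitives in `two_phase.rs`. As a sketch only, here is the rescaling trick that name refers to, restated stdlib-only on plain `f32` slices; the crate itself operates on burn tensors, and the function name and signature below are illustrative, not the crate's real API:

```rust
/// Stdlib-only sketch of online-softmax merging: `out` is the
/// un-normalized weighted sum so far, `m` the running max logit,
/// `s` the running sum of exponentials. Merging a new (logit, value)
/// pair rescales the old aggregate by exp(m - m'), so the final
/// `out / s` is independent of the order logits arrive in.
fn online_merge(out: &mut [f32], m: &mut f32, s: &mut f32, logit: f32, value: &[f32]) {
    let new_m = m.max(logit);
    let scale_old = (*m - new_m).exp();
    let scale_new = (logit - new_m).exp();
    for (o, &v) in out.iter_mut().zip(value) {
        *o = *o * scale_old + v * scale_new;
    }
    *s = *s * scale_old + scale_new;
    *m = new_m;
}

fn main() {
    // Merge two (logit, value) pairs incrementally...
    let mut out = vec![0.0f32; 2];
    let (mut m, mut s) = (f32::NEG_INFINITY, 0.0f32);
    online_merge(&mut out, &mut m, &mut s, 1.0, &[1.0, 0.0]);
    online_merge(&mut out, &mut m, &mut s, 2.0, &[0.0, 1.0]);
    let merged: Vec<f32> = out.iter().map(|o| o / s).collect();

    // ...and check against a direct two-term softmax.
    let (e1, e2) = (1.0f32.exp(), 2.0f32.exp());
    let direct = [e1 / (e1 + e2), e2 / (e1 + e2)];
    for (a, b) in merged.iter().zip(direct) {
        assert!((a - b).abs() < 1e-6);
    }
    println!("merged = {merged:?}");
}
```

Starting from `m = -inf, s = 0` makes the first merge degenerate cleanly (the old aggregate is scaled by `exp(-inf) = 0`), which is why no special first-iteration case is needed.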
13 changes: 8 additions & 5 deletions CLAUDE.md
@@ -9,16 +9,18 @@ attnres-rs: First Rust implementation of Attention Residuals (MoonshotAI/Kimi pa
| ML Framework| burn | 0.20 | tracel-ai/burn — multi-backend DL framework |
| Backends | CUDA, Metal, wgpu, NdArray | — | NdArray for CPU testing, wgpu for cross-platform GPU |
| Testing | cargo test | — | + proptest (property-based), criterion (benchmarks) |
| Serialization | safetensors | — | For weight loading/saving |
| Serialization | burn record (NamedMpk, bin) | — | Model weight save/load via burn's record system |
| Linting | clippy + rustfmt | — | Enforced in CI |
| CI | GitHub Actions | — | cargo test, clippy, fmt, build-examples |
</stack>

<status>
PROJECT PHASE: Alpha (v0.1.0 — core algorithm implemented, tests passing).
All source modules implemented. 57 tests passing (28 inline unit + 18 external unit + 3 differential + 2 property + 5 integration + 1 doctest).
CI configured (test, clippy, fmt, build-examples). Examples and benchmarks functional. burn upgraded to 0.20.
Known gaps: no safetensors serialization, two-phase inference not integrated into main forward path, GPU backends untested.
PROJECT PHASE: v0.2.0 — serialization and two-phase inference integrated.
All source modules implemented. 66 tests passing (32 inline unit + 18 external unit + 3 differential + 2 property + 9 integration + 2 doctest).
CI configured (test, clippy, fmt, build-examples). Examples and benchmarks functional. burn 0.20.
Model save/load via burn record system (NamedMpk, binary, compact). Config save/load via JSON.
Two-phase inference integrated into model via `forward_two_phase` method.
Known gaps: no PyTorch checkpoint import, GPU backends untested.
</status>

<structure>
@@ -44,6 +46,7 @@ attnres-rs/
│ ├── layer.rs # AttnResLayer [agent: CREATE/MODIFY]
│ ├── model.rs # AttnResTransformer [agent: CREATE/MODIFY]
│ ├── rms_norm.rs # RMSNorm implementation [agent: CREATE/MODIFY]
│ ├── serialization.rs # Model weight save/load [agent: CREATE/MODIFY]
│ ├── two_phase.rs # Two-phase inference [agent: CREATE/MODIFY]
│ ├── attention.rs # Multi-head attention [agent: CREATE/MODIFY]
│ ├── feed_forward.rs # MLP module [agent: CREATE/MODIFY]
10 changes: 6 additions & 4 deletions ROADMAP.md
@@ -18,16 +18,18 @@ Core algorithm implemented and tested. Suitable for research and experimentation
- [x] TwoPhase: Two-phase inference optimization (standalone)
- [x] Config: Validated configuration with builder pattern
- [x] Zero initialization of pseudo-query vectors
- [x] 57 tests passing (unit, differential, property-based, integration, doctest)
- [x] CI pipeline (test, clippy, fmt, build-examples)
- [x] 3 examples (train_tiny, compare_residuals, visualize_weights)
- [x] Criterion benchmarks
- [x] Upgrade to burn 0.20

## v0.2.0 — Serialization & Inference (Planned)
## v0.2.0 — Serialization & Inference

- [ ] Safetensors weight save/load
- [ ] Integrate two-phase inference into main forward path
- [x] Model weight save/load (NamedMpk default, binary, compact/half-precision formats)
- [x] Config save/load (JSON via burn's Config trait)
- [x] Integrate two-phase inference into the model's forward path via `forward_two_phase`
- [x] Layer accessor methods for two-phase inference components
- [x] 66 tests passing (unit, differential, property-based, integration, doctest)
- [ ] Pre-trained weight loading from PyTorch checkpoints
- [ ] Model export utilities

28 changes: 28 additions & 0 deletions src/config.rs
@@ -3,6 +3,9 @@
/// `num_layers` counts *sublayers* — each transformer layer has 2 sublayers
/// (attention + MLP), so `num_layers=8` creates 4 transformer layers.
///
/// Supports JSON serialization via [`save`](AttnResConfig::save) and
/// [`load`](AttnResConfig::load) methods.
///
/// Paper reference: Section 3, Block Attention Residuals.
use burn::config::Config;

@@ -177,6 +180,31 @@ mod tests {
assert_eq!(config.num_transformer_layers(), 6);
}

#[test]
fn test_config_save_load_roundtrip() {
let config = AttnResConfig::new(128, 24, 8)
.with_num_heads(8)
.with_d_ff(512)
.with_vocab_size(50000)
.with_dropout(0.1);

let path = std::env::temp_dir().join("attnres_test_config.json");
config.save(&path).expect("Failed to save config");

let loaded = AttnResConfig::load(&path).expect("Failed to load config");

assert_eq!(config.d_model, loaded.d_model);
assert_eq!(config.num_layers, loaded.num_layers);
assert_eq!(config.num_blocks, loaded.num_blocks);
assert_eq!(config.num_heads, loaded.num_heads);
assert_eq!(config.d_ff, loaded.d_ff);
assert_eq!(config.vocab_size, loaded.vocab_size);
assert!((config.dropout - loaded.dropout).abs() < 1e-10);
assert!((config.rms_norm_eps - loaded.rms_norm_eps).abs() < 1e-15);

let _ = std::fs::remove_file(&path);
}

#[test]
fn test_full_attnres_block_size_one() {
// Full AttnRes: each sublayer is its own block
41 changes: 41 additions & 0 deletions src/layer.rs
@@ -58,6 +58,47 @@ impl AttnResConfig {
}

impl<B: Backend> AttnResLayer<B> {
/// Get the layer index.
pub fn layer_idx(&self) -> usize {
self.layer_idx
}

/// Get the block size.
pub fn block_size(&self) -> usize {
self.block_size
}

/// Check if this layer is at a block boundary.
pub fn is_at_boundary(&self) -> bool {
let half_block = self.block_size / 2;
self.layer_idx > 0 && (half_block == 0 || self.layer_idx.is_multiple_of(half_block))
}

/// Get references to the AttnRes operations (attn_res, mlp_res).
pub fn attn_res_ops(&self) -> (&AttnResOp<B>, &AttnResOp<B>) {
(&self.attn_res, &self.mlp_res)
}

/// Execute only the attention sublayer (norm + multi-head attention).
///
/// Used by two-phase inference after AttnRes has been computed externally.
pub fn forward_attn_sublayer(
&self,
h: Tensor<B, 3>,
mask: Option<&Tensor<B, 3>>,
) -> Tensor<B, 3> {
let normed = self.attn_norm.forward(h);
self.attn.forward(normed, mask)
}

/// Execute only the MLP sublayer (norm + feed-forward).
///
/// Used by two-phase inference after AttnRes has been computed externally.
pub fn forward_mlp_sublayer(&self, h: Tensor<B, 3>) -> Tensor<B, 3> {
let normed = self.mlp_norm.forward(h);
self.mlp.forward(normed)
}

/// Forward pass for a single transformer layer with Block AttnRes.
///
/// Maps directly to the `forward` function in Figure 2 of the paper.
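The `is_at_boundary` accessor added above is the one subtle piece of this hunk: `block_size` counts *sublayers* (two per transformer layer), so the boundary test works in half-block units. A plain-integer restatement, with layer state replaced by arguments (a sketch for illustration, not the crate's API):

```rust
/// Plain-integer restatement of `AttnResLayer::is_at_boundary` from this
/// diff. `block_size` is in sublayers; each transformer layer holds two
/// sublayers, so boundaries fall at positive multiples of `block_size / 2`.
/// With `block_size == 1` (full AttnRes, each sublayer its own block),
/// `half_block` is 0 and every layer past the first is a boundary.
fn is_at_boundary(layer_idx: usize, block_size: usize) -> bool {
    let half_block = block_size / 2;
    layer_idx > 0 && (half_block == 0 || layer_idx % half_block == 0)
}

fn main() {
    // block_size = 4 sublayers => half_block = 2 transformer layers,
    // so boundaries fall at layer indices 2, 4, 6, ...
    let boundaries: Vec<usize> = (0..8).filter(|&i| is_at_boundary(i, 4)).collect();
    assert_eq!(boundaries, vec![2, 4, 6]);
    println!("boundaries = {boundaries:?}");
}
```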
2 changes: 2 additions & 0 deletions src/lib.rs
@@ -35,6 +35,7 @@ pub mod feed_forward;
pub mod layer;
pub mod model;
pub mod rms_norm;
pub mod serialization;
pub mod two_phase;
pub mod utils;

@@ -47,4 +48,5 @@ pub use feed_forward::{FeedForward, FeedForwardConfig};
pub use layer::AttnResLayer;
pub use model::AttnResTransformer;
pub use rms_norm::{RmsNorm, RmsNormConfig};
pub use serialization::SerializationError;
pub use utils::causal_mask;
164 changes: 164 additions & 0 deletions src/model.rs
@@ -12,6 +12,7 @@ use crate::block_state::BlockState;
use crate::config::AttnResConfig;
use crate::layer::AttnResLayer;
use crate::rms_norm::{RmsNorm, RmsNormConfig};
use crate::two_phase::{compute_intra_logit, online_softmax_merge, phase1_batched};

#[derive(Module, Debug)]
pub struct AttnResTransformer<B: Backend> {
@@ -85,6 +86,149 @@ impl<B: Backend> AttnResTransformer<B> {
self.lm_head.forward(normed)
}

/// Forward pass using two-phase inference optimization.
///
/// Produces identical results to [`forward`](Self::forward) but uses the two-phase
/// strategy from Algorithm 1 of the paper:
/// - Phase 1: Batch inter-block attention for all sublayers in each block
/// - Phase 2: Sequential intra-block attention with online softmax merge
///
/// This is beneficial during inference when blocks are cached, as Phase 1
/// can be parallelized across sublayers.
///
/// Paper reference: Algorithm 1, Section 4.1.
pub fn forward_two_phase(
&self,
input_ids: Tensor<B, 2, Int>,
mask: Option<&Tensor<B, 3>>,
) -> Tensor<B, 3> {
let embeddings = self.embedding.forward(input_ids);
let mut state = BlockState::new(embeddings);

// Group layers into blocks based on block boundaries
let mut block_start = 0;
while block_start < self.layers.len() {
// Find the end of this block: layers until next boundary
let mut block_end = block_start + 1;
while block_end < self.layers.len() && !self.layers[block_end].is_at_boundary() {
block_end += 1;
}

let block_layers = &self.layers[block_start..block_end];

if state.blocks.is_empty() {
// No inter-block context yet; use standard forward
for layer in block_layers {
state = layer.forward(state, mask);
}
} else {
// Two-phase forward for this block of layers
state = self.forward_block_two_phase(state, block_layers, mask);
}

block_start = block_end;
}

let output = state
.partial_block
.expect("partial_block missing after forward pass; this is a bug in AttnResLayer");

let normed = self.final_norm.forward(output);
self.lm_head.forward(normed)
}

/// Two-phase forward for a single block of layers.
///
/// Uses Phase 1 (batched inter-block attention) + Phase 2 (sequential intra-block).
fn forward_block_two_phase(
&self,
mut state: BlockState<B>,
block_layers: &[AttnResLayer<B>],
mask: Option<&Tensor<B, 3>>,
) -> BlockState<B> {
// Phase 2 setup: handle block boundary first so blocks are correct for Phase 1
let current_partial = state
.partial_block
.take()
.unwrap_or_else(|| Tensor::zeros_like(state.blocks.last().unwrap()));

let first_layer = &block_layers[0];
let at_boundary = first_layer.is_at_boundary();

if at_boundary {
state.blocks.push(current_partial.clone());
}

let mut partial = if at_boundary {
Tensor::zeros_like(state.blocks.last().unwrap())
} else {
current_partial
};

// Collect all AttnResOp references for Phase 1 batching
// Each layer has 2 ops: attn_res, mlp_res
let all_ops: Vec<_> = block_layers
.iter()
.flat_map(|layer| {
let (attn_op, mlp_op) = layer.attn_res_ops();
vec![attn_op, mlp_op]
})
.collect();

// Phase 1: Batch all inter-block attention (now with correct blocks)
let phase1 = phase1_batched(&all_ops, &state.blocks);

// Process each sublayer using Phase 1 results + online softmax merge
for (layer_idx, layer) in block_layers.iter().enumerate() {
let attn_op_idx = layer_idx * 2;
let mlp_op_idx = layer_idx * 2 + 1;

// === AttnRes before self-attention (using two-phase) ===
let h = if phase1.outputs.is_empty() {
// No inter-block context: fall back to standard
let (attn_op, _) = layer.attn_res_ops();
attn_op.forward(&state.blocks, &partial)
} else {
let (attn_op, _) = layer.attn_res_ops();
let intra_logit = compute_intra_logit(attn_op, &partial);
online_softmax_merge(
phase1.outputs[attn_op_idx].clone(),
phase1.max_logits[attn_op_idx].clone(),
phase1.sum_exp[attn_op_idx].clone(),
intra_logit,
partial.clone(),
)
};

// Attention sublayer
let attn_out = layer.forward_attn_sublayer(h, mask);
partial = partial + attn_out;

// === AttnRes before MLP (using two-phase) ===
let h = if phase1.outputs.is_empty() {
let (_, mlp_op) = layer.attn_res_ops();
mlp_op.forward(&state.blocks, &partial)
} else {
let (_, mlp_op) = layer.attn_res_ops();
let intra_logit = compute_intra_logit(mlp_op, &partial);
online_softmax_merge(
phase1.outputs[mlp_op_idx].clone(),
phase1.max_logits[mlp_op_idx].clone(),
phase1.sum_exp[mlp_op_idx].clone(),
intra_logit,
partial.clone(),
)
};

// MLP sublayer
let mlp_out = layer.forward_mlp_sublayer(h);
partial = partial + mlp_out;
}

state.partial_block = Some(partial);
state
}

/// Forward pass returning hidden states (without LM head).
pub fn forward_hidden(
&self,
@@ -128,6 +272,26 @@ mod tests {
assert_eq!(output.dims(), [1, 8, 100]);
}

#[test]
fn test_two_phase_matches_standard() {
let device = Default::default();
let config = AttnResConfig::new(32, 8, 2)
.with_num_heads(4)
.with_vocab_size(100);

let model = config.init_model::<TestBackend>(&device);

let input_ids = Tensor::<TestBackend, 2, Int>::zeros([1, 8], &device);
let standard_out = model.forward(input_ids.clone(), None);
let two_phase_out = model.forward_two_phase(input_ids, None);

let diff: f32 = (standard_out - two_phase_out).abs().max().into_scalar();
assert!(
diff < 1e-3,
"Two-phase forward should match standard forward, diff={diff}"
);
}

#[test]
fn test_model_forward_hidden_shape() {
let device = Default::default();
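The grouping walk at the top of `forward_two_phase` (find the next boundary, take every layer up to it as one block) can be restated stdlib-only on indices alone. This is a sketch: the function name is illustrative, and the inner `at_boundary` closure mirrors `AttnResLayer::is_at_boundary` rather than calling into layer state:

```rust
/// Index-only restatement of the block-grouping loop in
/// `forward_two_phase`: each block runs from one boundary layer up to
/// (but not including) the next. `half_block` is in transformer-layer
/// units, since `block_size` counts sublayers and each layer has two.
fn group_into_blocks(num_layers: usize, block_size: usize) -> Vec<(usize, usize)> {
    let half_block = block_size / 2;
    let at_boundary = |idx: usize| idx > 0 && (half_block == 0 || idx % half_block == 0);
    let mut groups = Vec::new();
    let mut start = 0;
    while start < num_layers {
        // A block always takes at least one layer, then extends until
        // the next boundary (or the end of the layer list).
        let mut end = start + 1;
        while end < num_layers && !at_boundary(end) {
            end += 1;
        }
        groups.push((start, end));
        start = end;
    }
    groups
}

fn main() {
    // 8 transformer layers, block_size = 4 sublayers => 2-layer blocks.
    let groups = group_into_blocks(8, 4);
    assert_eq!(groups, vec![(0, 2), (2, 4), (4, 6), (6, 8)]);
    println!("{groups:?}");
}
```

Note the invariant this buys the method: every `(start, end)` range covers each layer exactly once, so Phase 1 can batch all `2 * (end - start)` AttnRes ops of a block while Phase 2 still visits sublayers in order.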