hoshuaclawdbot/titans-cuda

Titans CUDA Framework 🚀

A high-performance CUDA implementation of the Titans neural memory architecture from "Titans: Learning to Memorize at Test Time" (Behrouz et al., 2024).

Overview

Titans introduces a novel neural memory module that learns at test time via gradient descent. This framework provides optimized CUDA implementations using:

  • CUDA Kernels: Custom kernels for memory operations
  • CUB: Block/warp-level primitives for reductions
  • Thrust: High-level parallel algorithms
  • cuBLAS: Optimized matrix operations

Architecture

Titans Memory Module
├── Memory: MLP that updates at test time
├── Surprise: Gradient + momentum signal
├── Forgetting: Weight decay for memory management
└── Variants: MAC, MAG, MAL

The Core Update Rule

S_t = η · S_{t-1} - θ · ∇ℓ(M_{t-1}; x_t)    # Surprise (momentum + gradient)
M_t = (1 - α) · M_{t-1} + S_t                # Memory update with forgetting

Where:

  • M = Neural memory (MLP weights)
  • S = Surprise signal (accumulated gradients)
  • η = Momentum coefficient
  • θ = Learning rate
  • α = Forgetting rate (weight decay)
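
As a sanity check on the rule above, here is a minimal host-side C++ sketch that applies one update step element-wise to flat memory and surprise buffers. The names `UpdateParams` and `apply_update` are hypothetical and do not come from this repository's headers, and the gradient is assumed to have been computed already at the previous memory state. A CPU reference like this can be handy for checking a kernel's output.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical names: UpdateParams and apply_update are illustrative only
// and are not taken from the titans-cuda headers.
struct UpdateParams {
    float lr;          // theta: learning rate
    float momentum;    // eta: momentum coefficient
    float forgetting;  // alpha: forgetting rate (weight decay)
};

// Applies one step of the update rule element-wise:
//   S_t = eta * S_{t-1} - theta * grad
//   M_t = (1 - alpha) * M_{t-1} + S_t
// `grad` holds the loss gradient evaluated at M_{t-1}.
inline void apply_update(std::vector<float>& memory,
                         std::vector<float>& surprise,
                         const std::vector<float>& grad,
                         const UpdateParams& p) {
    assert(memory.size() == surprise.size());
    assert(memory.size() == grad.size());
    for (std::size_t i = 0; i < memory.size(); ++i) {
        surprise[i] = p.momentum * surprise[i] - p.lr * grad[i];
        memory[i] = (1.0f - p.forgetting) * memory[i] + surprise[i];
    }
}
```

Calling it repeatedly with the same gradient shows the momentum term building up in `surprise`, while the `(1 - forgetting)` factor shrinks older memory content on every step.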

Project Structure

titans-cuda/
├── include/
│   ├── titans/
│   │   ├── memory.cuh          # Neural memory module
│   │   ├── surprise.cuh        # Surprise computation
│   │   ├── projections.cuh     # Key-value projections
│   │   ├── variants.cuh        # MAC, MAG, MAL variants
│   │   └── utils.cuh           # Utilities
│   └── common/
│       ├── cuda_utils.cuh      # Error checking, timing
│       └── tensor.cuh          # Simple tensor wrapper
├── src/
│   ├── memory.cu               # Memory module implementation
│   ├── surprise.cu             # Surprise kernels
│   ├── projections.cu          # Projection kernels
│   └── variants.cu             # Variant implementations
├── kernels/
│   ├── naive/                  # Baseline implementations
│   ├── optimized/              # Optimized versions
│   └── experimental/           # Cutting-edge optimizations
├── tests/
│   ├── test_memory.cu          # Memory module tests
│   ├── test_surprise.cu        # Surprise computation tests
│   └── benchmarks.cu           # Performance benchmarks
├── examples/
│   ├── simple_memory.cu        # Basic usage example
│   ├── sequence_modeling.cu    # Sequence task example
│   └── benchmark_vs_pytorch.py # Compare with PyTorch
├── python/
│   └── titans_cuda/            # Python bindings (optional)
├── CMakeLists.txt
└── README.md

Building

Prerequisites

  • CUDA Toolkit 12.x
  • CMake 3.20+
  • C++17 compiler
  • (Optional) PyTorch for Python bindings

Build

mkdir build && cd build
cmake .. -DCMAKE_BUILD_TYPE=Release
make -j$(nproc)

Run Tests

cd build
ctest --verbose

Run Benchmarks

./benchmarks --all

Implementation Roadmap

Phase 1: Core Components ✅

  • Basic tensor operations
  • MLP forward pass
  • Associative memory loss
  • Gradient computation
  • Memory update with momentum

Phase 2: Optimizations 🔄

  • Fused forward + backward kernel
  • Shared memory for MLP weights
  • CUB block reductions
  • Vectorized memory access
  • Multi-stream execution

Phase 3: Variants

  • MAC (Memory as Context)
  • MAG (Memory as Gate)
  • MAL (Memory as Layer)

Phase 4: Advanced

  • Chunked parallel training
  • Tensor Core support (FP16/BF16)
  • Multi-GPU support
  • Python/PyTorch bindings

Usage Example

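#include <titans/memory.cuh>

int main() {
    // Create memory module
    titans::NeuralMemory memory(
        /*dim=*/256,
        /*depth=*/2,
        /*hidden_mult=*/4.0f
    );

    // Create update config. Fields are assigned one by one because
    // designated initializers (.lr = ...) require C++20, while this
    // project targets C++17.
    titans::UpdateConfig config{};
    config.lr = 0.01f;
    config.momentum = 0.9f;
    config.forgetting = 0.01f;

    // NOTE: `projections`, `input`, and `seq_len` are assumed to be set up
    // by the caller; their construction is not shown in this snippet.

    // Process sequence
    for (int t = 0; t < seq_len; t++) {
        // Project input to key-value
        auto [key, value] = projections.forward(input[t]);

        // Query memory
        auto output = memory.forward(key);

        // Update memory (test-time learning!)
        memory.update(key, value, config);
    }

    return 0;
}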

Benchmarks

Target performance vs PyTorch baseline:

| Operation | PyTorch | Ours | Speedup |
|---|---|---|---|
| Memory Forward | TBD | TBD | - |
| Surprise Compute | TBD | TBD | - |
| Full Update | TBD | TBD | - |
| 1K Sequence | TBD | TBD | - |
| 10K Sequence | TBD | TBD | - |

Key Optimizations

1. Fused Forward-Backward

Instead of running separate forward and backward passes, compute both in one kernel so that intermediate activations stay on-chip rather than making a round trip through global memory.
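
To illustrate the idea, the host-side C++ sketch below computes the loss and its gradient in a single pass, so each per-row residual is consumed right after it is produced and never stored in an intermediate buffer. It makes two simplifying assumptions that are not taken from this repository: the memory is a single linear map W rather than an MLP, and the loss is the squared error ||W k - v||^2, the associative memory loss used in the Titans paper. `fused_loss_and_grad` is a hypothetical name.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper, not part of the titans-cuda API.
// Assumption: memory is one linear map W (rows x cols, row-major) and
// loss = ||W k - v||^2. The real memory module is an MLP.
inline float fused_loss_and_grad(const std::vector<float>& W,
                                 const std::vector<float>& k,
                                 const std::vector<float>& v,
                                 std::vector<float>& grad) {
    const std::size_t rows = v.size();
    const std::size_t cols = k.size();
    assert(W.size() == rows * cols);
    grad.assign(rows * cols, 0.0f);
    float loss = 0.0f;
    for (std::size_t i = 0; i < rows; ++i) {
        // Forward: one output element, kept in a local variable.
        float out = 0.0f;
        for (std::size_t j = 0; j < cols; ++j) {
            out += W[i * cols + j] * k[j];
        }
        const float r = out - v[i];  // residual for row i
        loss += r * r;
        // Backward: d(loss)/dW[i][j] = 2 * r * k[j], reusing r immediately.
        for (std::size_t j = 0; j < cols; ++j) {
            grad[i * cols + j] = 2.0f * r * k[j];
        }
    }
    return loss;
}
```

In a CUDA kernel the same structure could map one row (or tile of rows) to a warp or block, with the residual held in a register between the forward and backward halves.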

2. Shared Memory MLP

For small memory MLPs, keep weights in shared memory during the update step.

3. Warp-Level Gradient Accumulation

Use CUB's warp-level primitives for efficient gradient accumulation across threads.

4. Vectorized Access

Use float4 loads/stores so each thread moves 16 bytes per instruction while keeping accesses coalesced; this requires 16-byte-aligned addresses.

5. Chunked Processing

Process sequences in chunks for better parallelism, following the TTT paper's approach.
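
A rough host-side C++ sketch of the chunking idea, under assumptions not taken from this repository: the memory is reduced to a single scalar weight `w` with loss (w*k - v)^2, and every gradient inside a chunk is evaluated against the weight at the start of that chunk. That makes the per-token gradients independent of each other, which is what allows them to be computed in parallel; the momentum and forgetting recurrence is kept as a plain loop here. `ScalarMemory` and `run_chunked` are hypothetical names.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical toy state: a scalar memory weight and its surprise term.
struct ScalarMemory {
    float w;  // memory weight
    float s;  // surprise (momentum) accumulator
};

// Processes the sequence chunk by chunk. Gradients within a chunk are all
// taken at the chunk-start weight, so they do not depend on each other.
inline ScalarMemory run_chunked(ScalarMemory state,
                                const std::vector<float>& keys,
                                const std::vector<float>& values,
                                std::size_t chunk,
                                float lr, float momentum, float forgetting) {
    assert(keys.size() == values.size());
    assert(chunk > 0);
    std::vector<float> grads(chunk);
    for (std::size_t start = 0; start < keys.size(); start += chunk) {
        const std::size_t end = std::min(start + chunk, keys.size());
        const float w0 = state.w;  // weight frozen at the chunk boundary
        // Independent per-token gradients: the parallelizable part.
        for (std::size_t t = start; t < end; ++t) {
            grads[t - start] = 2.0f * (w0 * keys[t] - values[t]) * keys[t];
        }
        // Sequential recurrence over the precomputed gradients.
        for (std::size_t t = start; t < end; ++t) {
            state.s = momentum * state.s - lr * grads[t - start];
            state.w = (1.0f - forgetting) * state.w + state.s;
        }
    }
    return state;
}
```

With `chunk = 1` this reduces to the fully sequential update; larger chunks trade some fidelity to the sequential rule for parallelism.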

References

License

MIT


Built with 🦀 for learning and research
