A high-performance CUDA implementation of the Titans neural memory architecture from "Titans: Learning to Memorize at Test Time" (Behrouz et al., 2024).
Titans introduces a novel neural memory module that learns at test time via gradient descent. This framework provides optimized CUDA implementations using:
- CUDA Kernels: Custom kernels for memory operations
- CUB: Block/warp-level primitives for reductions
- Thrust: High-level parallel algorithms
- cuBLAS: Optimized matrix operations
Titans Memory Module
├── Memory: MLP that updates at test time
├── Surprise: Gradient + momentum signal
├── Forgetting: Weight decay for memory management
└── Variants: MAC, MAG, MAL
S_t = η · S_{t-1} - θ · ∇ℓ(M_{t-1}; x_t)   # Surprise (momentum + gradient)
M_t = (1 - α) · M_{t-1} + S_t             # Memory update with forgetting
Where:
- M = Neural memory (MLP weights)
- S = Surprise signal (accumulated gradients)
- η = Momentum coefficient
- θ = Learning rate
- α = Forgetting rate (weight decay)
titans-cuda/
├── include/
│   ├── titans/
│   │   ├── memory.cuh          # Neural memory module
│   │   ├── surprise.cuh        # Surprise computation
│   │   ├── projections.cuh     # Key-value projections
│   │   ├── variants.cuh        # MAC, MAG, MAL variants
│   │   └── utils.cuh           # Utilities
│   └── common/
│       ├── cuda_utils.cuh      # Error checking, timing
│       └── tensor.cuh          # Simple tensor wrapper
├── src/
│   ├── memory.cu               # Memory module implementation
│   ├── surprise.cu             # Surprise kernels
│   ├── projections.cu          # Projection kernels
│   └── variants.cu             # Variant implementations
├── kernels/
│   ├── naive/                  # Baseline implementations
│   ├── optimized/              # Optimized versions
│   └── experimental/           # Cutting-edge optimizations
├── tests/
│   ├── test_memory.cu          # Memory module tests
│   ├── test_surprise.cu        # Surprise computation tests
│   └── benchmarks.cu           # Performance benchmarks
├── examples/
│   ├── simple_memory.cu        # Basic usage example
│   ├── sequence_modeling.cu    # Sequence task example
│   └── benchmark_vs_pytorch.py # Compare with PyTorch
├── python/
│   └── titans_cuda/            # Python bindings (optional)
├── CMakeLists.txt
└── README.md
- CUDA Toolkit 12.x
- CMake 3.20+
- C++17 compiler
- (Optional) PyTorch for Python bindings
mkdir build && cd build
cmake .. -DCMAKE_BUILD_TYPE=Release
make -j$(nproc)

cd build
ctest --verbose

./benchmarks --all

- Basic tensor operations
- MLP forward pass
- Associative memory loss
- Gradient computation
- Memory update with momentum
- Fused forward + backward kernel
- Shared memory for MLP weights
- CUB block reductions
- Vectorized memory access
- Multi-stream execution
- MAC (Memory as Context)
- MAG (Memory as Gate)
- MAL (Memory as Layer)
- Chunked parallel training
- Tensor Core support (FP16/BF16)
- Multi-GPU support
- Python/PyTorch bindings
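#include <titans/memory.cuh>

int main() {
    // Create memory module
    titans::NeuralMemory memory(
        /*dim=*/256,
        /*depth=*/2,
        /*hidden_mult=*/4.0f
    );

    // Create update config.
    // Note: designated initializers are a C++20 feature; under a strict
    // C++17 build, assign the members individually instead.
    titans::UpdateConfig config{
        .lr = 0.01f,          // learning rate
        .momentum = 0.9f,     // momentum coefficient
        .forgetting = 0.01f   // forgetting rate (weight decay)
    };

    // seq_len, input, and projections are assumed to be set up by the
    // caller; their construction is not shown here.

    // Process sequence
    for (int t = 0; t < seq_len; t++) {
        // Project input to key-value
        auto [key, value] = projections.forward(input[t]);

        // Query memory
        auto output = memory.forward(key);

        // Update memory (test-time learning!)
        memory.update(key, value, config);
    }

    return 0;
}

Target performance vs PyTorch baseline: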
| Operation | PyTorch | Ours | Speedup |
|---|---|---|---|
| Memory Forward | TBD | TBD | - |
| Surprise Compute | TBD | TBD | - |
| Full Update | TBD | TBD | - |
| 1K Sequence | TBD | TBD | - |
| 10K Sequence | TBD | TBD | - |
Instead of separate forward and backward passes, compute both in one kernel to avoid memory round-trips.
For small memory MLPs, keep weights in shared memory during the update step.
Use CUB's warp-level primitives for efficient gradient accumulation across threads.
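The access pattern behind a warp-level sum can be simulated on the host. The sketch below is not CUB code and calls no CUB API; it only mimics the shuffle-down tree reduction that warp-level primitives typically implement, assuming a warp of 32 lanes. The name `simulated_warp_sum` is hypothetical.

```cpp
#include <array>

constexpr int kWarpSize = 32;  // assumed warp width

// Host-side simulation of a shuffle-down tree reduction over one warp.
// At each step, lane i adds the value held by lane i + offset, and the
// offset halves until lane 0 holds the total.
float simulated_warp_sum(std::array<float, kWarpSize> lanes) {
    for (int offset = kWarpSize / 2; offset > 0; offset /= 2) {
        for (int lane = 0; lane < offset; ++lane) {
            lanes[lane] += lanes[lane + offset];
        }
    }
    return lanes[0];
}
```

The reduction finishes in log2(32) = 5 steps, which is why this pattern is preferred over a serial loop or atomic adds when accumulating per-thread gradient contributions.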
Use float4 loads/stores for coalesced memory access patterns.
Process sequences in chunks for better parallelism, following the TTT paper's approach.
- Titans: Learning to Memorize at Test Time (Behrouz et al., 2024)
- Learning to (Learn at Test Time) (Sun et al., 2024) - TTT
- It's All Connected (Behrouz et al., 2025) - Miras Framework
MIT
Built for learning and research