From e3e9de828d258d727abc2ba9287d549c9afcc8bd Mon Sep 17 00:00:00 2001 From: mertalev <101130780+mertalev@users.noreply.github.com> Date: Thu, 1 Jan 2026 00:24:14 -0500 Subject: [PATCH] allow memory reuse --- crates/cubecl-cpu/src/compiler/memref.rs | 14 +++ crates/cubecl-cpu/src/compiler/mlir_data.rs | 40 +++----- crates/cubecl-cpu/src/compiler/mod.rs | 7 +- .../src/compiler/passes/shared_memories.rs | 92 +++++++++---------- 4 files changed, 77 insertions(+), 76 deletions(-) diff --git a/crates/cubecl-cpu/src/compiler/memref.rs b/crates/cubecl-cpu/src/compiler/memref.rs index f3a81d7e96..ab73e0a19d 100644 --- a/crates/cubecl-cpu/src/compiler/memref.rs +++ b/crates/cubecl-cpu/src/compiler/memref.rs @@ -30,4 +30,18 @@ impl LineMemRef { stride: [1], } } + + /// Create a LineMemRef from a raw pointer and length. + /// # Safety + /// The pointer must be valid and point to at least `len` bytes of writable memory. + pub unsafe fn from_raw_parts(pointer: *mut u8, len: usize) -> Self { + let pointer = pointer as *mut c_void; + Self { + allocated: pointer, + aligned: pointer, + offset: 0, + shape: [len as c_longlong], + stride: [1], + } + } } diff --git a/crates/cubecl-cpu/src/compiler/mlir_data.rs b/crates/cubecl-cpu/src/compiler/mlir_data.rs index 08c22a665a..33af109a5f 100644 --- a/crates/cubecl-cpu/src/compiler/mlir_data.rs +++ b/crates/cubecl-cpu/src/compiler/mlir_data.rs @@ -1,6 +1,6 @@ use super::passes::shared_memories::SharedMemories; use crate::{ - compiler::{builtin::BuiltinArray, memref::LineMemRef, passes::shared_memories::SharedMemory}, + compiler::{builtin::BuiltinArray, memref::LineMemRef}, compute::schedule::BindingsResource, }; use cubecl_common::stream_id::StreamId; @@ -86,33 +86,23 @@ impl MlirData { } let stream_id = StreamId::current(); - let mut smem_handles = Vec::with_capacity(shared_memories.0.len()); - for shared_memory in shared_memories.0.iter() { - let (handle, length) = match shared_memory { - SharedMemory::Array { ty, length, .. } => { - let length = (ty.size() * *length as usize) as u64; - let handle = memory_management_shared_memory.reserve(length).unwrap(); - (handle, length) - } - SharedMemory::Value { ty, .. } => { - let length = ty.size() as u64; - let handle = memory_management_shared_memory.reserve(length).unwrap(); - (handle, length) - } - }; - - smem_handles.push(handle.clone()); - - let b = Handle::new(handle, None, None, stream_id, 0, length).binding(); - let mut handle = memory_management_shared_memory + if let Some(smem_size) = shared_memories.size() { + let handle = memory_management_shared_memory.reserve(smem_size).unwrap(); + let b = Handle::new(handle.clone(), None, None, stream_id, 0, smem_size).binding(); + let mut resource = memory_management_shared_memory .get_resource(b.memory, b.offset_start, b.offset_end) .expect("Failed to find resource"); - let ptr = handle.write(); - let line_memref = LineMemRef::new(ptr); - push_undirected(line_memref); + + let smem_pool_ptr = resource.write().as_mut_ptr(); + for shared_memory in shared_memories.0.iter() { + // Compute pointer into the pool at the appropriate offset + let offset = shared_memory.offset() as usize; + let size = shared_memory.size() as usize; + let ptr = unsafe { smem_pool_ptr.add(offset) }; + let line_memref = unsafe { LineMemRef::from_raw_parts(ptr, size) }; + push_undirected(line_memref); + } } - // It is important to make sure multiple shared memories don't shared the same handle. - core::mem::drop(smem_handles); let ptr = shared_mlir_data.metadata.as_mut(); let line_memref = LineMemRef::new(ptr); diff --git a/crates/cubecl-cpu/src/compiler/mod.rs b/crates/cubecl-cpu/src/compiler/mod.rs index 1737149504..9e93d80f3b 100644 --- a/crates/cubecl-cpu/src/compiler/mod.rs +++ b/crates/cubecl-cpu/src/compiler/mod.rs @@ -22,7 +22,7 @@ use cubecl_core::{ prelude::KernelDefinition, server::ExecutionMode, }; -use cubecl_opt::OptimizerBuilder; +use cubecl_opt::{OptimizerBuilder, SharedLiveness}; use mlir_engine::MlirEngine; use crate::compiler::passes::{ @@ -63,7 +63,7 @@ impl Compiler for MlirCompiler { #[cfg(feature = "mlir-dump")] dump_scope(&kernel.body, &kernel.options.kernel_name); - let opt = OptimizerBuilder::default() + let mut opt = OptimizerBuilder::default() .with_transformer(ErfTransform) .with_transformer(HypotTransform) .with_transformer(RhypotTransform) @@ -72,8 +72,7 @@ impl Compiler for MlirCompiler { .with_processor(PredicateProcessor) .optimize(kernel.body.clone(), kernel.cube_dim); - let mut shared_memories = SharedMemories::default(); - shared_memories.visit(&opt); + let shared_memories = SharedMemories::from_liveness(&opt.analysis::()); #[cfg(feature = "mlir-dump")] dump_opt(&opt, &kernel.options.kernel_name); diff --git a/crates/cubecl-cpu/src/compiler/passes/shared_memories.rs b/crates/cubecl-cpu/src/compiler/passes/shared_memories.rs index e20efdc735..000b17e377 100644 --- a/crates/cubecl-cpu/src/compiler/passes/shared_memories.rs +++ b/crates/cubecl-cpu/src/compiler/passes/shared_memories.rs @@ -1,17 +1,19 @@ -use cubecl_core::ir::{OperationReflect, StorageType, Variable, VariableKind}; -use cubecl_opt::Optimizer; +use cubecl_core::ir::Type; +use cubecl_opt::SharedLiveness; #[derive(Debug, PartialEq, Eq, Clone)] pub enum SharedMemory { Array { id: u32, - ty: StorageType, - // Length include the vectorization factor + ty: Type, + // Length includes unroll_factor; vectorization is in ty.size() length: u32, + offset: u32, }, Value { id: u32, - ty: StorageType, + ty: Type, + offset: u32, }, } @@ -22,55 +24,51 @@ impl SharedMemory { SharedMemory::Value { id, .. } => *id, } } + + pub fn offset(&self) -> u32 { + match self { + SharedMemory::Array { offset, .. } => *offset, + SharedMemory::Value { offset, .. } => *offset, + } + } + + pub fn size(&self) -> u32 { + match self { + SharedMemory::Array { ty, length, .. } => *length * ty.size() as u32, + SharedMemory::Value { ty, .. } => ty.size() as u32, + } + } } #[derive(Default)] pub struct SharedMemories(pub Vec); impl SharedMemories { - pub fn visit_variable(&mut self, variable: Variable) { - // Alignment is ignored for the moment it is taken from the type - match variable.kind { - VariableKind::SharedArray { id, length, .. } => { - if self.0.iter().all(|shared_memory| shared_memory.id() != id) { - let elem = variable.storage_type(); - let vectorization = variable.line_size(); - let length = length * vectorization; - self.0.push(SharedMemory::Array { - id, - ty: elem, - length, - }); - } - } - VariableKind::Shared { id } => { - if self.0.iter().all(|shared_memory| shared_memory.id() != id) { - let elem = variable.storage_type(); - self.0.push(SharedMemory::Value { id, ty: elem }); - } - } - _ => {} - } + /// Build from the [SharedLiveness] analysis so non-overlapping lifetimes can reuse memory. + pub fn from_liveness(shared_liveness: &SharedLiveness) -> Self { + let mut memories: Vec = shared_liveness + .allocations + .values() + .map(|alloc| match alloc.smem { + cubecl_opt::SharedMemory::Array { id, length, ty, .. } => SharedMemory::Array { + id, + ty, + length, + offset: alloc.offset, + }, + cubecl_opt::SharedMemory::Value { id, ty, .. } => SharedMemory::Value { + id, + ty, + offset: alloc.offset, + }, + }) + .collect(); + + memories.sort_by_key(|m| m.id()); + Self(memories) } - pub fn visit(&mut self, opt: &Optimizer) { - for node in opt.program.node_indices().collect::>() { - let phi = opt.program[node].phi_nodes.clone(); - let ops = opt.program[node].ops.clone(); - for phi in phi.borrow_mut().iter_mut() { - self.visit_variable(phi.out); - } - for op in ops.borrow_mut().values_mut() { - if let Some(out) = op.out { - self.visit_variable(out); - } - if let Some(args) = op.operation.args() { - for arg in args { - self.visit_variable(arg); - } - } - } - } - self.0.sort_by_key(|a| a.id()); + pub fn size(&self) -> Option { + self.0.iter().map(|m| (m.offset() + m.size()) as u64).max() } }