diff --git a/crates/cubecl-convolution/src/kernels/layered/algorithm/simple.rs b/crates/cubecl-convolution/src/kernels/layered/algorithm/simple.rs index 3d893ccc84..555afdb5b3 100644 --- a/crates/cubecl-convolution/src/kernels/layered/algorithm/simple.rs +++ b/crates/cubecl-convolution/src/kernels/layered/algorithm/simple.rs @@ -17,6 +17,7 @@ use cubecl_matmul::components::{ tile::TileMatmulFamily, }; +use cubecl_runtime::stride::{is_contiguous, is_inner_contiguous_rows}; use cubecl_std::{ CubeOption, tensor::{TensorHandle, into_contiguous}, @@ -80,12 +81,8 @@ impl(handle: &TensorHandleRef<'_, R>, ident: MatmulIdent) -> bool { - let rank = handle.shape.len(); - let dim_c = rank - 1; - match ident { - MatmulIdent::Lhs => handle.strides[dim_c] == 1, - MatmulIdent::Rhs => handle.strides[dim_c] == 1, - MatmulIdent::Out => unreachable!(), - } +fn has_valid_layout(handle: &TensorHandleRef<'_, R>, _ident: MatmulIdent) -> bool { + // Accept fully contiguous or inner‑contiguous rows (rank>=2) + is_contiguous(handle.shape, handle.strides) + || is_inner_contiguous_rows(handle.shape, handle.strides) } diff --git a/crates/cubecl-cpu/src/compute/server.rs b/crates/cubecl-cpu/src/compute/server.rs index b3d0e9b13f..584a3e68f8 100644 --- a/crates/cubecl-cpu/src/compute/server.rs +++ b/crates/cubecl-cpu/src/compute/server.rs @@ -18,6 +18,7 @@ use cubecl_runtime::{ }; use crate::{CpuCompiler, compute::alloc_controller::CpuAllocController}; +use cubecl_runtime::stride::{contiguous_strides, pitched_rows_layout, row_pitch_elems}; use super::scheduler::Scheduler; @@ -66,10 +67,43 @@ impl CpuServer { ) -> Result, IoError> { let mut result = Vec::with_capacity(descriptors.len()); for desc in descriptors { - let len = desc.binding.size() as usize; - let (controller, alloc) = - CpuAllocController::init(desc.binding, &mut ctx.memory_management)?; - result.push(unsafe { Bytes::from_raw_parts(alloc, len, Box::new(controller)) }); + let binding = desc.binding; + let elem = desc.elem_size; + let 
size = desc.shape.iter().product::() * elem; + + // Contiguous: return zero-copy Bytes over the binding with logical len + if contiguous_strides(desc.shape) == desc.strides { + let (controller, alloc) = + CpuAllocController::init(binding, &mut ctx.memory_management)?; + result + .push(unsafe { Bytes::from_raw_parts(alloc, size, Box::new(controller)) }); + continue; + } + + // Inner-contiguous rows: reconstruct rows into contiguous buffer + if let Some(row_pitch_elems) = row_pitch_elems(desc.shape, desc.strides) { + let resource = ctx + .memory_management + .get_resource(binding.memory, binding.offset_start, binding.offset_end) + .ok_or(IoError::InvalidHandle)?; + let last = desc.shape.len() - 1; + let rows = desc.shape[..last].iter().product::(); + let cols = desc.shape[last]; + let row_bytes = cols * elem; + let row_pitch = row_pitch_elems * elem; + let src = resource.read(); + let mut out = vec![0u8; rows * row_bytes]; + for r in 0..rows { + let src_off = r * row_pitch; + let dst_off = r * row_bytes; + out[dst_off..dst_off + row_bytes] + .copy_from_slice(&src[src_off..src_off + row_bytes]); + } + result.push(Bytes::from_bytes_vec(out)); + continue; + } + + return Err(IoError::UnsupportedStrides); } Ok(result) } @@ -90,14 +124,22 @@ impl ComputeServer for CpuServer { descriptors: Vec>, ) -> Result, IoError> { let align = 8; - let strides = descriptors - .iter() - .map(|desc| contiguous_strides(desc.shape)) - .collect::>(); - let sizes = descriptors - .iter() - .map(|desc| desc.shape.iter().product::() * desc.elem_size) - .collect::>(); + let mut strides = Vec::with_capacity(descriptors.len()); + let mut sizes = Vec::with_capacity(descriptors.len()); + + use cubecl_core::server::AllocationKind; + + for desc in &descriptors { + let rank = desc.shape.len(); + if matches!(desc.kind, AllocationKind::Optimized) && rank > 1 { + let (s, size) = pitched_rows_layout(desc.shape, desc.elem_size, align); + strides.push(s); + sizes.push(size); + } else { + 
strides.push(contiguous_strides(desc.shape)); + sizes.push(desc.shape.iter().product::() * desc.elem_size); + } + } let total_size = sizes .iter() .map(|it| it.next_multiple_of(align)) @@ -123,11 +165,42 @@ impl ComputeServer for CpuServer { fn write(&mut self, descriptors: Vec<(CopyDescriptor<'_>, &[u8])>) -> Result<(), IoError> { for (desc, data) in descriptors { - if desc.strides != contiguous_strides(desc.shape) { - return Err(IoError::UnsupportedStrides); + // Contiguous path + if contiguous_strides(desc.shape) == desc.strides { + self.copy_to_binding(desc.binding, data); + continue; + } + + // Inner-contiguous rows: copy into pitched destination row-by-row + if let Some(row_pitch_elems) = row_pitch_elems(desc.shape, desc.strides) { + let last = desc.shape.len() - 1; + let rows = desc.shape[..last].iter().product::(); + let cols = desc.shape[last]; + let elem = desc.elem_size; + let row_bytes = cols * elem; + let row_pitch = row_pitch_elems * elem; + + let resource = self + .ctx + .memory_management + .get_resource( + desc.binding.memory, + desc.binding.offset_start, + desc.binding.offset_end, + ) + .ok_or(IoError::InvalidHandle)?; + + let dst = resource.write(); + for r in 0..rows { + let dst_off = r * row_pitch; + let src_off = r * row_bytes; + dst[dst_off..dst_off + row_bytes] + .copy_from_slice(&data[src_off..src_off + row_bytes]); + } + continue; } - self.copy_to_binding(desc.binding, data); + return Err(IoError::UnsupportedStrides); } Ok(()) } @@ -219,12 +292,3 @@ impl CpuServer { resource.write().copy_from_slice(data); } } - -pub(crate) fn contiguous_strides(shape: &[usize]) -> Vec { - let rank = shape.len(); - let mut strides = vec![1; rank]; - for i in (0..rank - 1).rev() { - strides[i] = strides[i + 1] * shape[i + 1]; - } - strides -} diff --git a/crates/cubecl-cpu/src/runtime.rs b/crates/cubecl-cpu/src/runtime.rs index ff00357b11..9798a69018 100644 --- a/crates/cubecl-cpu/src/runtime.rs +++ b/crates/cubecl-cpu/src/runtime.rs @@ -5,12 +5,12 @@ use 
cubecl_core::{ client::ComputeClient, ir::{StorageType, TargetProperties}, }; +use cubecl_runtime::stride::{is_contiguous, is_inner_contiguous_rows}; use cubecl_runtime::{ ComputeRuntime, DeviceProperties, memory_management::{HardwareProperties, MemoryDeviceProperties, MemoryManagement}, storage::BytesStorage, }; -use cubecl_std::tensor::is_contiguous; use sysinfo::System; use crate::{ @@ -72,6 +72,8 @@ fn create_client(options: RuntimeOptions) -> ComputeClient { mem_properties, topology, TimingMethod::Device, + // Default to contiguous on CPU. + cubecl_runtime::server::AllocationKind::Contiguous, ); register_supported_types(&mut device_props); @@ -111,7 +113,7 @@ impl Runtime for CpuRuntime { } fn can_read_tensor(shape: &[usize], strides: &[usize]) -> bool { - is_contiguous(shape, strides) + is_contiguous(shape, strides) || is_inner_contiguous_rows(shape, strides) } fn target_properties() -> TargetProperties { diff --git a/crates/cubecl-cuda/src/compute/io/base.rs b/crates/cubecl-cuda/src/compute/io/base.rs index 57e46f9ef4..33ee1a6bfc 100644 --- a/crates/cubecl-cuda/src/compute/io/base.rs +++ b/crates/cubecl-cuda/src/compute/io/base.rs @@ -1,8 +1,9 @@ use super::controller::PinnedMemoryManagedAllocController; -use crate::compute::{CudaContext, MB, valid_strides}; +use crate::compute::{CudaContext, MB}; use cubecl_common::bytes::Bytes; use cubecl_core::server::{CopyDescriptor, IoError}; use cubecl_runtime::memory_management::MemoryHandle; +use cubecl_runtime::stride::{is_contiguous, is_inner_contiguous_rows, row_pitch_elems}; use cudarc::driver::sys::{CUDA_MEMCPY2D_st, CUmemorytype, cuMemcpy2DAsync_v2}; use std::{ffi::c_void, ops::DerefMut}; @@ -52,7 +53,7 @@ pub fn register_copy_to_bytes( elem_size, } = descriptor; - if !valid_strides(shape, strides) { + if !(is_contiguous(shape, strides) || is_inner_contiguous_rows(shape, strides)) { return Err(IoError::UnsupportedStrides); } @@ -77,7 +78,7 @@ pub fn register_copy_to_bytes( let dim_x = shape[rank - 1]; let 
width_bytes = dim_x * elem_size; let dim_y: usize = shape.iter().rev().skip(1).product(); - let pitch = strides[rank - 2] * elem_size; + let pitch = row_pitch_elems(shape, strides).unwrap() * elem_size; let slice = bytes.deref_mut(); let cpy = CUDA_MEMCPY2D_st { diff --git a/crates/cubecl-cuda/src/compute/server.rs b/crates/cubecl-cuda/src/compute/server.rs index dee40b70dc..a11c6f7d34 100644 --- a/crates/cubecl-cuda/src/compute/server.rs +++ b/crates/cubecl-cuda/src/compute/server.rs @@ -14,6 +14,14 @@ use cubecl_cpp::formatter::format_cpp; use cubecl_cpp::{cuda::arch::CudaArchitecture, shared::CompilationOptions}; use super::storage::gpu::{GpuResource, GpuStorage}; +use cubecl_runtime::data_service::DataTransferId; +use cubecl_runtime::logging::ServerLogger; +use cubecl_runtime::stride::{ + contiguous_strides, is_contiguous, is_inner_contiguous_rows, pitched_rows_layout, + row_pitch_elems, +}; +use cubecl_runtime::{memory_management::offset_handles, timestamp_profiler::TimestampProfiler}; + use super::sync::{Fence, SyncStream}; use crate::compute::{ DataTransferItem, DataTransferRuntime, io::register_copies_to_bytes, @@ -27,15 +35,14 @@ use cubecl_core::{ ir::FloatKind, server::{Bindings, CopyDescriptor, TensorMapBinding}, }; -use cubecl_runtime::data_service::DataTransferId; -use cubecl_runtime::logging::ServerLogger; +// deduped above: DataTransferId, ServerLogger use cubecl_runtime::memory_management::MemoryUsage; use cubecl_runtime::storage::BindingResource; use cubecl_runtime::{ memory_management::MemoryManagement, server::{self, ComputeServer}, }; -use cubecl_runtime::{memory_management::offset_handles, timestamp_profiler::TimestampProfiler}; +// deduped above: offset_handles, TimestampProfiler use cudarc::driver::sys::{ CUDA_MEMCPY2D_st, CUctx_st, CUfunction_attribute, CUmemorytype, CUtensorMap, CUtensorMapDataType, CUtensorMapFloatOOBfill, CUtensorMapL2promotion, CUtensorMapSwizzle, @@ -145,31 +152,20 @@ impl ComputeServer for CudaServer { let mut 
total_size = 0; for descriptor in descriptors { - let pitch_align = match descriptor.kind { - AllocationKind::Contiguous => 1, - AllocationKind::Optimized => self.mem_alignment, - }; - let rank = descriptor.shape.len(); - let width = *descriptor.shape.last().unwrap_or(&1); - let height: usize = descriptor.shape.iter().rev().skip(1).product(); - let height = height.max(1); - let width_bytes = width * descriptor.elem_size; - let pitch = width_bytes.next_multiple_of(pitch_align); - let size = height * pitch; - total_size += size.next_multiple_of(self.mem_alignment); - let mut stride = vec![1; rank]; - if rank > 1 { - stride[rank - 2] = pitch / descriptor.elem_size; - } - if rank > 2 { - for i in (0..rank - 2).rev() { - stride[i] = stride[i + 1] * descriptor.shape[i + 1]; - } + if matches!(descriptor.kind, AllocationKind::Optimized) && rank > 1 { + let (s, size) = + pitched_rows_layout(descriptor.shape, descriptor.elem_size, self.mem_alignment); + total_size += size.next_multiple_of(self.mem_alignment); + strides.push(s); + sizes.push(size); + } else { + let s = contiguous_strides(descriptor.shape); + let size = descriptor.shape.iter().product::() * descriptor.elem_size; + total_size += size.next_multiple_of(self.mem_alignment); + strides.push(s); + sizes.push(size); } - - strides.push(stride); - sizes.push(size); } let ctx = self.get_context(); @@ -197,7 +193,7 @@ impl ComputeServer for CudaServer { } = descriptor; let rank = shape.len(); - if !valid_strides(shape, strides) { + if !(is_contiguous(shape, strides) || is_inner_contiguous_rows(shape, strides)) { return Err(IoError::UnsupportedStrides); } @@ -210,7 +206,8 @@ impl ComputeServer for CudaServer { let dim_x = shape[rank - 1]; let width_bytes = dim_x * elem_size; let dim_y: usize = shape.iter().rev().skip(1).product(); - let pitch = strides[rank - 2] * elem_size; + let pitch = + row_pitch_elems(shape, strides).unwrap_or(strides[rank - 2]) * elem_size; let cpy = CUDA_MEMCPY2D_st { srcMemoryType: 
CUmemorytype::CU_MEMORYTYPE_HOST, @@ -940,28 +937,3 @@ fn oob_to_cuda(fill: OobFill) -> CUtensorMapFloatOOBfill { OobFill::NaN => CU_TENSOR_MAP_FLOAT_OOB_FILL_NAN_REQUEST_ZERO_FMA, } } - -pub fn valid_strides(shape: &[usize], strides: &[usize]) -> bool { - let rank = shape.len(); - if strides[rank - 1] != 1 { - return false; - } - if rank <= 1 { - return true; - } - - let mut sorted = strides.to_vec(); - sorted.sort(); - sorted.reverse(); - - if sorted != strides { - return false; - } - - for i in 0..rank - 2 { - if strides[i] != shape[i + 1] * strides[i + 1] { - return false; - } - } - true -} diff --git a/crates/cubecl-cuda/src/runtime.rs b/crates/cubecl-cuda/src/runtime.rs index 11284ace31..2b0ea07b15 100644 --- a/crates/cubecl-cuda/src/runtime.rs +++ b/crates/cubecl-cuda/src/runtime.rs @@ -6,7 +6,6 @@ use crate::{ cpu::{PINNED_MEMORY_ALIGNMENT, PinnedMemoryStorage}, gpu::GpuStorage, }, - valid_strides, }, device::CudaDevice, }; @@ -27,6 +26,7 @@ use cubecl_cpp::{ register_wmma_features, }, }; +use cubecl_runtime::stride::{is_contiguous, is_inner_contiguous_rows}; use cubecl_runtime::{ ComputeRuntime, DeviceProperties, Plane, Tma, TypeUsage, channel::MutexComputeChannel, @@ -184,6 +184,8 @@ fn create_client>>( mem_properties, hardware_props, TimingMethod::System, + // Prefer pitched rows by default on CUDA (hardware efficient). 
+ cubecl_runtime::server::AllocationKind::Optimized, ); register_supported_types(&mut device_props); device_props.register_type_usage(ElemType::Float(FloatKind::TF32), TypeUsage::Conversion); @@ -311,7 +313,7 @@ impl Runtime for CudaRuntime { } fn can_read_tensor(shape: &[usize], strides: &[usize]) -> bool { - valid_strides(shape, strides) + is_contiguous(shape, strides) || is_inner_contiguous_rows(shape, strides) } fn target_properties() -> TargetProperties { diff --git a/crates/cubecl-hip/src/compute/io/base.rs b/crates/cubecl-hip/src/compute/io/base.rs index ce088ed851..ce3a30a0b7 100644 --- a/crates/cubecl-hip/src/compute/io/base.rs +++ b/crates/cubecl-hip/src/compute/io/base.rs @@ -1,9 +1,10 @@ use super::controller::PinnedMemoryManagedAllocController; -use crate::compute::{HipContext, MB, valid_strides}; +use crate::compute::{HipContext, MB}; use cubecl_common::bytes::Bytes; use cubecl_core::server::{CopyDescriptor, IoError}; use cubecl_hip_sys::{HIP_SUCCESS, hipMemcpyKind_hipMemcpyDeviceToHost}; use cubecl_runtime::memory_management::MemoryHandle; +use cubecl_runtime::stride::{is_contiguous, is_inner_contiguous_rows, row_pitch_elems}; /// Registers multiple lazy buffer copies to [Bytes], potentially using pinned memory. 
/// @@ -51,7 +52,7 @@ pub fn register_copy_to_bytes( elem_size, } = descriptor; - if !valid_strides(shape, strides) { + if !(is_contiguous(shape, strides) || is_inner_contiguous_rows(shape, strides)) { return Err(IoError::UnsupportedStrides); } @@ -84,7 +85,7 @@ pub fn register_copy_to_bytes( let dim_x = shape[rank - 1]; let width_bytes = dim_x * elem_size; let dim_y: usize = shape.iter().rev().skip(1).product(); - let pitch = strides[rank - 2] * elem_size; + let pitch = row_pitch_elems(shape, strides).unwrap() * elem_size; unsafe { let status = cubecl_hip_sys::hipMemcpy2DAsync( diff --git a/crates/cubecl-hip/src/compute/server.rs b/crates/cubecl-hip/src/compute/server.rs index 31bb296851..7402cf884b 100644 --- a/crates/cubecl-hip/src/compute/server.rs +++ b/crates/cubecl-hip/src/compute/server.rs @@ -25,6 +25,10 @@ use cubecl_runtime::logging::ServerLogger; use cubecl_runtime::memory_management::MemoryUsage; use cubecl_runtime::memory_management::offset_handles; use cubecl_runtime::storage::BindingResource; +use cubecl_runtime::stride::{ + contiguous_strides, is_contiguous, is_inner_contiguous_rows, pitched_rows_layout, + row_pitch_elems, +}; use cubecl_runtime::timestamp_profiler::TimestampProfiler; use cubecl_runtime::{ memory_management::MemoryManagement, @@ -139,30 +143,17 @@ impl ComputeServer for HipServer { let mut sizes = Vec::new(); for descriptor in descriptors { - let pitch_align = match descriptor.kind { - AllocationKind::Contiguous => 1, - AllocationKind::Optimized => self.mem_alignment, - }; - let rank = descriptor.shape.len(); - let width = *descriptor.shape.last().unwrap_or(&1); - let height: usize = descriptor.shape.iter().rev().skip(1).product(); - let height = height.max(1); - let width_bytes = width * descriptor.elem_size; - let pitch = width_bytes.next_multiple_of(pitch_align); - let size = height * pitch; + let (stride, size) = if matches!(descriptor.kind, AllocationKind::Optimized) && rank > 1 + { + pitched_rows_layout(descriptor.shape, 
descriptor.elem_size, self.mem_alignment) + } else { + ( + contiguous_strides(descriptor.shape), + descriptor.shape.iter().product::() * descriptor.elem_size, + ) + }; total_size += size.next_multiple_of(self.mem_alignment); - - let mut stride = vec![1; rank]; - if rank > 1 { - stride[rank - 2] = pitch / descriptor.elem_size; - } - if rank > 2 { - for i in (0..rank - 2).rev() { - stride[i] = stride[i + 1] * descriptor.shape[i + 1]; - } - } - strides.push(stride); sizes.push(size); } @@ -199,14 +190,16 @@ impl ComputeServer for HipServer { } = descriptor; let rank = shape.len(); - if !valid_strides(shape, strides) { + if !(is_contiguous(shape, strides) || is_inner_contiguous_rows(shape, strides)) { return Err(IoError::UnsupportedStrides); } if rank > 1 { - let stride = strides[rank - 2]; - - self.copy_to_binding_2d(binding, data, shape, stride, elem_size); + if let Some(pitch_elems) = row_pitch_elems(shape, strides) { + self.copy_to_binding_2d(binding, data, shape, pitch_elems, elem_size); + } else { + self.copy_to_binding(binding, data); + } } else { self.copy_to_binding(binding, data); } @@ -657,15 +650,6 @@ impl HipServer { } } -pub(crate) fn contiguous_strides(shape: &[usize]) -> Vec { - let rank = shape.len(); - let mut strides = vec![1; rank]; - for i in (0..rank - 1).rev() { - strides[i] = strides[i + 1] * shape[i + 1]; - } - strides -} - #[derive(Debug)] pub(crate) enum LaunchError { OutOfMemory, @@ -680,28 +664,3 @@ impl From for ProfileError { } } } - -pub fn valid_strides(shape: &[usize], strides: &[usize]) -> bool { - let rank = shape.len(); - if strides[rank - 1] != 1 { - return false; - } - if rank <= 1 { - return true; - } - - let mut sorted = strides.to_vec(); - sorted.sort(); - sorted.reverse(); - - if sorted != strides { - return false; - } - - for i in 0..rank - 2 { - if strides[i] != shape[i + 1] * strides[i + 1] { - return false; - } - } - true -} diff --git a/crates/cubecl-hip/src/runtime.rs b/crates/cubecl-hip/src/runtime.rs index 
00f48a41f3..3409dcff9e 100644 --- a/crates/cubecl-hip/src/runtime.rs +++ b/crates/cubecl-hip/src/runtime.rs @@ -25,12 +25,13 @@ use cubecl_runtime::{ use crate::{ HipWmmaCompiler, compute::{ - HipContext, HipServer, contiguous_strides, + HipContext, HipServer, cpu::{PINNED_MEMORY_ALIGNMENT, PinnedMemoryStorage}, storage::gpu::GpuStorage, }, device::AmdDevice, }; +use cubecl_runtime::stride::{is_contiguous, is_inner_contiguous_rows}; /// The values that control how a HIP Runtime will perform its calculations. #[derive(Default)] @@ -174,6 +175,8 @@ fn create_client>>( mem_properties, topology, TimingMethod::System, + // Prefer pitched rows by default on HIP (hardware efficient). + cubecl_runtime::server::AllocationKind::Optimized, ); register_supported_types(&mut device_props); @@ -232,17 +235,7 @@ impl Runtime for HipRuntime { } fn can_read_tensor(shape: &[usize], strides: &[usize]) -> bool { - if shape.is_empty() { - return true; - } - - for (expected, &stride) in contiguous_strides(shape).into_iter().zip(strides) { - if expected != stride { - return false; - } - } - - true + is_contiguous(shape, strides) || is_inner_contiguous_rows(shape, strides) } fn target_properties() -> TargetProperties { diff --git a/crates/cubecl-random/src/tests/bernoulli.rs b/crates/cubecl-random/src/tests/bernoulli.rs index 5407856fbf..0ae49fe2eb 100644 --- a/crates/cubecl-random/src/tests/bernoulli.rs +++ b/crates/cubecl-random/src/tests/bernoulli.rs @@ -15,7 +15,7 @@ macro_rules! 
testgen_random_bernoulli { random_bernoulli::(&client, prob, output.as_ref()); - let output_data = client.read_one(output.handle); + let output_data = client.read_one_tensor(output.as_copy_descriptor()); let output_data = E::from_bytes(&output_data); output_data.to_owned() diff --git a/crates/cubecl-random/src/tests/normal.rs b/crates/cubecl-random/src/tests/normal.rs index 6e2f217ac0..d6db8357c5 100644 --- a/crates/cubecl-random/src/tests/normal.rs +++ b/crates/cubecl-random/src/tests/normal.rs @@ -16,7 +16,7 @@ macro_rules! testgen_random_normal { random_normal::(&client, mean, std, output.as_ref()); - let output_data = client.read_one(output.handle); + let output_data = client.read_one_tensor(output.as_copy_descriptor()); let output_data = E::from_bytes(&output_data); output_data.to_owned() diff --git a/crates/cubecl-random/src/tests/uniform.rs b/crates/cubecl-random/src/tests/uniform.rs index 6f107057e0..30de521021 100644 --- a/crates/cubecl-random/src/tests/uniform.rs +++ b/crates/cubecl-random/src/tests/uniform.rs @@ -17,7 +17,7 @@ macro_rules! 
testgen_random_uniform { random_uniform::(&client, lower_bound, upper_bound, output.as_ref()); - let output_data = client.read_one(output.handle); + let output_data = client.read_one_tensor(output.as_copy_descriptor()); let output_data = E::from_bytes(&output_data); output_data.to_owned() diff --git a/crates/cubecl-reduce/src/config.rs b/crates/cubecl-reduce/src/config.rs index d6b6ab806c..b9eb3aa337 100644 --- a/crates/cubecl-reduce/src/config.rs +++ b/crates/cubecl-reduce/src/config.rs @@ -2,7 +2,7 @@ use cubecl_core::{ channel::ComputeChannel, prelude::*, server::ComputeServer, tensor_line_size_parallel, tensor_line_size_perpendicular, }; -use cubecl_std::tensor::is_contiguous; +use cubecl_runtime::stride::is_contiguous; use crate::ReduceStrategy; diff --git a/crates/cubecl-runtime/src/client.rs b/crates/cubecl-runtime/src/client.rs index 8d28d50eb5..93e5399517 100644 --- a/crates/cubecl-runtime/src/client.rs +++ b/crates/cubecl-runtime/src/client.rs @@ -108,7 +108,21 @@ where async fn do_read(&self, descriptors: Vec>) -> Result, IoError> { self.profile_guard(); - + // Preflight stride compatibility using shared helpers for clearer, uniform errors. + let mut ok = true; + for d in &descriptors { + let shape = d.shape; + let strides = d.strides; + if !(crate::stride::is_contiguous(shape, strides) + || crate::stride::is_inner_contiguous_rows(shape, strides)) + { + ok = false; + break; + } + } + if !ok { + return Err(IoError::UnsupportedStrides); + } self.channel.read(descriptors).await } @@ -201,7 +215,14 @@ where data: Vec<&[u8]>, ) -> Result, IoError> { self.profile_guard(); - + // Preflight size checks to avoid allocating on mismatch. 
+ for (desc, buf) in descriptors.iter().zip(data.iter().copied()) { + let expected = desc.shape.iter().product::() * desc.elem_size; + let actual = buf.len(); + if expected != actual { + return Err(IoError::InvalidSize { expected, actual }); + } + } let allocations = self.channel.create(descriptors.clone())?; let descriptors = descriptors .into_iter() @@ -297,7 +318,13 @@ where /// Reserves `shape` in the storage, and returns a tensor handle for it. /// See [ComputeClient::create_tensor] pub fn empty_tensor(&self, shape: &[usize], elem_size: usize) -> Allocation { - let descriptor = AllocationDescriptor::new(AllocationKind::Optimized, shape, elem_size); + // Use the device's default allocation preference for rank > 1 tensors. + let kind = if shape.len() > 1 { + self.state.properties.default_alloc_rank_gt1 + } else { + AllocationKind::Contiguous + }; + let descriptor = AllocationDescriptor::new(kind, shape, elem_size); self.do_empty(vec![descriptor]).unwrap().remove(0) } @@ -324,10 +351,45 @@ where ) -> Allocation { let shape = src_descriptor.shape; let elem_size = src_descriptor.elem_size; - let alloc_desc = AllocationDescriptor::new(AllocationKind::Optimized, shape, elem_size); + // Prefer the device default for rank > 1 inner‑contiguous rows, otherwise + // fall back to contiguous. + let kind = if shape.len() > 1 + && crate::stride::is_inner_contiguous_rows(shape, src_descriptor.strides) + { + self.state.properties.default_alloc_rank_gt1 + } else { + AllocationKind::Contiguous + }; + let alloc_desc = AllocationDescriptor::new(kind, shape, elem_size); self.data_transfer(src_descriptor, alloc_desc, dst_server) } + /// Write tensor data to bindings asynchronously and wait for completion. + pub async fn write_async( + &self, + writes: Vec<(CopyDescriptor<'_>, &[u8])>, + ) -> Result<(), IoError> { + self.profile_guard(); + // Preflight stride compatibility. 
+ for (d, _data) in &writes { + let shape = d.shape; + let strides = d.strides; + if !(crate::stride::is_contiguous(shape, strides) + || crate::stride::is_inner_contiguous_rows(shape, strides)) + { + return Err(IoError::UnsupportedStrides); + } + } + self.channel.write(writes)?; + self.channel.sync().await; + Ok(()) + } + + /// Write tensor data to bindings; panics on error. + pub fn write(&self, writes: Vec<(CopyDescriptor<'_>, &[u8])>) { + cubecl_common::reader::read_sync(self.write_async(writes)).unwrap() + } + #[track_caller] unsafe fn execute_inner( &self, diff --git a/crates/cubecl-runtime/src/feature_set.rs b/crates/cubecl-runtime/src/feature_set.rs index 15ebe4298a..acdac69423 100644 --- a/crates/cubecl-runtime/src/feature_set.rs +++ b/crates/cubecl-runtime/src/feature_set.rs @@ -8,6 +8,10 @@ use enumset::EnumSet; /// Properties of what the device can do, like what `Feature` are /// supported by it and what its memory properties are. +use crate::server::AllocationKind; + +/// Properties/features exposed by a device/runtime, used by higher layers for +/// capability checks and defaults. #[derive(Debug)] pub struct DeviceProperties { /// The features supported by the runtime. @@ -18,6 +22,14 @@ pub struct DeviceProperties { pub hardware: HardwareProperties, /// The method used for profiling on the device. pub timing_method: TimingMethod, + /// Default allocation preference for rank > 1 tensors when both contiguous and + /// inner‑contiguous row layouts are supported by the backend IO path. + /// + /// Backends can set this to `AllocationKind::Optimized` (pitched rows) when + /// strided IO is efficient in hardware (e.g., CUDA/HIP), or to + /// `AllocationKind::Contiguous` when contiguous copies are generally faster + /// (e.g., WGPU/CPU by default). 
+ pub default_alloc_rank_gt1: AllocationKind, } impl DeviceProperties { @@ -27,12 +39,14 @@ impl DeviceProperties { memory_props: MemoryDeviceProperties, hardware: HardwareProperties, timing_method: TimingMethod, + default_alloc_rank_gt1: AllocationKind, ) -> Self { DeviceProperties { features, memory: memory_props, hardware, timing_method, + default_alloc_rank_gt1, } } diff --git a/crates/cubecl-runtime/src/lib.rs b/crates/cubecl-runtime/src/lib.rs index 0e4a9aad9b..a309b4b31c 100644 --- a/crates/cubecl-runtime/src/lib.rs +++ b/crates/cubecl-runtime/src/lib.rs @@ -53,3 +53,6 @@ pub mod timestamp_profiler; /// Utilities for data transfers between servers pub mod data_service; + +/// Stride compatibility helpers for preflight checks and host I/O planning. +pub mod stride; diff --git a/crates/cubecl-runtime/src/server.rs b/crates/cubecl-runtime/src/server.rs index cf952a2768..4de5e7bd12 100644 --- a/crates/cubecl-runtime/src/server.rs +++ b/crates/cubecl-runtime/src/server.rs @@ -229,6 +229,14 @@ pub enum IoError { /// Handle wasn't found in the memory pool #[error("couldn't find resource for that handle")] InvalidHandle, + /// Data size does not match `shape.product() * elem_size` (in bytes) + #[error("data size {actual} does not match expected {expected}")] + InvalidSize { + /// Expected byte length computed from `shape.product() * elem_size`. + expected: usize, + /// Actual byte length of the provided data buffer. + actual: usize, + }, /// Unknown error happened during execution #[error("Unknown error happened during execution")] Unknown(String), diff --git a/crates/cubecl-runtime/src/stride.rs b/crates/cubecl-runtime/src/stride.rs new file mode 100644 index 0000000000..8ebd734fa7 --- /dev/null +++ b/crates/cubecl-runtime/src/stride.rs @@ -0,0 +1,228 @@ +//! Stride compatibility helpers for preflight checks and host I/O planning. +//! +//! These utilities describe common stride patterns independently of any backend, +//! 
so hosts and higher layers can make informed choices and surface clearer errors. +//! +//! Strides are expressed in element units (not bytes). Element size may be used +//! by callers to convert to/from byte pitches as needed. + +use crate::server::AllocationKind; +use alloc::vec; +use alloc::vec::Vec; + +/// Canonical contiguous row-major strides for a given shape (in elements). +/// +/// Example: shape [R, C] -> strides [C, 1] +pub fn contiguous_strides(shape: &[usize]) -> Vec { + if shape.is_empty() { + return vec![]; + } + let mut strides = vec![0; shape.len()]; + let mut s = 1usize; + for (i, dim) in shape.iter().enumerate().rev() { + strides[i] = s; + s = s.saturating_mul(*dim.max(&1)); + } + strides +} + +/// A coarse description of a stride pattern. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum StridePattern { + /// Fully contiguous row-major layout. + Contiguous, + /// 2D with inner-most contiguous axis and a row pitch (in elements) on the outer axis. + /// `row_pitch_elems >= cols` is required. + InnerContiguous2D { + /// Pitch between consecutive rows in elements (not bytes). + row_pitch_elems: usize, + }, + /// Rank >= 2 with inner-most contiguous axis and a row pitch over all outer dimensions flattened + /// into rows. `row_pitch_elems >= cols` is required. + InnerContiguousRows { + /// Pitch between consecutive flattened rows in elements (not bytes). + row_pitch_elems: usize, + }, + /// Any other non-supported or irregular stride pattern. + Other, +} + +/// Describe the given shape/strides pair. +pub fn describe(shape: &[usize], strides: &[usize]) -> StridePattern { + if shape.len() != strides.len() { + return StridePattern::Other; + } + + if strides == contiguous_strides(shape).as_slice() { + return StridePattern::Contiguous; + } + + if shape.len() == 2 { + let rows = shape[0]; + let cols = shape[1]; + let row_pitch = strides[0]; + let inner = strides[1]; + + // Accept inner-most contiguous 2D with row pitch >= cols. 
+ if inner == 1 && row_pitch >= cols && rows > 0 && cols > 0 { + return StridePattern::InnerContiguous2D { + row_pitch_elems: row_pitch, + }; + } + } + + // General inner-contiguous rows for rank >= 2: last axis contiguous, outer strides chain + // multiplicatively while allowing row pitch padding on the last-but-one axis. + if shape.len() >= 2 { + let last = shape.len() - 1; + if strides[last] == 1 { + // Verify the stride chain for the outer dimensions: s[i] == shape[i+1] * s[i+1] + let mut ok = true; + for i in 0..last - 1 { + if strides[i] != shape[i + 1].saturating_mul(strides[i + 1]) { + ok = false; + break; + } + } + // For rank==2, the above loop is skipped; we fall back to the >= cols check below + let row_pitch = strides[last - 1]; + if ok && row_pitch >= shape[last] { + return StridePattern::InnerContiguousRows { + row_pitch_elems: row_pitch, + }; + } + } + } + + StridePattern::Other +} + +/// Whether the given shape/strides is fully contiguous. +#[inline] +pub fn is_contiguous(shape: &[usize], strides: &[usize]) -> bool { + matches!(describe(shape, strides), StridePattern::Contiguous) +} + +/// Whether the given shape/strides is rank-2 with inner-most contiguous axis and a row pitch. +#[inline] +pub fn is_inner_contiguous_2d(shape: &[usize], strides: &[usize]) -> bool { + matches!( + describe(shape, strides), + StridePattern::InnerContiguous2D { .. } + ) +} + +/// Whether the given shape/strides is rank>=2 with inner-most contiguous axis and a row pitch +/// across all outer dimensions flattened. +#[inline] +pub fn is_inner_contiguous_rows(shape: &[usize], strides: &[usize]) -> bool { + matches!( + describe(shape, strides), + StridePattern::InnerContiguous2D { .. } | StridePattern::InnerContiguousRows { .. } + ) +} + +/// If `shape/strides` forms inner-contiguous rows (rank>=2), return the row pitch (in elements). 
+#[inline] +pub fn row_pitch_elems(shape: &[usize], strides: &[usize]) -> Option<usize> { + match describe(shape, strides) { + StridePattern::InnerContiguous2D { row_pitch_elems } + | StridePattern::InnerContiguousRows { row_pitch_elems } => Some(row_pitch_elems), + _ => None, + } +} + +/// Compute pitched-rows layout and allocation size for rank>1 tensors. +/// Returns (strides_in_elements, total_size_in_bytes). The row pitch (in bytes) +/// is aligned up to `align`. +pub fn pitched_rows_layout(shape: &[usize], elem_size: usize, align: usize) -> (Vec<usize>, usize) { + let rank = shape.len(); + let width = *shape.last().unwrap_or(&1); + let height: usize = shape.iter().rev().skip(1).product(); + let height = height.max(1); + let width_bytes = width * elem_size; + let row_pitch_bytes = width_bytes.next_multiple_of(align); + let size = height * row_pitch_bytes; + + let mut strides = vec![1usize; rank]; + if rank > 1 { + strides[rank - 2] = row_pitch_bytes / elem_size; + } + if rank > 2 { + for i in (0..rank - 2).rev() { + strides[i] = strides[i + 1] * shape[i + 1]; + } + } + + (strides, size) +} + +/// Suggest an allocation kind given a shape/strides pair. +/// - Contiguous -> `AllocationKind::Contiguous` +/// - Inner-contiguous rows -> `AllocationKind::Optimized` +/// - Otherwise fallback to `AllocationKind::Contiguous` (caller may still reject on use). 
+pub fn preferred_allocation_kind(shape: &[usize], strides: &[usize]) -> AllocationKind { + if is_contiguous(shape, strides) { + AllocationKind::Contiguous + } else if is_inner_contiguous_rows(shape, strides) { + AllocationKind::Optimized + } else { + AllocationKind::Contiguous + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn contiguous_for_1d_and_2d() { + assert_eq!(contiguous_strides(&[8]), vec![1]); + assert_eq!(contiguous_strides(&[2, 3]), vec![3, 1]); + assert!(is_contiguous(&[8], &[1])); + assert!(is_contiguous(&[2, 3], &[3, 1])); + assert!(!is_contiguous(&[2, 3], &[4, 1])); + } + + #[test] + fn inner_contiguous_2d_detection() { + // 2D pitched: rows=4, cols=5, pitch=8 (in elems) + let shape = [4, 5]; + let strides = [8, 1]; + assert!(is_inner_contiguous_2d(&shape, &strides)); + match describe(&shape, &strides) { + StridePattern::InnerContiguous2D { row_pitch_elems } => assert_eq!(row_pitch_elems, 8), + other => panic!("unexpected: {other:?}"), + } + // Not inner-contiguous + assert!(!is_inner_contiguous_2d(&shape, &[8, 2])); + // Pitch less than cols should not be accepted + assert!(!is_inner_contiguous_2d(&shape, &[4, 1])); + } + + #[test] + fn inner_contiguous_rows_rank3() { + // Rank 3 with inner-contiguous rows and padded pitch + let shape = [2, 3, 5]; // rows=2*3=6, cols=5 + let row_pitch = 8usize; + // Stride chain: s[1] arbitrary row_pitch, s[0] == shape[1] * s[1] == 3 * 8 = 24 + let strides = [24, row_pitch, 1]; + assert!(is_inner_contiguous_rows(&shape, &strides)); + assert_eq!(row_pitch_elems(&shape, &strides), Some(row_pitch)); + // Reject when last stride not 1 + assert!(!is_inner_contiguous_rows(&shape, &[24, 8, 2])); + // Reject when chain breaks + assert!(!is_inner_contiguous_rows(&shape, &[16, 8, 1])); + } + + #[test] + fn describe_other() { + // Rank 3 non-contiguous pattern should be Other + assert!(matches!( + describe(&[2, 3, 4], &[10, 4, 1]), + StridePattern::Other + )); + // Mismatched lengths + 
assert!(matches!(describe(&[2], &[]), StridePattern::Other)); + } +} diff --git a/crates/cubecl-runtime/tests/dummy/compute.rs b/crates/cubecl-runtime/tests/dummy/compute.rs index 3619918fe1..e066d8cb35 100644 --- a/crates/cubecl-runtime/tests/dummy/compute.rs +++ b/crates/cubecl-runtime/tests/dummy/compute.rs @@ -52,6 +52,7 @@ pub fn init_client() -> ComputeClient( ); } -/// Checks if the tensor associated with the given shape and strides is contiguous. -pub fn is_contiguous(shape: &[usize], strides: &[usize]) -> bool { - if shape.is_empty() { - return true; - } - - for (expected, &stride) in compact_strides(shape).into_iter().zip(strides) { - if expected != stride { - return false; - } - } - - true -} - /// Checks if a tensor is only strided on the last dimension, and could be safely reinterpreted as /// a 2D tensor with unit stride on the last dimension. This will always hold for non-permuted /// tensors allocated on a runtime. pub fn is_contiguous_pitched(shape: &[usize], strides: &[usize]) -> bool { - let rank = shape.len(); - if strides[rank - 1] != 1 { - return false; - } - if rank <= 1 { - return true; - } - - let mut sorted = strides.to_vec(); - sorted.sort(); - sorted.reverse(); - - if sorted != strides { - return false; - } - - for i in 0..rank - 2 { - if strides[i] != shape[i + 1] * strides[i + 1] { - return false; - } - } - true -} - -pub fn compact_strides(shape: &[usize]) -> Vec { - let rank = shape.len(); - let mut strides = vec![1; rank]; - for i in (0..rank - 1).rev() { - strides[i] = strides[i + 1] * shape[i + 1]; - } - strides + cubecl_runtime::stride::is_inner_contiguous_rows(shape, strides) } diff --git a/crates/cubecl-std/src/tensor/handle.rs b/crates/cubecl-std/src/tensor/handle.rs index 33cd8527fe..8790e11a96 100644 --- a/crates/cubecl-std/src/tensor/handle.rs +++ b/crates/cubecl-std/src/tensor/handle.rs @@ -4,6 +4,7 @@ use cubecl_core::{Runtime, server}; use cubecl_core::{calculate_cube_count_elemwise, server::Allocation}; use 
cubecl_core::{prelude::*, server::CopyDescriptor}; use cubecl_runtime::server::Handle; +use cubecl_runtime::stride as stride_util; /// Tensor representation containing a [server handle](Handle) as well as basic tensor metadata., pub struct TensorHandle @@ -86,7 +87,7 @@ where /// Create a new tensor with a contiguous memory layout. pub fn new_contiguous(shape: Vec, handle: Handle) -> Self { - let strides = Self::contiguous_strides(&shape); + let strides = stride_util::contiguous_strides(&shape); Self { handle, @@ -135,18 +136,6 @@ where elem_size: size_of::(), } } - - fn contiguous_strides(shape: &[usize]) -> Vec { - let mut strides = Vec::with_capacity(shape.len()); - - let mut current = 1; - shape.iter().enumerate().rev().for_each(|(_, val)| { - strides.push(current); - current *= val; - }); - strides.reverse(); - strides - } } impl TensorHandle where diff --git a/crates/cubecl-std/src/tensor/layout/linear.rs b/crates/cubecl-std/src/tensor/layout/linear.rs index 5d814feb0b..721203753e 100644 --- a/crates/cubecl-std/src/tensor/layout/linear.rs +++ b/crates/cubecl-std/src/tensor/layout/linear.rs @@ -2,7 +2,7 @@ use cubecl::prelude::*; use cubecl_core::{self as cubecl, unexpanded}; use crate::tensor::{ - is_contiguous, is_contiguous_pitched, + is_contiguous_pitched, launch::{TypedView, TypedViewLaunch}, layout::{ Coords1d, Layout, LayoutExpand, VirtualLayoutOperationsExpand, @@ -11,6 +11,7 @@ use crate::tensor::{ strided::{StridedLayout, StridedLayoutLaunch}, }, }; +use cubecl_runtime::stride::is_contiguous; /// Maps a linear index based on line count to a potentially strided tensor. 
Only applies the /// necessary level of striding, either none, only the last dim (for freshly allocated strided diff --git a/crates/cubecl-wgpu/src/compute/server.rs b/crates/cubecl-wgpu/src/compute/server.rs index c3762ee23c..6a49fce8d0 100644 --- a/crates/cubecl-wgpu/src/compute/server.rs +++ b/crates/cubecl-wgpu/src/compute/server.rs @@ -13,10 +13,11 @@ use cubecl_core::{ }; use cubecl_core::{ compute::{CubeTask, DebugInformation}, - server::{Allocation, AllocationDescriptor, IoError}, + server::{Allocation, AllocationDescriptor, AllocationKind, IoError}, }; use cubecl_runtime::logging::ServerLogger; use cubecl_runtime::memory_management::offset_handles; +use cubecl_runtime::stride::{contiguous_strides, pitched_rows_layout, row_pitch_elems}; use cubecl_runtime::{ memory_management::MemoryDeviceProperties, server::ComputeServer, storage::BindingResource, }; @@ -128,14 +129,23 @@ impl ComputeServer for WgpuServer { descriptors: Vec>, ) -> Result, IoError> { let align = self.device.limits().min_storage_buffer_offset_alignment as usize; - let strides = descriptors - .iter() - .map(|desc| contiguous_strides(desc.shape)) - .collect::>(); - let sizes = descriptors - .iter() - .map(|desc| desc.shape.iter().product::() * desc.elem_size) - .collect::>(); + + let mut strides_out = Vec::with_capacity(descriptors.len()); + let mut sizes = Vec::with_capacity(descriptors.len()); + + for desc in &descriptors { + let rank = desc.shape.len(); + if matches!(desc.kind, AllocationKind::Optimized) && rank > 1 { + let (strides, size) = pitched_rows_layout(desc.shape, desc.elem_size, align); + strides_out.push(strides); + sizes.push(size); + } else { + // Contiguous allocation + strides_out.push(contiguous_strides(desc.shape)); + sizes.push(desc.shape.iter().product::() * desc.elem_size); + } + } + let total_size = sizes .iter() .map(|it| it.next_multiple_of(align)) @@ -146,7 +156,7 @@ impl ComputeServer for WgpuServer { Ok(handles .into_iter() - .zip(strides) + .zip(strides_out) 
.map(|(handle, strides)| Allocation::new(handle, strides)) .collect()) } @@ -155,20 +165,33 @@ impl ComputeServer for WgpuServer { &mut self, descriptors: Vec>, ) -> DynFut, IoError>> { - for desc in &descriptors { - if contiguous_strides(desc.shape) != desc.strides { - return Box::pin(async { Err(IoError::UnsupportedStrides) }); - } - } self.stream.read_buffers(descriptors) } fn write(&mut self, descriptors: Vec<(CopyDescriptor<'_>, &[u8])>) -> Result<(), IoError> { for (desc, data) in descriptors { - if contiguous_strides(desc.shape) != desc.strides { - return Err(IoError::UnsupportedStrides); + // Contiguous path + if contiguous_strides(desc.shape) == desc.strides { + self.stream.write(desc.binding, data); + continue; + } + + // Inner-contiguous pitched rows: rank>=2 + if let Some(pitch_elems) = row_pitch_elems(desc.shape, desc.strides) { + let last = desc.shape.len() - 1; + let rows = desc.shape[..last].iter().product::() as u64; + let cols = desc.shape[last] as u64; + let elem = desc.elem_size as u64; + let row_bytes = cols * elem; + let row_pitch = pitch_elems as u64 * elem; + + let resource = self.stream.mem_manage.get_resource(desc.binding); + self.stream + .write_rows_pitched(&resource, rows, row_bytes, row_pitch, data); + continue; } - self.stream.write(desc.binding, data); + + return Err(IoError::UnsupportedStrides); } Ok(()) } @@ -230,12 +253,3 @@ fn compiler(backend: wgpu::Backend) -> AutoCompiler { _ => AutoCompiler::Wgsl(Default::default()), } } - -pub(crate) fn contiguous_strides(shape: &[usize]) -> Vec { - let rank = shape.len(); - let mut strides = vec![1; rank]; - for i in (0..rank - 1).rev() { - strides[i] = strides[i + 1] * shape[i + 1]; - } - strides -} diff --git a/crates/cubecl-wgpu/src/compute/stream.rs b/crates/cubecl-wgpu/src/compute/stream.rs index 4c843bb6a4..3076bafe3e 100644 --- a/crates/cubecl-wgpu/src/compute/stream.rs +++ b/crates/cubecl-wgpu/src/compute/stream.rs @@ -1,5 +1,6 @@ +use super::controller::WgpuAllocController; use 
super::{mem_manager::WgpuMemManager, poll::WgpuPoll, timings::QueryProfiler}; -use crate::{WgpuResource, controller::WgpuAllocController}; +use crate::WgpuResource; use cubecl_common::{ bytes::Bytes, profile::{ProfileDuration, TimingMethod}, @@ -9,8 +10,10 @@ use cubecl_core::{ future::{self, DynFut}, server::{Binding, Bindings, CopyDescriptor, Handle, IoError, ProfileError, ProfilingToken}, }; +use cubecl_runtime::stride::{contiguous_strides, row_pitch_elems}; use cubecl_runtime::{ - memory_management::MemoryDeviceProperties, timestamp_profiler::TimestampProfiler, + memory_management::{MemoryDeviceProperties, SliceBinding}, + timestamp_profiler::TimestampProfiler, }; use std::{future::Future, num::NonZero, pin::Pin, sync::Arc}; use wgpu::ComputePipeline; @@ -163,6 +166,8 @@ impl WgpuStream { pass.set_pipeline(&pipeline); pass.set_bind_group(0, &bind_group, &[]); + // Rationale: support common row-pitched (2D) host layouts on WGPU. + // Limits: rank==2 and inner-most contiguous; other stride patterns remain UnsupportedStrides. match dispatch.clone() { CubeCount::Static(x, y, z) => { pass.dispatch_workgroups(x, y, z); @@ -190,44 +195,103 @@ impl WgpuStream { descriptors: Vec, ) -> DynFut, IoError>> { self.compute_pass = None; - let mut staging_info = Vec::with_capacity(descriptors.len()); + enum Stored { + Managed { + staging: WgpuResource, + binding: SliceBinding, + size: usize, + }, + Pitched { + buffer: wgpu::Buffer, + size: usize, + row_bytes: usize, + row_pitch: usize, + }, + } + let mut entries: Vec = Vec::with_capacity(descriptors.len()); let mut callbacks = Vec::with_capacity(descriptors.len()); for descriptor in descriptors { let binding = descriptor.binding; - let size = descriptor.shape.iter().product::() * descriptor.elem_size; - // Copying into a buffer has to be 4 byte aligned. We can safely do so, as - // memory is 32 bytes aligned (see WgpuStorage). 
- let align = wgpu::COPY_BUFFER_ALIGNMENT; + let elem = descriptor.elem_size as u64; let resource = self.mem_manage.get_resource(binding); - let aligned_len = resource.size.div_ceil(align) * align; - let (staging, binding) = self.mem_manage.reserve_staging(aligned_len).unwrap(); - - self.tasks_count += 1; - self.encoder.copy_buffer_to_buffer( - &resource.buffer, - resource.offset, - &staging.buffer, - 0, - aligned_len, - ); - staging_info.push((staging, binding, size)); - } + let align = wgpu::COPY_BUFFER_ALIGNMENT; - // Flush all commands to the queue, so GPU gets started on copying to the staging buffer. - self.flush(); + if contiguous_strides(descriptor.shape) == descriptor.strides { + let size = descriptor.shape.iter().product::() * descriptor.elem_size; + let aligned_len = resource.size.div_ceil(align) * align; + let (staging, sbinding) = self.mem_manage.reserve_staging(aligned_len).unwrap(); + self.tasks_count += 1; + self.encoder.copy_buffer_to_buffer( + &resource.buffer, + resource.offset, + &staging.buffer, + 0, + aligned_len, + ); + entries.push(Stored::Managed { + staging, + binding: sbinding, + size, + }); + continue; + } - for (staging, _binding, _size) in staging_info.iter() { - let (sender, receiver) = async_channel::bounded(1); - staging - .buffer - .slice(..) - .map_async(wgpu::MapMode::Read, move |v| { - // This might fail if the channel is closed (eg. the future is dropped). - // This is fine, just means results aren't needed anymore. 
- let _ = sender.try_send(v); + if let Some(pitch_elems) = row_pitch_elems(descriptor.shape, descriptor.strides) { + let last = descriptor.shape.len() - 1; + let rows = descriptor.shape[..last].iter().product::() as u64; + let cols = descriptor.shape[last] as u64; + let row_bytes = cols * elem; + let row_pitch = pitch_elems as u64 * elem; + let total = rows * row_pitch; + let aligned_total = total.div_ceil(align) * align; + + let staging_buffer = self.device.create_buffer(&wgpu::BufferDescriptor { + label: None, + size: aligned_total, + usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST, + mapped_at_creation: false, + }); + self.tasks_count += 1; + self.encoder.copy_buffer_to_buffer( + &resource.buffer, + resource.offset, + &staging_buffer, + 0, + aligned_total, + ); + let out_size = (rows * row_bytes) as usize; + entries.push(Stored::Pitched { + buffer: staging_buffer, + size: out_size, + row_bytes: row_bytes as usize, + row_pitch: row_pitch as usize, }); + continue; + } + + return Box::pin(async { Err(IoError::UnsupportedStrides) }); + } + // Flush copies to queue + self.flush(); + for entry in &entries { + let (sender, receiver) = async_channel::bounded(1); + match entry { + Stored::Managed { staging, .. } => { + staging + .buffer + .slice(..) + .map_async(wgpu::MapMode::Read, move |v| { + let _ = sender.try_send(v); + }); + } + Stored::Pitched { buffer, .. } => { + buffer.slice(..).map_async(wgpu::MapMode::Read, move |v| { + let _ = sender.try_send(v); + }); + } + } callbacks.push(receiver); } @@ -242,19 +306,40 @@ impl WgpuStream { .expect("Failed to map buffer"); } - // Can stop polling now. 
core::mem::drop(poll); - let result = { - staging_info - .into_iter() - .map(|(staging, binding, size)| { + let result: Vec = entries + .into_iter() + .map(|entry| match entry { + Stored::Managed { + staging, + binding, + size, + } => { let (controller, alloc) = WgpuAllocController::init(binding, staging.buffer, size); unsafe { Bytes::from_raw_parts(alloc, size, Box::new(controller)) } - }) - .collect() - }; + } + Stored::Pitched { + buffer, + size, + row_bytes, + row_pitch, + } => { + let data = buffer.slice(..).get_mapped_range(); + let rows = if row_bytes == 0 { 0 } else { size / row_bytes }; + let mut out = vec![0u8; size]; + for r in 0..rows { + let src_off = r * row_pitch; + let dst_off = r * row_bytes; + out[dst_off..dst_off + row_bytes] + .copy_from_slice(&data[src_off..src_off + row_bytes]); + } + buffer.unmap(); + Bytes::from_bytes_vec(out) + } + }) + .collect(); Ok(result) }) @@ -384,6 +469,27 @@ impl WgpuStream { self.write_to_buffer(&resource, data); } + /// Writes a contiguous slice into a pitched 2D region row-by-row using queue writes. + /// Assumes `data` is laid out as `rows` contiguous rows of `row_bytes` each. + pub fn write_rows_pitched( + &mut self, + resource: &WgpuResource, + rows: u64, + row_bytes: u64, + row_pitch: u64, + data: &[u8], + ) { + // Ensure queued compute work is flushed before queue writes for ordering. + self.flush(); + for r in 0..rows { + let src_off = (r * row_bytes) as usize; + let dst_off = resource.offset + r * row_pitch; + let end = src_off + row_bytes as usize; + let slice = &data[src_off..end]; + self.queue.write_buffer(&resource.buffer, dst_off, slice); + } + } + // Nb: this function submits a command to the _queue_ not to the encoder, // so you have to be really careful about the ordering of operations here. 
// Any buffer which has outstanding (not yet flushed) compute work should diff --git a/crates/cubecl-wgpu/src/runtime.rs b/crates/cubecl-wgpu/src/runtime.rs index 7c550b56de..e86cc76b90 100644 --- a/crates/cubecl-wgpu/src/runtime.rs +++ b/crates/cubecl-wgpu/src/runtime.rs @@ -1,8 +1,6 @@ -use crate::{ - AutoCompiler, AutoGraphicsApi, GraphicsApi, WgpuDevice, backend, compute::WgpuServer, - contiguous_strides, -}; +use crate::{AutoCompiler, AutoGraphicsApi, GraphicsApi, WgpuDevice, backend, compute::WgpuServer}; use cubecl_common::{future, profile::TimingMethod}; +use cubecl_runtime::stride::{is_contiguous, is_inner_contiguous_rows}; use cubecl_core::{CubeCount, CubeDim, Runtime, ir::TargetProperties}; pub use cubecl_runtime::memory_management::MemoryConfiguration; @@ -80,17 +78,7 @@ impl Runtime for WgpuRuntime { } fn can_read_tensor(shape: &[usize], strides: &[usize]) -> bool { - if shape.is_empty() { - return true; - } - - for (expected, &stride) in contiguous_strides(shape).into_iter().zip(strides) { - if expected != stride { - return false; - } - } - - true + is_contiguous(shape, strides) || is_inner_contiguous_rows(shape, strides) } fn target_properties() -> TargetProperties { @@ -268,6 +256,8 @@ pub(crate) fn create_client_on_setup( mem_props.clone(), hardware_props, time_measurement, + // Default to contiguous for rank>1 allocations on WGPU unless overridden. 
+ cubecl_runtime::server::AllocationKind::Contiguous, ); #[cfg(not(all(target_os = "macos", feature = "msl")))] diff --git a/crates/cubecl-wgpu/tests/strided_io.rs b/crates/cubecl-wgpu/tests/strided_io.rs new file mode 100644 index 0000000000..f49fcdebd7 --- /dev/null +++ b/crates/cubecl-wgpu/tests/strided_io.rs @@ -0,0 +1,99 @@ +#![cfg(feature = "std")] + +use cubecl_common::{future::block_on, reader::read_sync}; +use cubecl_runtime::server::CopyDescriptor; + +#[test] +fn wgpu_strided_io_roundtrip_u8_rows_pitched() { + type R = cubecl_wgpu::WgpuRuntime; + let client = R::client(&R::Device::default()); + + let rows: usize = 4; + let cols: usize = 5; + let pitch_elems: usize = 8; // >= cols; introduces padding per row + let elem_size: usize = 1; // u8 + let total_bytes = rows * pitch_elems * elem_size; + + // Allocate a buffer large enough to hold the pitched rows. + let handle = client.empty(total_bytes); + + // Prepare contiguous row-major data (no padding in source). + let mut data = vec![0u8; rows * cols]; + for r in 0..rows { + for c in 0..cols { + data[r * cols + c] = (r as u8) * 32 + (c as u8); + } + } + + let binding = handle.clone().binding(); + let shape = [rows, cols]; + let strides = [pitch_elems, 1]; + + // Write using pitched descriptor; the runtime should place each row at row_pitch offsets. + let write_desc = CopyDescriptor::new(binding.clone(), &shape, &strides, elem_size); + block_on(client.write_async(vec![(write_desc, &data)])).expect("pitched write ok"); + + // Read back using the same pitched descriptor. The runtime reconstructs contiguous rows. 
+ let read_desc = CopyDescriptor::new(binding.clone(), &shape, &strides, elem_size); + let out = read_sync(client.read_tensor_async(vec![read_desc])); + assert_eq!(out.len(), 1); + assert_eq!(out[0], data); +} + +#[test] +fn wgpu_strided_io_roundtrip_f32_rows_pitched() { + type R = cubecl_wgpu::WgpuRuntime; + let client = R::client(&R::Device::default()); + + let rows: usize = 3; + let cols: usize = 7; + let pitch_elems: usize = 10; // >= cols; introduces padding per row + let elem_size: usize = core::mem::size_of::(); + let total_bytes = rows * pitch_elems * elem_size; + + // Allocate a buffer large enough to hold the pitched rows. + let handle = client.empty(total_bytes); + + // Prepare contiguous row-major data (no padding in source). + let mut data = vec![0f32; rows * cols]; + for r in 0..rows { + for c in 0..cols { + data[r * cols + c] = (r as f32) * 100.0 + (c as f32); + } + } + let bytes: &[u8] = bytemuck::cast_slice(&data); + + let binding = handle.clone().binding(); + let shape = [rows, cols]; + let strides = [pitch_elems, 1]; + + // Write using pitched descriptor; the runtime should place each row at row_pitch offsets. + let write_desc = CopyDescriptor::new(binding.clone(), &shape, &strides, elem_size); + block_on(client.write_async(vec![(write_desc, bytes)])).expect("pitched write ok"); + + // Read back using the same pitched descriptor. The runtime reconstructs contiguous rows. + let read_desc = CopyDescriptor::new(binding.clone(), &shape, &strides, elem_size); + let out = read_sync(client.read_tensor_async(vec![read_desc])); + assert_eq!(out.len(), 1); + assert_eq!(out[0], bytes); +} + +#[test] +#[should_panic] +fn wgpu_strided_io_read_unsupported_strides_panics_rank3() { + type R = cubecl_wgpu::WgpuRuntime; + let client = R::client(&R::Device::default()); + + // Rank 3 descriptor with non-trivial strides is currently unsupported on WGPU. 
+ let shape = [2usize, 3usize, 4usize]; + let strides = [12usize, 4usize, 1usize]; + let elem_size = 1usize; + let total_bytes = shape.iter().product::() * elem_size; + + let handle = client.empty(total_bytes); + let binding = handle.binding(); + + // Attempting to read should surface UnsupportedStrides, which panics via client read wrapper. + let desc = CopyDescriptor::new(binding, &shape, &strides, elem_size); + let _ = client.read_tensor(vec![desc]); +} diff --git a/cubecl-book/src/getting-started/simple_reduction.md b/cubecl-book/src/getting-started/simple_reduction.md index 97ba4eefdd..3c9a36378b 100644 --- a/cubecl-book/src/getting-started/simple_reduction.md +++ b/cubecl-book/src/getting-started/simple_reduction.md @@ -27,7 +27,7 @@ The following code creates a 3x3 matrix, initializes the input tensor, and calls This example demonstrates how to perform a simple reduction operation on a multidimensional array (tensor) using CubeCL. It is a simple implementation that will be used as a starting point to show how to use CubeCL in the next chapters. ## GpuTensor struct -The `GpuTensor` struct is a representation of a tensor that resides on the GPU. It contains the data handle, shape, strides, and marker types for the runtime and floating-point type. The `GpuTensor` struct provides methods to create tensors, read data from the GPU, and convert them into tensor arguments for kernel execution. Please note that it is generic over the runtime and floating-point type, allowing it to work with different CubeCL runtimes and floating-point types (e.g., `f16`, `f32`). Also, the strides can be computed using the `compact_strides` function from the `cubecl::std::tensor` module, which will compute the strides for a given shape with a compact representation. +The `GpuTensor` struct is a representation of a tensor that resides on the GPU. It contains the data handle, shape, strides, and marker types for the runtime and floating-point type. 
The `GpuTensor` struct provides methods to create tensors, read data from the GPU, and convert them into tensor arguments for kernel execution. Please note that it is generic over the runtime and floating-point type, allowing it to work with different CubeCL runtimes and floating-point types (e.g., `f16`, `f32`). Also, the strides can be computed using the `contiguous_strides` helper from the `cubecl_runtime::stride` module, which computes canonical row‑major strides for a given shape. Another important concept is the `ComputeClient` trait, which define what a runtime should implement to be able to run kernels. Each runtime has their own implementation of the `ComputeClient` trait, which provides methods to create tensors and read data from the GPU. The `ComputeClient` can send compute task to a `Server` that will run the kernel on the GPU and schedule the tasks. diff --git a/cubecl-book/src/getting-started/src/cpu_tensor.rs b/cubecl-book/src/getting-started/src/cpu_tensor.rs index f2da0db6fc..3591b1e2ea 100644 --- a/cubecl-book/src/getting-started/src/cpu_tensor.rs +++ b/cubecl-book/src/getting-started/src/cpu_tensor.rs @@ -1,3 +1,5 @@ +use cubecl_runtime::stride::contiguous_strides; + /// Example of a naive multidimensional tensor in pure Rust #[derive(Debug, Clone)] pub struct CpuTensor { @@ -9,22 +11,13 @@ pub struct CpuTensor { pub shape: Vec, } -/// Function to compute strides in a compact layout -fn compact_strides(shape: &[usize]) -> Vec { - let rank = shape.len(); - let mut strides = vec![1; rank]; - for i in (0..rank - 1).rev() { - strides[i] = strides[i + 1] * shape[i + 1]; - } - strides -} impl CpuTensor { /// Create a CpuTensor with a shape filled by number in order pub fn arange(shape: Vec) -> Self { let size = shape.iter().product(); let data = (0..size).map(|i| i as f32).collect(); - let strides = compact_strides(&shape); + let strides = contiguous_strides(&shape); Self { data, strides, @@ -36,7 +29,7 @@ impl CpuTensor { pub fn empty(shape: Vec) -> 
Self { let size = shape.iter().product(); let data = vec![0.0; size]; - let strides = compact_strides(&shape); + let strides = contiguous_strides(&shape); Self { data, strides, diff --git a/cubecl-book/src/getting-started/src/gpu_tensor.rs b/cubecl-book/src/getting-started/src/gpu_tensor.rs index 24f36be793..715cb691d0 100644 --- a/cubecl-book/src/getting-started/src/gpu_tensor.rs +++ b/cubecl-book/src/getting-started/src/gpu_tensor.rs @@ -1,6 +1,7 @@ use std::marker::PhantomData; -use cubecl::{prelude::*, server::Handle, std::tensor::compact_strides}; +use cubecl::{prelude::*, server::Handle}; +use cubecl_runtime::stride::contiguous_strides; /// Simple GpuTensor #[derive(Debug)] @@ -31,7 +32,7 @@ impl GpuTensor { let data: Vec = (0..size).map(|i| F::from_int(i as i64)).collect(); let data = client.create(F::as_bytes(&data)); - let strides = compact_strides(&shape); + let strides = contiguous_strides(&shape); Self { data, shape, @@ -46,7 +47,7 @@ impl GpuTensor { let size = shape.iter().product::() * core::mem::size_of::(); let data = client.empty(size); - let strides = compact_strides(&shape); + let strides = contiguous_strides(&shape); Self { data, shape,