Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ use cubecl_matmul::components::{
tile::TileMatmulFamily,
};

use cubecl_runtime::stride::{is_contiguous, is_inner_contiguous_rows};
use cubecl_std::{
CubeOption,
tensor::{TensorHandle, into_contiguous},
Expand Down Expand Up @@ -80,12 +81,8 @@ impl<TMM: TileMatmulFamily<LhsTile = Strided, RhsTile = Strided, AccTile = CubeO
}
}

fn has_valid_layout<R: Runtime>(handle: &TensorHandleRef<'_, R>, ident: MatmulIdent) -> bool {
let rank = handle.shape.len();
let dim_c = rank - 1;
match ident {
MatmulIdent::Lhs => handle.strides[dim_c] == 1,
MatmulIdent::Rhs => handle.strides[dim_c] == 1,
MatmulIdent::Out => unreachable!(),
}
fn has_valid_layout<R: Runtime>(handle: &TensorHandleRef<'_, R>, _ident: MatmulIdent) -> bool {
    // A layout is usable when the tensor is either fully contiguous, or its
    // rows are contiguous with a (possibly padded) pitch between them
    // (inner-contiguous rows require rank >= 2).
    let (shape, strides) = (handle.shape, handle.strides);
    if is_contiguous(shape, strides) {
        return true;
    }
    is_inner_contiguous_rows(shape, strides)
}
112 changes: 88 additions & 24 deletions crates/cubecl-cpu/src/compute/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ use cubecl_runtime::{
};

use crate::{CpuCompiler, compute::alloc_controller::CpuAllocController};
use cubecl_runtime::stride::{contiguous_strides, pitched_rows_layout, row_pitch_elems};

use super::scheduler::Scheduler;

Expand Down Expand Up @@ -66,10 +67,43 @@ impl CpuServer {
) -> Result<Vec<Bytes>, IoError> {
let mut result = Vec::with_capacity(descriptors.len());
for desc in descriptors {
let len = desc.binding.size() as usize;
let (controller, alloc) =
CpuAllocController::init(desc.binding, &mut ctx.memory_management)?;
result.push(unsafe { Bytes::from_raw_parts(alloc, len, Box::new(controller)) });
let binding = desc.binding;
let elem = desc.elem_size;
let size = desc.shape.iter().product::<usize>() * elem;

// Contiguous: return zero-copy Bytes over the binding with logical len
if contiguous_strides(desc.shape) == desc.strides {
let (controller, alloc) =
CpuAllocController::init(binding, &mut ctx.memory_management)?;
result
.push(unsafe { Bytes::from_raw_parts(alloc, size, Box::new(controller)) });
continue;
}

// Inner-contiguous rows: reconstruct rows into contiguous buffer
if let Some(row_pitch_elems) = row_pitch_elems(desc.shape, desc.strides) {
let resource = ctx
.memory_management
.get_resource(binding.memory, binding.offset_start, binding.offset_end)
.ok_or(IoError::InvalidHandle)?;
let last = desc.shape.len() - 1;
let rows = desc.shape[..last].iter().product::<usize>();
let cols = desc.shape[last];
let row_bytes = cols * elem;
let row_pitch = row_pitch_elems * elem;
let src = resource.read();
let mut out = vec![0u8; rows * row_bytes];
for r in 0..rows {
let src_off = r * row_pitch;
let dst_off = r * row_bytes;
out[dst_off..dst_off + row_bytes]
.copy_from_slice(&src[src_off..src_off + row_bytes]);
}
result.push(Bytes::from_bytes_vec(out));
continue;
}

return Err(IoError::UnsupportedStrides);
}
Ok(result)
}
Expand All @@ -90,14 +124,22 @@ impl ComputeServer for CpuServer {
descriptors: Vec<AllocationDescriptor<'_>>,
) -> Result<Vec<Allocation>, IoError> {
let align = 8;
let strides = descriptors
.iter()
.map(|desc| contiguous_strides(desc.shape))
.collect::<Vec<_>>();
let sizes = descriptors
.iter()
.map(|desc| desc.shape.iter().product::<usize>() * desc.elem_size)
.collect::<Vec<_>>();
let mut strides = Vec::with_capacity(descriptors.len());
let mut sizes = Vec::with_capacity(descriptors.len());

use cubecl_core::server::AllocationKind;

for desc in &descriptors {
let rank = desc.shape.len();
if matches!(desc.kind, AllocationKind::Optimized) && rank > 1 {
let (s, size) = pitched_rows_layout(desc.shape, desc.elem_size, align);
strides.push(s);
sizes.push(size);
} else {
strides.push(contiguous_strides(desc.shape));
sizes.push(desc.shape.iter().product::<usize>() * desc.elem_size);
}
}
let total_size = sizes
.iter()
.map(|it| it.next_multiple_of(align))
Expand All @@ -123,11 +165,42 @@ impl ComputeServer for CpuServer {

fn write(&mut self, descriptors: Vec<(CopyDescriptor<'_>, &[u8])>) -> Result<(), IoError> {
for (desc, data) in descriptors {
if desc.strides != contiguous_strides(desc.shape) {
return Err(IoError::UnsupportedStrides);
// Contiguous path
if contiguous_strides(desc.shape) == desc.strides {
self.copy_to_binding(desc.binding, data);
continue;
}

// Inner-contiguous rows: copy into pitched destination row-by-row
if let Some(row_pitch_elems) = row_pitch_elems(desc.shape, desc.strides) {
let last = desc.shape.len() - 1;
let rows = desc.shape[..last].iter().product::<usize>();
let cols = desc.shape[last];
let elem = desc.elem_size;
let row_bytes = cols * elem;
let row_pitch = row_pitch_elems * elem;

let resource = self
.ctx
.memory_management
.get_resource(
desc.binding.memory,
desc.binding.offset_start,
desc.binding.offset_end,
)
.ok_or(IoError::InvalidHandle)?;

let dst = resource.write();
for r in 0..rows {
let dst_off = r * row_pitch;
let src_off = r * row_bytes;
dst[dst_off..dst_off + row_bytes]
.copy_from_slice(&data[src_off..src_off + row_bytes]);
}
continue;
}

self.copy_to_binding(desc.binding, data);
return Err(IoError::UnsupportedStrides);
}
Ok(())
}
Expand Down Expand Up @@ -219,12 +292,3 @@ impl CpuServer {
resource.write().copy_from_slice(data);
}
}

/// Computes row-major (C-order) strides for a densely packed tensor of the
/// given `shape`, expressed in elements.
///
/// The innermost dimension always has stride 1; each outer stride is the
/// product of all inner dimension sizes. An empty shape (rank 0) yields an
/// empty stride vector.
pub(crate) fn contiguous_strides(shape: &[usize]) -> Vec<usize> {
    let rank = shape.len();
    let mut strides = vec![1; rank];
    // `saturating_sub` guards the rank-0 case: `0..rank - 1` would underflow
    // and panic (debug) or wrap (release) for an empty shape.
    for i in (0..rank.saturating_sub(1)).rev() {
        strides[i] = strides[i + 1] * shape[i + 1];
    }
    strides
}
6 changes: 4 additions & 2 deletions crates/cubecl-cpu/src/runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@ use cubecl_core::{
client::ComputeClient,
ir::{StorageType, TargetProperties},
};
use cubecl_runtime::stride::{is_contiguous, is_inner_contiguous_rows};
use cubecl_runtime::{
ComputeRuntime, DeviceProperties,
memory_management::{HardwareProperties, MemoryDeviceProperties, MemoryManagement},
storage::BytesStorage,
};
use cubecl_std::tensor::is_contiguous;
use sysinfo::System;

use crate::{
Expand Down Expand Up @@ -72,6 +72,8 @@ fn create_client(options: RuntimeOptions) -> ComputeClient<Server, Channel> {
mem_properties,
topology,
TimingMethod::Device,
// Default to contiguous on CPU.
cubecl_runtime::server::AllocationKind::Contiguous,
);
register_supported_types(&mut device_props);

Expand Down Expand Up @@ -111,7 +113,7 @@ impl Runtime for CpuRuntime {
}

fn can_read_tensor(shape: &[usize], strides: &[usize]) -> bool {
is_contiguous(shape, strides)
is_contiguous(shape, strides) || is_inner_contiguous_rows(shape, strides)
}

fn target_properties() -> TargetProperties {
Expand Down
7 changes: 4 additions & 3 deletions crates/cubecl-cuda/src/compute/io/base.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use super::controller::PinnedMemoryManagedAllocController;
use crate::compute::{CudaContext, MB, valid_strides};
use crate::compute::{CudaContext, MB};
use cubecl_common::bytes::Bytes;
use cubecl_core::server::{CopyDescriptor, IoError};
use cubecl_runtime::memory_management::MemoryHandle;
use cubecl_runtime::stride::{is_contiguous, is_inner_contiguous_rows, row_pitch_elems};
use cudarc::driver::sys::{CUDA_MEMCPY2D_st, CUmemorytype, cuMemcpy2DAsync_v2};
use std::{ffi::c_void, ops::DerefMut};

Expand Down Expand Up @@ -52,7 +53,7 @@ pub fn register_copy_to_bytes(
elem_size,
} = descriptor;

if !valid_strides(shape, strides) {
if !(is_contiguous(shape, strides) || is_inner_contiguous_rows(shape, strides)) {
return Err(IoError::UnsupportedStrides);
}

Expand All @@ -77,7 +78,7 @@ pub fn register_copy_to_bytes(
let dim_x = shape[rank - 1];
let width_bytes = dim_x * elem_size;
let dim_y: usize = shape.iter().rev().skip(1).product();
let pitch = strides[rank - 2] * elem_size;
let pitch = row_pitch_elems(shape, strides).unwrap() * elem_size;
let slice = bytes.deref_mut();

let cpy = CUDA_MEMCPY2D_st {
Expand Down
78 changes: 25 additions & 53 deletions crates/cubecl-cuda/src/compute/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,14 @@ use cubecl_cpp::formatter::format_cpp;
use cubecl_cpp::{cuda::arch::CudaArchitecture, shared::CompilationOptions};

use super::storage::gpu::{GpuResource, GpuStorage};
use cubecl_runtime::data_service::DataTransferId;
use cubecl_runtime::logging::ServerLogger;
use cubecl_runtime::stride::{
contiguous_strides, is_contiguous, is_inner_contiguous_rows, pitched_rows_layout,
row_pitch_elems,
};
use cubecl_runtime::{memory_management::offset_handles, timestamp_profiler::TimestampProfiler};

use super::sync::{Fence, SyncStream};
use crate::compute::{
DataTransferItem, DataTransferRuntime, io::register_copies_to_bytes,
Expand All @@ -27,15 +35,14 @@ use cubecl_core::{
ir::FloatKind,
server::{Bindings, CopyDescriptor, TensorMapBinding},
};
use cubecl_runtime::data_service::DataTransferId;
use cubecl_runtime::logging::ServerLogger;
// deduped above: DataTransferId, ServerLogger
use cubecl_runtime::memory_management::MemoryUsage;
use cubecl_runtime::storage::BindingResource;
use cubecl_runtime::{
memory_management::MemoryManagement,
server::{self, ComputeServer},
};
use cubecl_runtime::{memory_management::offset_handles, timestamp_profiler::TimestampProfiler};
// deduped above: offset_handles, TimestampProfiler
use cudarc::driver::sys::{
CUDA_MEMCPY2D_st, CUctx_st, CUfunction_attribute, CUmemorytype, CUtensorMap,
CUtensorMapDataType, CUtensorMapFloatOOBfill, CUtensorMapL2promotion, CUtensorMapSwizzle,
Expand Down Expand Up @@ -145,31 +152,20 @@ impl ComputeServer for CudaServer {
let mut total_size = 0;

for descriptor in descriptors {
let pitch_align = match descriptor.kind {
AllocationKind::Contiguous => 1,
AllocationKind::Optimized => self.mem_alignment,
};

let rank = descriptor.shape.len();
let width = *descriptor.shape.last().unwrap_or(&1);
let height: usize = descriptor.shape.iter().rev().skip(1).product();
let height = height.max(1);
let width_bytes = width * descriptor.elem_size;
let pitch = width_bytes.next_multiple_of(pitch_align);
let size = height * pitch;
total_size += size.next_multiple_of(self.mem_alignment);
let mut stride = vec![1; rank];
if rank > 1 {
stride[rank - 2] = pitch / descriptor.elem_size;
}
if rank > 2 {
for i in (0..rank - 2).rev() {
stride[i] = stride[i + 1] * descriptor.shape[i + 1];
}
if matches!(descriptor.kind, AllocationKind::Optimized) && rank > 1 {
let (s, size) =
pitched_rows_layout(descriptor.shape, descriptor.elem_size, self.mem_alignment);
total_size += size.next_multiple_of(self.mem_alignment);
strides.push(s);
sizes.push(size);
} else {
let s = contiguous_strides(descriptor.shape);
let size = descriptor.shape.iter().product::<usize>() * descriptor.elem_size;
total_size += size.next_multiple_of(self.mem_alignment);
strides.push(s);
sizes.push(size);
}

strides.push(stride);
sizes.push(size);
}

let ctx = self.get_context();
Expand Down Expand Up @@ -197,7 +193,7 @@ impl ComputeServer for CudaServer {
} = descriptor;
let rank = shape.len();

if !valid_strides(shape, strides) {
if !(is_contiguous(shape, strides) || is_inner_contiguous_rows(shape, strides)) {
return Err(IoError::UnsupportedStrides);
}

Expand All @@ -210,7 +206,8 @@ impl ComputeServer for CudaServer {
let dim_x = shape[rank - 1];
let width_bytes = dim_x * elem_size;
let dim_y: usize = shape.iter().rev().skip(1).product();
let pitch = strides[rank - 2] * elem_size;
let pitch =
row_pitch_elems(shape, strides).unwrap_or(strides[rank - 2]) * elem_size;

let cpy = CUDA_MEMCPY2D_st {
srcMemoryType: CUmemorytype::CU_MEMORYTYPE_HOST,
Expand Down Expand Up @@ -940,28 +937,3 @@ fn oob_to_cuda(fill: OobFill) -> CUtensorMapFloatOOBfill {
OobFill::NaN => CU_TENSOR_MAP_FLOAT_OOB_FILL_NAN_REQUEST_ZERO_FMA,
}
}

/// Returns `true` when `strides` describe a layout this backend can copy:
/// row-major ordering with a contiguous innermost dimension, where every
/// dimension except the row-pitch one (`rank - 2`) is tightly packed.
/// This accepts both fully contiguous tensors and pitched-row layouts.
///
/// `shape` and `strides` must have the same length; strides are in elements.
pub fn valid_strides(shape: &[usize], strides: &[usize]) -> bool {
    let rank = shape.len();
    // Rank 0 (scalar): nothing to validate. The previous implementation
    // panicked here on `strides[rank - 1]`.
    if rank == 0 {
        return true;
    }
    // The innermost dimension must be contiguous.
    if strides[rank - 1] != 1 {
        return false;
    }
    if rank == 1 {
        return true;
    }
    // Strides must be non-increasing (row-major ordering). Equivalent to the
    // former sort-descending-and-compare check, without allocating.
    if strides.windows(2).any(|w| w[0] < w[1]) {
        return false;
    }
    // All dimensions above the pitched one must be tightly packed; only
    // `strides[rank - 2]` (the row pitch) may exceed the packed value.
    (0..rank - 2).all(|i| strides[i] == shape[i + 1] * strides[i + 1])
}
6 changes: 4 additions & 2 deletions crates/cubecl-cuda/src/runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ use crate::{
cpu::{PINNED_MEMORY_ALIGNMENT, PinnedMemoryStorage},
gpu::GpuStorage,
},
valid_strides,
},
device::CudaDevice,
};
Expand All @@ -27,6 +26,7 @@ use cubecl_cpp::{
register_wmma_features,
},
};
use cubecl_runtime::stride::{is_contiguous, is_inner_contiguous_rows};
use cubecl_runtime::{
ComputeRuntime, DeviceProperties, Plane, Tma, TypeUsage,
channel::MutexComputeChannel,
Expand Down Expand Up @@ -184,6 +184,8 @@ fn create_client<M: DialectWmmaCompiler<CudaDialect<M>>>(
mem_properties,
hardware_props,
TimingMethod::System,
// Prefer pitched rows by default on CUDA (hardware efficient).
cubecl_runtime::server::AllocationKind::Optimized,
);
register_supported_types(&mut device_props);
device_props.register_type_usage(ElemType::Float(FloatKind::TF32), TypeUsage::Conversion);
Expand Down Expand Up @@ -311,7 +313,7 @@ impl Runtime for CudaRuntime {
}

fn can_read_tensor(shape: &[usize], strides: &[usize]) -> bool {
valid_strides(shape, strides)
is_contiguous(shape, strides) || is_inner_contiguous_rows(shape, strides)
}

fn target_properties() -> TargetProperties {
Expand Down
Loading