📝 Background
Currently, ttnn.experimental.conv3d in tt-metal does not support sharded memory layout for input, weight, or bias tensors. The implementation only accepts interleaved memory layouts, enforced by explicit assertions in the device operation code. However, the Conv3d kernels already use TensorAccessor, which provides unified access to both interleaved and sharded tensors. According to the TensorAccessor Guide, this accessor handles the mapping from logical tensor indices to physical memory locations for all tensor distributions, including ND sharding.
Supporting sharded layouts in Conv3d is essential for achieving optimal performance in 3D volumetric applications such as UNet-3D, video analysis, and medical imaging models.
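As a rough illustration of the mapping described above, the two distributions differ only in the function that turns a logical page index into a physical location. The sketch below is plain Python, not tt-metal code; the real logic lives in `tt_metal/hw/inc/api/tensor/tensor_accessor.h`, and all names here are illustrative.

```python
# Hypothetical model of logical page id -> (bank/core, byte offset).
# Not tt-metal API; a minimal sketch of the two distribution schemes.

def interleaved_location(page_id: int, num_banks: int, page_size: int):
    # Interleaved: pages are striped round-robin across DRAM banks.
    return page_id % num_banks, (page_id // num_banks) * page_size

def height_sharded_location(page_id: int, pages_per_shard: int, page_size: int):
    # Height-sharded: each core owns a contiguous run of row pages.
    return page_id // pages_per_shard, (page_id % pages_per_shard) * page_size

# Page 7 with 4 banks / 4 pages per shard and 32-byte pages:
print(interleaved_location(7, 4, 32))     # -> (3, 32)
print(height_sharded_location(7, 4, 32))  # -> (1, 96)
```

A unified accessor only needs to pick the right mapping at construction time; the calling code is identical either way.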
🎯 What Success Looks Like
- Remove the restrictions that force input, weight, and bias tensors to be interleaved.
- Leverage the existing `TensorAccessor` infrastructure; no new kernels should be added.
- Support all sharding layouts (`HEIGHT_SHARDED`, `WIDTH_SHARDED`, `BLOCK_SHARDED`) via the ND sharding path in `TensorAccessor`.
- API accepts an optional `memory_config` parameter for specifying a sharded output memory configuration.
- Properly configure `TensorAccessorArgs` for sharded buffers in the program factory.
- Comprehensive tests added to `tests/ttnn/nightly/unit_tests/operations/conv/test_conv3d.py` validating:
  - Numerical correctness against PyTorch Conv3d with sharded inputs/outputs
  - Various sharding configurations (height, width, block sharded)
  - Edge cases with different shard sizes
💡 Problem to Solve
The Conv3d kernels already use TensorAccessor for memory access:
```cpp
// reader_vol2col.cpp
constexpr auto in_args = TensorAccessorArgs<28>();
const auto in_reader = TensorAccessor(in_args, in_addr, in_row_size_bytes);

// writer.cpp
constexpr auto out_args = TensorAccessorArgs<22>();
const auto out_writer = TensorAccessor(out_args, out_addr, out_row_size_bytes);
```

`TensorAccessor` already supports ND sharding transparently; it works for both sharded and interleaved tensors. The only blockers are the explicit assertions in `conv3d_device_operation.cpp`:

```cpp
// input and weight must both be interleaved, bfloat16
TT_FATAL(!input_tensor_a.memory_config().is_sharded(), "Activation tensor must be interleaved.");
TT_FATAL(!weight_tensor.memory_config().is_sharded(), "Weight tensor must be interleaved.");
TT_FATAL(!bias_tensor.memory_config().is_sharded(), "Bias tensor must be interleaved.");
```

This bounty requires removing these restrictions and ensuring the `TensorAccessorArgs` are properly configured for sharded buffers.
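To make the point concrete, here is a minimal plain-Python stand-in (not ttnn code, all names invented) for the layout-agnostic accessor pattern: the "kernel" loop below never branches on memory layout, which is why removing the validation checks is the main change needed.

```python
# Sketch of a layout-agnostic page accessor. `locate` maps a logical page
# id to (bank/core, offset); the read loop is identical for both layouts.

def make_storage(data, locate, units, pages_per_unit):
    # Scatter logical pages into per-bank/per-core storage.
    banks = [[None] * pages_per_unit for _ in range(units)]
    for page_id, value in enumerate(data):
        unit, off = locate(page_id)
        banks[unit][off] = value
    return banks

def read_all(banks, locate, num_pages):
    # "Kernel" loop: the same code path serves interleaved and sharded data.
    return [banks[locate(p)[0]][locate(p)[1]] for p in range(num_pages)]

data = list(range(8))
interleaved = lambda p: (p % 4, p // 4)  # round-robin over 4 banks
sharded = lambda p: (p // 2, p % 2)      # 2 pages per shard over 4 cores

for locate in (interleaved, sharded):
    banks = make_storage(data, locate, units=4, pages_per_unit=2)
    assert read_all(banks, locate, 8) == data
```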
🧭 Guidance & Starting Points
- The assertions blocking sharded layouts are in `ttnn/cpp/ttnn/operations/experimental/conv3d/device/conv3d_device_operation.cpp` (lines 62, 66, 72).
- `TensorAccessorArgs` is already used in `conv3d_program_factory.cpp`:

  ```cpp
  tt::tt_metal::TensorAccessorArgs(*input_tensor.buffer()).append_to(reader_compile_time_args);
  tt::tt_metal::TensorAccessorArgs(*output_tensor.buffer()).append_to(writer_compile_time_args);
  ```

  These should work with sharded buffers; verify and test.
- The `TensorAccessor` tech report and iterator guide explain how the accessor handles sharded tensors.
- For sharded tensors, the Pages Iterator provides optimized iteration: `tensor_accessor.pages()` automatically handles shard-local access patterns.
- Verify that output tensor creation in `create_output_tensors()` properly supports a sharded `output_mem_config`.
- Test coverage goes in `tests/ttnn/nightly/unit_tests/operations/conv/test_conv3d.py`.
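The value of a shard-aware pages iterator like `tensor_accessor.pages()` can be sketched in plain Python (illustrative only, not the tt-metal API): within a shard the offset simply increments, and the core is recomputed only when a shard boundary is crossed, instead of running the full index mapping for every page.

```python
# Sketch of shard-local page iteration for a height-sharded layout.
# Assumed parameters: pages_per_shard pages per core, page_size bytes each.

def pages(start, end, pages_per_shard, page_size):
    core = start // pages_per_shard
    offset = (start % pages_per_shard) * page_size
    for page_id in range(start, end):
        yield page_id, core, offset
        offset += page_size
        if (page_id + 1) % pages_per_shard == 0:  # crossed a shard boundary
            core += 1
            offset = 0

# Agrees with the per-page mapping p -> (p // pps, (p % pps) * page_size):
expected = [(p, p // 4, (p % 4) * 32) for p in range(3, 11)]
assert list(pages(3, 11, pages_per_shard=4, page_size=32)) == expected
```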
🔎 Possible Approaches
- **Remove assertions:** Delete the `TT_FATAL` checks requiring interleaved layout for input, weight, and bias tensors.
- **Verify TensorAccessorArgs configuration:** Ensure `TensorAccessorArgs` correctly handles sharded buffer metadata. The constructor `TensorAccessorArgs(*buffer)` should already extract shard specs from the buffer.
- **Update output tensor creation:** Ensure `create_output_tensors()` respects `output_mem_config` when it specifies a sharded layout, creating the output buffer with the appropriate shard spec.
- **Verify kernel compatibility:** The kernels use `TensorAccessor` with page-based access. Ensure the page iteration logic works correctly with sharded tensors (the iterator handles shard boundaries automatically).
- **Add validation:** Add appropriate validation for sharding constraints (e.g., shard sizes must be compatible with the convolution access pattern).
- **Test thoroughly:** Add test cases covering:
  - HEIGHT_SHARDED, WIDTH_SHARDED, BLOCK_SHARDED inputs
  - Sharded weights and bias
  - Various shard grid configurations
  - Comparison against PyTorch for numerical correctness
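The real tests would run `ttnn.experimental.conv3d` on sharded device tensors and compare against PyTorch's Conv3d. The self-contained sketch below only illustrates the shard/reassemble-then-compare test pattern; a 1-D convolution over Python lists stands in for conv3d and device tensors, and all helper names are invented.

```python
# Pattern sketch: shard the input, verify the round trip, and check the
# result computed from reassembled shards against a golden reference.

def conv1d_valid(x, w):
    # Naive "valid" 1-D convolution, standing in for conv3d.
    k = len(w)
    return [sum(x[i + j] * w[j] for j in range(k)) for i in range(len(x) - k + 1)]

def height_shard(x, num_shards):
    # Split into equal contiguous shards (assumes len(x) % num_shards == 0).
    n = len(x) // num_shards
    return [x[i * n:(i + 1) * n] for i in range(num_shards)]

def gather(shards):
    return [v for s in shards for v in s]

x = [1, 2, 3, 4, 5, 6, 7, 8]
w = [1, 0, -1]

reference = conv1d_valid(x, w)           # golden output ("torch" side)
shards = height_shard(x, num_shards=4)   # sharded input ("device" side)
assert gather(shards) == x               # round-trip sanity check
assert conv1d_valid(gather(shards), w) == reference
```

In the actual pytest file this pattern would be parameterized over the shard strategy (height, width, block), the shard grid, and the conv3d shapes.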
📚 Resources
- TensorAccessor documentation:
  - Tensor Accessor Guide
  - Tensor Accessor Iterator Guide
  - `tt_metal/hw/inc/api/tensor/tensor_accessor.h`
- Relevant tt-metal Conv3d files:
  - `ttnn/cpp/ttnn/operations/experimental/conv3d/device/conv3d_device_operation.cpp` (assertions to remove)
  - `ttnn/cpp/ttnn/operations/experimental/conv3d/device/conv3d_program_factory.cpp` (TensorAccessorArgs setup)
  - `ttnn/cpp/ttnn/operations/experimental/conv3d/device/kernels/reader_vol2col.cpp` (input accessor usage)
  - `ttnn/cpp/ttnn/operations/experimental/conv3d/device/kernels/writer.cpp` (output/weight accessor usage)
  - `tests/ttnn/nightly/unit_tests/operations/conv/test_conv3d.py`
- PyTorch Conv3d Documentation
- GitHub Issue #34375 (original feature request)