24 commits
477dac6
balanced traffic for AG minimal on 4-device ring, opus 4.5, attempt 1.
llongTT Jan 28, 2026
9fbcc3d
opus 4.5 attempt 2.
llongTT Jan 29, 2026
93ecdb5
opus 4.5 attempt 3
llongTT Jan 29, 2026
c885302
Local claude fixed the hang issue. Great.
llongTT Jan 29, 2026
55be2ab
Merge branch 'main' into llong/dit_ag_min
llongTT Feb 3, 2026
4f60d5e
introduce max_payload_size to boost fabric bandwidth.
Feb 3, 2026
ca45c88
put zone scopes in reader/writer kernels for the tracy gui.
Feb 4, 2026
dc032c1
remove the zone scope from kernels.
Feb 4, 2026
276d47f
Merge branch 'main' into llong/dit_ag_min
Feb 9, 2026
8eea24d
Add back some statements removed by Agent. It's safer to keep them.
Feb 9, 2026
465de17
split local writes half/half between forward/backward worker.
Feb 9, 2026
cc17165
extend the split forward feature to ring size >2 and ring size even. …
Feb 10, 2026
283ced2
address the copilot suggestion to pass pipeline.
Feb 10, 2026
6b0163b
Merge branch 'main' into llong/dit_ag_min
llongTT Feb 11, 2026
c4d3ed5
Merge branch 'main' into llong/dit_ag_min
llongTT Feb 11, 2026
862e68b
some updates from pipeline feedback/sheran.
Feb 12, 2026
5998e2c
Merge branch 'main' into llong/dit_ag_min
llongTT Feb 12, 2026
340ccf5
update the fabric max payload size to precisely 8K
Feb 12, 2026
a1fb5fe
revert the local write split feature.
Feb 12, 2026
971b544
revert the whisper model perf change.
Feb 12, 2026
62d1d85
skip the new unit test on wormhole due to memory limit.
llongTT Feb 13, 2026
9e3a7f3
Merge branch 'main' into llong/dit_ag_min
llongTT Feb 13, 2026
5c5ade9
Merge branch 'main' into llong/dit_ag_min
llongTT Feb 14, 2026
b3f7d27
some treatment of the edge case when the number of tiles to read = 1.
Feb 16, 2026
163 changes: 163 additions & 0 deletions tests/nightly/tg/ccl/test_minimal_all_gather_async.py
@@ -7,8 +7,18 @@
import ttnn

from tests.nightly.t3000.ccl.test_minimal_all_gather_async import run_all_gather_impl
from tests.ttnn.multidevice_perf_tests.sweep_all_gather_hyperparameters_t3000 import get_max_chunks_per_sync
from models.common.utility_functions import skip_for_blackhole, skip_for_wormhole_b0
from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_equal, comp_pcc
from models.perf.benchmarking_utils import BenchmarkProfiler
from tracy import signpost


def create_fabric_router_config(max_payload_size):
"""Helper to create FabricRouterConfig with custom max payload size."""
config = ttnn._ttnn.fabric.FabricRouterConfig()
config.max_packet_payload_size_bytes = max_payload_size
return config
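# Why 8192: assuming standard 32x32 bfloat16 tiles, one tile page is
# 32 * 32 * 2 = 2048 bytes, so the 8192-byte payload used below holds exactly
# four tile pages -- matching the four scatter-write destinations the op
# packs per packet.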


@skip_for_blackhole("This test is for wormhole")
@@ -554,3 +564,156 @@ def test_all_gather_async_wan_galaxy_4x32(
# torch_reference = torch_input.repeat([devices, 1, 1, 1])
# eq, output = comp_equal(torch_output, torch_reference)
# assert eq, f"Output mismatch between torch and ttnn all-gather: {output}"


@skip_for_wormhole_b0("This test is for blackhole")
@pytest.mark.parametrize(
"ag_output_shape, dim, cluster_axis, ag_input_dtype, layout, mem_config_input, mem_config_ag",
[
([1, 1, 9472, 5120], 3, 0, ttnn.bfloat16, ttnn.TILE_LAYOUT, ttnn.DRAM_MEMORY_CONFIG, ttnn.DRAM_MEMORY_CONFIG),
([1, 1, 9472, 256], 3, 0, ttnn.bfloat16, ttnn.TILE_LAYOUT, ttnn.DRAM_MEMORY_CONFIG, ttnn.DRAM_MEMORY_CONFIG),
([1, 1, 9472, 128], 3, 0, ttnn.bfloat16, ttnn.TILE_LAYOUT, ttnn.DRAM_MEMORY_CONFIG, ttnn.DRAM_MEMORY_CONFIG),
([1, 1, 118, 128], 3, 0, ttnn.bfloat16, ttnn.TILE_LAYOUT, ttnn.DRAM_MEMORY_CONFIG, ttnn.DRAM_MEMORY_CONFIG),
],
ids=[
"spatial_activation",
"layernorm_stats",
"rmsnorm_stats_spatial",
"rmsnorm_stats_prompt",
],
)
@pytest.mark.parametrize(
"device_params, all_gather_topology",
[
(
{
"fabric_config": ttnn.FabricConfig.FABRIC_1D_RING,
"fabric_router_config": create_fabric_router_config(8192),
"trace_region_size": 1000000,
},
ttnn.Topology.Ring,
),
],
indirect=["device_params"],
ids=["fabric_ring"],
)
@pytest.mark.parametrize("num_links", [2], ids=lambda v: f"{v}links")
@pytest.mark.parametrize("chunks_per_sync", [320], ids=lambda v: f"{v}chunks")
@pytest.mark.parametrize("num_workers_per_link", [3], ids=lambda v: f"{v}workers")
@pytest.mark.parametrize("num_buffers_per_channel", [4], ids=lambda v: f"{v}buffers")
@pytest.mark.parametrize("num_iters, warmup_iters", [(75, 10)])
@pytest.mark.parametrize("mesh_device", [(4, 8)], indirect=True)
def test_all_gather_wan(
mesh_device,
ag_output_shape,
dim,
cluster_axis,
ag_input_dtype,
layout,
mem_config_input,
mem_config_ag,
num_links,
chunks_per_sync,
num_workers_per_link,
num_buffers_per_channel,
all_gather_topology,
num_iters,
warmup_iters,
):
from loguru import logger

# Create input tensor
mesh_shape = tuple(mesh_device.shape)
input_shape = ag_output_shape
num_devices = mesh_shape[cluster_axis]

torch.manual_seed(2005)
torch_input = torch.rand(input_shape, dtype=torch.bfloat16)

shard_dims = (None, dim) if cluster_axis == 1 else (dim, None)
tt_input = ttnn.from_torch(
torch_input,
layout=layout,
dtype=ag_input_dtype,
memory_config=mem_config_input,
mesh_mapper=ttnn.ShardTensor2dMesh(mesh_device, dims=shard_dims, mesh_shape=mesh_shape),
device=mesh_device,
)

# AllGather config
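    # chunks_per_sync is parametrized to 320 above; passing "MAX" instead
    # would derive the largest legal value from the t3000 sweep helper.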
if chunks_per_sync == "MAX":
chunks_per_sync_val = get_max_chunks_per_sync(num_devices, ag_output_shape, num_links)
else:
chunks_per_sync_val = chunks_per_sync

# Compile Run
logger.info("Compiling op")
tt_output = ttnn.all_gather(
tt_input,
dim=dim,
cluster_axis=cluster_axis,
topology=all_gather_topology,
num_links=num_links,
memory_config=mem_config_ag,
chunks_per_sync=chunks_per_sync_val,
num_workers_per_link=num_workers_per_link,
num_buffers_per_channel=num_buffers_per_channel,
)
ttnn.synchronize_device(mesh_device)

# Check output
errors = []
for dev, tt_out in enumerate(ttnn.get_device_tensors(tt_output)):
eq, mess = comp_pcc(torch_input, ttnn.to_torch(tt_out))
if not eq:
errors.append(f"Device {dev}: {mess}")
assert not errors, f"PCC check failed on {len(errors)} device(s):\n" + "\n".join(errors)

################## TRACE RUN #######################

# Capture trace
logger.info("Capturing trace")

def capture_trace(n_iters):
trace_id = ttnn.begin_trace_capture(mesh_device, cq_id=0)
for i in range(n_iters):
_ = ttnn.all_gather(
tt_input,
dim=dim,
cluster_axis=cluster_axis,
topology=all_gather_topology,
num_links=num_links,
memory_config=mem_config_ag,
chunks_per_sync=chunks_per_sync_val,
num_workers_per_link=num_workers_per_link,
num_buffers_per_channel=num_buffers_per_channel,
)
ttnn.end_trace_capture(mesh_device, trace_id, cq_id=0)
ttnn.synchronize_device(mesh_device)
return trace_id

if warmup_iters > 0:
trace_id_warmup = capture_trace(warmup_iters)
trace_id = capture_trace(num_iters)

# Run the op
logger.info("Starting Trace perf test...")
profiler = BenchmarkProfiler()
profiler.start("all-gather-async-trace-warmup")
if warmup_iters > 0:
ttnn.execute_trace(mesh_device, trace_id_warmup, blocking=False)
ttnn.release_trace(mesh_device, trace_id_warmup)
ttnn.synchronize_device(mesh_device)
profiler.end("all-gather-async-trace-warmup")

profiler.start("all-gather-async-trace")
signpost("start")
ttnn.execute_trace(mesh_device, trace_id, blocking=False)
ttnn.release_trace(mesh_device, trace_id)
ttnn.synchronize_device(mesh_device)
signpost("stop")
profiler.end("all-gather-async-trace")
time_taken = profiler.get_duration("all-gather-async-trace")
logger.info(f"Time taken e2e: {time_taken} s")
logger.info(f"Time per iter e2e: {time_taken / num_iters} s")
logger.info(f"Time per iter e2e: {time_taken / num_iters * 1e6} us")
@@ -383,6 +383,12 @@ AllGatherProgramArtifacts build_all_gather_async_minimal_default_program_artifac
// L1 Scratch CB Creation
const size_t packet_size_bytes = tt::tt_fabric::get_tt_fabric_channel_buffer_size_bytes();
uint32_t l1_scratch_cb_page_size_bytes = page_size;
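// Sanity check: the fabric packet payload must be able to hold a full tensor
// page; fail fast with a fix hint rather than proceeding with an undersized packet.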
TT_FATAL(
packet_size_bytes >= l1_scratch_cb_page_size_bytes,
"Fabric packet size ({} bytes) must be >= tensor page size ({} bytes). "
"Increase max_packet_payload_size_bytes in FabricRouterConfig.",
packet_size_bytes,
l1_scratch_cb_page_size_bytes);

// scatter-write currently supports 4 distinct noc addresses
uint32_t max_target_noc_addresses_per_packet = 4;