Merged
162 changes: 0 additions & 162 deletions tests/nightly/tg/ccl/test_minimal_all_gather_async.py
@@ -7,18 +7,8 @@
import ttnn

from tests.nightly.t3000.ccl.test_minimal_all_gather_async import run_all_gather_impl
from tests.ttnn.multidevice_perf_tests.sweep_all_gather_hyperparameters_t3000 import get_max_chunks_per_sync
from models.common.utility_functions import skip_for_blackhole, skip_for_wormhole_b0
from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_equal, comp_pcc
Copilot AI commented on Feb 13, 2026:

comp_equal and comp_pcc are imported but only referenced in commented-out code, so they are currently unused. If the repo runs ruff/flake8 in CI, this will fail with an unused-import error; please remove the unused imports or re-enable the assertions that use them.

Suggested change
from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_equal, comp_pcc
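
For reference, the alternative fix mentioned above (re-enabling the checks) would look roughly like the commented-out block shown further down in this diff; this is a sketch only, and assumes torch_output, torch_input, and devices are already defined at that point in the test:

# Sketch of the alternative: re-enable the commented-out equality check.
# Assumes torch_output, torch_input, and devices are defined as in that block.
torch_reference = torch_input.repeat([devices, 1, 1, 1])
eq, output = comp_equal(torch_output, torch_reference)
assert eq, f"Output mismatch between torch and ttnn all-gather: {output}"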

from models.perf.benchmarking_utils import BenchmarkProfiler
from tracy import signpost


def create_fabric_router_config(max_payload_size):
"""Helper to create FabricRouterConfig with custom max payload size."""
config = ttnn._ttnn.fabric.FabricRouterConfig()
config.max_packet_payload_size_bytes = max_payload_size
return config


@skip_for_blackhole("This test is for wormhole")
@@ -564,155 +554,3 @@ def test_all_gather_async_wan_galaxy_4x32(
# torch_reference = torch_input.repeat([devices, 1, 1, 1])
# eq, output = comp_equal(torch_output, torch_reference)
# assert eq, f"Output mismatch between torch and ttnn all-gather: {output}"


@pytest.mark.parametrize(
"ag_output_shape, dim, cluster_axis, ag_input_dtype, layout, mem_config_input, mem_config_ag",
[
([1, 1, 9472, 5120], 3, 0, ttnn.bfloat16, ttnn.TILE_LAYOUT, ttnn.DRAM_MEMORY_CONFIG, ttnn.DRAM_MEMORY_CONFIG),
([1, 1, 9472, 256], 3, 0, ttnn.bfloat16, ttnn.TILE_LAYOUT, ttnn.DRAM_MEMORY_CONFIG, ttnn.DRAM_MEMORY_CONFIG),
([1, 1, 9472, 128], 3, 0, ttnn.bfloat16, ttnn.TILE_LAYOUT, ttnn.DRAM_MEMORY_CONFIG, ttnn.DRAM_MEMORY_CONFIG),
([1, 1, 118, 128], 3, 0, ttnn.bfloat16, ttnn.TILE_LAYOUT, ttnn.DRAM_MEMORY_CONFIG, ttnn.DRAM_MEMORY_CONFIG),
],
ids=[
"spatial_activation",
"layernorm_stats",
"rmsnorm_stats_spatial",
"rmsnorm_stats_prompt",
],
)
@pytest.mark.parametrize(
"device_params, all_gather_topology",
[
(
{
"fabric_config": ttnn.FabricConfig.FABRIC_1D_RING,
"fabric_router_config": create_fabric_router_config(8192),
"trace_region_size": 1000000,
},
ttnn.Topology.Ring,
),
],
indirect=["device_params"],
ids=["fabric_ring"],
)
@pytest.mark.parametrize("num_links", [2], ids=lambda v: f"{v}links")
@pytest.mark.parametrize("chunks_per_sync", [320], ids=lambda v: f"{v}chunks")
@pytest.mark.parametrize("num_workers_per_link", [3], ids=lambda v: f"{v}workers")
@pytest.mark.parametrize("num_buffers_per_channel", [4], ids=lambda v: f"{v}buffers")
@pytest.mark.parametrize("num_iters, warmup_iters", [(75, 10)])
@pytest.mark.parametrize("mesh_device", [(4, 8)], indirect=True)
def test_all_gather_wan(
mesh_device,
ag_output_shape,
dim,
cluster_axis,
ag_input_dtype,
layout,
mem_config_input,
mem_config_ag,
num_links,
chunks_per_sync,
num_workers_per_link,
num_buffers_per_channel,
all_gather_topology,
num_iters,
warmup_iters,
):
from loguru import logger

# Create input tensor
mesh_shape = tuple(mesh_device.shape)
input_shape = ag_output_shape
num_devices = mesh_shape[cluster_axis]

torch.manual_seed(2005)
torch_input = torch.rand(input_shape, dtype=torch.bfloat16)

shard_dims = (None, dim) if cluster_axis == 1 else (dim, None)
tt_input = ttnn.from_torch(
torch_input,
layout=layout,
dtype=ag_input_dtype,
memory_config=mem_config_input,
mesh_mapper=ttnn.ShardTensor2dMesh(mesh_device, dims=shard_dims, mesh_shape=mesh_shape),
device=mesh_device,
)

# AllGather config
if chunks_per_sync == "MAX":
chunks_per_sync_val = get_max_chunks_per_sync(num_devices, ag_output_shape, num_links)
else:
chunks_per_sync_val = chunks_per_sync

# Compile Run
logger.info("Compiling op")
tt_output = ttnn.all_gather(
tt_input,
dim=dim,
cluster_axis=cluster_axis,
topology=all_gather_topology,
num_links=num_links,
memory_config=mem_config_ag,
chunks_per_sync=chunks_per_sync_val,
num_workers_per_link=num_workers_per_link,
num_buffers_per_channel=num_buffers_per_channel,
)
ttnn.synchronize_device(mesh_device)

# Check output
errors = []
for dev, tt_out in enumerate(ttnn.get_device_tensors(tt_output)):
eq, mess = comp_pcc(torch_input, ttnn.to_torch(tt_out))
if not eq:
errors.append(f"Device {dev}: {mess}")
assert not errors, f"PCC check failed on {len(errors)} device(s):\n" + "\n".join(errors)

################## TRACE RUN #######################

# Capture trace
logger.info("Capturing trace")

def capture_trace(n_iters):
trace_id = ttnn.begin_trace_capture(mesh_device, cq_id=0)
for i in range(n_iters):
_ = ttnn.all_gather(
tt_input,
dim=dim,
cluster_axis=cluster_axis,
topology=all_gather_topology,
num_links=num_links,
memory_config=mem_config_ag,
chunks_per_sync=chunks_per_sync_val,
num_workers_per_link=num_workers_per_link,
num_buffers_per_channel=num_buffers_per_channel,
)
ttnn.end_trace_capture(mesh_device, trace_id, cq_id=0)
ttnn.synchronize_device(mesh_device)
return trace_id

if warmup_iters > 0:
trace_id_warmup = capture_trace(warmup_iters)
trace_id = capture_trace(num_iters)

# Run the op
logger.info("Starting Trace perf test...")
profiler = BenchmarkProfiler()
profiler.start("all-gather-async-trace-warmup")
if warmup_iters > 0:
ttnn.execute_trace(mesh_device, trace_id_warmup, blocking=False)
ttnn.release_trace(mesh_device, trace_id_warmup)
ttnn.synchronize_device(mesh_device)
profiler.end("all-gather-async-trace-warmup")

profiler.start("all-gather-async-trace")
signpost("start")
ttnn.execute_trace(mesh_device, trace_id, blocking=False)
ttnn.release_trace(mesh_device, trace_id)
ttnn.synchronize_device(mesh_device)
signpost("stop")
profiler.end("all-gather-async-trace")
time_taken = profiler.get_duration("all-gather-async-trace")
logger.info(f"Time taken e2e: {time_taken} s")
logger.info(f"Time per iter e2e: {time_taken / num_iters} s")
logger.info(f"Time per iter e2e: {time_taken / num_iters * 1e6} us")
@@ -383,12 +383,6 @@ AllGatherProgramArtifacts build_all_gather_async_minimal_default_program_artifac
// L1 Scratch CB Creation
const size_t packet_size_bytes = tt::tt_fabric::get_tt_fabric_channel_buffer_size_bytes();
uint32_t l1_scratch_cb_page_size_bytes = page_size;
TT_FATAL(
packet_size_bytes >= l1_scratch_cb_page_size_bytes,
"Fabric packet size ({} bytes) must be >= tensor page size ({} bytes). "
"Increase max_packet_payload_size_bytes in FabricRouterConfig.",
packet_size_bytes,
l1_scratch_cb_page_size_bytes);

// scatter-write currently supports 4 distinct noc addresses
uint32_t max_target_noc_addresses_per_packet = 4;