Skip to content

Reduce scatter yields inaccurate results when using bf16 inputs and an fp32 accumulator #37884

@smeydanshahiTT

Description

@smeydanshahiTT

Component / Area

kernels

Issue Type

Inaccurate output. Issue description: #36749

Observed

With bf16 inputs, enabling the fp32 accumulator still produces output whose PCC is too low, apparently because of the rounding applied at each iterative reduction step.

Expected

The reduce_scatter op should yield accurate output for bf16 inputs when the fp32 accumulator is enabled.

The following test case fails for the `fp32_dest_acc_en=True` instance, which takes bf16 inputs.

import torch
import pytest
from loguru import logger
import ttnn
from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_pcc


def create_global_semaphores(mesh_device, cores, initial_value, count=3):
    """Create ``count`` global semaphores on ``mesh_device`` over ``cores``.

    Each semaphore is created with ``ttnn.create_global_semaphore`` and starts at
    ``initial_value``. The handles are returned as a list in creation order.
    """
    semaphores = []
    for _ in range(count):
        semaphores.append(ttnn.create_global_semaphore(mesh_device, cores, initial_value))
    return semaphores


def _pcc_to_float(pcc_result):
    """Extract the PCC value reported by ``comp_pcc`` as a float.

    A message string of the form ``"... PCC: <value>, ..."`` is parsed. A plain numeric
    result is also accepted, in case ``comp_pcc`` reports the number directly.
    NOTE(review): confirm which form the installed ``comp_pcc`` returns.

    Raises:
        ValueError: if no PCC value can be found. Raising here, instead of defaulting to a
            perfect 1.0, keeps a format change from silently passing the precision assertions.
    """
    try:
        return float(pcc_result)
    except (TypeError, ValueError):
        pass
    text = str(pcc_result)
    _, marker, tail = text.partition("PCC: ")
    if not marker:
        raise ValueError(f"Could not find a PCC value in comp_pcc output: {text!r}")
    return float(tail.split(",")[0])


def _bf16_ulp(reference):
    """Return, element-wise, the spacing between adjacent bf16 values at ``reference``'s magnitude.

    bf16 stores 7 explicit mantissa bits. A value whose leading bit is ``2**k`` is therefore
    representable only to a granularity of ``2**(k - 7)``, e.g. 256 for values in [32768, 65536).
    """
    # frexp gives |x| = m * 2**e with m in [0.5, 1), so the leading bit is 2**(e - 1).
    _, exponent = torch.frexp(reference.float())
    return torch.pow(2.0, (exponent - 8).float())


@pytest.mark.parametrize("num_devices", [8])
@pytest.mark.parametrize("rs_input_shape", [[1, 1, 32, 1024]])
@pytest.mark.parametrize("dim", [3])
@pytest.mark.parametrize("num_links", [1])
@pytest.mark.parametrize("fp32_dest_acc_en", [True, False])
@pytest.mark.parametrize(
    "device_params",
    [{"fabric_config": ttnn.FabricConfig.FABRIC_1D}],
    indirect=True,
)
@pytest.mark.parametrize("mesh_device", [(1, 8)], indirect=True)
def test_reduce_scatter_bf16_precision(
    mesh_device,
    num_devices,
    rs_input_shape,
    dim,
    num_links,
    fp32_dest_acc_en,
):
    """Check reduce_scatter precision for bf16 inputs with and without fp32 dest accumulation.

    Runs ``reduce_scatter_minimal_async`` and then ``all_gather_async`` across ``num_devices``
    devices. The gathered result is compared with a torch golden sum of the same bf16-rounded
    inputs.

    The intermediate and output buffers are bf16. The best any accumulator can therefore
    produce is the exact sum rounded once to bf16, and the fp32 assertions are measured against
    that bound rather than against exact fp32 agreement. The ``fp32_dest_acc_en=False`` case
    asserts that the known precision loss is still observed.
    """
    if mesh_device.get_num_devices() < num_devices:
        pytest.skip(f"Requires at least {num_devices} devices")

    torch.manual_seed(42)
    mem_config = ttnn.MemoryConfig(ttnn.TensorMemoryLayout.INTERLEAVED, ttnn.BufferType.DRAM)

    compute_grid_size = mesh_device.compute_with_storage_grid_size()
    ccl_sub_device_crs = ttnn.CoreRangeSet(
        {ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(compute_grid_size.x - 1, compute_grid_size.y - 1))}
    )
    worker_sub_device = ttnn.SubDevice([ccl_sub_device_crs])
    worker_sub_device_id = ttnn.SubDeviceId(0)
    sub_device_stall_group = [worker_sub_device_id]

    sub_device_manager = mesh_device.create_sub_device_manager([worker_sub_device], 0)
    mesh_device.load_sub_device_manager(sub_device_manager)
    try:
        mesh_device.set_sub_device_stall_group(sub_device_stall_group)

        rs_global_input_shape = rs_input_shape[:]
        rs_global_input_shape[dim] *= num_devices
        # Values in [7500, 8500), so the sum over 8 devices is ~64,000.
        torch_input = torch.rand(rs_global_input_shape).float() * 1000 + 7500
        torch_input_bf16 = torch_input.bfloat16()

        # The golden uses the same bf16-rounded inputs, so input quantization is not counted as error.
        input_for_golden = torch_input_bf16.float()
        input_chunks = torch.chunk(input_for_golden, num_devices, dim)
        golden_reduce = torch.sum(torch.stack(input_chunks), dim=0)

        compute_kernel_config = ttnn.WormholeComputeKernelConfig(
            fp32_dest_acc_en=fp32_dest_acc_en,
            math_fidelity=ttnn.MathFidelity.HiFi4,
        )

        # --- Method 1: reduce_scatter + all_gather (Hardware Path) ---
        rs_semaphore_handles = create_global_semaphores(mesh_device, ccl_sub_device_crs, 0, count=3)

        intermediate_shape = rs_input_shape[:]
        intermediate_shape.insert(0, 2)
        # NOTE(review): partial sums passed between devices presumably go through this bf16 buffer.
        # That would round at every hop regardless of fp32_dest_acc_en; confirm against the kernel.
        persistent_intermediate = ttnn.from_torch(
            torch.zeros(intermediate_shape),
            device=mesh_device,
            layout=ttnn.TILE_LAYOUT,
            dtype=ttnn.bfloat16,
            memory_config=mem_config,
            mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device),
        )

        rs_output_shape = rs_input_shape[:]
        rs_output_shape[dim] //= num_devices
        persistent_output = ttnn.from_torch(
            torch.zeros(rs_output_shape),
            device=mesh_device,
            layout=ttnn.TILE_LAYOUT,
            dtype=ttnn.bfloat16,
            memory_config=mem_config,
            mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device),
        )

        input_rs = ttnn.from_torch(
            torch_input_bf16,
            device=mesh_device,
            layout=ttnn.TILE_LAYOUT,
            dtype=ttnn.bfloat16,
            memory_config=mem_config,
            mesh_mapper=ttnn.create_mesh_mapper(
                mesh_device,
                ttnn.MeshMapperConfig(
                    [ttnn.PlacementReplicate(), ttnn.PlacementShard(dim)], ttnn.MeshShape(1, num_devices)
                ),
            ),
        )

        # Step 1: reduce_scatter using the new compute_kernel_config
        rs_output = ttnn.experimental.reduce_scatter_minimal_async(
            input_rs,
            persistent_output_buffers=[persistent_intermediate, persistent_output],
            dim=dim,
            multi_device_global_semaphore=rs_semaphore_handles,
            num_links=num_links,
            memory_config=mem_config,
            topology=ttnn.Topology.Linear,
            subdevice_id=worker_sub_device_id,
            compute_kernel_config=compute_kernel_config,
        )
        ttnn.synchronize_device(mesh_device, sub_device_ids=sub_device_stall_group)

        # Step 2: all_gather to get full reduction result back for comparison
        ag_semaphore_handles = create_global_semaphores(mesh_device, ccl_sub_device_crs, 0, count=2)
        all_reduce_output = ttnn.experimental.all_gather_async(
            rs_output,
            dim=dim,
            multi_device_global_semaphore=ag_semaphore_handles,
            num_links=num_links,
            memory_config=mem_config,
            topology=ttnn.Topology.Linear,
            subdevice_id=worker_sub_device_id,
        )
        ttnn.synchronize_device(mesh_device, sub_device_ids=sub_device_stall_group)

        # Every device holds the full gathered result; device 0's copy is compared.
        all_reduce_torch = ttnn.to_torch(
            ttnn.from_device(all_reduce_output), mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=0)
        )
        rs_ag_result = all_reduce_torch[0:1].float()

        # --- FINAL VALIDATION ---
        _, rs_ag_pcc = comp_pcc(rs_ag_result, golden_reduce, pcc=0.99)
        pcc_val = _pcc_to_float(rs_ag_pcc)
        rs_ag_max_diff = (rs_ag_result - golden_reduce).abs().max().item()

        # The output is bf16, so even an exact fp32 sum is rounded to bf16 when written out
        # (spacing 256 near 64,000). Measure that floor so the fp32 path is held to an achievable bound.
        _, best_case_pcc_msg = comp_pcc(golden_reduce.bfloat16().float(), golden_reduce, pcc=0.99)
        best_case_pcc = _pcc_to_float(best_case_pcc_msg)
        rs_ag_ulp_err = ((rs_ag_result - golden_reduce).abs() / _bf16_ulp(golden_reduce)).max().item()
    finally:
        # Always restore the device, so a failure above does not leak sub-device state into later tests.
        mesh_device.reset_sub_device_stall_group()
        mesh_device.clear_loaded_sub_device_manager()

    logger.info(f"Configuration: fp32_dest_acc_en={fp32_dest_acc_en}")
    logger.info(f"Result: PCC={pcc_val}, Max Diff={rs_ag_max_diff:.2f}")
    logger.info(f"Best case with bf16 output: PCC={best_case_pcc}, max error={rs_ag_ulp_err:.2f} bf16 ULP")

    if fp32_dest_acc_en:
        # An exact sum rounded once into bf16 lies within one bf16 ULP of the fp32 golden: half an ULP
        # with round-to-nearest, under one with truncation. Its PCC should match the best-case bf16 PCC
        # up to a small margin, presumably enough to absorb rounding-mode differences.
        assert pcc_val >= best_case_pcc - 1e-3, (
            f"Precision regression with FP32 accumulation enabled: PCC={pcc_val} "
            f"(best achievable with bf16 output: {best_case_pcc})"
        )
        assert rs_ag_ulp_err <= 1.0, (
            f"Error exceeds one bf16 ULP with FP32 accumulation: {rs_ag_ulp_err:.2f} ULP "
            f"(max abs diff {rs_ag_max_diff})"
        )
    else:
        # Default (16-bit dest accumulation) path is expected to show the ~768.0 delta / lower PCC
        logger.info("Observed expected precision loss on default (FP16) accumulation path.")
        assert pcc_val < 0.99, "Default path unexpectedly high precision; check if flag is being ignored."

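The results show that the PCC is still too low, matching the value obtained with bf16 inputs and a bf16 accumulator. (The first test passes because it asserts that the default path is inaccurate, which it is.) Note that the test's output buffers are bf16, so even an exact fp32 sum is rounded to bf16 when written out; bf16 values near 64,000 are spaced 256 apart. Some difference from the fp32 golden is therefore expected on either path. However, the reported PCC of 0.9856 is noticeably lower than a single final bf16 rounding would explain (roughly 0.996 for this input range).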

PASSED tests/ttnn/unit_tests/operations/ccl/a.py::test_reduce_scatter_bf16_precision[silicon_arch_name=wormhole_b0-mesh_device=(1, 8)-device_params={'fabric_config': FabricConfig.FABRIC_1D}-fp32_dest_acc_en=False-num_links=1-dim=3-rs_input_shape=[1, 1, 32, 1024]-num_devices=8]
FAILED tests/ttnn/unit_tests/operations/ccl/a.py::test_reduce_scatter_bf16_precision[silicon_arch_name=wormhole_b0-mesh_device=(1, 8)-device_params={'fabric_config': FabricConfig.FABRIC_1D}-fp32_dest_acc_en=True-num_links=1-dim=3-rs_input_shape=[1, 1, 32, 1024]-num_devices=8] - AssertionError: Precision regression with FP32 accumulation enabled: PCC=0.9856133555412304
assert 0.9856133555412304 > 0.999

Metadata

Metadata

Assignees

No one assigned

    Labels

    Type

    Projects

    Status

    🆕 New

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions