Component / Area
kernels
Issue Type
Inaccurate output
Issue description: #36749
Observed
With bf16 inputs, enabling the fp32 accumulator still produces outputs with too low a PCC, because of the iterative rounding applied to intermediate results.
Expected
The reduce_scatter op yields valid output when given bf16 inputs.
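To illustrate the iterative rounding described under Observed, here is a quick host-side sketch in plain PyTorch (no ttnn involved; the shapes, seed, and value range mirror the test below, and the step-by-step accumulation order is an assumption about the kernel's behaviour, not taken from the source):

import torch

torch.manual_seed(42)
# Eight bf16 "per-device" partials with values near 8000, as in the test below.
chunks = [(torch.rand(32, 1024) * 1000 + 7500).bfloat16() for _ in range(8)]

# Reference: keep the running sum in fp32 (what fp32_dest_acc_en is expected to do).
golden = torch.zeros(32, 1024)
for c in chunks:
    golden += c.float()

# Low-precision path: round the running sum back to bf16 after every partial add.
acc = torch.zeros(32, 1024).bfloat16()
for c in chunks:
    acc = (acc.float() + c.float()).bfloat16()

# Near 64,000 the bf16 spacing is 256, so the iteratively rounded sum can drift
# by several hundred, while the fp32 running sum stays exact for these values.
print((acc.float() - golden).abs().max().item())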
The following test case fails for the fp32_dest_acc_en=True instance, as it takes bf16 inputs.
import torch
import pytest
from loguru import logger
import ttnn
from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_pcc

def create_global_semaphores(mesh_device, cores, initial_value, count=3):
    return [ttnn.create_global_semaphore(mesh_device, cores, initial_value) for _ in range(count)]

@pytest.mark.parametrize("num_devices", [8])
@pytest.mark.parametrize("rs_input_shape", [[1, 1, 32, 1024]])
@pytest.mark.parametrize("dim", [3])
@pytest.mark.parametrize("num_links", [1])
@pytest.mark.parametrize("fp32_dest_acc_en", [True, False])
@pytest.mark.parametrize(
"device_params",
[{"fabric_config": ttnn.FabricConfig.FABRIC_1D}],
indirect=True,
)
@pytest.mark.parametrize("mesh_device", [(1, 8)], indirect=True)
def test_reduce_scatter_bf16_precision(
    mesh_device,
    num_devices,
    rs_input_shape,
    dim,
    num_links,
    fp32_dest_acc_en,
):
    if mesh_device.get_num_devices() < num_devices:
        pytest.skip(f"Requires at least {num_devices} devices")
    torch.manual_seed(42)
    mem_config = ttnn.MemoryConfig(ttnn.TensorMemoryLayout.INTERLEAVED, ttnn.BufferType.DRAM)
    compute_grid_size = mesh_device.compute_with_storage_grid_size()
    ccl_sub_device_crs = ttnn.CoreRangeSet(
        {ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(compute_grid_size.x - 1, compute_grid_size.y - 1))}
    )
    worker_sub_device = ttnn.SubDevice([ccl_sub_device_crs])
    worker_sub_device_id = ttnn.SubDeviceId(0)
    sub_device_stall_group = [worker_sub_device_id]
    sub_device_manager = mesh_device.create_sub_device_manager([worker_sub_device], 0)
    mesh_device.load_sub_device_manager(sub_device_manager)
    mesh_device.set_sub_device_stall_group(sub_device_stall_group)
    rs_global_input_shape = rs_input_shape[:]
    rs_global_input_shape[dim] *= num_devices
    # Values near 8000 so sum of 8 is ~64,000
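    # At ~64,000 the bf16 spacing between representable values is 256, so an
    # accumulator that rounds to bf16 after every partial add can end up a few
    # steps off (the ~768.0 delta mentioned for the default path below), while
    # an fp32 accumulator should stay within one final bf16 rounding.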
    torch_input = torch.rand(rs_global_input_shape).float() * 1000 + 7500
    torch_input_bf16 = torch_input.bfloat16()
    input_for_golden = torch_input_bf16.float()
    input_chunks = torch.chunk(input_for_golden, num_devices, dim)
    golden_reduce = torch.sum(torch.stack(input_chunks), dim=0)
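    # The golden sum is computed in fp32 from the already-bf16-quantized inputs,
    # so any mismatch measured below reflects accumulation error on device, not
    # input quantization.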
    compute_kernel_config = ttnn.WormholeComputeKernelConfig(
        fp32_dest_acc_en=fp32_dest_acc_en,
        math_fidelity=ttnn.MathFidelity.HiFi4,
    )
    # --- Method 1: reduce_scatter + all_gather (Hardware Path) ---
    rs_semaphore_handles = create_global_semaphores(mesh_device, ccl_sub_device_crs, 0, count=3)
    intermediate_shape = rs_input_shape[:]
    intermediate_shape.insert(0, 2)
    persistent_intermediate = ttnn.from_torch(
        torch.zeros(intermediate_shape),
        device=mesh_device,
        layout=ttnn.TILE_LAYOUT,
        dtype=ttnn.bfloat16,
        memory_config=mem_config,
        mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device),
    )
    rs_output_shape = rs_input_shape[:]
    rs_output_shape[dim] //= num_devices
    persistent_output = ttnn.from_torch(
        torch.zeros(rs_output_shape),
        device=mesh_device,
        layout=ttnn.TILE_LAYOUT,
        dtype=ttnn.bfloat16,
        memory_config=mem_config,
        mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device),
    )
    input_rs = ttnn.from_torch(
        torch_input_bf16,
        device=mesh_device,
        layout=ttnn.TILE_LAYOUT,
        dtype=ttnn.bfloat16,
        memory_config=mem_config,
        mesh_mapper=ttnn.create_mesh_mapper(
            mesh_device,
            ttnn.MeshMapperConfig(
                [ttnn.PlacementReplicate(), ttnn.PlacementShard(dim)], ttnn.MeshShape(1, num_devices)
            ),
        ),
    )
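    # Each device now holds one [1, 1, 32, 1024] shard of the [1, 1, 32, 8192]
    # global input: replicated across the single mesh row, sharded along `dim`
    # across the 8 devices.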
    # Step 1: reduce_scatter using the new compute_kernel_config
    rs_output = ttnn.experimental.reduce_scatter_minimal_async(
        input_rs,
        persistent_output_buffers=[persistent_intermediate, persistent_output],
        dim=dim,
        multi_device_global_semaphore=rs_semaphore_handles,
        num_links=num_links,
        memory_config=mem_config,
        topology=ttnn.Topology.Linear,
        subdevice_id=worker_sub_device_id,
        compute_kernel_config=compute_kernel_config,
    )
    ttnn.synchronize_device(mesh_device, sub_device_ids=sub_device_stall_group)
    # Step 2: all_gather to get full reduction result back for comparison
    ag_semaphore_handles = create_global_semaphores(mesh_device, ccl_sub_device_crs, 0, count=2)
    all_reduce_output = ttnn.experimental.all_gather_async(
        rs_output,
        dim=dim,
        multi_device_global_semaphore=ag_semaphore_handles,
        num_links=num_links,
        memory_config=mem_config,
        topology=ttnn.Topology.Linear,
        subdevice_id=worker_sub_device_id,
    )
    ttnn.synchronize_device(mesh_device, sub_device_ids=sub_device_stall_group)
    all_reduce_torch = ttnn.to_torch(
        ttnn.from_device(all_reduce_output), mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=0)
    )
    rs_ag_result = all_reduce_torch[0:1]
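    # After the all_gather every device holds the full reduced tensor, so the
    # first device's slice is sufficient for comparison against the golden.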
    # --- FINAL VALIDATION ---
    _, rs_ag_pcc = comp_pcc(rs_ag_result.float(), golden_reduce, pcc=0.99)
    rs_ag_max_diff = (rs_ag_result.float() - golden_reduce).abs().max().item()
    mesh_device.reset_sub_device_stall_group()
    mesh_device.clear_loaded_sub_device_manager()
    # Split the PCC string to get a float for the assertion
    pcc_val = float(rs_ag_pcc.split("PCC: ")[1].split(",")[0]) if "PCC: " in rs_ag_pcc else 1.0
    logger.info(f"Configuration: fp32_dest_acc_en={fp32_dest_acc_en}")
    logger.info(f"Result: PCC={pcc_val}, Max Diff={rs_ag_max_diff:.2f}")
    if fp32_dest_acc_en:
        # High precision path should have near-perfect PCC and low Max Diff
        assert pcc_val > 0.999, f"Precision regression with FP32 accumulation enabled: PCC={pcc_val}"
        assert rs_ag_max_diff < 1.0, f"Max difference too high with FP32 accumulation: {rs_ag_max_diff}"
    else:
        # Default path is expected to show the ~768.0 delta / lower PCC
        logger.info("Observed expected precision loss on default (FP16) accumulation path.")
        assert pcc_val < 0.99, "Default path unexpectedly high precision; check if flag is being ignored."
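For reference, both parametrizations were run with a single pytest invocation (the file path is taken from the node IDs in the output below and depends on where the repro is saved):

pytest tests/ttnn/unit_tests/operations/ccl/a.py::test_reduce_scatter_bf16_precision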
The result shows that the PCC is still too low with fp32 accumulation enabled, and it matches the value obtained with bf16 inputs and a bf16 accumulator. (The first test below passes only because it asserts that the default, non-fp32 path is expected to be inaccurate.)
PASSED tests/ttnn/unit_tests/operations/ccl/a.py::test_reduce_scatter_bf16_precision[silicon_arch_name=wormhole_b0-mesh_device=(1, 8)-device_params={'fabric_config': FabricConfig.FABRIC_1D}-fp32_dest_acc_en=False-num_links=1-dim=3-rs_input_shape=[1, 1, 32, 1024]-num_devices=8]
FAILED tests/ttnn/unit_tests/operations/ccl/a.py::test_reduce_scatter_bf16_precision[silicon_arch_name=wormhole_b0-mesh_device=(1, 8)-device_params={'fabric_config': FabricConfig.FABRIC_1D}-fp32_dest_acc_en=True-num_links=1-dim=3-rs_input_shape=[1, 1, 32, 1024]-num_devices=8] - AssertionError: Precision regression with FP32 accumulation enabled: PCC=0.9856133555412304
assert 0.9856133555412304 > 0.999