Skip to content

Significant precision error in Sage3 causal mode #340

@Edenzzzz

Description

@Edenzzzz

When I compare the precision of Sage3 causal mode against torch SDPA, the max diff is around 2.7 even when only using one head (H=1).

[Two attached screenshots show the measured precision-comparison output.]

Test script


import torch
import torch.nn.functional as F
from math import sqrt
import sys
import os

# Add root directory to sys.path to allow importing attn_qat
# (the script lives one level below the repo root, hence the double dirname).
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# The Sage3 Blackwell kernel is a hard requirement for this benchmark; there is
# no CPU/torch fallback, so bail out early with a clear message if it's absent.
try:
    from sageattn3 import sageattn3_blackwell
except ImportError:
    print("Could not import sageattn3_blackwell from sageattn3. Please ensure it is installed.")
    # For testing purposes without the library, we might want to mock it or exit.
    # Assuming the environment has it as per user request.
    sys.exit(1)

def _collect_iteration_diff(B, H, L, D, dtype, device, causal, seed):
    """Run one seeded Sage3-vs-SDPA comparison.

    Returns the element-wise absolute-difference tensor between the Sage3
    output and the torch SDPA reference, or None when the Sage3 kernel
    raised (e.g. unsupported configuration), so the caller can skip the run.
    """
    torch.manual_seed(seed)
    q = torch.randn((B, H, L, D), dtype=dtype, device=device)
    k = torch.randn((B, H, L, D), dtype=dtype, device=device)
    v = torch.randn((B, H, L, D), dtype=dtype, device=device)

    # Reference output from torch SDPA. The softmax scale is passed explicitly
    # (1/sqrt(D), which is also SDPA's default) so the reference scale is
    # unambiguous; the original computed sm_scale but never used it.
    sm_scale = 1.0 / sqrt(D)
    ref_out = F.scaled_dot_product_attention(q, k, v, is_causal=causal, scale=sm_scale)

    try:
        sage_out = sageattn3_blackwell(q, k, v, is_causal=causal)
    except Exception:
        # Best-effort: a failed kernel launch only invalidates this iteration.
        return None
    return (sage_out - ref_out).abs()


def _report_config(total_max, total_mean, valid_count, top_diffs_per_iter):
    """Print averaged diff statistics for one (L, H, causal) configuration."""
    print("\n--- Average Results ---")
    if valid_count == 0:
        print("Sage3 vs Torch BF16: Skipped (all executions failed or not available)")
        return

    avg_max = total_max / valid_count
    avg_mean = total_mean / valid_count
    print(f"Sage3 vs Torch BF16 (over {valid_count} runs): Avg Max Diff = {avg_max:.6f}, Avg Mean Diff = {avg_mean:.6f}")

    # Only dump the per-rank top-N breakdown when the error looks suspicious.
    if avg_max > 5e-1 and top_diffs_per_iter:
        # Average the k-th largest diff across iterations (torch.topk returns
        # values sorted descending, so index k lines up across lists).
        # Truncate to the shortest list in case some iterations had fewer entries.
        num_top = min(len(diffs) for diffs in top_diffs_per_iter)
        if num_top > 0:
            avg_topN = [
                sum(diffs[k] for diffs in top_diffs_per_iter) / len(top_diffs_per_iter)
                for k in range(num_top)
            ]
            print(f"  Avg Top {num_top} largest diffs: {[f'{d:.6f}' for d in avg_topN]}")


def test_error_comparison():
    """Compare sageattn3_blackwell precision against torch SDPA in bf16.

    Sweeps sequence length and head count for both causal and non-causal
    attention, averaging the max/mean absolute error over several seeded
    runs and logging the largest per-element diffs when the error is big.
    Requires a CUDA device and the sageattn3 package.
    """
    # Fixed problem dimensions.
    B = 2
    D = 128
    dtype = torch.bfloat16
    device = torch.device("cuda")  # Sage3 Blackwell kernels require CUDA

    # Swept configurations.
    L_values = [512, 1024, 2048, 4096, 8192]  # sequence lengths
    # H_values = [8, 16, 32, 64]  # wider head sweep, disabled for now
    H_values = [1]  # a single head isolates kernel precision from head count
    causal_values = [True, False]

    num_iters = 10
    num_top_diffs = 20  # top diffs to log when avg max diff exceeds threshold

    print(f"Benchmarking with B={B}, D={D}, dtype={dtype}")
    print(f"Configurations (L x H): {[f'{l}x{h}' for l in L_values for h in H_values]}")
    print("Testing both causal and non-causal attention")
    print(f"Running for {num_iters} iterations per config...")

    for causal in causal_values:
        causal_str = "Causal" if causal else "Non-Causal"
        print(f"\n{'='*60}")
        print(f"Testing {causal_str} Attention")
        print(f"{'='*60}")

        for L in L_values:
            for H in H_values:
                print("\n==========================================")
                print(f"Testing Config: L={L}, H={H}, Causal={causal}")
                print("==========================================")

                total_max = 0.0
                total_mean = 0.0
                valid_count = 0
                top_diffs_per_iter = []  # top-N abs diffs per successful run

                for i in range(num_iters):
                    diff = _collect_iteration_diff(
                        B, H, L, D, dtype, device, causal, seed=42 + i
                    )
                    if diff is not None:
                        total_max += diff.max().item()
                        total_mean += diff.mean().item()
                        valid_count += 1
                        k = min(num_top_diffs, diff.numel())
                        top_diffs_per_iter.append(
                            torch.topk(diff.flatten(), k=k).values.cpu().tolist()
                        )
                    print(f"Iteration {i+1}/{num_iters} done.", end='\r')

                _report_config(total_max, total_mean, valid_count, top_diffs_per_iter)

if __name__ == "__main__":
    test_error_comparison()

cc @jt-zhang

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions