-
Notifications
You must be signed in to change notification settings - Fork 360
Open
Description
When I compare the precision of Sage3 causal against torch SDPA, the max diff is around 2.7 even when only using one attention head (H=1).
Test script
import torch
import torch.nn.functional as F
from math import sqrt
import sys
import os

# Add root directory to sys.path to allow importing attn_qat
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

try:
    from sageattn3 import sageattn3_blackwell
except ImportError:
    print("Could not import sageattn3_blackwell from sageattn3. Please ensure it is installed.")
    # For testing purposes without the library, we might want to mock it or exit.
    # Assuming the environment has it as per user request.
    sys.exit(1)
def test_error_comparison():
    """Compare SageAttention3 (Blackwell) output against torch SDPA.

    For each (sequence length L, head count H, causal flag) configuration,
    runs ``num_iters`` seeded trials, computes the element-wise absolute
    difference between ``sageattn3_blackwell`` and
    ``F.scaled_dot_product_attention`` on identical bf16 inputs, and prints
    per-config average max/mean diffs. When the average max diff exceeds 0.5,
    also prints the averaged top-N largest diffs to characterize the error tail.

    Requires a CUDA device and the ``sageattn3`` package. No return value;
    all results are printed to stdout.
    """
    # Fixed parameters for every configuration.
    B = 2
    D = 128
    dtype = torch.bfloat16
    device = torch.device("cuda")

    # Expanded configurations to test
    L_values = [512, 1024, 2048, 4096, 8192]  # Various sequence lengths
    # H_values = [8, 16, 32, 64]  # Various number of heads
    H_values = [1]
    # Test both causal and non-causal attention
    causal_values = [True, False]
    num_iters = 10
    num_top_diffs = 20  # Number of top diffs to log when max diff > threshold

    print(f"Benchmarking with B={B}, D={D}, dtype={dtype}")
    print(f"Configurations (L x H): {[f'{l}x{h}' for l in L_values for h in H_values]}")
    print(f"Testing both causal and non-causal attention")
    print(f"Running for {num_iters} iterations per config...")

    for causal in causal_values:
        causal_str = "Causal" if causal else "Non-Causal"
        print(f"\n{'='*60}")
        print(f"Testing {causal_str} Attention")
        print(f"{'='*60}")

        for L in L_values:
            for H in H_values:
                print(f"\n==========================================")
                print(f"Testing Config: L={L}, H={H}, Causal={causal}")
                print(f"==========================================")

                total_max_diff_sage = 0.0
                total_mean_diff_sage = 0.0
                sage_valid_count = 0
                sage_top_diffs_per_iter = []  # Store top N diffs per iteration

                for i in range(num_iters):
                    # Initialize tensors in BHLD format (standard for Torch/Sage).
                    # Seed per iteration so every backend sees identical inputs
                    # and runs are reproducible.
                    torch.manual_seed(42 + i)
                    q = torch.randn((B, H, L, D), dtype=dtype, device=device)
                    k = torch.randn((B, H, L, D), dtype=dtype, device=device)
                    v = torch.randn((B, H, L, D), dtype=dtype, device=device)

                    # NOTE: no explicit softmax scale is passed; SDPA's default
                    # is already 1/sqrt(D), and sageattn3_blackwell is assumed
                    # to use the same default (previously an unused sm_scale
                    # local was computed here — removed).

                    # 1. Torch BF16 Attention (Reference)
                    ref_out = F.scaled_dot_product_attention(q, k, v, is_causal=causal)

                    # 2. Sage3 Attention (best-effort: a failure for one config
                    # should not abort the whole sweep)
                    try:
                        sage_out = sageattn3_blackwell(q, k, v, is_causal=causal)
                    except Exception:
                        sage_out = None

                    # Compare Sage vs Torch
                    if sage_out is not None:
                        diff_sage = (sage_out - ref_out).abs()
                        max_diff_sage = diff_sage.max().item()
                        total_max_diff_sage += max_diff_sage
                        total_mean_diff_sage += diff_sage.mean().item()
                        sage_valid_count += 1
                        # Get top N diffs for this iteration (torch.topk returns
                        # values sorted in descending order).
                        topN_this_iter = torch.topk(diff_sage.flatten(), k=min(num_top_diffs, diff_sage.numel())).values.cpu().tolist()
                        sage_top_diffs_per_iter.append(topN_this_iter)

                    print(f"Iteration {i+1}/{num_iters} done.", end='\r')

                print("\n--- Average Results ---")
                # Sage3 Results
                if sage_valid_count > 0:
                    avg_max_sage = total_max_diff_sage / sage_valid_count
                    avg_mean_sage = total_mean_diff_sage / sage_valid_count
                    print(f"Sage3 vs Torch BF16 (over {sage_valid_count} runs): Avg Max Diff = {avg_max_sage:.6f}, Avg Mean Diff = {avg_mean_sage:.6f}")

                    # If avg max diff > 0.5, log the average top N diffs
                    if avg_max_sage > 5e-1:
                        # Average the k-th largest diff across all iterations:
                        # each entry of sage_top_diffs_per_iter is one
                        # iteration's descending top-N list.
                        num_top_values = min(len(diffs) for diffs in sage_top_diffs_per_iter) if sage_top_diffs_per_iter else 0
                        if num_top_values > 0:
                            # `rank` indexes the k-th largest diff (avoids
                            # shadowing the iteration counter `i` above).
                            avg_topN = [sum(diffs[rank] for diffs in sage_top_diffs_per_iter) / len(sage_top_diffs_per_iter)
                                        for rank in range(num_top_values)]
                            print(f" Avg Top {num_top_values} largest diffs: {[f'{d:.6f}' for d in avg_topN]}")
                else:
                    print("Sage3 vs Torch BF16: Skipped (all executions failed or not available)")
# Script entry point: run the full comparison sweep when executed directly.
if __name__ == "__main__":
    test_error_comparison()
cc @jt-zhang
Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
No labels