
Add TTLInsertInterLoopCBSync pass for interloop cb synchronization #272

Draft

brnorris03 wants to merge 9 commits into bnorris/refactor-synchronization from bnorris/insert-interloop-cbsync-pass

Conversation

@brnorris03 (Contributor) commented Jan 25, 2026

Problem

When multiple compute operations share the same CB (the first writes, the second reads), the consumer must wait for the data to be available. Without explicit synchronization, it reads stale or uninitialized data.

Solution

New pass ttl-insert-inter-loop-cb-sync runs after ttl-lower-to-loops. It uses loop marker attributes (ttl.tile_loop.input_cbs, ttl.tile_loop.output_cbs) to detect CB dependencies and inserts cb_wait before consumer loops.

  • Checks ALL dominating producers using MLIR's DominanceInfo, not just adjacent loops (see the sketch after this list)
    • Handles non-consecutive dependencies (loop0 writes CB2, loop1 writes CB3, loop2 reads CB2)
    • Handles cross-block dependencies (producer at function level, consumer inside user loop)
  • Supports multiple input/output CBs per compute
  • Deduplicates shared CBs to avoid redundant waits
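
A minimal sketch of the dominance check, assuming the marker attribute names above; the helper names (cbIndices, cbsToWaitOn) are hypothetical, not the actual pass source:

#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Dominance.h"
#include "llvm/ADT/SetVector.h"

using namespace mlir;

// Read the I64 array attribute holding CB indices off a marked loop.
static SmallVector<int64_t> cbIndices(Operation *loop, StringRef attrName) {
  SmallVector<int64_t> result;
  if (auto arr = loop->getAttrOfType<ArrayAttr>(attrName))
    for (Attribute a : arr)
      result.push_back(cast<IntegerAttr>(a).getInt());
  return result;
}

// Collect the deduplicated CB indices a consumer loop must wait on by
// scanning every marked loop that dominates it -- not just the adjacent
// one, so non-consecutive and cross-block producers are covered too.
static llvm::SetVector<int64_t> cbsToWaitOn(scf::ForOp consumer,
                                            ArrayRef<scf::ForOp> markedLoops,
                                            DominanceInfo &dom) {
  llvm::SetVector<int64_t> pending; // SetVector deduplicates shared CBs
  SmallVector<int64_t> inputs = cbIndices(consumer, "ttl.tile_loop.input_cbs");
  for (scf::ForOp producer : markedLoops) {
    if (producer == consumer || !dom.dominates(producer, consumer))
      continue;
    for (int64_t out : cbIndices(producer, "ttl.tile_loop.output_cbs"))
      if (llvm::is_contained(inputs, out))
        pending.insert(out);
  }
  return pending; // one cb_wait is then inserted per index before `consumer`
}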

Generated code change:

// Before: consumer reads stale data
cb_push_back(CB2);       // producer done
init_sfpu(CB2, CB3);     // consumer starts immediately

// After: explicit sync
cb_push_back(CB2);       // producer done
cb_wait_front(CB2);      // wait for data
init_sfpu(CB2, CB3);     // consumer reads valid data

Testing

Lit tests cover: consecutive deps, different CBs (no sync), user loops, non-consecutive deps, cross-block deps.

E2E tests: TestTwoComputesChained (computes (a + b)^2; sketched below), TestThreeComputePipeline, and TestMultiInputCompute.
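
For intuition, a hypothetical sketch of the flow the chained test exercises; CB numbering and tile counts here are illustrative, not the actual generated code:

// Hypothetical flow for TestTwoComputesChained ((a + b)^2).
cb_wait_front(CB0, 1);   // input a ready
cb_wait_front(CB1, 1);   // input b ready
// ... compute1: a + b, packed into intermediate CB2 ...
cb_push_back(CB2, 1);    // producer done

cb_wait_front(CB2, 1);   // inserted by ttl-insert-inter-loop-cb-sync
// ... compute2: square CB2, packed into output CB3 ...
cb_push_back(CB3, 1);    // consumer done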

Add TTLInsertInterLoopCBSync pass that inserts cb_wait between consecutive compute loops when the output CB of one loop feeds into the input CB of the next. Also adds intermediate CB support to the Python runner infrastructure.
@brnorris03 force-pushed the bnorris/insert-interloop-cbsync-pass branch from 8a1a407 to c2c73d8 on January 25, 2026 01:30
@brnorris03 changed the title Add TTLInsertInterLoopCBSync pass that inserts cb_wait between co… Add TTLInsertInterLoopCBSync pass for interloop cb synchronization Jan 25, 2026
@brnorris03 marked this pull request as ready for review January 25, 2026 23:01
@brnorris03 requested a review from a team as a code owner January 25, 2026 23:01
consecutive compute loops when the output CB of one matches the input CB of
the next.

2. TestTwoComputesSecondResult: compute1(a + b) -> CB2, compute2(a * b) -> CB2
Contributor

You don't think we should keep both?

Contributor Author

No, this seems like a pointless scenario that would be DCE-ed before you reach C++ in regular code.

return CircularBuffer(tensor, shape, buffer_factor)


def make_intermediate_cb(
Contributor

I wonder if we could maybe put this in an me2e module or make it clear that it's not supposed to be used by the frontend?

Contributor Author

Are you saying we should disallow L1 intermediate cbs in the DSL?

Comment on lines +65 to +67
*,
dtype: Any = None,
cb_index: int = None,
Contributor

I may be getting too anxious here, but I'm concerned that if we have multiple ways to construct and set up CBs, it's going to be error-prone.

Before this change we have a single, very clearly defined state for the CB on initialization, and a single clear process for binding the cb, etc. I'm slightly worried that in the future we might try to either use this initializer when we shouldn't or use a property that's not initialized. At the very least maybe we can communicate that this is only for me2e? Or maybe make this initializer a factory function?

Contributor Author
@brnorris03 Jan 26, 2026

What do you propose to use instead for L1-only CBs not backed by function arguments? This is not a test-only scenario. They are required for a number of operations, including broadcast (any time a chain of compute ops has to be produced).

self._cb_index = _next_cb_index()
self._explicit_dtype = dtype
# Allow explicit cb_index for intermediate CBs, otherwise auto-assign.
self._cb_index = cb_index if cb_index is not None else _next_cb_index()
Contributor

Verify these don't overlap?


// Insert init_sfpu before outermost loop if not present.
// Use stored CB indices from loop attributes to find the CBs.
// Use first input/output CB for init_sfpu (hardware only needs one pair).
Contributor

I'm confused about this. Where does "Hardware only needs one pair" come from? Can you explain the logic here a bit? Seems wrong to init with (potentially) different CBs than are used.

Contributor Author

That's how the init_sfpu function is defined; the arguments are only used for the metadata. You can confirm this in the Metalium API.

Contributor

Could you add that as a comment please? Maybe

  // init_sfpu only uses CB metadata (format, num_faces, face_dim) to configure
  // hardware. It does not access CB data. All input CBs must have matching
  // metadata for correct operation.

I think maybe we should also have a validator to check that all inputs/outputs match (but I guess inputs and outputs don't need to match... what if we have two outputs of different kinds and two inputs of different kinds...?).

It seems like the metadata used is 1) data format/type 2) num faces 3) row dimension
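
A hypothetical shape for such a validator; ttl::CBType and its accessors are assumed names here, not the dialect's confirmed API:

  // Hypothetical sketch: reject mismatched input-CB element types, as a
  // stand-in for the full format/num_faces/face_dim metadata comparison.
  static LogicalResult verifyMatchingCBMetadata(Operation *op,
                                                ArrayRef<Value> inputCBs) {
    if (inputCBs.empty())
      return success();
    auto refTy = cast<ttl::CBType>(inputCBs.front().getType());
    for (Value cb : inputCBs.drop_front()) {
      auto ty = cast<ttl::CBType>(cb.getType());
      if (ty.getElementType() != refTy.getElementType())
        return op->emitOpError()
               << "init_sfpu requires matching CB metadata, got "
               << refTy.getElementType() << " vs " << ty.getElementType();
    }
    return success();
  }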

Contributor

Concretely, this currently produces incorrect codegen:

  func.func @use_second_input_only(%a: tensor<2x2x!ttcore.tile<32x32, f32>>,
                                   %b: tensor<2x2x!ttcore.tile<32x32, bf16>>)
      -> tensor<2x2x!ttcore.tile<32x32, bf16>> {
    %cb0 = ttl.bind_cb {cb_index = 0} : !ttl.cb<[2, 2], !ttcore.tile<32x32, f32>, 1>
    %cb1 = ttl.bind_cb {cb_index = 1} : !ttl.cb<[2, 2], !ttcore.tile<32x32, bf16>, 1>
    %cb2 = ttl.bind_cb {cb_index = 2} : !ttl.cb<[2, 2], !ttcore.tile<32x32, bf16>, 1>

    %a_ready = ttl.cb_wait %cb0 : ...
    %b_ready = ttl.cb_wait %cb1 : ...

    %result = ttl.compute
        ins(%a_ready, %b_ready : tensor<...f32>, tensor<...bf16>)
        outs(%init_cb : tensor<...bf16>) {
      ^bb0(%a_tile: !ttcore.tile<32x32, f32>,
           %b_tile: !ttcore.tile<32x32, bf16>, ...):
        ttl.yield %b_tile : !ttcore.tile<32x32, bf16>  // Only use bf16 input
    }
  }

After pass:

  %1 = ttl.bind_cb{cb_index = 0} : <[2, 2], !ttcore.tile<32x32, f32>, 1>
  %2 = ttl.bind_cb{cb_index = 1} : <[2, 2], !ttcore.tile<32x32, bf16>, 1>
  %3 = ttl.bind_cb{cb_index = 2} : <[2, 2], !ttcore.tile<32x32, bf16>, 1>
  %5 = ttl.cb_wait %2 : ... -> tensor<2x2x!ttcore.tile<32x32, bf16>>

  ttl.init_sfpu(%1, %3) : <..., f32, 1>, <..., bf16, 1>
  //            ^^ CB0 (f32) used for input config

  %7 = scf.for ... {
    %8 = scf.for ... {
      %extracted = tensor.extract %5[...] : tensor<2x2x!ttcore.tile<32x32, bf16>>
      //                          ^^ CB1 (bf16) actually unpacked
      %dst_token, %dst_tile = ttl.copy_tile %extracted, ... : !ttcore.tile<32x32, bf16>
    }
  }

Contributor Author

Yes, good catch -- will fix; this was leftover from the single-compute-per-function case. Thanks

), "Should have 1 mul operation"
# Inter-loop syncs: cb_wait_front for CB2 (before compute2) and CB3 (before compute3)
assert (
source.count("cb_wait_front") >= 2
Contributor

nit: maybe make this ==2?

break;
}
// Stop if we hit a loop without tile_loop marker - it's a user loop.
if (!parentFor->hasAttr(kTileLoopAttrName)) {
Contributor

This will break if we have a loop in between, e.g. if we had a 3D shape in the future.

I know we don't allow 3D inputs right now, but we should at least raise an error in the caller with some info; otherwise we'll return the wrong loop and put the wait/init in the wrong spot. Maybe:

  if (loopNest.loops.size() > 2) {
    return rewriter.notifyMatchFailure(
        op, "3D+ iteration spaces not yet supported");
  }

I created a local reproducer if you'd like to add it as a lit test, but no need if we don't want to support 3D shapes yet.

Contributor Author
@brnorris03 Jan 26, 2026

One of the very near-future PRs will flatten all compute loops, so I don't think this will be an issue for long (unless I misunderstand). Flattening is required for unrolling, among other things.

Contributor

If it's easy, would you mind adding the error and then removing it in the follow-up PR that has tests for this case?

brnorris03 added a commit that referenced this pull request Jan 27, 2026
  - Added marking of outermost loop with kTileLoopOuterAttrName
  - Uses kTileLoopInputCBsAttrName and kTileLoopOutputCBsAttrName (plural, array-based)

  2. TTL.h (attribute constants):
  - Added kTileLoopOuterAttrName = "ttl.tile_loop.outer"
  - Changed to plural names: kTileLoopInputCBsAttrName, kTileLoopOutputCBsAttrName
  - Added comments documenting ArrayAttr of I64IntegerAttr format

  3. TTLOpsUtils.h (utility functions):
  - Added findBindCBByIndex - find bind_cb by index
  - Added getCBIndicesFromLoopAttr - get indices from array attribute
  - Added getCBValuesFromLoopAttr - get CB values by looking up bind_cb ops
  - Added findOutermostLoop - find outermost scf.for
  - Added findOutermostComputeLoop - find outermost compute loop respecting markers, with > 2D error check (sketched after this list)

  4. TTLInsertTileRegsSync.cpp (complete rewrite):
  - Uses findOutermostComputeLoop instead of simple outermost loop
  - Uses findExtractedInputCB to find the correct CB for init_sfpu (prefers CB matching output shape)
  - Error if no input CB is extracted but input_cbs attribute is non-empty
  - Uses array-based getCBValuesFromLoopAttr for multiple outputs
  - Maps tensor.insert to correct output CB via iter_arg indices
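
A hypothetical sketch of the findOutermostComputeLoop utility from point 3 above; the marker string and exact signature are assumptions, not the actual source:

  // Hypothetical sketch: walk up the enclosing scf.for nest, stop at user
  // loops (no tile_loop marker), and reject >2D nests per the error check.
  static FailureOr<scf::ForOp> findOutermostComputeLoop(Operation *op) {
    scf::ForOp outermost;
    unsigned depth = 0;
    for (auto loop = op->getParentOfType<scf::ForOp>(); loop;
         loop = loop->getParentOfType<scf::ForOp>()) {
      if (!loop->hasAttr("ttl.tile_loop")) // assumed kTileLoopAttrName value
        break;
      outermost = loop;
      ++depth;
    }
    if (depth > 2) {
      op->emitOpError("3D+ iteration spaces not yet supported");
      return failure();
    }
    return outermost;
  }
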
@brnorris03 marked this pull request as draft January 29, 2026 23:59