Support prefetched inputs in Pallas pipelines#38165

Open: copybara-service[bot] wants to merge 1 commit into main from test_921750865

Support prefetched inputs in Pallas pipelines

This change adds support for pre-populated VMEM buffers (prefetched inputs) in Pallas pipelines. It:

 - Introduces `PrefetchedInput`, which wraps a reference together with its pre-allocated VMEM buffer.
 - Updates the `emit_pipeline` execution logic to skip the initial copy-in steps for prefetched inputs, avoiding redundant DMA operations.
 - Adds unit tests in `tpu_pallas_pipeline_test.py` covering the standard, lookahead, and trivial windowing configurations.

With a pre-populated window buffer, the pipeline avoids the bubble at its head. The initial copy-in wait can be moved outside the pipeline loop, where it can potentially be hidden behind other computation.
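To make the scheduling idea concrete, here is a toy model in plain Python. It is only a sketch and does not use the Pallas API: `ToyPrefetchedInput` and `run_toy_pipeline` are hypothetical names invented for illustration, and the "DMA" is an ordinary list assignment. It shows a double-buffered loop that issues a prologue copy-in and an exposed wait for a normal input, and skips both when slot 0 of the buffer was filled before the loop started.

```python
from dataclasses import dataclass


@dataclass
class ToyPrefetchedInput:
    """Hypothetical stand-in, not the real PrefetchedInput class.

    Pairs a source with a two-slot buffer whose slot 0 already holds block 0.
    """
    source: list
    buffer: list


def run_toy_pipeline(inp, num_steps):
    """Runs a double-buffered loop and returns (outputs, event log)."""
    log = []
    if isinstance(inp, ToyPrefetchedInput):
        # Slot 0 was populated by the caller, so no prologue copy-in is needed.
        source, buf = inp.source, inp.buffer
    else:
        source, buf = inp, [None, None]
        log.append("copy_in(0)")  # prologue DMA
        buf[0] = source[0]
        log.append("wait_in(0)")  # exposed wait: the bubble at the pipeline head

    outputs = []
    for step in range(num_steps):
        slot = step % 2
        if step + 1 < num_steps:
            # Start fetching the next block into the other slot.
            log.append(f"copy_in({step + 1})")
            buf[1 - slot] = source[step + 1]
        if step > 0:
            # Wait for the fetch issued during the previous step.
            log.append(f"wait_in({step})")
        outputs.append(buf[slot] * 2)  # stand-in for the kernel body
        log.append(f"compute({step})")
    return outputs, log


blocks = [1, 2, 3, 4]
plain_out, plain_log = run_toy_pipeline(blocks, len(blocks))

# The caller fills slot 0 ahead of time, outside the pipeline loop.
prefetched = ToyPrefetchedInput(source=blocks, buffer=[blocks[0], None])
pre_out, pre_log = run_toy_pipeline(prefetched, len(blocks))
```

In this model both runs produce the same outputs, and the prefetched run's log simply lacks the leading `copy_in(0)` and `wait_in(0)` events. In a real kernel the caller would still have to wait on the prefetch somewhere before the pipeline starts; the point of the change as described is that this wait can sit outside the loop, next to unrelated work.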

PiperOrigin-RevId: 921750865