TorchWave fused PyTorch nativert executor (#16878)
Open
oerling wants to merge 1 commit into facebookincubator:main
Conversation
Summary: TorchWave is a GPU kernel fusion and execution framework that compiles nativert FX graphs into fused CUDA kernels. It analyzes the dataflow graph, groups operations into composite kernels, generates CUDA code, and executes them with efficient GPU resource management.

**Core files:**

- **Registry.h/.cpp, Builtins.cpp** — Operation registry mapping nativert op names to metadata (elementwise traits, cost, code generation functions). Builtins registers standard aten ops (add, sub, mul, div, etc.).
- **ParallelExpr.h/.cpp** — Analyzes the nativert graph to identify independent subgraphs (ProjectNodes) that can execute in parallel, partitioning the graph into sequential stages with internal parallelism.
- **Compile.h/.cpp** — The compilation pipeline. Extracts subgraphs from ProjectNodes, matches isomorphic subgraphs to reuse compiled kernels, generates fused CUDA code for elementwise expression trees, and assembles CompositeKernels from multiple KernelOperations.
- **CompiledOp.h** — Data structures for the compiled representation: KernelOperation (a single fused op with CUDA code), OpInvocation (runtime binding of a KernelOperation to actual values), CompositeKernel (groups KernelOperations into one compiled CUDA kernel), CompositeInvocation (runtime invocation with grid building and param filling), CompiledNode (parallel/sequential kernel groups), and WaveGraph (top-level compiled graph).
- **Executor.h/.cpp** — The executor that ties compilation to execution. WaveGraphExecutor extends nativert's GraphExecutorBase. Contains process-wide GPU resource initialization (arenas, stream/event pools), ExecutionState management, grid construction (makeGrid), and the full execution path: output allocation, BlockInfo grid building, pinned buffer param filling, H2D transfer, and kernel launch.
- **KernelParams.h** — Shared host/device structs: Tensor (storage pointer + dims/strides for up to 3D), BlockInfo (per-thread-block dispatch info with op code, element count, param pointer), TorchWaveParams (kernel entry point parameter).
- **Core.cuh, Elementwise.cuh** — CUDA device code. Core.cuh provides the kernel entry macro (ENTRY) and BlockInfo loading. Elementwise.cuh implements the fused elementwise kernel body with fast-path detection for contiguous tensors and broadcast support.
- **Utils.h/.cpp** — Thread-safe Pool<T> template for reusing GPU Streams and Events.
- **Pt2Load.h/.cpp** — Loads .pt2 archives: deserializes the nativert graph, tensor metadata, and weight paths from PyTorchStreamReader.
- **GraphView.h/.cpp** — Diagnostic printing of the nativert graph and compiled WaveGraph structure.
- **NativertSerialization.cpp** — Deserialization of nativert graph IR from JSON format within .pt2 archives.
- **Execute.h/.cpp** — Standalone entry point for loading and running a .pt2 model through the WaveGraph executor.
- **Project.h/.cpp** — ProjectNode representation for the parallelism analysis stage.
- **tests/ExecutorTest.cpp** — End-to-end test: loads a .pt2 model, runs it through both nativert SerialGraphExecutor and WaveGraphExecutor, and verifies outputs match eager-mode expectations.
- **tests/GraphTool.cpp** — CLI tool for inspecting .pt2 graph structure and compiled WaveGraph.
- **tests/element_test.py, element_test_run_pt2.py** — Python test model (elementwise arithmetic on int64 tensors) and script to export it as a .pt2 archive for the C++ tests.

Differential Revision: D95696931
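To make the Registry.h/Builtins.cpp description concrete, here is a hedged C++ sketch of an op registry of that general shape. All names here (OpMetadata, registerOp, lookupOp, registerBuiltins) are illustrative, not the actual TorchWave API:

```cpp
#include <cstdint>
#include <functional>
#include <string>
#include <unordered_map>

// Hypothetical metadata record for one op: traits the fusion compiler
// consults, plus a callback that emits a CUDA source expression.
struct OpMetadata {
  bool elementwise = false;  // can participate in elementwise fusion
  int32_t cost = 1;          // relative cost estimate
  std::function<std::string(const std::string&, const std::string&)>
      genCode;               // emits CUDA code for the op's expression
};

inline std::unordered_map<std::string, OpMetadata>& opRegistry() {
  static std::unordered_map<std::string, OpMetadata> registry;
  return registry;
}

inline void registerOp(const std::string& name, OpMetadata meta) {
  opRegistry()[name] = std::move(meta);
}

inline const OpMetadata* lookupOp(const std::string& name) {
  auto it = opRegistry().find(name);
  return it == opRegistry().end() ? nullptr : &it->second;
}

// Builtins.cpp-style registration of a couple of standard aten ops.
inline void registerBuiltins() {
  registerOp("aten.add",
             {true, 1, [](const std::string& a, const std::string& b) {
                return "(" + a + " + " + b + ")";
              }});
  registerOp("aten.mul",
             {true, 1, [](const std::string& a, const std::string& b) {
                return "(" + a + " * " + b + ")";
              }});
}
```

Driving code generation through per-op callbacks like this is what lets a compiler stitch a chain of registered elementwise ops into one fused expression string.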
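The stage partitioning described for ParallelExpr.h/.cpp can be illustrated with a generic topological leveling: nodes at the same longest-path depth have no dependency on each other and can run in parallel, while depth levels execute sequentially. This is a minimal sketch of the idea, not the actual TorchWave algorithm:

```cpp
#include <algorithm>
#include <utility>
#include <vector>

// Group DAG nodes into sequential stages with internal parallelism.
// Each edge is a (producer, consumer) pair; a node's stage is its
// longest-path depth from the graph inputs.
std::vector<std::vector<int>> partitionIntoStages(
    int numNodes, const std::vector<std::pair<int, int>>& edges) {
  std::vector<int> depth(numNodes, 0);
  // Relax depths to a fixed point (adequate for small acyclic graphs).
  bool changed = true;
  while (changed) {
    changed = false;
    for (const auto& e : edges) {
      if (depth[e.second] < depth[e.first] + 1) {
        depth[e.second] = depth[e.first] + 1;
        changed = true;
      }
    }
  }
  int maxDepth = 0;
  for (int d : depth) maxDepth = std::max(maxDepth, d);
  std::vector<std::vector<int>> stages(maxDepth + 1);
  for (int n = 0; n < numNodes; ++n) stages[depth[n]].push_back(n);
  return stages;
}
```

For a diamond graph 0→{1,2}→3, this yields three stages, with nodes 1 and 2 sharing the middle stage.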
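Grid construction in the spirit of Executor's makeGrid can be sketched on the host as splitting a flat element count into fixed-size per-block work items, one BlockInfo-like record per thread block. The field names and the block size here are assumptions for illustration:

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// One dispatch record per thread block, loosely modeled on the
// BlockInfo description above.
struct BlockWork {
  int32_t opCode;       // which fused op this block runs
  int32_t numElements;  // elements this block is responsible for
  int64_t begin;        // starting flat element index
};

// Split totalElements into per-block chunks; the last block may be short.
std::vector<BlockWork> makeGrid(int32_t opCode, int64_t totalElements,
                                int32_t elementsPerBlock) {
  std::vector<BlockWork> grid;
  for (int64_t begin = 0; begin < totalElements; begin += elementsPerBlock) {
    int32_t n = static_cast<int32_t>(
        std::min<int64_t>(elementsPerBlock, totalElements - begin));
    grid.push_back({opCode, n, begin});
  }
  return grid;
}
```

In the real executor these records would be filled into a pinned host buffer and copied H2D before the launch; here only the partitioning arithmetic is shown.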
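The fast-path detection mentioned for Elementwise.cuh rests on a contiguity check: when every operand is dense row-major, the kernel can index with a single flat counter instead of computing per-dimension strided offsets. A host-side sketch of that check (the real device code lives in the .cuh file):

```cpp
#include <cstdint>

// True if (dims, strides) describe a dense row-major layout, i.e. the
// innermost stride is 1 and each outer stride equals the product of the
// inner dims. Size-1 dims are ignored, as their stride never matters.
bool isContiguous(const int64_t* dims, const int64_t* strides, int rank) {
  int64_t expected = 1;
  for (int d = rank - 1; d >= 0; --d) {
    if (dims[d] != 1 && strides[d] != expected) {
      return false;
    }
    expected *= dims[d];
  }
  return true;
}
```

A broadcast operand (stride 0 in some dimension) fails this test and falls back to the general strided/broadcast path.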
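A minimal sketch of a thread-safe object pool in the spirit of the Pool<T> described for Utils.h/.cpp (names and interface are illustrative): acquire() reuses a previously released object if one exists, otherwise constructs a new one, which is exactly the behavior wanted for expensive-to-create CUDA streams and events:

```cpp
#include <memory>
#include <mutex>
#include <utility>
#include <vector>

template <typename T>
class Pool {
 public:
  // Returns a pooled object if available, else default-constructs one.
  std::unique_ptr<T> acquire() {
    std::lock_guard<std::mutex> lock(mutex_);
    if (!free_.empty()) {
      auto obj = std::move(free_.back());
      free_.pop_back();
      return obj;
    }
    return std::make_unique<T>();
  }

  // Returns an object to the pool for later reuse.
  void release(std::unique_ptr<T> obj) {
    std::lock_guard<std::mutex> lock(mutex_);
    free_.push_back(std::move(obj));
  }

 private:
  std::mutex mutex_;
  std::vector<std::unique_ptr<T>> free_;
};
```

Releasing and then re-acquiring hands back the same object rather than allocating a fresh one, which is the point of pooling streams and events across kernel launches.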