Raised by @v0i0 - this would be helpful for consolidating common logic for attention variants, e.g. in #764.
Example:

from __future__ import annotations
import torch
import helion
import helion.language as hl
from helion.language import Tile


# TODO: we could add some decorator here to specifically say that "this is a Helion device loop"
# e.g. `@helion.device_loop()`
def inner_device_loop(tile: Tile, x_chunk: torch.Tensor, y_chunk: torch.Tensor) -> torch.Tensor:
    """Device helper that performs its own hl.tile iteration."""
    tmp = torch.empty_like(x_chunk)
    # Second-level device loop: iterate over the elements owned by ``tile``
    for local_tile in hl.tile(tile.block_size, block_size=32):
        tmp[local_tile] = x_chunk[local_tile] + y_chunk[local_tile]
    return tmp


@helion.kernel()
def nested_device_loops(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Outer kernel that delegates a chunk of work to ``inner_device_loop``."""
    assert x.shape == y.shape
    out = torch.empty_like(x)
    # First-level device loop tiles the full iteration space.
    for tile in hl.tile(x.numel(), block_size=128):
        x_chunk = x[tile]
        y_chunk = y[tile]
        # Call into a helper that contains another device loop.
        out[tile] = inner_device_loop(tile, x_chunk, y_chunk)
    return out


def main() -> None:
    if not torch.cuda.is_available():
        raise RuntimeError("This example expects a CUDA-capable device.")
    size = 1 << 12
    x = torch.randn(size, device="cuda", dtype=torch.float32)
    y = torch.randn(size, device="cuda", dtype=torch.float32)
    out = nested_device_loops(x, y)
    torch.testing.assert_close(out, x + y)


if __name__ == "__main__":
    main()
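
For comparison, here is a minimal sketch of the flattened version, with the second-level loop written inline in the kernel body rather than factored into a helper (this assumes the same `hl.tile` and tile-indexing API used above and is illustrative only). The inlined inner block is exactly the common logic that a reusable helper like `inner_device_loop` would consolidate across kernels:

from __future__ import annotations

import torch

import helion
import helion.language as hl


@helion.kernel()
def flat_device_loops(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Same computation as ``nested_device_loops``, with the inner loop inlined."""
    assert x.shape == y.shape
    out = torch.empty_like(x)
    # First-level device loop tiles the full iteration space, as before.
    for tile in hl.tile(x.numel(), block_size=128):
        x_chunk = x[tile]
        y_chunk = y[tile]
        # The body of ``inner_device_loop`` has to be repeated here, so it
        # cannot be shared with other kernels (e.g. other attention variants).
        tmp = torch.empty_like(x_chunk)
        for local_tile in hl.tile(tile.block_size, block_size=32):
            tmp[local_tile] = x_chunk[local_tile] + y_chunk[local_tile]
        out[tile] = tmp
    return out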