
How to write in-place custom ops compatible with torch.compile using pallas #8385

@soodoshll

Description

❓ Questions and Help

I'm trying to implement an in-place operator using Pallas and wrap it as a torch custom op. However, I'm finding it difficult to make it work with torch.compile. More specifically, I'm unclear about how to set buffer donation, input-output aliases, and the op schema. It seems that having an output aliased with the input leads to functionalization problems in the torch compiler.
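
For reference, my understanding of the plain torch.library pattern for a mutating op (no XLA or Pallas involved; the names below are only illustrative) is roughly the following sketch:

import torch

# Illustrative sketch: a mutating op lists the mutated argument in mutates_args
# and returns None, so torch.compile's functionalization can handle it.
@torch.library.custom_op("mylib::plus_one_inplace", mutates_args={"x"})
def plus_one_inplace(x: torch.Tensor) -> None:
    x.add_(1)

@plus_one_inplace.register_fake
def _(x: torch.Tensor) -> None:
    # The fake implementation for a None-returning op also returns None.
    return None

What I can't figure out is how to express the same thing when the body goes through make_kernel_from_pallas, which (as far as I can tell) returns a new output tensor rather than mutating its argument.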

Thanks!

My script looks like this:

from typing import List, Callable

import os

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
import torch
import torch_xla
from torch_xla.experimental import custom_kernel
from functools import partial
import torch_xla.debug.profiler as xp

# Start a profiler server and capture a trace in the background.
server = xp.start_server(9012)
profile_logdir = "./profile"
xp.trace_detached('localhost:9012', profile_logdir)

os.environ["XLA_SAVE_TENSORS_FILE"] = "./graph.txt"
os.environ["XLA_FLAGS"] = "--xla_dump_to=./graph_hlo/"
os.environ["XLA_DUMP_HLO_GRAPH"]="1"

M = 4096
N = 1024

def plus_one_kernel(x_ref, o_ref):
    o_ref[...] = x_ref[...] + 1

def plus_one_pallas(x: jax.Array):
    return pl.pallas_call(
        plus_one_kernel,
        grid=[2, 2],
        in_specs=[pl.BlockSpec([M, N], lambda i, j: (i, j))],
        out_specs=pl.BlockSpec([M, N], lambda i, j: (i, j)),
        out_shape=jax.ShapeDtypeStruct(x.shape, dtype=jnp.int32),
        input_output_aliases={0: 0}  # alias input 0 to output 0 (donate x's buffer)
    )(x)

# Wrap the Pallas kernel as a torch custom op; mutates_args is left empty because
# the result is returned instead of being declared as a mutation of x.
@torch.library.custom_op("xla::plus_one_", mutates_args={})
def plus_one_(x: torch.Tensor) -> torch.Tensor:
    plus_one_pt = torch_xla.experimental.custom_kernel.make_kernel_from_pallas(
        plus_one_pallas, output_shape_dtype_fn = lambda x: [(x.shape, x.dtype)]
    )
    return plus_one_pt(x)

@plus_one_.register_fake
def plus_one_fake(x: torch.Tensor) -> torch.Tensor:
    # Fake implementation used during torch.compile tracing; it just returns the input.
    return x

def fn(x):
    torch.ops.xla.dynamo_set_buffer_donor_(x, True)  # mark x's buffer as donatable
    ret = plus_one_(x)
    return ret

fn = torch.compile(fn, backend="openxla")
x = torch.ones([M * 2, N * 2], dtype=torch.int32, device='xla')

ret = fn(x)
print(ret)

However, it seems that it does not change the value of x in place.
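
For what it's worth, this is roughly how I check whether x was updated (a sketch; xm.mark_step() may have a different preferred spelling in newer torch_xla releases):

import torch_xla.core.xla_model as xm

xm.mark_step()   # flush any pending XLA computation
print(x.cpu())   # still all ones for me, i.e. x is not updated in place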
