❓ Questions and Help
I'm trying to implement an in-place operator using Pallas and wrap it as a torch custom op. However, I found it difficult to make it work with torch.compile. More specifically, I'm unclear about how to set up donation, input-output aliases, and the op schema. It seems that having an output aliased with the input leads to functionalization problems in the torch compiler.
Thanks!
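For context, this is roughly how I would declare a plain in-place custom op in eager mode with torch.library, with the mutation expressed through mutates_args and no return value (just a sketch; the op name and the add_ body are placeholders, since my real kernel has to go through Pallas):

```python
import torch

# Sketch of a purely mutating custom op (eager mode, no Pallas involved).
# mutates_args names the inputs that are written in place, and the op
# returns None, so no output aliases an input.
@torch.library.custom_op("mylib::plus_one_inplace", mutates_args=("x",))
def plus_one_inplace(x: torch.Tensor) -> None:
    x.add_(1)
```

What I can't figure out is how to combine this style of declaration with make_kernel_from_pallas, where the Pallas call produces an output that is supposed to alias the donated input.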
My script is like this:
```python
from typing import List, Callable
import os
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
import torch
import torch_xla
from torch_xla.experimental import custom_kernel
from functools import partial
import torch_xla.debug.profiler as xp

server = xp.start_server(9012)
profile_logdir = "./profile"
xp.trace_detached('localhost:9012', profile_logdir)

os.environ["XLA_SAVE_TENSORS_FILE"] = "./graph.txt"
os.environ["XLA_FLAGS"] = "--xla_dump_to=./graph_hlo/"
os.environ["XLA_DUMP_HLO_GRAPH"] = "1"

M = 4096
N = 1024


def plus_one_kernel(x_ref, o_ref):
    o_ref[...] = x_ref[...] + 1


def plus_one_pallas(x: jax.Array):
    return pl.pallas_call(
        plus_one_kernel,
        grid=[2, 2],
        in_specs=[pl.BlockSpec([M, N], lambda i, j: (i, j))],
        out_specs=pl.BlockSpec([M, N], lambda i, j: (i, j)),
        out_shape=jax.ShapeDtypeStruct(x.shape, dtype=jnp.int32),
        input_output_aliases={0: 0},
    )(x)


@torch.library.custom_op("xla::plus_one_", mutates_args={})
def plus_one_(x: torch.Tensor) -> torch.Tensor:
    plus_one_pt = torch_xla.experimental.custom_kernel.make_kernel_from_pallas(
        plus_one_pallas, output_shape_dtype_fn=lambda x: [(x.shape, x.dtype)]
    )
    return plus_one_pt(x)


@plus_one_.register_fake
def plus_one_fake(x: torch.Tensor) -> torch.Tensor:
    return x


def fn(x):
    torch.ops.xla.dynamo_set_buffer_donor_(x, True)
    ret = plus_one_(x)
    return ret


fn = torch.compile(fn, backend="openxla")

x = torch.ones([M * 2, N * 2], dtype=torch.int32, device='xla')
ret = fn(x)
print(ret)
```
And it seems the op does not change the value of x in place.
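For reference, this is the kind of check I would do on the pure-JAX side to see whether the donation/aliasing takes effect at all (a sketch assuming TPU buffer donation via jax.jit; it reuses plus_one_pallas, M, and N from the script above):

```python
import jax
import jax.numpy as jnp

# Sketch: run the Pallas kernel directly in JAX, donating the input buffer.
jitted = jax.jit(plus_one_pallas, donate_argnums=0)

x = jnp.ones((M * 2, N * 2), dtype=jnp.int32)
y = jitted(x)

# After the call, the donated input should be marked deleted
# (donation is only honored on backends that support it, e.g. TPU).
print(x.is_deleted())
print(y[0, 0])
```

What I'm really after is how to express the same donation and aliasing on the torch_xla / torch.compile side without running into functionalization issues.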