0.5.0 add support for most Triton's features; large rewrite/refactoring#383

Open
Arech8 wants to merge 1 commit into jax-ml:main from Arech8:pr2_bugfix_kwargs_namespace

Conversation

Contributor

Arech8 commented Apr 16, 2026

0.5.0

Breaking changes

  • float now follows the upstream convention and is represented as fp32 instead of
    the old fp64
  • zeroed_outputs= parameter of triton_call() no longer supports zeroing of aliased
    input-output arguments.
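
The practical effect of the first change can be illustrated with a stdlib round-trip. This is a plain-Python sketch of fp32 vs fp64 precision, not jax-triton code:

```python
import struct

def to_fp32(x: float) -> float:
    """Round-trip a Python float through an IEEE-754 binary32 ("fp32") value."""
    return struct.unpack("f", struct.pack("f", x))[0]

# fp64 carries ~16 significant decimal digits, fp32 only ~7, so code that
# relied on Python `float` mapping to fp64 may now observe precision loss.
print(to_fp32(0.1))         # 0.10000000149011612
print(to_fp32(0.1) == 0.1)  # False
```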

New features / bugfixing

  • all backend initialization options are now fully supported and handled similarly
    to upstream (via a single kwargs dictionary).
  • support for the @jt.kernel decorator and a concise Triton-native form of launching
    a kernel with kernel[grid](*args, **kwargs) syntax.
  • arrays and other run-time values can now also be passed as key-value pairs to the
    launcher when out_names= is set or when the new dictionary form of out_shape= is used.
  • handling of kernel argument specialization and default values is now fully delegated
    to the upstream Triton code, which enables full support for default values, kernel
    parameter annotations, related @triton.jit() arguments such as do_not_specialize,
    and the use of tuples (including deeply nested ones), callables, or strings as
    kernel arguments.
  • out_shape, input_output_aliases and zeroed_outputs handling has been fully reworked
    to support nested tuples and is now based on a kernel-signature coordinate system
    instead of flat array indices, leading to much clearer launcher syntax.
  • the dictionary form of input_output_aliases= is deprecated but still fully supported
  • the CAN_USE_TRITON guard is dropped as obsolete
  • the test suite grew from 187 to 438 test cases
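
To give an intuition for the signature coordinate system mentioned above, here is a conceptual plain-Python sketch (hypothetical helper names, not the actual jax_triton implementation): a coordinate addresses an argument by its position in the kernel signature, descending into nested tuples, whereas a flat index shifts whenever the nesting changes.

```python
# Conceptual sketch only: hypothetical helpers, not jax_triton's real code.
# A "coordinate" like (2, 1, 0) means: third kernel parameter, then its
# second element, then that element's first element.

def resolve(args, coord):
    """Walk nested tuples `args` following the index path `coord`."""
    node = args
    for idx in coord:
        node = node[idx]
    return node

def flatten(args):
    """Old-style view: nested tuples collapsed into a flat argument list."""
    out = []
    for a in args:
        if isinstance(a, tuple):
            out.extend(flatten(a))
        else:
            out.append(a)
    return out

args = ("x_ptr", "y_ptr", ("a", ("b", "c")))
print(resolve(args, (2, 1, 0)))   # "b"
print(flatten(args).index("b"))   # 3 -- the flat index depends on nesting
```

The coordinate form stays stable when an unrelated argument gains or loses nesting, which is what makes the new launcher syntax clearer.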

Performance

Despite the large amount of additional code needed to support the kernel coordinate space for triton_call() parameters, the performance of jitted launchers remains indistinguishable from the previous version. Non-jitted launchers are indeed roughly 25% slower for comparable features (I have a changeset that lowers this to at most ~10%, but it makes the code less simple and straightforward, so I'll publish it after this one is merged). Typical current numbers for launching the following kernel

import triton
import triton.language as tl

@triton.jit
def copy_scalar_triton(in_ptr, out_ptr):
    value = tl.load(in_ptr)
    tl.store(out_ptr, value)

on a scalar or a 4G-element array are:

                               Benchmark comparison results (Brunner Munzel test, alpha=0.00100)
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃                                       ┃ mean (means),              ┃ median (means),            ┃ min (means),               ┃
┃ Benchmark                             ┃ [0%, 50%, 100%]            ┃ [0%, 50%, 100%]            ┃ [0%, 50%, 100%]            ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ startup(1) | new vs old               │ 116.8us > 85.06us {-27.2%} │ 113.0us > 83.66us {-26.0%} │ 101.7us > 75.22us {-26.1%} │
│                                       │ [96.40u,114.9u,339.4u] >   │ [93.90u,112.6u,133.9u] >   │ [78.47u,102.6u,124.0u] >   │
│                                       │ [66.87u,84.97u,107.0u]     │ [64.59u,83.37u,103.7u]     │ [58.01u,75.64u,89.43u]     │
│                                       │ {-30.6%,-26.0%,-68.5%}     │ {-31.2%,-25.9%,-22.6%}     │ {-26.1%,-26.3%,-27.9%}     │
│                                       │ p=0.00000+(200 vs 200)     │ p=0.00000+(200 vs 200)     │ p=0.00000+(200 vs 200)     │
│                                       │ pvs(101) >:100.0%(101)     │ pvs(101) >:100.0%(101)     │ pvs(101) >:100.0%(101)     │
│ startup(1) jcall | new vs old         │ 57.46us ~ 57.96us {+0.9%}  │ 55.84us ~ 56.01us {+0.3%}  │ 50.95us ~ 51.06us {+0.2%}  │
│                                       │ [41.76u,57.51u,74.07u] ~   │ [40.99u,56.38u,70.59u] ~   │ [38.51u,51.64u,69.87u] ~   │
│                                       │ [43.57u,57.51u,93.08u]     │ [42.42u,56.20u,73.00u]     │ [38.10u,52.11u,70.49u]     │
│                                       │ {+4.3%,+0.0%,+25.7%}       │ {+3.5%,-0.3%,+3.4%}        │ {-1.1%,+0.9%,+0.9%}        │
│                                       │           (200 vs 200)     │           (200 vs 200)     │           (200 vs 200)     │
│                                       │ pvs(101) ~:100.0%(101)     │ pvs(101) ~:100.0%(101)     │ pvs(101) ~:100.0%(101)     │
│ startup(4G) jcall | new vs old        │ 62.75us ~ 62.83us {+0.1%}  │ 62.28us ~ 61.65us {-1.0%}  │ 53.18us ~ 53.50us {+0.6%}  │
│                                       │ [47.59u,62.36u,77.50u] ~   │ [43.37u,60.14u,80.14u] ~   │ [41.05u,52.98u,72.50u] ~   │
│                                       │ [47.27u,62.58u,149.0u]     │ [44.12u,59.63u,85.54u]     │ [39.60u,53.10u,72.60u]     │
│                                       │ {-0.7%,+0.4%,+92.2%}       │ {+1.7%,-0.8%,+6.7%}        │ {-3.5%,+0.2%,+0.1%}        │
│                                       │           (200 vs 200)     │           (200 vs 200)     │           (200 vs 200)     │
│                                       │ pvs(101) ~:100.0%(101)     │ pvs(101) ~:100.0%(101)     │ pvs(101) ~:100.0%(101)     │
│ startup(4G) out jcall | new vs old    │ 7.352ms ~ 7.349ms {-0.0%}  │ 7.348ms ~ 7.347ms {-0.0%}  │ 7.305ms ~ 7.307ms {+0.0%}  │
│                                       │ [7.288m,7.349m,7.567m] ~   │ [7.302m,7.346m,7.421m] ~   │ [7.179m,7.307m,7.366m] ~   │
│                                       │ [7.304m,7.348m,7.452m]     │ [7.305m,7.346m,7.413m]     │ [7.183m,7.312m,7.364m]     │
│                                       │ {+0.2%,-0.0%,-1.5%}        │ {+0.0%,+0.0%,-0.1%}        │ {+0.0%,+0.1%,-0.0%}        │
│                                       │           (200 vs 200)     │           (200 vs 200)     │           (200 vs 200)     │
│                                       │ pvs(101) ~:100.0%(101)     │ pvs(101) ~:100.0%(101)     │ pvs(101) ~:100.0%(101)     │
│ startup(4G) donate jcall | new vs old │ 73.95us ~ 71.23us {-3.7%}  │ 70.54us ~ 71.07us {+0.8%}  │ 61.66us ~ 61.93us {+0.4%}  │
│                                       │ [47.98u,71.71u,347.2u] ~   │ [45.21u,69.52u,97.43u] ~   │ [40.62u,62.06u,88.55u] ~   │
│                                       │ [47.93u,71.87u,100.8u]     │ [47.03u,70.32u,100.3u]     │ [41.30u,62.64u,94.38u]     │
│                                       │ {-0.1%,+0.2%,-71.0%}       │ {+4.0%,+1.1%,+3.0%}        │ {+1.7%,+0.9%,+6.6%}        │
│                                       │           (200 vs 200)     │           (200 vs 200)     │           (200 vs 200)     │
│                                       │ pvs(101) ~:100.0%(101)     │ pvs(101) ~:100.0%(101)     │ pvs(101) ~:100.0%(101)     │
└───────────────────────────────────────┴────────────────────────────┴────────────────────────────┴────────────────────────────┘

To reproduce, put the following preparation script in a directory that is a sibling of the jax-triton checkout and run it:

#!/usr/bin/bash

THIS_DIR="$(cd "$(dirname "$(readlink -f "${BASH_SOURCE[0]}")")" && pwd)"
#echo "THIS_DIR: $THIS_DIR"

cd "$THIS_DIR/../jax-triton"
git worktree add ../jax-triton-old upstream/main

cd "$THIS_DIR/../jax-triton-old"
sed -i 's/name = "jax-triton"/name = "jax-triton-old"/' ./pyproject.toml
sed -i 's/packages = \["jax_triton"\]/packages = \["jax_triton_old"\]/' ./pyproject.toml
sed -i 's/version = {attr = "jax_triton.version.__version__"}/version = {attr = "jax_triton_old.version.__version__"}/' ./pyproject.toml
find ./jax_triton -name '*.py' -exec sed -i 's/\bjax_triton\b/jax_triton_old/g' {} +
mv ./jax_triton ./jax_triton_old

pip install .
pip install benchstats  # benchmarks runner + statistical testing package used

echo "Run as $ python ./comparative_bm.py --iters=5 --reps=200 --warmup=5"

The comparative_bm.py script is:

import argparse
import os
import time

from benchstats import qbench as qb
from benchstats.common import LoggingConsole
from benchstats.render import makeReadable
import jax
import jax.numpy as jnp
import jax_triton as jt
import jax_triton_old as jt_old
import numpy as np
import triton
import triton.language as tl
from triton.experimental import gluon
from triton.experimental.gluon import language as gl


@qb.registerBenchmark
def make_startup_benchmark():
  NGigs = 4

  @triton.jit
  def copy_scalar_triton(in_ptr, out_ptr):
    value = tl.load(in_ptr)
    tl.store(out_ptr, value)

  def startup(module, input: jnp.ndarray, kernel) -> jnp.ndarray:
    return module.triton_call(
      input,
      kernel=kernel,
      out_shape=jax.ShapeDtypeStruct(shape=input.shape, dtype=input.dtype),
      grid=1,
    )

  def startup_out(module, input: jnp.ndarray, output: jnp.ndarray, kernel) -> jnp.ndarray:
    return module.triton_call(
      input,
      output,
      kernel=kernel,
      out_shape=jax.ShapeDtypeStruct(shape=input.shape, dtype=input.dtype),
      grid=1,
      input_output_aliases={1: 0},
    )

  def init_scalar() -> list:
    return [jnp.array(42.314)]

  def init_vec() -> list:
    return [jnp.arange(NGigs * 1024 * 1024 * 1024)]

  def init_vec_out() -> list:
    i = jnp.arange(NGigs * 1024 * 1024 * 1024)
    return [i, jnp.empty_like(i)]

  # fmt: off
  return {
    "startup(1)|new": (lambda x: startup(jt, x, copy_scalar_triton), init_scalar),
    "startup(1)|old": (lambda x: startup(jt_old, x, copy_scalar_triton), init_scalar),
    "startup(1) jcall|new": (jax.jit(lambda x: startup(jt, x, copy_scalar_triton)), init_scalar),
    "startup(1) jcall|old": (jax.jit(lambda x: startup(jt_old, x, copy_scalar_triton)), init_scalar),
    f"startup({NGigs}G) jcall|new": (jax.jit(lambda x: startup(jt, x, copy_scalar_triton)), init_vec),
    f"startup({NGigs}G) jcall|old": (jax.jit(lambda x: startup(jt_old, x, copy_scalar_triton)), init_vec),
    f"startup({NGigs}G) out jcall|new": (jax.jit(lambda x, y: startup_out(jt, x, y, copy_scalar_triton)), init_vec_out),
    f"startup({NGigs}G) out jcall|old": (jax.jit(lambda x, y: startup_out(jt_old, x, y, copy_scalar_triton)), init_vec_out),
    f"startup({NGigs}G) donate jcall|new": (jax.jit(lambda x, y: startup_out(jt, x, y, copy_scalar_triton), donate_argnums=(1,)), init_vec_out),
    f"startup({NGigs}G) donate jcall|old": (jax.jit(lambda x, y: startup_out(jt_old, x, y, copy_scalar_triton), donate_argnums=(1,)), init_vec_out),
  }
  # fmt: on


def run_benchmarks(
  enabled: list[str] | None = None,
  *,
  iters: int = 100,
  reps: int = 10,
  warmup: int = 3,
  batch_functions: bool = False,
  pvalue_stats_bootstrap: int = 1000,
):
  start = time.perf_counter_ns()

  if not enabled:
    enabled = qb.getRegisteredBenchmarkSetNames()

  console = LoggingConsole()

  bms = qb.getRegisteredBenchmarks(enabled)

  bm_names = tuple(bms.keys())
  all_bms = '", "'.join(bm_names)
  print(f'Going to run {len(bm_names)} benchmarks ({jax.local_device_count()} device(s) available): "{all_bms}"')
  if jax.local_device_count() > 1:
    console.warning(
      "More than 1 device is visible. For potentially better results consistency, restrict the number of devices using ROCR_VISIBLE_DEVICES or similar environment variable."
    )

  # we're interested in a wall-clock time of invoking kernels, including all python
  # related overhead, so time.perf_counter_ns() used there is appropriate
  _, results = qb.benchmark(
    bms.values(),
    iters=iters,
    reps=reps,
    warmup=warmup,
    randomize_iterations=True,
    batch_functions=batch_functions,
    wait_complete=jax.block_until_ready,
    show_progress_each=1,
    bm_names=bm_names,
    alt_delimiter="|",
    metrics={"mean": np.mean, "median": np.median, "min": np.min},
    console=console,
    pvalue_stats_bootstrap=pvalue_stats_bootstrap,
  )

  end = time.perf_counter_ns()
  console.print(f"Done in {makeReadable((end - start) * 1e-9, 1)}s")
  return results


_g_ProgramName = "jax-triton/jax-triton-old benchmarks runner"


def main():
  parser = argparse.ArgumentParser(description=_g_ProgramName)
  parser = qb.makeArgumentParser(parser)

  args = parser.parse_args()
  run_benchmarks(
    enabled=args.benchmark_sets or None,
    iters=args.iters,
    reps=args.reps,
    warmup=args.warmup,
    batch_functions=args.batch_functions,
    pvalue_stats_bootstrap=args.pvalue_stats_bootstrap,
  )


if __name__ == "__main__":
  main()

Arech8 force-pushed the pr2_bugfix_kwargs_namespace branch 3 times, most recently from 4d14ff0 to 28a56c6 on April 17, 2026 09:30
Arech8 force-pushed the pr2_bugfix_kwargs_namespace branch from 28a56c6 to e921d3f on April 17, 2026 09:34