
Commit 4962afd

Merge branch 'main' of https://github.com/jax-ml/jax
2 parents 9017a46 + df6758f

File tree: 22 files changed, +284 -99 lines


.github/workflows/ci-build.yaml

Lines changed: 8 additions & 3 deletions
@@ -144,13 +144,19 @@ jobs:

  documentation_render:
    name: Documentation - render documentation
-    runs-on: ubuntu-latest
+    runs-on: linux-x86-n2-16
+    container:
+      image: index.docker.io/library/ubuntu@sha256:6d8d9799fe6ab3221965efac00b4c34a2bcc102c086a58dff9e19a08b913c7ef # ratchet:ubuntu:20.04
    timeout-minutes: 10
    strategy:
      matrix:
        python-version: ['3.10']
    steps:
    - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+    - name: Image Setup
+      run: |
+        apt update
+        apt install -y libssl-dev libsqlite3-dev
    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
      with:
@@ -170,8 +176,7 @@ jobs:
        pip install -r docs/requirements.txt
    - name: Render documentation
      run: |
-        sphinx-build --color -W --keep-going -b html -D nb_execution_mode=off docs docs/build/html
-
+        sphinx-build -j auto --color -W --keep-going -b html -D nb_execution_mode=off docs docs/build/html

  jax2tf_test:
    name: "jax2tf_test (py ${{ matrix.python-version }} on ${{ matrix.os }}, x64=${{ matrix.enable-x64}})"

build/build.py

Lines changed: 4 additions & 2 deletions
@@ -598,10 +598,12 @@ async def main():
    wheel_build_command.append(f"--jaxlib_git_hash={git_hash}")

    result = await executor.run(wheel_build_command.get_command_as_string(), args.dry_run)
+    # Exit with error if any wheel build fails.
    if result.return_code != 0:
      raise RuntimeError(f"Command failed with return code {result.return_code}")
-    else:
-      sys.exit(0)
+
+  # Exit with success if all wheels in the list were built successfully.
+  sys.exit(0)


if __name__ == "__main__":

build/rocm/dev_build_rocm.py

Lines changed: 5 additions & 4 deletions
@@ -77,13 +77,14 @@ def build_jax_xla(xla_path, rocm_version, rocm_target, use_clang, clang_path):
    build_command = [
        "python3",
        "./build/build.py",
-        "--enable_rocm",
-        "--build_gpu_plugin",
-        "--gpu_plugin_rocm_version=60",
+        "build"
        f"--use_clang={str(use_clang).lower()}",
+        "--wheels=jaxlib,jax-rocm-plugin,jax-rocm-pjrt"
+        "--rocm_path=%/opt/rocm-{rocm_version}/",
+        "--rocm_version=60",
        f"--rocm_amdgpu_targets={rocm_target}",
-        f"--rocm_path=/opt/rocm-{rocm_version}/",
        bazel_options,
+        "--verbose"
    ]

    if clang_option:

build/rocm/tools/build_wheels.py

Lines changed: 4 additions & 3 deletions
@@ -93,11 +93,12 @@ def build_jaxlib_wheel(
    cmd = [
        "python",
        "build/build.py",
-        "--enable_rocm",
-        "--build_gpu_plugin",
-        "--gpu_plugin_rocm_version=60",
+        "build"
+        "--wheels=jaxlib,jax-rocm-plugin,jax-rocm-pjrt"
        "--rocm_path=%s" % rocm_path,
+        "--rocm_version=60",
        "--use_clang=%s" % use_clang,
+        "--verbose"
    ]

    # Add clang path if clang is used.
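For orientation, here is a sketch of the standalone command that both updated ROCm scripts now assemble: the new "build" subcommand with an explicit --wheels list replaces the old --enable_rocm / --build_gpu_plugin flags. This is illustrative only; the ROCm install path is an assumed placeholder, and the scripts themselves construct the argument list programmatically as shown in the diffs above.

# Hypothetical standalone invocation of the new build CLI (placeholder ROCm path).
import subprocess

cmd = [
    "python3",
    "build/build.py",
    "build",
    "--wheels=jaxlib,jax-rocm-plugin,jax-rocm-pjrt",
    "--rocm_path=/opt/rocm-6.2.0/",  # placeholder install path, not from the commit
    "--rocm_version=60",
    "--use_clang=true",
    "--verbose",
]
subprocess.run(cmd, check=True)  # raises CalledProcessError on a non-zero exit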

jax/_src/config.py

Lines changed: 8 additions & 1 deletion
@@ -212,6 +212,7 @@ def trace_context():
  return (axis_env_state.value, mesh_context_manager.value,
          xla_metadata_context_manager.value,
          abstract_mesh_context_manager.value,
+          device_context.value,
          compute_on_context_manager.value, enable_x64.value,
          numpy_rank_promotion.value, default_matmul_precision.value,
          dynamic_shapes.value,
@@ -245,6 +246,7 @@ def trace_context():
  axis_env_state = ()
  mesh_context_manager = ()
  abstract_mesh_context_manager = ()
+  device_context = ()
  xla_metadata_context_manager = ()
  compute_on_context_manager = ()

@@ -255,12 +257,14 @@ def trace_context():
    mesh_context_manager = context.mesh_context_manager
  if context and context.abstract_mesh_context_manager:
    abstract_mesh_context_manager = context.abstract_mesh_context_manager
+  if context and context.device_context:
+    device_context = context.device_context
  if context and context.xla_metadata_context_manager:
    xla_metadata_context_manager = context.xla_metadata_context_manager
  if context and context.compute_on_context_manager:
    compute_on_context_manager = context.compute_on_context_manager
  return (axis_env_state, mesh_context_manager, abstract_mesh_context_manager,
-          xla_metadata_context_manager,
+          device_context, xla_metadata_context_manager,
          compute_on_context_manager, enable_x64.value,
          numpy_rank_promotion.value, default_matmul_precision.value,
          dynamic_shapes.value,
@@ -976,6 +980,7 @@ def enum_flag(name, default, *args, **kwargs) -> Flag[str]:
  axis_env_state = config_ext.Config((), include_in_jit_key=True)
  mesh_context_manager = config_ext.Config((), include_in_jit_key=True)
  abstract_mesh_context_manager = config_ext.Config((), include_in_jit_key=True)
+  device_context = config_ext.Config((), include_in_jit_key=True)
  compute_on_context_manager = config_ext.Config((), include_in_jit_key=True)
  xla_metadata_context_manager = config_ext.Config((), include_in_jit_key=True)
else:
@@ -1019,6 +1024,7 @@ class _ThreadLocalExtraJitContext(NamedTuple):
  axis_env_state: Hashable = ()
  mesh_context_manager: Hashable = ()
  abstract_mesh_context_manager: Hashable = ()
+  device_context: Hashable = ()
  compute_on_context_manager: Hashable = ()
  xla_metadata_context_manager: Hashable = ()

@@ -1086,6 +1092,7 @@ def set_local(self, value):
  axis_env_state = JitConfig('axis_env_state')
  mesh_context_manager = JitConfig('mesh_context_manager')
  abstract_mesh_context_manager = JitConfig('abstract_mesh_context_manager')
+  device_context = JitConfig('device_context')
  compute_on_context_manager = JitConfig('compute_on_context_manager')
  xla_metadata_context_manager = JitConfig('xla_metadata_context_manager')
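The practical effect of threading device_context through trace_context() is that the active concrete mesh becomes part of the tracing cache key. A minimal sketch of that behaviour, assuming a concrete jax.sharding.Mesh and the set_mesh helper added in jax/_src/mesh.py further down this commit:

# Sketch only: trace_context() now includes device_context.value, so installing
# a concrete mesh changes the cache key and forces jit to re-trace.
from jax._src import config as jax_config
from jax._src import mesh as mesh_lib

def cache_key_changes(mesh):
  # `mesh` is assumed to be a concrete jax.sharding.Mesh.
  outside = jax_config.trace_context()
  with mesh_lib.set_mesh(mesh):
    inside = jax_config.trace_context()
  # The tuples differ because device_context.value differs inside the context.
  return inside != outside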

jax/_src/core.py

Lines changed: 1 addition & 1 deletion
@@ -1605,7 +1605,7 @@ def get_sharding(sharding, ndim):
    assert len(sharding.spec) == ndim
    return sharding

-  context_mesh = mesh_lib.mesh_context.mesh
+  context_mesh = mesh_lib.abstract_mesh_context.mesh
  # TODO(yashkatariya): Error out and ask users to set the context mesh in their
  # code.
  if context_mesh is None:

jax/_src/interpreters/pxla.py

Lines changed: 9 additions & 2 deletions
@@ -2193,8 +2193,15 @@ def lower_sharding_computation(
  assert len(out_shardings) == len(out_layouts) == len(global_out_avals), (
      len(out_shardings), len(out_layouts), len(global_out_avals))

-  devices_from_context = (None if context_mesh is None or context_mesh.empty
-                          else context_mesh._flat_devices_tuple)
+  if config.sharding_in_types.value:
+    # TODO(yashkatariya): Thread it via jit path and remove the None check by
+    # making tests go via set_mesh API always.
+    devices_from_context = (
+        None if mesh_lib.device_context.concrete_mesh is None
+        else mesh_lib.device_context.concrete_mesh._flat_devices_tuple)
+  else:
+    devices_from_context = (None if context_mesh is None or context_mesh.empty
+                            else context_mesh._flat_devices_tuple)
  # Device assignment across all inputs, outputs and shardings inside jaxpr
  # should be the same.
  unique_intermediate_shardings = util.stable_unique(

jax/_src/mesh.py

Lines changed: 48 additions & 14 deletions
@@ -455,10 +455,10 @@ def local_mesh(self):
    _raise_value_error("local_mesh")

  def __enter__(self):
-    return push_mesh_context(self)
+    return push_abstract_mesh_context(self)

  def __exit__(self, exc_type, exc_value, traceback):
-    pop_mesh_context()
+    pop_abstract_mesh_context()
    return False

  @staticmethod
@@ -473,36 +473,70 @@ def _raise_value_error(name):
  raise ValueError(f"AbstractMesh does not implement {name}")


-class MeshContext(threading.local):
+class AbstractMeshContext(threading.local):
  def __init__(self):
    self.stack = [None]
    self.mesh = self.stack[-1]

-mesh_context = MeshContext()
+abstract_mesh_context = AbstractMeshContext()

-def push_mesh_context(val):
-  mesh_context.stack.append(val)
-  mesh_context.mesh = val
+def push_abstract_mesh_context(val):
+  abstract_mesh_context.stack.append(val)
+  abstract_mesh_context.mesh = val
  # TODO(yashkatariya): Allow setting empty tuples and tuples with None in them.
  # Right now that leads to weird numerical issues.
-  non_none_meshes = tuple(m for m in mesh_context.stack if m is not None)
+  non_none_meshes = tuple(m for m in abstract_mesh_context.stack
+                          if m is not None)
  if non_none_meshes:
    jax_config.abstract_mesh_context_manager.set_local(non_none_meshes)
  return val

-def pop_mesh_context():
-  mesh_context.stack.pop()
-  mesh_context.mesh = mesh_context.stack[-1]
-  non_none_meshes = tuple(m for m in mesh_context.stack if m is not None)
+def pop_abstract_mesh_context():
+  abstract_mesh_context.stack.pop()
+  abstract_mesh_context.mesh = abstract_mesh_context.stack[-1]
+  non_none_meshes = tuple(m for m in abstract_mesh_context.stack
+                          if m is not None)
  if non_none_meshes:
    jax_config.abstract_mesh_context_manager.set_local(non_none_meshes)


class null_mesh_context:

  def __enter__(self):
-    return push_mesh_context(None)
+    return push_abstract_mesh_context(None)

  def __exit__(self, *excinfo):
-    pop_mesh_context()
+    pop_abstract_mesh_context()
    return False
+
+
+@contextlib.contextmanager
+def set_mesh(mesh: Mesh):
+  with (mesh.abstract_mesh, jax_config.sharding_in_types(True),
+        enter_device_context(mesh)):
+    yield
+
+
+class DeviceContext(threading.local):
+  def __init__(self):
+    self.stack = [None]
+    self.concrete_mesh = self.stack[-1]
+
+device_context = DeviceContext()
+
+
+@contextlib.contextmanager
+def enter_device_context(mesh: Mesh):
+  device_context.stack.append(mesh)
+  device_context.concrete_mesh = mesh
+  non_none_meshes = tuple(m for m in device_context.stack if m is not None)
+  if non_none_meshes:
+    jax_config.device_context.set_local(non_none_meshes)
+  try:
+    yield
+  finally:
+    device_context.stack.pop()
+    device_context.concrete_mesh = device_context.stack[-1]
+    non_none_meshes = tuple(m for m in device_context.stack if m is not None)
+    if non_none_meshes:
+      jax_config.device_context.set_local(non_none_meshes)
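To make the new plumbing concrete, a minimal usage sketch (not part of the commit) of set_mesh: it enters the mesh's abstract mesh, flips sharding_in_types, and records the concrete mesh in the thread-local device_context, which lower_sharding_computation then consults for the device assignment. The device array shape below is an assumption for illustration.

# Usage sketch for the new set_mesh helper; device layout is illustrative.
import numpy as np
import jax
from jax._src import mesh as mesh_lib

devs = np.array(jax.devices())  # however many devices are available
mesh = jax.sharding.Mesh(devs.reshape(1, -1), ("x", "y"))

with mesh_lib.set_mesh(mesh):
  # The concrete mesh is now visible to lowering via device_context, and the
  # corresponding abstract mesh sits on abstract_mesh_context at the same time.
  assert mesh_lib.device_context.concrete_mesh is mesh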

jax/_src/pallas/mosaic_gpu/BUILD

Lines changed: 1 addition & 1 deletion
@@ -91,7 +91,7 @@ pytype_strict_library(
        ":lowering",
        "//jax",
        "//jax:core",
-        "//jax:effects",
+        "//jax:mlir",
        "//jax:mosaic_gpu",
        "//jax:tree_util",
        "//jax:util",

jax/_src/pallas/mosaic_gpu/primitives.py

Lines changed: 30 additions & 0 deletions
@@ -25,8 +25,10 @@
from jax._src import state
from jax._src import tree_util
from jax._src import util
+from jax._src.interpreters import mlir
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import arith as arith_dialect
+from jax._src.lib.mlir.dialects import llvm as llvm_dialect
from jax._src.lib.mlir.dialects import nvvm as nvvm_dialect
from jax._src.pallas import core as pallas_core
from jax._src.pallas.mosaic_gpu import core as gpu_core
@@ -692,3 +694,31 @@ def _commit_smem_lowering(ctx: lowering.LoweringRuleContext):
def commit_smem():
  """Commits all writes to SMEM, making them visible to loads, TMA and WGMMA."""
  commit_smem_p.bind()
+
+
+broadcasted_iota_p = jax_core.Primitive("broadcasted_iota")
+
+@broadcasted_iota_p.def_abstract_eval
+def _broadcasted_iota_abstract_eval(dtype, shape, dimension, layout):
+  del layout, dimension
+  return jax_core.ShapedArray(shape, dtype)
+
+@lowering.register_lowering_rule(broadcasted_iota_p)
+def _broadcasted_iota_lowering(ctx: lowering.LoweringRuleContext, dtype, shape, dimension, layout):
+  del ctx
+  undef = llvm_dialect.mlir_undef(mlir.dtype_to_ir_type(dtype))
+  is_signed = (
+      jnp.issubdtype(dtype, jnp.signedinteger)
+      if jnp.issubdtype(dtype, jnp.integer)
+      else None
+  )
+  mlir_dtype = mlir.dtype_to_ir_type(dtype)
+  return mgpu.FragmentedArray.splat(
+      undef, shape, layout.value, is_signed=is_signed
+  ).foreach(
+      lambda _, idx: arith_dialect.index_cast(mlir_dtype, idx[dimension]), create_array=True, is_signed=is_signed
+  )
+
+
+def broadcasted_iota(dtype, shape, dimension, *, layout: Layout | None = None):
+  return broadcasted_iota_p.bind(dtype=jnp.dtype(dtype), shape=shape, dimension=dimension, layout=layout)
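A small, hypothetical sketch of the new primitive from the trace side (not taken from the commit): because broadcasted_iota_p has an abstract-eval rule, it can be staged out much like jax.lax.broadcasted_iota; actually lowering it requires running inside a Pallas Mosaic GPU kernel with a concrete Layout and a jaxlib build with Mosaic GPU support, both of which are elided here.

# Hypothetical staging example; lowering to a FragmentedArray only happens
# inside a Pallas Mosaic GPU kernel.
import jax
import jax.numpy as jnp
from jax._src.pallas.mosaic_gpu import primitives as mgpu_primitives

jaxpr = jax.make_jaxpr(
    lambda: mgpu_primitives.broadcasted_iota(jnp.int32, (8, 128), dimension=0)
)()
print(jaxpr)  # one broadcasted_iota equation producing an (8, 128) int32 value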
