Commit 9267b5d

Merge branch 'main' of https://github.com/jax-ml/jax

2 parents 4962afd + a212a29

8 files changed: +46 additions, -17 deletions

jax/_src/mesh_utils.py

Lines changed: 8 additions & 1 deletion
@@ -705,6 +705,12 @@ def _transpose_trick(
       *_TRANSPOSE_TRICKS[topology][mesh_shape_no_trivial_dims]
   )
 
+def _validate_axis_shapes(axis_shapes: Sequence[int], arg_name: str,
+                          fun_name: str):
+  if not all(isinstance(s, int) for s in axis_shapes):
+    raise ValueError(
+        f'{arg_name} passed to {fun_name} should be a sequence of ints. Got'
+        f' {axis_shapes}')
 
 def create_device_mesh(
     mesh_shape: Sequence[int],
@@ -740,7 +746,8 @@ def create_device_mesh(
   """
   if devices is None:
     devices = xb.devices()
-  if np.prod(mesh_shape) != len(devices):
+  _validate_axis_shapes(mesh_shape, 'mesh_shape', 'create_device_mesh')
+  if math.prod(mesh_shape) != len(devices):
     raise ValueError(
         f'Number of devices {len(devices)} must equal the product '
         f'of mesh_shape {mesh_shape}'
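
Side note on the np.prod -> math.prod swap: math.prod multiplies Python ints at
arbitrary precision and returns a plain int, while np.prod computes in a
fixed-width integer dtype and can overflow silently. A minimal sketch of the
difference (illustrative values, not taken from this commit):

    import math
    import numpy as np

    shape = [2**40, 2**40]   # product is 2**80, far beyond int64
    print(np.prod(shape))    # 0 on most platforms: silent fixed-width overflow
    print(math.prod(shape))  # 1208925819614629174706176, the exact Python int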

jax/_src/pallas/triton/lowering.py

Lines changed: 21 additions & 6 deletions
@@ -1469,10 +1469,22 @@ def _float_int_cast(
   dst_element_type = ir.IntegerType(_element_type(dst_type))
   if dst_element_type.width == 1:
     return _not_equal(src, _full(src.type, 0), signed=signed)
-  elif signed:
-    return arith_dialect.fptosi(dst_type, src)
   else:
-    return arith_dialect.fptoui(dst_type, src)
+    # We clamp the float value to the min/max integer destination value
+    # in order to match JAX/XLA casting behavior. Note that this differs
+    # from numpy casting behavior.
+    if signed:
+      maxint = 2**(dst_element_type.width-1) - 1
+      minint = -2**(dst_element_type.width-1)
+    else:
+      maxint = 2**dst_element_type.width - 1
+      minint = 0
+    src = arith_dialect.minimumf(src, _full(src.type, maxint))
+    src = arith_dialect.maximumf(src, _full(src.type, minint))
+    if signed:
+      return arith_dialect.fptosi(dst_type, src)
+    else:
+      return arith_dialect.fptoui(dst_type, src)
 
 
 def _int_float_cast(
@@ -1499,10 +1511,12 @@ def _cast(
       src,
       _dtype_to_ir_type(dst_type),
       signed=jnp.issubdtype(src_type, jnp.signedinteger),
+      dst_signed=jnp.issubdtype(dst_type, jnp.signedinteger),
   )
 
 
-def _ir_cast(src: ir.Value, dst_type: ir.Type, *, signed: bool) -> ir.Value:
+def _ir_cast(src: ir.Value, dst_type: ir.Type, *,
+             signed: bool, dst_signed: bool = False) -> ir.Value:
   if ir.RankedTensorType.isinstance(
       src.type
   ) and not ir.RankedTensorType.isinstance(dst_type):
@@ -1527,7 +1541,8 @@ def _ir_cast(src: ir.Value, dst_type: ir.Type, *, signed: bool) -> ir.Value:
       dst_element_type, ir.F32Type
   ):
     return _ir_cast(
-        _ir_cast(src, ir.F32Type.get(), signed=False), dst_type, signed=False
+        _ir_cast(src, ir.F32Type.get(), signed=False),
+        dst_type, signed=False, dst_signed=dst_signed
     )
 
   if isinstance(src_element_type, ir.FloatType) and isinstance(
@@ -1543,7 +1558,7 @@ def _ir_cast(src: ir.Value, dst_type: ir.Type, *, signed: bool) -> ir.Value:
   if isinstance(src_element_type, ir.FloatType) and isinstance(
       dst_element_type, ir.IntegerType
   ):
-    return _float_int_cast(src, dst_type, signed=signed)
+    return _float_int_cast(src, dst_type, signed=dst_signed)
   if isinstance(src_element_type, ir.IntegerType) and isinstance(
       dst_element_type, ir.FloatType
   ):
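
For reference, the behavior the new clamping matches: JAX/XLA saturates an
out-of-range float-to-int cast at the destination dtype's bounds, while numpy's
result is formally undefined (commonly wraparound on x86). A small
illustration, assuming a working JAX install (values are illustrative, not
taken from the tests):

    import jax.numpy as jnp
    import numpy as np

    x = jnp.array(1e10, dtype=jnp.float32)  # well above int32 max (2**31 - 1)
    print(x.astype(jnp.int32))              # 2147483647: clamped to the int32 bound
    print(np.array(1e10, np.float32).astype(np.int32))  # platform-dependent in numpy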

jax/_src/pjit.py

Lines changed: 2 additions & 4 deletions
@@ -1662,10 +1662,8 @@ def _pjit_call_impl_python(
     pgle_compile_options['fdo_profile'] = fdo_profile
 
   compiler_options_kvs = compiler_options_kvs + tuple(pgle_compile_options.items())
-  # TODO(patrios): Do not pass mutable profile session through cached lowering
-  # chain. Instead we need to move profilers dictionary to pxla module and use
-  # module as key. Right now we can't do that since there is no way to evict
-  # _pjit_lower_cached cache for in PGLE mode.
+  # Passing mutable PGLE profile here since it should be extracted by JAXPR to
+  # initialize the fdo_profile compile option.
   compiled = _resolve_and_lower(
       args, jaxpr=jaxpr, in_shardings=in_shardings,
       out_shardings=out_shardings, in_layouts=in_layouts,

jax/_src/sharding_impls.py

Lines changed: 1 addition & 0 deletions
@@ -1714,6 +1714,7 @@ def make_mesh(axis_shapes: Sequence[int], axis_names: Sequence[str],
   """
   if devices is None:
     devices = xla_bridge.devices()
+  mesh_utils._validate_axis_shapes(axis_shapes, 'axis_shapes', 'make_mesh')
   axis_size = math.prod(axis_shapes)
   if axis_size > len(devices):
     raise ValueError(
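
A quick usage sketch of the validated entry point (assumes a backend exposing
at least 8 devices; the error text comes from the new _validate_axis_shapes
helper):

    import jax

    mesh = jax.make_mesh((2, 4), ('x', 'y'))  # flat sequence of ints: accepted
    jax.make_mesh(((4,), 4), ('x', 'y'))      # now rejected up front:
    # ValueError: axis_shapes passed to make_mesh should be a sequence of ints.
    # Got ((4,), 4)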

tests/mesh_utils_test.py

Lines changed: 6 additions & 0 deletions
@@ -353,6 +353,12 @@ def test_create_device_mesh_for_nd_torus(
     )
     self.assertArraysEqual(assignment, expected_assignment_matrix)
 
+  def test_create_device_mesh_non_int_error(self):
+    with self.assertRaisesRegex(
+        ValueError,
+        "mesh_shape passed to create_device_mesh should be a sequence of ints"):
+      mesh_utils.create_device_mesh(((4,), 4))
+
   @parameterized.named_parameters(
       ('2x2x1', mock_2x2x1_devices,),
       ('2x2x4', mock_2x2x4_devices, ),

tests/pallas/ops_test.py

Lines changed: 0 additions & 4 deletions
@@ -556,10 +556,6 @@ def test_cast(self, from_dtype, to_dtype, data):
       self.skipTest("Not supported: bad canonicalization")
     if from_dtype == "bool" and to_dtype in {"int16", "int8"}:
       self.skipTest("Not supported: cannot extend to sub-32 bit types")
-    if jtu.test_device_matches(["gpu"]):
-      if (from_dtype in {"bfloat16", "float32"} and
-          to_dtype in {"int8", "int16", "int32"}):
-        self.skipTest("TODO: wrong result on GPU")
 
     if from_dtype == "bfloat16":
       from_dtype = jnp.bfloat16

tests/pjit_test.py

Lines changed: 6 additions & 0 deletions
@@ -4458,6 +4458,12 @@ def g(x):
     self.assertEqual(out2.sharding, s)
     self.assertEqual(out2.dtype, np.float32)
 
+  def test_make_mesh_non_int_error(self):
+    with self.assertRaisesRegex(
+        ValueError,
+        "axis_shapes passed to make_mesh should be a sequence of ints"):
+      jax.make_mesh(((4,), 4), ('x', 'y'))
+
   def test_jnp_array_reshard_error(self):
     if jax.device_count() < 2:
       self.skipTest('Requires >=2 devices')

third_party/xla/workspace.bzl

Lines changed: 2 additions & 2 deletions
@@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")
 # curl -L https://github.com/openxla/xla/archive/<git hash>.tar.gz | sha256sum
 # and update XLA_SHA256 with the result.
 
-XLA_COMMIT = "c7fdcbc588fa9ea021cf8766530604e8d0fef332"
-XLA_SHA256 = "c0e82c28e5e74065c8446199af657af71ae2f786ba33ddb23d6e1bbcd4463d50"
+XLA_COMMIT = "e2fe67323ea46076a61230952a3551df04ec559d"
+XLA_SHA256 = "0cdc3108f44f8ab37c90e165bae3bc72e16d049ad18c46d2aa8004f93df2d9f9"
 
 def repo():
     tf_http_archive(
