Skip to content

Commit cab2ff6

Browse files
committed
Merge branch 'main' of https://github.com/google/jax
2 parents dcbc5b9 + a041ea1 commit cab2ff6

File tree

19 files changed

+791
-51
lines changed

19 files changed

+791
-51
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
2525
for information on migrating to the new API.
2626
* The `initial` argument to {func}`jax.nn.softmax` and {func}`jax.nn.log_softmax`
2727
has been removed, after being deprecated in v0.4.27.
28+
* Calling `np.asarray` on typed PRNG keys (i.e. keys produced by {func}`jax.random.key`)
29+
now raises an error. Previously, this returned a scalar object array.
2830
* The following deprecated methods and functions in {mod}`jax.export` have
2931
been removed:
3032
* `jax.export.DisabledSafetyCheck.shape_assertions`: it had no effect

jax/_src/core.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -955,6 +955,7 @@ def __eq__(self, other):
955955
@dataclass(frozen=True)
956956
class AxisEnv:
957957
axis_sizes : dict[AxisName, int]
958+
spmd_axis_names : set[AxisName]
958959

959960
def axis_size(self, axis_name):
960961
if axis_name not in self.axis_sizes:
@@ -971,20 +972,24 @@ def axis_names(self):
971972
def pop_pure(self, axis_name):
972973
new_sizes = self.axis_sizes.copy()
973974
new_sizes.pop(axis_name)
974-
return AxisEnv(new_sizes)
975+
return AxisEnv(new_sizes, self.spmd_axis_names)
975976

976977
def extend_pure(self, name_size_pairs):
977978
new_sizes = self.axis_sizes.copy()
978979
new_sizes.update((name, size) for name, size in name_size_pairs
979980
if name is not no_axis_name)
980-
return AxisEnv(new_sizes)
981+
return AxisEnv(new_sizes, self.spmd_axis_names)
982+
983+
def add_spmd_axis_names(self, axis_names):
984+
new_spmd_axis_names = self.spmd_axis_names | set(axis_names)
985+
return AxisEnv(self.axis_sizes, new_spmd_axis_names)
981986

982987
def as_hashable_key(self):
983988
return tuple((name, size) for (name, size) in self.axis_sizes.items()
984989
if name is not no_axis_name)
985990

986991
eval_trace = EvalTrace()
987-
top_axis_env = AxisEnv({})
992+
top_axis_env = AxisEnv({}, set())
988993

989994
class TracingContext(threading.local):
990995
trace: Trace | None
@@ -1045,6 +1050,16 @@ def extend_axis_env_nd(name_size_pairs : Iterable[tuple[AxisName, int]]):
10451050
finally:
10461051
trace_ctx.set_axis_env(prev)
10471052

1053+
@contextmanager
1054+
def add_spmd_axis_names(axis_names: AxisName | None):
1055+
prev = trace_ctx.axis_env
1056+
try:
1057+
if axis_names is not None:
1058+
trace_ctx.set_axis_env(prev.add_spmd_axis_names(axis_names))
1059+
yield
1060+
finally:
1061+
trace_ctx.set_axis_env(prev)
1062+
10481063
def get_axis_env():
10491064
return trace_ctx.axis_env
10501065

jax/_src/extend/ffi.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import ctypes
1919
import functools
2020
import os
21-
from typing import Any
21+
from typing import Any, overload
2222

2323
import numpy as np
2424

@@ -240,6 +240,43 @@ def _convert_layouts_for_ffi_call(
240240
for aval, layout in zip(avals, layouts))
241241

242242

243+
# ffi_call() returns as many results as result_shape_dtypes.
244+
@overload
245+
def ffi_call(
246+
target_name: str,
247+
result_shape_dtypes: ResultMetadata,
248+
*deprecated_args: ArrayLike,
249+
has_side_effect: bool = ...,
250+
vmap_method: str | None = ...,
251+
input_layouts: Sequence[FfiLayoutOptions] | None = ...,
252+
output_layouts: FfiLayoutOptions | Sequence[FfiLayoutOptions] | None = ...,
253+
input_output_aliases: dict[int, int] | None = ...,
254+
custom_call_api_version: int = ...,
255+
legacy_backend_config: str | None = ...,
256+
vectorized: bool | DeprecatedArg = ...,
257+
**deprecated_kwargs: Any,
258+
) -> Callable[..., Array] | Array:
259+
...
260+
261+
262+
@overload
263+
def ffi_call(
264+
target_name: str,
265+
result_shape_dtypes: Sequence[ResultMetadata],
266+
*deprecated_args: ArrayLike,
267+
has_side_effect: bool = ...,
268+
vmap_method: str | None = ...,
269+
input_layouts: Sequence[FfiLayoutOptions] | None = ...,
270+
output_layouts: FfiLayoutOptions | Sequence[FfiLayoutOptions] | None = ...,
271+
input_output_aliases: dict[int, int] | None = ...,
272+
custom_call_api_version: int = ...,
273+
legacy_backend_config: str | None = ...,
274+
vectorized: bool | DeprecatedArg = ...,
275+
**deprecated_kwargs: Any,
276+
) -> Callable[..., Sequence[Array]] | Sequence[Array]:
277+
...
278+
279+
243280
def ffi_call(
244281
target_name: str,
245282
result_shape_dtypes: ResultMetadata | Sequence[ResultMetadata],

jax/_src/interpreters/batching.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -596,9 +596,10 @@ def _batch_inner(axis_data, out_dim_dests, tag, in_dims, *in_vals):
596596
idx = memoize(lambda: BatchTracer(trace, make_iota(axis_data.size), 0,
597597
source_info_util.current()))
598598
in_tracers = map(partial(to_elt, trace, idx), in_vals, in_dims)
599-
with core.set_current_trace(trace):
600-
with core.extend_axis_env_nd([(axis_data.name, axis_data.size)]):
601-
outs = yield in_tracers, {}
599+
with (core.set_current_trace(trace),
600+
core.extend_axis_env_nd([(axis_data.name, axis_data.size)]),
601+
core.add_spmd_axis_names(axis_data.spmd_name)):
602+
outs = yield in_tracers, {}
602603

603604
out_dim_dests = out_dim_dests() if callable(out_dim_dests) else out_dim_dests
604605
out_vals = map(partial(from_elt, trace, axis_data.size), range(len(outs)),
@@ -795,9 +796,10 @@ def _batch_jaxpr_inner(axis_data, tag, in_axes, *in_vals):
795796
_, in_axes = resolve_ragged_axes(in_vals, in_axes)
796797
in_tracers = [BatchTracer(trace, val, dim) if dim is not None else val
797798
for val, dim in zip(in_vals, in_axes)]
798-
with core.set_current_trace(trace):
799-
with core.extend_axis_env_nd([(axis_data.name, axis_data.size)]):
800-
outs = yield in_tracers, {}
799+
with (core.set_current_trace(trace),
800+
core.extend_axis_env_nd([(axis_data.name, axis_data.size)]),
801+
core.add_spmd_axis_names(axis_data.spmd_name)):
802+
outs = yield in_tracers, {}
801803
out_vals, out_axes = unzip2(map(trace.to_batch_info, outs))
802804
new_out_axes = indirectify_ragged_axes_against_inputs_outputs(
803805
out_axes, in_vals, out_vals)

jax/_src/prng.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,11 @@ def copy(self):
279279
__hash__ = None # type: ignore[assignment]
280280
__array_priority__ = 100
281281

282+
def __array__(self, dtype: np.dtype | None = None, copy: bool | None = None) -> np.ndarray:
283+
raise TypeError("JAX array with PRNGKey dtype cannot be converted to a NumPy array."
284+
" Use jax.random.key_data(arr) if you wish to extract the underlying"
285+
" integer array.")
286+
282287
# Overwritten immediately below
283288
@property
284289
def at(self) -> _IndexUpdateHelper: assert False # type: ignore[override]

jax/experimental/shard_map.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1506,7 +1506,7 @@ def _unmentioned2(mesh: Mesh, names: AxisNames) -> list[AxisName]:
15061506
# We use a filtered-down version of unmentioned to avoid defensive-psum over
15071507
# more chips than required in the transpose-no-check-rep case.
15081508
name_set = {n for ns in names.values() for n in ns}
1509-
return [n for n in mesh.axis_names if n not in name_set]
1509+
return [n for n in _all_mesh_names_except_spmd(mesh) if n not in name_set]
15101510

15111511

15121512
def _shard_map_transpose(out_cts, *args, jaxpr, mesh, in_names, out_names,
@@ -1652,10 +1652,7 @@ def _pe_custom_params(unks_in, inst_in, kept_outs_known, kept_outs_staged,
16521652

16531653
# TODO(mattjj): remove this mechanism when we revise mesh scopes
16541654
def _all_mesh_names_except_spmd(mesh: Mesh, trace=None) -> tuple[AxisName, ...]:
1655-
trace = core.unsafe_get_current_trace() if trace is None else trace
1656-
stack = core.unsafe_get_trace_stack(trace)
1657-
batch_traces = [t for t in stack if isinstance(t, batching.BatchTrace)]
1658-
spmd_names = {n for trace in batch_traces for n in trace.axis_data.spmd_name }
1655+
spmd_names = core.get_axis_env().spmd_axis_names
16591656
return tuple(name for name in mesh.axis_names if name not in spmd_names)
16601657

16611658
# DCE

jaxlib/mosaic/dialect/gpu/BUILD

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,9 @@ td_library(
2929
srcs = ["mosaic_gpu.td"],
3030
includes = ["."],
3131
deps = [
32+
"@llvm-project//mlir:BasicPtxBuilderIntTdFiles",
3233
"@llvm-project//mlir:BuiltinDialectTdFiles",
34+
"@llvm-project//mlir:InferTypeOpInterfaceTdFiles",
3335
"@llvm-project//mlir:OpBaseTdFiles",
3436
],
3537
)
@@ -109,16 +111,19 @@ cc_library(
109111
hdrs = ["mosaic_gpu.h"],
110112
deps = [
111113
":mosaic_gpu_inc_gen",
114+
"@com_google_absl//absl/algorithm:container",
112115
"@com_google_absl//absl/status",
113116
"@com_google_absl//absl/status:statusor",
114117
"@com_google_absl//absl/strings",
115118
"@llvm-project//llvm:Support",
116119
"@llvm-project//mlir:ArithDialect",
117120
"@llvm-project//mlir:FuncDialect",
118121
"@llvm-project//mlir:IR",
122+
"@llvm-project//mlir:InferTypeOpInterface",
119123
"@llvm-project//mlir:LLVMCommonConversion",
120124
"@llvm-project//mlir:LLVMDialect",
121125
"@llvm-project//mlir:MemRefDialect",
126+
"@llvm-project//mlir:MemRefUtils",
122127
"@llvm-project//mlir:SCFUtils",
123128
"@llvm-project//mlir:Support",
124129
"@tsl//tsl/platform:statusor",
@@ -152,12 +157,19 @@ cc_test(
152157
gentbl_filegroup(
153158
name = "mosaic_gpu_python_gen_raw",
154159
tbl_outs = [
160+
(
161+
[
162+
"-gen-python-enum-bindings",
163+
"-bind-dialect=mosaic_gpu",
164+
],
165+
"_mosaic_gpu_gen_enums_raw.py",
166+
),
155167
(
156168
[
157169
"-gen-python-op-bindings",
158170
"-bind-dialect=mosaic_gpu",
159171
],
160-
"_mosaic_gpu_gen_raw.py",
172+
"_mosaic_gpu_gen_ops_raw.py",
161173
),
162174
],
163175
tblgen = "@llvm-project//mlir:mlir-tblgen",
@@ -169,10 +181,19 @@ gentbl_filegroup(
169181
)
170182

171183
genrule(
172-
name = "mosaic_gpu_python_gen",
173-
srcs = ["_mosaic_gpu_gen_raw.py"],
174-
outs = ["_mosaic_gpu_gen.py"],
175-
cmd = "cat $(location _mosaic_gpu_gen_raw.py) | sed -e 's/^from \\./from jaxlib\\.mlir\\.dialects\\./g' > $@",
184+
name = "mosaic_gpu_python_gen_enums",
185+
srcs = ["_mosaic_gpu_gen_enums_raw.py"],
186+
outs = ["_mosaic_gpu_gen_enums.py"],
187+
cmd = """
188+
cat $(location _mosaic_gpu_gen_enums_raw.py) | \
189+
sed -e 's/^from \\.\\.ir/from jaxlib\\.mlir\\.ir/g; s/^from \\./from jaxlib\\.mlir\\.dialects\\./g' > $@""",
190+
)
191+
192+
genrule(
193+
name = "mosaic_gpu_python_gen_ops",
194+
srcs = ["_mosaic_gpu_gen_ops_raw.py"],
195+
outs = ["_mosaic_gpu_gen_ops.py"],
196+
cmd = "cat $(location _mosaic_gpu_gen_ops_raw.py) | sed -e 's/^from \\./from jaxlib\\.mlir\\.dialects\\./g' > $@",
176197
)
177198

178199
DIALECT_CAPI_SOURCES = [

jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc

Lines changed: 87 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,18 +18,17 @@ limitations under the License.
1818
#include <cstdint>
1919
#include <vector>
2020

21-
#include "absl/status/status.h"
22-
#include "absl/status/statusor.h"
23-
#include "absl/strings/str_cat.h"
24-
#include "absl/strings/string_view.h"
2521
#include "llvm/ADT/STLExtras.h"
2622
#include "llvm/ADT/TypeSwitch.h" // IWYU pragma: keep
2723
#include "llvm/Support/Casting.h"
24+
#include "llvm/Support/FormatVariadic.h"
25+
#include "llvm/Support/LogicalResult.h"
2826
#include "mlir/Conversion/LLVMCommon/MemRefBuilder.h"
2927
#include "mlir/Dialect/Arith/IR/Arith.h"
3028
#include "mlir/Dialect/Func/IR/FuncOps.h"
3129
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
3230
#include "mlir/Dialect/MemRef/IR/MemRef.h"
31+
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
3332
#include "mlir/Dialect/SCF/Utils/Utils.h"
3433
#include "mlir/IR/Builders.h"
3534
#include "mlir/IR/BuiltinAttributes.h"
@@ -44,6 +43,12 @@ limitations under the License.
4443
#include "mlir/IR/Value.h"
4544
#include "mlir/IR/ValueRange.h"
4645
#include "mlir/Support/LLVM.h"
46+
#include "absl/algorithm/container.h"
47+
#include "absl/status/status.h"
48+
#include "absl/status/statusor.h"
49+
#include "absl/strings/str_cat.h"
50+
#include "absl/strings/string_view.h"
51+
#include "mlir/include/mlir/IR/Diagnostics.h"
4752
#include "tsl/platform/statusor.h"
4853

4954
// Generated definitions.
@@ -232,11 +237,89 @@ void DeclareRuntimeFunctions(mlir::OpBuilder& builder) {
232237
.setVisibility(mlir::func::FuncOp::Visibility::Private);
233238
}
234239

240+
bool IsContiguous(mlir::MemRefType type) {
241+
return type.getLayout().isIdentity() ||
242+
(type.hasStaticShape() && type.getNumElements() > 0 &&
243+
mlir::memref::isStaticShapeAndContiguousRowMajor(type));
244+
}
245+
246+
namespace {
247+
llvm::LogicalResult VerifyCommonLoadStoreOp(
248+
mlir::Location loc, mlir::MemRefType gmem_type, absl::string_view gmem_name,
249+
mlir::MemRefType smem_type, absl::string_view smem_name,
250+
mlir::ArrayRef<int64_t> slice_lengths, int num_indices) {
251+
auto error = [loc](auto... params) {
252+
return emitError(loc, llvm::formatv(params...));
253+
};
254+
255+
if (!IsContiguous(smem_type)) {
256+
return error("The `{0}` memref must be contiguous.", smem_name);
257+
}
258+
if (gmem_type.getElementType() != smem_type.getElementType()) {
259+
return error(
260+
"The `source` and `destination` memrefs must have the same element "
261+
"type.");
262+
}
263+
if (absl::c_any_of(slice_lengths, [](int64_t s) { return s < -1; })) {
264+
return error(
265+
"The `slice_lengths` attribute must not contain values less than -1.");
266+
}
267+
if (gmem_type.getRank() !=
268+
smem_type.getRank() + absl::c_count(slice_lengths, -1)) {
269+
return error(
270+
"The rank of the `{0}` must be equal to the rank of the "
271+
"`{1}` plus the number of collapsed dimensions as indicated "
272+
"by -1 values in `slice_lengths`.",
273+
gmem_name, smem_name);
274+
}
275+
if (num_indices != gmem_type.getRank()) {
276+
return error("The size of `indices` must be equal to the rank of `{0}`.",
277+
gmem_name);
278+
}
279+
if (slice_lengths.size() != gmem_type.getRank()) {
280+
return error(
281+
"The size of `slice_lengths` must be equal to the rank of `{0}`.",
282+
gmem_name);
283+
}
284+
return llvm::success();
285+
}
286+
} // namespace
287+
288+
llvm::LogicalResult AsyncLoadOp::verify() {
289+
auto r = VerifyCommonLoadStoreOp(getLoc(), getSource().getType(), "source",
290+
getDestination().getType(), "destination",
291+
getSliceLengths(), getIndices().size());
292+
if (failed(r)) {
293+
return r;
294+
}
295+
296+
for (int i = 0; i < getCollective().size(); ++i) {
297+
for (int k = i + 1; k < getCollective().size(); ++k)
298+
if (getCollective()[i] == getCollective()[k]) {
299+
return emitError(
300+
"The `collective` attribute must not contain duplicate "
301+
"dimensions.");
302+
}
303+
}
304+
305+
return llvm::success();
306+
}
307+
308+
llvm::LogicalResult AsyncStoreOp::verify() {
309+
return VerifyCommonLoadStoreOp(getLoc(), getDestination().getType(),
310+
"destination", getSource().getType(), "source",
311+
getSliceLengths(), getIndices().size());
312+
}
313+
235314
void MosaicGPUDialect::initialize() {
236315
addTypes<
237316
#define GET_TYPEDEF_LIST
238317
#include "jaxlib/mosaic/dialect/gpu/mosaic_gpu_types.cc.inc"
239318
>();
319+
addAttributes<
320+
#define GET_ATTRDEF_LIST
321+
#include "jaxlib/mosaic/dialect/gpu/mosaic_gpu_attrdefs.cc.inc"
322+
>();
240323
addOperations<
241324
#define GET_OP_LIST
242325
#include "jaxlib/mosaic/dialect/gpu/mosaic_gpu_ops.cc.inc"

0 commit comments

Comments
 (0)