Skip to content

Commit 67f24df

Browse files
dfm authored and Google-ML-Automation committed
Activate FFI implementation of symmetric Eigendecomposition.
These kernels support shape polymorphism in all dimensions and no GPU is required during lowering. The kernels have been included in jaxlib for more than 3 weeks so we don't need to include any forward compatibility checks.

PiperOrigin-RevId: 682415506
1 parent 18f48bd commit 67f24df

File tree

7 files changed

+445
-243
lines changed

7 files changed

+445
-243
lines changed

jax/_src/export/_export.py

Lines changed: 3 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -947,11 +947,6 @@ def _check_lowering(lowering) -> None:
947947
"__gpu$xla.gpu.triton", # Pallas call on GPU
948948
# cholesky on CPU
949949
"lapack_spotrf", "lapack_dpotrf", "lapack_cpotrf", "lapack_zpotrf",
950-
# eigh on CPU
951-
"lapack_ssyevd", "lapack_dsyevd", "lapack_cheevd", "lapack_zheevd",
952-
# eigh on GPU
953-
"cusolver_syevj", "cusolver_syevd",
954-
"hipsolver_syevj", "hipsolver_syevd",
955950
# eigh on TPU
956951
"Eigh",
957952
# eig on CPU
@@ -969,9 +964,12 @@ def _check_lowering(lowering) -> None:
969964
# lu on GPU
970965
"cu_lu_pivots_to_permutation", "cusolver_getrf_ffi",
971966
"hip_lu_pivots_to_permutation", "hipsolver_getrf_ffi",
967+
"cu_lu_pivots_to_permutation", "cusolver_getrf_ffi",
972968
# qr on GPU
973969
"cusolver_geqrf_ffi", "cusolver_orgqr_ffi",
974970
"hipsolver_geqrf_ffi", "hipsolver_orgqr_ffi",
971+
# eigh on GPU
972+
"cusolver_syevd_ffi", "hipsolver_syevd_ffi",
975973
# svd on GPU
976974
# lu on TPU
977975
"LuDecomposition",

jax/_src/internal_test_util/export_back_compat_test_data/cuda_eigh_cusolver_syev.py

Lines changed: 340 additions & 1 deletion
Large diffs are not rendered by default.

jax/_src/lax/linalg.py

Lines changed: 30 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -873,7 +873,7 @@ def _eigh_abstract_eval(operand, *, lower, sort_eigenvalues, subset_by_index):
873873
if isinstance(operand, ShapedArray):
874874
if operand.ndim < 2 or operand.shape[-2] != operand.shape[-1]:
875875
raise ValueError(
876-
"Argument to symmetric eigendecomposition must have shape [..., n, n],"
876+
"Argument to symmetric eigendecomposition must have shape [..., n, n], "
877877
"got shape {}".format(operand.shape))
878878

879879
batch_dims = operand.shape[:-2]
@@ -894,33 +894,39 @@ def _eigh_abstract_eval(operand, *, lower, sort_eigenvalues, subset_by_index):
894894

895895

896896
def _eigh_cpu_gpu_lowering(
897-
syevd_impl, ctx, operand, *, lower, sort_eigenvalues, subset_by_index,
898-
platform=None
897+
ctx, operand, *, lower, sort_eigenvalues, subset_by_index,
898+
target_name_prefix: str
899899
):
900900
del sort_eigenvalues # The CPU/GPU implementations always sort.
901901
operand_aval, = ctx.avals_in
902902
v_aval, w_aval = ctx.avals_out
903903
n = operand_aval.shape[-1]
904-
batch_dims = operand_aval.shape[:-2]
905-
906-
# The eigh implementation on CPU and GPU uses lapack helper routines to
907-
# find the size of the workspace based on the non-batch dimensions.
908-
# Therefore, we cannot yet support dynamic non-batch dimensions.
909-
if not is_constant_shape(operand_aval.shape[-2:]):
910-
raise NotImplementedError(
911-
"Shape polymorphism for native lowering for eigh is implemented "
912-
f"only for the batch dimensions: {operand_aval.shape}")
913-
914904
if not (subset_by_index is None or subset_by_index == (0, n)):
915-
raise NotImplementedError("subset_by_index not implemented for CPU and GPU")
905+
raise NotImplementedError("subset_by_index not supported on CPU and GPU")
906+
batch_dims = operand_aval.shape[:-2]
907+
nb = len(batch_dims)
908+
layout = (nb, nb + 1) + tuple(range(nb - 1, -1, -1))
909+
result_layouts = [layout, tuple(range(nb, -1, -1)),
910+
tuple(range(nb - 1, -1, -1))]
911+
if target_name_prefix == "cpu":
912+
dtype = operand_aval.dtype
913+
prefix = "he" if dtypes.issubdtype(dtype, np.complexfloating) else "sy"
914+
target_name = lapack.prepare_lapack_call(f"{prefix}evd_ffi",
915+
operand_aval.dtype)
916+
kwargs = {
917+
"mode": np.uint8(ord("V")),
918+
"uplo": np.uint8(ord("L" if lower else "U")),
919+
}
920+
else:
921+
target_name = f"{target_name_prefix}solver_syevd_ffi"
922+
kwargs = {"lower": lower, "algorithm": np.uint8(0)}
916923

917-
op_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, operand_aval.shape)
918-
cpu_args = []
919-
if platform == "cpu":
920-
ctx_args = (ctx,)
921-
cpu_args.extend(ctx_args)
922-
v, w, info = syevd_impl(*cpu_args, operand_aval.dtype, operand,
923-
a_shape_vals=op_shape_vals, lower=lower)
924+
rule = ffi.ffi_lowering(target_name, operand_layouts=[layout],
925+
result_layouts=result_layouts,
926+
operand_output_aliases={0: 0})
927+
info_aval = ShapedArray(batch_dims, np.dtype(np.int32))
928+
sub_ctx = ctx.replace(avals_out=[v_aval, w_aval, info_aval])
929+
v, w, info = rule(sub_ctx, operand, **kwargs)
924930

925931
zeros = mlir.full_like_aval(ctx, 0, ShapedArray(batch_dims, np.dtype(np.int32)))
926932
ok = mlir.compare_hlo(info, zeros, "EQ", "SIGNED")
@@ -1054,17 +1060,15 @@ def _eigh_batching_rule(
10541060
batching.primitive_batchers[eigh_p] = _eigh_batching_rule
10551061

10561062
mlir.register_lowering(
1057-
eigh_p, partial(_eigh_cpu_gpu_lowering, lapack.syevd_hlo, platform='cpu'),
1063+
eigh_p, partial(_eigh_cpu_gpu_lowering, target_name_prefix='cpu'),
10581064
platform='cpu')
10591065

10601066
if gpu_solver is not None:
10611067
mlir.register_lowering(
1062-
eigh_p, partial(_eigh_cpu_gpu_lowering, gpu_solver.cuda_syevd,
1063-
platform='cuda'),
1068+
eigh_p, partial(_eigh_cpu_gpu_lowering, target_name_prefix='cu'),
10641069
platform='cuda')
10651070
mlir.register_lowering(
1066-
eigh_p, partial(_eigh_cpu_gpu_lowering, gpu_solver.rocm_syevd,
1067-
platform='rocm'),
1071+
eigh_p, partial(_eigh_cpu_gpu_lowering, target_name_prefix='hip'),
10681072
platform='rocm')
10691073

10701074
mlir.register_lowering(

jaxlib/gpu_solver.py

Lines changed: 1 addition & 78 deletions
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,6 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from collections.abc import Sequence
1615
from functools import partial
1716
import importlib
1817
import math
@@ -24,9 +23,7 @@
2423

2524
from jaxlib import xla_client
2625

27-
from .hlo_helpers import (
28-
DimensionSize, ShapeTypePair, mk_result_types_and_shapes,
29-
custom_call, ensure_hlo_s32, hlo_s32, dense_int_array)
26+
from .hlo_helpers import custom_call, dense_int_array
3027

3128
try:
3229
from .cuda import _blas as _cublas # pytype: disable=import-error
@@ -122,80 +119,6 @@ def _csrlsvqr_hlo(platform, gpu_solver, dtype, data,
122119
cuda_csrlsvqr = partial(_csrlsvqr_hlo, "cu", _cusolver)
123120

124121

125-
def _syevd_hlo(platform, gpu_solver, have_jacobi_solver, dtype, a, *,
126-
a_shape_vals: tuple[DimensionSize, ...], lower=False):
127-
"""Symmetric (Hermitian) eigendecomposition."""
128-
a_type = ir.RankedTensorType(a.type)
129-
assert len(a_shape_vals) >= 2
130-
m, n = a_shape_vals[-2:]
131-
assert type(m) is int and type(n) is int and m == n, a_shape_vals
132-
batch_dims_vals = a_shape_vals[:-2]
133-
134-
num_bd = len(batch_dims_vals)
135-
layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
136-
137-
dynamic_batch_dims = any(type(d) != int for d in batch_dims_vals)
138-
if dynamic_batch_dims:
139-
batch_int = -1 # Signals to the kernel that the batch is an operand.
140-
else:
141-
batch_int = math.prod(batch_dims_vals)
142-
143-
if have_jacobi_solver and n <= 32 and not dynamic_batch_dims:
144-
# We cannot use syevj for dynamic shapes because the workspace size
145-
# depends on the batch size.
146-
kernel = f"{platform}solver_syevj"
147-
lwork, opaque = gpu_solver.build_syevj_descriptor(
148-
np.dtype(dtype), lower, batch_int, n)
149-
else:
150-
kernel = f"{platform}solver_syevd"
151-
lwork, opaque = gpu_solver.build_syevd_descriptor(
152-
np.dtype(dtype), lower, batch_int, n)
153-
# TODO(Ruturaj4): Currently, hipsolverSsyevd sets lwork to 0 if n==0.
154-
# Remove if this behavior changes in then new ROCm release.
155-
if n > 0 or platform != "hip":
156-
assert lwork > 0
157-
158-
if ir.ComplexType.isinstance(a_type.element_type):
159-
eigvals_type = ir.ComplexType(a_type.element_type).element_type
160-
else:
161-
eigvals_type = a_type.element_type
162-
163-
i32_type = ir.IntegerType.get_signless(32)
164-
operands = [a]
165-
operand_layouts = [layout]
166-
if dynamic_batch_dims:
167-
batch_size_val = hlo_s32(1)
168-
for b_v in batch_dims_vals:
169-
batch_size_val = hlo.multiply(batch_size_val, ensure_hlo_s32(b_v))
170-
operands.append(batch_size_val)
171-
operand_layouts.append(())
172-
173-
shape_type_pairs: Sequence[ShapeTypePair] = [
174-
(a_shape_vals, a_type.element_type),
175-
(batch_dims_vals + (n,), eigvals_type),
176-
(batch_dims_vals, i32_type),
177-
([lwork], a_type.element_type)]
178-
result_types, result_shapes = mk_result_types_and_shapes(shape_type_pairs)
179-
out = custom_call(
180-
kernel,
181-
result_types=result_types,
182-
operands=operands,
183-
backend_config=opaque,
184-
operand_layouts=operand_layouts,
185-
result_layouts=[
186-
layout,
187-
tuple(range(num_bd, -1, -1)),
188-
tuple(range(num_bd - 1, -1, -1)),
189-
[0],
190-
],
191-
operand_output_aliases={0: 0},
192-
result_shapes=result_shapes).results
193-
return out[:3]
194-
195-
cuda_syevd = partial(_syevd_hlo, "cu", _cusolver, True)
196-
rocm_syevd = partial(_syevd_hlo, "hip", _hipsolver, True)
197-
198-
199122
def _gesvd_hlo(platform, gpu_solver, have_jacobi_solver, dtype, a,
200123
full_matrices=True, compute_uv=True):
201124
"""Singular value decomposition."""

jaxlib/lapack.py

Lines changed: 0 additions & 114 deletions
Original file line number | Diff line number | Diff line change
@@ -340,120 +340,6 @@ def gesdd_hlo(ctx, dtype, a: ir.Value, *, full_matrices=True, compute_uv=True,
340340
).results[1:]
341341

342342

343-
# # syevd: Symmetric eigendecomposition
344-
345-
def syevd_hlo(ctx, dtype, a: ir.Value,
346-
a_shape_vals: tuple[DimensionSize, ...],
347-
lower=False):
348-
a_type = ir.RankedTensorType(a.type)
349-
assert len(a_shape_vals) >= 2
350-
m, n = a_shape_vals[-2:]
351-
# Non-batch dimensions must be static
352-
assert type(m) is int and type(n) is int and m == n, a_shape_vals
353-
354-
batch_dims_vals = a_shape_vals[:-2]
355-
num_bd = len(a_shape_vals) - 2
356-
mode = _enum_to_char_attr(eig.ComputationMode.kComputeEigenvectors)
357-
358-
i32_type = ir.IntegerType.get_signless(32)
359-
workspace: list[ShapeTypePair]
360-
layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
361-
# Hermitian is for complex square matrices, symmetric otherwise.
362-
fn_base = "he" if dtype == np.complex64 or dtype == np.complex128 else "sy"
363-
fn_base = prepare_lapack_call(fn_base=fn_base + "evd", dtype=dtype)
364-
if ctx.is_forward_compat():
365-
fn = fn_base
366-
if dtype == np.float32:
367-
eigvals_type = ir.F32Type.get()
368-
workspace = [
369-
([_lapack.syevd_work_size(n)], a_type.element_type),
370-
([_lapack.syevd_iwork_size(n)], i32_type),
371-
]
372-
elif dtype == np.float64:
373-
eigvals_type = ir.F64Type.get()
374-
workspace = [
375-
([_lapack.syevd_work_size(n)], a_type.element_type),
376-
([_lapack.syevd_iwork_size(n)], i32_type),
377-
]
378-
elif dtype == np.complex64:
379-
eigvals_type = ir.F32Type.get()
380-
workspace = [
381-
([_lapack.heevd_work_size(n)], a_type.element_type),
382-
([_lapack.heevd_rwork_size(n)], eigvals_type),
383-
([_lapack.syevd_iwork_size(n)], i32_type),
384-
]
385-
elif dtype == np.complex128:
386-
eigvals_type = ir.F64Type.get()
387-
workspace = [
388-
([_lapack.heevd_work_size(n)], a_type.element_type),
389-
([_lapack.heevd_rwork_size(n)], eigvals_type),
390-
([_lapack.syevd_iwork_size(n)], i32_type),
391-
]
392-
else:
393-
raise NotImplementedError(f"Unsupported dtype {dtype}")
394-
395-
batch_size_val = hlo_s32(1)
396-
for b_v in batch_dims_vals:
397-
batch_size_val = hlo.multiply(batch_size_val, ensure_hlo_s32(b_v))
398-
399-
scalar_layout = []
400-
shape_layout = [0]
401-
workspace_layouts = [shape_layout] * len(workspace)
402-
layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
403-
404-
result_types, result_shapes = mk_result_types_and_shapes(
405-
[(a_shape_vals, a_type.element_type),
406-
(batch_dims_vals + (n,), eigvals_type),
407-
(batch_dims_vals, i32_type)] + workspace
408-
)
409-
410-
return custom_call(
411-
fn,
412-
result_types=result_types,
413-
operands=[hlo_s32(1 if lower else 0), batch_size_val, ensure_hlo_s32(n), a],
414-
operand_layouts=[scalar_layout] * 3 + [layout],
415-
result_layouts=[
416-
layout,
417-
tuple(range(num_bd, -1, -1)),
418-
tuple(range(num_bd - 1, -1, -1)),
419-
] + workspace_layouts,
420-
operand_output_aliases={3: 0},
421-
result_shapes=result_shapes,
422-
).results[:3]
423-
fn = fn_base + "_ffi"
424-
if dtype == np.float32 or dtype == np.complex64:
425-
eigvals_type = ir.F32Type.get()
426-
elif dtype == np.float64 or dtype == np.complex128:
427-
eigvals_type = ir.F64Type.get()
428-
else:
429-
raise NotImplementedError(f"Unsupported dtype {dtype}")
430-
431-
result_types, result_shapes = mk_result_types_and_shapes([
432-
(a_shape_vals, a_type.element_type),
433-
(batch_dims_vals + (n,), eigvals_type),
434-
(batch_dims_vals, i32_type),
435-
])
436-
437-
return custom_call(
438-
fn,
439-
result_types=result_types,
440-
operands=[a],
441-
operand_layouts=[layout],
442-
result_layouts=[
443-
layout,
444-
tuple(range(num_bd, -1, -1)),
445-
tuple(range(num_bd - 1, -1, -1)),
446-
],
447-
operand_output_aliases={0: 0},
448-
result_shapes=result_shapes,
449-
backend_config={
450-
"uplo": _matrix_uplo_attr(lower=lower),
451-
"mode": mode,
452-
},
453-
api_version=4,
454-
).results
455-
456-
457343
# # geev: Nonsymmetric eigendecomposition (eig)
458344

459345
def geev_hlo(ctx, dtype, input, *,

0 commit comments

Comments
 (0)