From 5cd95bd08415b39ef9047e10234c44fb97844e44 Mon Sep 17 00:00:00 2001 From: TomasPegado Date: Thu, 27 Feb 2025 11:34:25 -0300 Subject: [PATCH 1/7] feat: add lu matrix decomposition --- nx/lib/nx/lin_alg/lu.ex | 98 +++++++++++++++++++++++++++++++++++++ nx/test/nx/lin_alg_test.exs | 58 ++++++++++++++++++++++ 2 files changed, 156 insertions(+) create mode 100644 nx/lib/nx/lin_alg/lu.ex diff --git a/nx/lib/nx/lin_alg/lu.ex b/nx/lib/nx/lin_alg/lu.ex new file mode 100644 index 0000000000..e1dd4501b8 --- /dev/null +++ b/nx/lib/nx/lin_alg/lu.ex @@ -0,0 +1,98 @@ +defmodule Nx.LinAlg.LU do + import Nx.Defn + + defn lu(input_data, opts \\ []) do + opts = keyword!(opts, eps: 1.0e-10) + eps = opts[:eps] + + {p, a_prime} = lu_validate_and_pivot(input_data) + + {n, _} = Nx.shape(input_data) + + l = u = Nx.fill(a_prime, 0.0) + + {l, u, _} = + while {l, u, {a_prime, eps, n}}, j <- 0..(n - 1) do + l = Nx.put_slice(l, [j, j], Nx.tensor([[1.0]])) + + {u, _} = + while {u, {l, a_prime, eps, j, i = 0}}, Nx.less_equal(i, j) do + sum = vector_dot_slice(u[[.., j]], l[i], i) + a_ij = a_prime[i][j] + + value = a_ij - sum + + if Nx.less(Nx.abs(value), eps) do + {Nx.put_slice(u, [i, j], Nx.tensor([[0.0]])), {l, a_prime, eps, j, i + 1}} + else + {Nx.put_slice(u, [i, j], Nx.reshape(value, {1, 1})), {l, a_prime, eps, j, i + 1}} + end + end + + {l, _} = + while {l, {u, a_prime, eps, j, n, i = j + 1}}, Nx.less_equal(i, n - 1) do + sum = vector_dot_slice(u[[.., j]], l[i], i) + + a_ij = a_prime[i][j] + u_jj = u[j][j] + + value = + cond do + u_jj != 0 -> + (a_ij - sum) / u_jj + + a_ij >= sum -> + Nx.Constants.infinity() + + true -> + Nx.Constants.neg_infinity() + end + + if Nx.abs(value) < eps do + {Nx.put_slice(l, [i, j], Nx.tensor([[0]])), {u, a_prime, eps, j, n, i + 1}} + else + {Nx.put_slice(l, [i, j], Nx.reshape(value, {1, 1})), {u, a_prime, eps, j, n, i + 1}} + end + end + + {l, u, {a_prime, eps, n}} + end + + {p, l, u} + end + + defnp vector_dot_slice(u, v, last_idx) do + {n} = Nx.shape(u) + u = Nx.select(Nx.iota({n}) < last_idx, u, 0) + {n} = Nx.shape(v) + v = Nx.select(Nx.iota({n}) < last_idx, v, 0) + Nx.dot(u, v) + end + + defnp lu_validate_and_pivot(t) do + {n, _} = Nx.shape(t) + p = Nx.iota({n}) + + {p, _} = + while {p, t}, i <- 0..(n - 2) do + max_idx = + Nx.select(Nx.iota({n}) < i, 0, Nx.abs(t[[.., i]])) + |> Nx.argmax(axis: 0) + + if max_idx == i do + {p, t} + else + indices = Nx.stack([i, max_idx]) |> Nx.reshape({2, 1}) + updates = Nx.stack([p[max_idx], p[i]]) + + p = Nx.indexed_put(p, indices, updates) + + {p, Nx.take(t, p)} + end + end + + permutation = Nx.new_axis(p, 1) == Nx.iota({1, n}) + + {permutation, t[p]} + end +end diff --git a/nx/test/nx/lin_alg_test.exs b/nx/test/nx/lin_alg_test.exs index d8c8fe2bb4..e8729426bb 100644 --- a/nx/test/nx/lin_alg_test.exs +++ b/nx/test/nx/lin_alg_test.exs @@ -945,6 +945,64 @@ defmodule Nx.LinAlgTest do key end end + + test "LU decomposition of a 2x2 tensor" do + a = + Nx.tensor( + [ + [4, 3], + [6, 3] + ], + type: :f32 + ) + + {p, l, u} = Nx.LinAlg.lu(a) + + assert Nx.dot(p, Nx.dot(l, u)) + |> Nx.multiply(10) + |> Nx.round() + |> Nx.divide(10) == a + end + + test "LU decomposition of a 3x3 tensor" do + a = + Nx.tensor( + [ + [2, 3, 1], + [4, 7, 3], + [6, 18, 5] + ], + type: :f32 + ) + + {p, l, u} = Nx.LinAlg.lu(a) + + assert Nx.dot(p, Nx.dot(l, u)) + |> Nx.multiply(10) + |> Nx.round() + |> Nx.divide(10) == a + end + + test "LU decomposition of a 3D tensor" do + a = + Nx.tensor( + [ + [ + [1.0, 2.0, 3.0], + [4.0, 5.0, 6.0] + ], + [ + [7.0, 8.0, 9.0], + [10.0, 11.0, 
12.0] + ] + ], + type: :f32 + ) + + assert_raise ArgumentError, fn -> + Nx.LinAlg.lu(a) + end + end end describe "cholesky" do From e6592826dda19f3c7d993c119b7406053298e4d2 Mon Sep 17 00:00:00 2001 From: TomasPegado Date: Thu, 27 Feb 2025 16:40:57 -0300 Subject: [PATCH 2/7] feat: use Nx.LinAlg.LU.lu implementation on Nx.LinAlg.ex --- nx/lib/nx/lin_alg.ex | 32 ++++++++++---------- nx/lib/nx/lin_alg/lu.ex | 28 ++++++++++++++++-- nx/lib/nx/lin_alg/qr.ex | 4 +-- nx/test/nx/lin_alg_test.exs | 58 ------------------------------------- 4 files changed, 43 insertions(+), 79 deletions(-) diff --git a/nx/lib/nx/lin_alg.ex b/nx/lib/nx/lin_alg.ex index edced80246..588811455c 100644 --- a/nx/lib/nx/lin_alg.ex +++ b/nx/lib/nx/lin_alg.ex @@ -1709,22 +1709,22 @@ defmodule Nx.LinAlg do ** (ArgumentError) tensor must be a square matrix or a batch of square matrices, got shape: {3, 4} """ def lu(tensor, opts \\ []) do - apply_vectorized(tensor, fn tensor -> - opts = keyword!(opts, eps: 1.0e-10) - %T{type: type, shape: shape} = tensor - - output_type = Nx.Type.to_floating(type) - {p_shape, l_shape, u_shape} = Nx.Shape.lu(shape) - names = List.duplicate(nil, tuple_size(shape)) - - impl!(tensor).lu( - {%{tensor | type: type, shape: p_shape, names: names}, - %{tensor | type: output_type, shape: l_shape, names: names}, - %{tensor | type: output_type, shape: u_shape, names: names}}, - tensor, - opts - ) - end) + opts = keyword!(opts, eps: 1.0e-10) + %T{vectorized_axes: vectorized_axes} = tensor = Nx.to_tensor(tensor) + %T{type: type, shape: shape} = tensor = Nx.devectorize(tensor) + + output_type = Nx.Type.to_floating(type) + {p_shape, l_shape, u_shape} = Nx.Shape.lu(shape) + names = List.duplicate(nil, tuple_size(shape)) + + output = + {%{tensor | type: type, shape: p_shape, names: names}, + %{tensor | type: output_type, shape: l_shape, names: names}, + %{tensor | type: output_type, shape: u_shape, names: names}} + + :lu + |> Nx.Shared.optional([tensor, opts], output, &Nx.LinAlg.LU.lu/2) + |> Nx.vectorize(vectorized_axes) end @doc """ diff --git a/nx/lib/nx/lin_alg/lu.ex b/nx/lib/nx/lin_alg/lu.ex index e1dd4501b8..0c2335c722 100644 --- a/nx/lib/nx/lin_alg/lu.ex +++ b/nx/lib/nx/lin_alg/lu.ex @@ -1,13 +1,25 @@ defmodule Nx.LinAlg.LU do import Nx.Defn - defn lu(input_data, opts \\ []) do + defn lu(a, opts \\ []) do opts = keyword!(opts, eps: 1.0e-10) + + vectorized_axes = a.vectorized_axes + + a + |> Nx.revectorize([collapsed_axes: :auto], + target_shape: {Nx.axis_size(a, -2), Nx.axis_size(a, -1)} + ) + |> lu_matrix(opts) + |> revectorize_result(a.shape, vectorized_axes) + end + + defnp lu_matrix(a, opts \\ []) do eps = opts[:eps] - {p, a_prime} = lu_validate_and_pivot(input_data) + {p, a_prime} = lu_validate_and_pivot(a) - {n, _} = Nx.shape(input_data) + {n, _} = Nx.shape(a) l = u = Nx.fill(a_prime, 0.0) @@ -61,6 +73,16 @@ defmodule Nx.LinAlg.LU do {p, l, u} end + deftransformp revectorize_result({p, l, u}, shape, vectorized_axes) do + {p_shape, l_shape, u_shape} = Nx.Shape.lu(shape) + + { + Nx.revectorize(p, vectorized_axes, target_shape: p_shape), + Nx.revectorize(l, vectorized_axes, target_shape: l_shape), + Nx.revectorize(u, vectorized_axes, target_shape: u_shape) + } + end + defnp vector_dot_slice(u, v, last_idx) do {n} = Nx.shape(u) u = Nx.select(Nx.iota({n}) < last_idx, u, 0) diff --git a/nx/lib/nx/lin_alg/qr.ex b/nx/lib/nx/lin_alg/qr.ex index 9b428ea3cc..a93602a95e 100644 --- a/nx/lib/nx/lin_alg/qr.ex +++ b/nx/lib/nx/lin_alg/qr.ex @@ -16,7 +16,7 @@ defmodule Nx.LinAlg.QR do |> revectorize_result(a.shape, 
vectorized_axes, opts) custom_grad(result, [a], fn g -> - qr_grad(result, a, g) + qr_grad(result, g) end) end @@ -145,7 +145,7 @@ defmodule Nx.LinAlg.QR do Nx.select(selector, eye - scale * Nx.outer(v, v), eye) end - defn qr_grad({q, r}, _input, {dq, dr}) do + defn qr_grad({q, r}, {dq, dr}) do # Definition taken from https://arxiv.org/pdf/2009.10071.pdf # Equation (3) r_inv = Nx.LinAlg.invert(r) diff --git a/nx/test/nx/lin_alg_test.exs b/nx/test/nx/lin_alg_test.exs index e8729426bb..d8c8fe2bb4 100644 --- a/nx/test/nx/lin_alg_test.exs +++ b/nx/test/nx/lin_alg_test.exs @@ -945,64 +945,6 @@ defmodule Nx.LinAlgTest do key end end - - test "LU decomposition of a 2x2 tensor" do - a = - Nx.tensor( - [ - [4, 3], - [6, 3] - ], - type: :f32 - ) - - {p, l, u} = Nx.LinAlg.lu(a) - - assert Nx.dot(p, Nx.dot(l, u)) - |> Nx.multiply(10) - |> Nx.round() - |> Nx.divide(10) == a - end - - test "LU decomposition of a 3x3 tensor" do - a = - Nx.tensor( - [ - [2, 3, 1], - [4, 7, 3], - [6, 18, 5] - ], - type: :f32 - ) - - {p, l, u} = Nx.LinAlg.lu(a) - - assert Nx.dot(p, Nx.dot(l, u)) - |> Nx.multiply(10) - |> Nx.round() - |> Nx.divide(10) == a - end - - test "LU decomposition of a 3D tensor" do - a = - Nx.tensor( - [ - [ - [1.0, 2.0, 3.0], - [4.0, 5.0, 6.0] - ], - [ - [7.0, 8.0, 9.0], - [10.0, 11.0, 12.0] - ] - ], - type: :f32 - ) - - assert_raise ArgumentError, fn -> - Nx.LinAlg.lu(a) - end - end end describe "cholesky" do From 3fbce5665c38f3cf50bbc014dca28dda2c736160 Mon Sep 17 00:00:00 2001 From: TomasPegado Date: Mon, 3 Mar 2025 15:13:14 -0300 Subject: [PATCH 3/7] feat: optimize lu decomposition results with custom_grad() --- nx/lib/nx/lin_alg/lu.ex | 32 ++++++++++++++++++++++++++------ 1 file changed, 26 insertions(+), 6 deletions(-) diff --git a/nx/lib/nx/lin_alg/lu.ex b/nx/lib/nx/lin_alg/lu.ex index 0c2335c722..b39eef55ae 100644 --- a/nx/lib/nx/lin_alg/lu.ex +++ b/nx/lib/nx/lin_alg/lu.ex @@ -6,12 +6,17 @@ defmodule Nx.LinAlg.LU do vectorized_axes = a.vectorized_axes - a - |> Nx.revectorize([collapsed_axes: :auto], - target_shape: {Nx.axis_size(a, -2), Nx.axis_size(a, -1)} - ) - |> lu_matrix(opts) - |> revectorize_result(a.shape, vectorized_axes) + result = + a + |> Nx.revectorize([collapsed_axes: :auto], + target_shape: {Nx.axis_size(a, -2), Nx.axis_size(a, -1)} + ) + |> lu_matrix(opts) + |> revectorize_result(a.shape, vectorized_axes) + + custom_grad(result, [a], fn g -> + lu_grad(result, g) + end) end defnp lu_matrix(a, opts \\ []) do @@ -117,4 +122,19 @@ defmodule Nx.LinAlg.LU do {permutation, t[p]} end + + defn lu_grad({l, u}, {dl, du}) do + # Definition taken from https://arxiv.org/pdf/2009.10071.pdf + # Equation (3) + r_inv = Nx.LinAlg.invert(u) + + m = Nx.dot(u, Nx.LinAlg.adjoint(du)) |> Nx.subtract(Nx.dot(Nx.LinAlg.adjoint(dl), l)) + + # copyltu + m_ltu = Nx.tril(m) |> Nx.add(m |> Nx.tril(k: -1) |> Nx.LinAlg.adjoint()) + + da = dl |> Nx.add(Nx.dot(l, m_ltu)) |> Nx.dot(Nx.LinAlg.adjoint(r_inv)) + + [da] + end end From faaf8e3cd4af7abdf597d042c69c6453fc9e9a51 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Tue, 4 Mar 2025 15:51:03 -0300 Subject: [PATCH 4/7] fix: use new implementation in tests --- nx/lib/nx/backend.ex | 1 + nx/lib/nx/binary_backend.ex | 31 ------------- nx/lib/nx/defn/expr.ex | 8 ---- nx/lib/nx/defn/grad.ex | 23 ---------- nx/lib/nx/lin_alg.ex | 22 ++++----- nx/lib/nx/lin_alg/lu.ex | 91 ++++++++++++++++++++++++------------- nx/test/nx/lin_alg_test.exs | 5 +- 7 files changed, 75 insertions(+), 106 deletions(-) diff --git 
a/nx/lib/nx/backend.ex b/nx/lib/nx/backend.ex index 1193a0b8c4..ac1a09f6db 100644 --- a/nx/lib/nx/backend.ex +++ b/nx/lib/nx/backend.ex @@ -175,6 +175,7 @@ defmodule Nx.Backend do top_k: 3, fft2: 3, ifft2: 3, + lu: 3, qr: 3, cholesky: 2, eigh: 3, diff --git a/nx/lib/nx/binary_backend.ex b/nx/lib/nx/binary_backend.ex index 10c3b58e33..74c3247585 100644 --- a/nx/lib/nx/binary_backend.ex +++ b/nx/lib/nx/binary_backend.ex @@ -1258,26 +1258,6 @@ defmodule Nx.BinaryBackend do output_batch_groups |> Enum.with_index() |> Enum.map(fn {x, i} -> {x, rem(i, groups)} end) end - @impl true - def lu( - {%{type: p_type} = p_holder, %{type: l_type} = l_holder, %{type: u_type} = u_holder}, - %{type: input_type, shape: input_shape} = tensor, - opts - ) do - bin = to_binary(tensor) - rank = tuple_size(input_shape) - n = elem(input_shape, rank - 1) - - {p, l, u} = - bin_batch_reduce(bin, n * n, input_type, {<<>>, <<>>, <<>>}, fn matrix, - {p_acc, l_acc, u_acc} -> - {p, l, u} = B.Matrix.lu(matrix, input_type, {n, n}, p_type, l_type, u_type, opts) - {p_acc <> p, l_acc <> l, u_acc <> u} - end) - - {from_binary(p_holder, p), from_binary(l_holder, l), from_binary(u_holder, u)} - end - @impl true def triangular_solve( %{type: output_type} = out, @@ -2414,17 +2394,6 @@ defmodule Nx.BinaryBackend do bin_zip_reduce_axis(rest1, rest2, s1, s2, bin, acc, fun) end - defp bin_batch_reduce(bin, batch_size, {_, size}, acc, fun) do - batch_bit_size = batch_size * size - batches = bit_size(bin) |> div(batch_bit_size) - - for i <- 0..(batches - 1), reduce: acc do - acc -> - batch = bitstring_part(bin, i * batch_bit_size, batch_bit_size) - fun.(batch, acc) - end - end - ## Conversion helpers defp bitstring_part(bitstring, skip, size) do diff --git a/nx/lib/nx/defn/expr.ex b/nx/lib/nx/defn/expr.ex index a51d72e677..3940be47d9 100644 --- a/nx/lib/nx/defn/expr.ex +++ b/nx/lib/nx/defn/expr.ex @@ -1188,14 +1188,6 @@ defmodule Nx.Defn.Expr do expr(out, context, :triangular_solve, [a, b, opts]) end - @impl true - def lu({p, l, u}, tensor, opts) do - tensor = to_expr(tensor) - context = tensor.data.context - out = %T{names: [], shape: {}, type: {:tuple, 3}} - tuple(expr(out, context, :lu, [{p, l, u}, tensor, opts]), [p, l, u]) - end - @impl true def sort(out, tensor, opts) do %{data: %{context: context}} = tensor = to_expr(tensor) diff --git a/nx/lib/nx/defn/grad.ex b/nx/lib/nx/defn/grad.ex index 25b4a1178b..c18dbf0970 100644 --- a/nx/lib/nx/defn/grad.ex +++ b/nx/lib/nx/defn/grad.ex @@ -716,29 +716,6 @@ defmodule Nx.Defn.Grad do pairs end - defp grad(:lu, [{p, l, u}, input, _opts], ans, [_dp, dl, du]) do - # Definition taken from: https://sethaxen.com/blog/2021/02/differentiating-the-lu-decomposition/ - # Where dF = tril_strict(L^* . dL) + triu(dU . U^*) - # dA = P^t . (L^*)^-1 . dF . 
(U^*)^-1 - - {p, l, u} = Nx.Defn.Expr.tuple(ans, [p, l, u]) - - u_h = Nx.LinAlg.adjoint(u) - l_h = Nx.LinAlg.adjoint(l) - p_t = Nx.LinAlg.adjoint(p) - - lh_dl = Nx.dot(l_h, dl) - du_uh = Nx.dot(du, u_h) - - lt_inv = Nx.LinAlg.invert(l_h) - ut_inv = Nx.LinAlg.invert(u_h) - - df = lh_dl |> Nx.tril(k: -1) |> Nx.add(Nx.triu(du_uh)) - da = p_t |> Nx.dot(lt_inv) |> Nx.dot(df) |> Nx.dot(ut_inv) - - [{input, da}] - end - defp grad(:sort, [t, opts], _ans, g) do idx = Nx.argsort(t, opts) take_along_opts = Keyword.take(opts, [:axis]) diff --git a/nx/lib/nx/lin_alg.ex b/nx/lib/nx/lin_alg.ex index 588811455c..abd0fda989 100644 --- a/nx/lib/nx/lin_alg.ex +++ b/nx/lib/nx/lin_alg.ex @@ -1523,7 +1523,7 @@ defmodule Nx.LinAlg do [ [1.0, 0.0, 0.0], [0.5714285969734192, 1.0, 0.0], - [0.1428571492433548, 2.0, 1.0] + [0.1428571492433548, 2.0000009536743164, 1.0] ] > iex> u @@ -1531,7 +1531,7 @@ defmodule Nx.LinAlg do f32[3][3] [ [7.0, 8.0, 9.0], - [0.0, 0.4285714328289032, 0.8571428656578064], + [0.0, 0.4285712242126465, 0.857142448425293], [0.0, 0.0, 0.0] ] > @@ -1607,7 +1607,7 @@ defmodule Nx.LinAlg do [ [1.0, 0.0, 0.0], [0.6666666865348816, 1.0, 0.0], - [0.3333333432674408, 2.0, 1.0] + [0.3333333432674408, 1.9999992847442627, 1.0] ], [ [1.0, 0.0, 0.0], @@ -1622,8 +1622,8 @@ defmodule Nx.LinAlg do [ [ [9.0, 8.0, 7.0], - [0.0, -0.3333333432674408, -0.6666666865348816], - [0.0, 0.0, 0.0] + [0.0, -0.33333349227905273, -0.6666669845581055], + [0.0, 0.0, 5.960464477539063e-8] ], [ [-1.0, 0.0, -1.0], @@ -1638,7 +1638,7 @@ defmodule Nx.LinAlg do [ [ [9.0, 8.0, 7.0], - [6.0, 5.0, 4.0], + [6.0, 5.0, 3.999999761581421], [3.0, 2.0, 1.0] ], [ @@ -1676,7 +1676,7 @@ defmodule Nx.LinAlg do [ [1.0, 0.0, 0.0], [0.6666666865348816, 1.0, 0.0], - [0.3333333432674408, 2.0, 1.0] + [0.3333333432674408, 1.9999992847442627, 1.0] ], [ [1.0, 0.0, 0.0], @@ -1692,8 +1692,8 @@ defmodule Nx.LinAlg do [ [ [9.0, 8.0, 7.0], - [0.0, -0.3333333432674408, -0.6666666865348816], - [0.0, 0.0, 0.0] + [0.0, -0.33333349227905273, -0.6666669845581055], + [0.0, 0.0, 5.960464477539063e-8] ], [ [-1.0, 0.0, -1.0], @@ -1892,7 +1892,7 @@ defmodule Nx.LinAlg do ...> ])) #Nx.Tensor< f32 - 48.0 + 47.999996185302734 > iex> Nx.LinAlg.determinant(Nx.tensor([ @@ -1904,7 +1904,7 @@ defmodule Nx.LinAlg do ...> ])) #Nx.Tensor< f32 - 48.0 + 47.999996185302734 > iex> Nx.LinAlg.determinant(Nx.tensor([ diff --git a/nx/lib/nx/lin_alg/lu.ex b/nx/lib/nx/lin_alg/lu.ex index b39eef55ae..5f924a649a 100644 --- a/nx/lib/nx/lin_alg/lu.ex +++ b/nx/lib/nx/lin_alg/lu.ex @@ -21,55 +21,77 @@ defmodule Nx.LinAlg.LU do defnp lu_matrix(a, opts \\ []) do eps = opts[:eps] + type = Nx.Type.to_floating(a.type) + real_type = Nx.Type.to_real(type) {p, a_prime} = lu_validate_and_pivot(a) + # {p, a_prime} = {Nx.eye(a.shape, vectorized_axes: a.vectorized_axes, type: a.type), a} + a_prime = Nx.as_type(a_prime, type) {n, _} = Nx.shape(a) l = u = Nx.fill(a_prime, 0.0) + [eps, _] = Nx.broadcast_vectors([Nx.as_type(eps, real_type), l]) {l, u, _} = while {l, u, {a_prime, eps, n}}, j <- 0..(n - 1) do - l = Nx.put_slice(l, [j, j], Nx.tensor([[1.0]])) + l = Nx.put_slice(l, [j, j], Nx.tensor([[1.0]], type: type)) + [j, i, _] = Nx.broadcast_vectors([j, 0, l]) {u, _} = - while {u, {l, a_prime, eps, j, i = 0}}, Nx.less_equal(i, j) do + while {u, {l, a_prime, eps, j, i}}, i <= j do sum = vector_dot_slice(u[[.., j]], l[i], i) a_ij = a_prime[i][j] value = a_ij - sum - if Nx.less(Nx.abs(value), eps) do - {Nx.put_slice(u, [i, j], Nx.tensor([[0.0]])), {l, a_prime, eps, j, i + 1}} - else - {Nx.put_slice(u, [i, j], 
Nx.reshape(value, {1, 1})), {l, a_prime, eps, j, i + 1}} - end + updated_u = + if Nx.abs(value) < eps do + Nx.put_slice(u, [i, j], Nx.tensor([[0]], type: type)) + else + Nx.put_slice(u, [i, j], Nx.reshape(value, {1, 1})) + end + + {updated_u, {l, a_prime, eps, j, i + 1}} end {l, _} = - while {l, {u, a_prime, eps, j, n, i = j + 1}}, Nx.less_equal(i, n - 1) do + while {l, {u, a_prime, eps, j, n, i = j + 1}}, i <= n - 1 do sum = vector_dot_slice(u[[.., j]], l[i], i) a_ij = a_prime[i][j] u_jj = u[j][j] value = - cond do - u_jj != 0 -> - (a_ij - sum) / u_jj - - a_ij >= sum -> - Nx.Constants.infinity() - + case Nx.Type.complex?(type) do true -> - Nx.Constants.neg_infinity() + if u_jj != 0 do + (a_ij - sum) / u_jj + else + Nx.Constants.nan(real_type) + end + + false -> + cond do + u_jj != 0 -> + (a_ij - sum) / u_jj + + a_ij >= sum -> + Nx.Constants.infinity(real_type) + + true -> + Nx.Constants.neg_infinity(real_type) + end end - if Nx.abs(value) < eps do - {Nx.put_slice(l, [i, j], Nx.tensor([[0]])), {u, a_prime, eps, j, n, i + 1}} - else - {Nx.put_slice(l, [i, j], Nx.reshape(value, {1, 1})), {u, a_prime, eps, j, n, i + 1}} - end + updated_l = + if Nx.abs(value) < eps do + Nx.put_slice(l, [i, j], Nx.tensor([[0]], type: type)) + else + Nx.put_slice(l, [i, j], Nx.reshape(value, {1, 1})) + end + + {updated_l, {u, a_prime, eps, j, n, i + 1}} end {l, u, {a_prime, eps, n}} @@ -90,15 +112,15 @@ defmodule Nx.LinAlg.LU do defnp vector_dot_slice(u, v, last_idx) do {n} = Nx.shape(u) - u = Nx.select(Nx.iota({n}) < last_idx, u, 0) - {n} = Nx.shape(v) - v = Nx.select(Nx.iota({n}) < last_idx, v, 0) + selector = Nx.iota({n}) < last_idx + u = Nx.select(selector, u, 0) + v = Nx.select(selector, v, 0) Nx.dot(u, v) end defnp lu_validate_and_pivot(t) do {n, _} = Nx.shape(t) - p = Nx.iota({n}) + p = Nx.iota({n}, vectorized_axes: t.vectorized_axes) {p, _} = while {p, t}, i <- 0..(n - 2) do @@ -120,20 +142,25 @@ defmodule Nx.LinAlg.LU do permutation = Nx.new_axis(p, 1) == Nx.iota({1, n}) - {permutation, t[p]} + {Nx.as_type(permutation, t.type), t[p]} end - defn lu_grad({l, u}, {dl, du}) do + defn lu_grad({p, l, u}, {_dp, dl, du}) do # Definition taken from https://arxiv.org/pdf/2009.10071.pdf # Equation (3) - r_inv = Nx.LinAlg.invert(u) - m = Nx.dot(u, Nx.LinAlg.adjoint(du)) |> Nx.subtract(Nx.dot(Nx.LinAlg.adjoint(dl), l)) + u_h = Nx.LinAlg.adjoint(u) + l_h = Nx.LinAlg.adjoint(l) + p_t = Nx.LinAlg.adjoint(p) + + lh_dl = Nx.dot(l_h, dl) + du_uh = Nx.dot(du, u_h) - # copyltu - m_ltu = Nx.tril(m) |> Nx.add(m |> Nx.tril(k: -1) |> Nx.LinAlg.adjoint()) + lt_inv = Nx.LinAlg.invert(l_h) + ut_inv = Nx.LinAlg.invert(u_h) - da = dl |> Nx.add(Nx.dot(l, m_ltu)) |> Nx.dot(Nx.LinAlg.adjoint(r_inv)) + df = lh_dl |> Nx.tril(k: -1) |> Nx.add(Nx.triu(du_uh)) + da = p_t |> Nx.dot(lt_inv) |> Nx.dot(df) |> Nx.dot(ut_inv) [da] end diff --git a/nx/test/nx/lin_alg_test.exs b/nx/test/nx/lin_alg_test.exs index d8c8fe2bb4..404dc3c1d7 100644 --- a/nx/test/nx/lin_alg_test.exs +++ b/nx/test/nx/lin_alg_test.exs @@ -923,6 +923,7 @@ defmodule Nx.LinAlgTest do describe "lu" do test "property" do key = Nx.Random.key(System.unique_integer()) + key = Nx.Random.key(42) for _ <- 1..10, type <- [{:f, 32}, {:c, 64}], reduce: key do key -> @@ -941,7 +942,9 @@ defmodule Nx.LinAlgTest do a = Nx.dot(l_prime, [2], [0], u_prime, [1], [0]) assert {p, l, u} = Nx.LinAlg.lu(a) - assert_all_close(p |> Nx.dot([2], [0], l, [1], [0]) |> Nx.dot([2], [0], u, [1], [0]), a) + + actual = p |> Nx.dot([2], [0], l, [1], [0]) |> Nx.dot([2], [0], u, [1], [0]) + 
assert_all_close(actual, a) key end end From bc6aba46bcb980fac9dd2d14cb39564710be6da0 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Wed, 5 Mar 2025 12:21:05 -0300 Subject: [PATCH 5/7] fix: return transposed p --- nx/lib/nx/lin_alg/lu.ex | 4 +++- nx/test/nx/lin_alg_test.exs | 1 - 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/nx/lib/nx/lin_alg/lu.ex b/nx/lib/nx/lin_alg/lu.ex index 5f924a649a..82d7d588be 100644 --- a/nx/lib/nx/lin_alg/lu.ex +++ b/nx/lib/nx/lin_alg/lu.ex @@ -140,7 +140,9 @@ defmodule Nx.LinAlg.LU do end end - permutation = Nx.new_axis(p, 1) == Nx.iota({1, n}) + # The comparison order here is deliberate, because if + # we use p == iota instead, we get the inverse/transposed permutation. + permutation = Nx.iota({n, 1}) == Nx.new_axis(p, 0) {Nx.as_type(permutation, t.type), t[p]} end diff --git a/nx/test/nx/lin_alg_test.exs b/nx/test/nx/lin_alg_test.exs index 404dc3c1d7..4e146cf53c 100644 --- a/nx/test/nx/lin_alg_test.exs +++ b/nx/test/nx/lin_alg_test.exs @@ -923,7 +923,6 @@ defmodule Nx.LinAlgTest do describe "lu" do test "property" do key = Nx.Random.key(System.unique_integer()) - key = Nx.Random.key(42) for _ <- 1..10, type <- [{:f, 32}, {:c, 64}], reduce: key do key -> From b0cdc4e41874c43db0075733258ed55046bd0b8a Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Wed, 5 Mar 2025 12:23:06 -0300 Subject: [PATCH 6/7] chore: remove unused functions from BinaryBackend.Matrix --- nx/lib/nx/binary_backend/matrix.ex | 135 ----------------------------- 1 file changed, 135 deletions(-) diff --git a/nx/lib/nx/binary_backend/matrix.ex b/nx/lib/nx/binary_backend/matrix.ex index afb55fb668..f989dd748c 100644 --- a/nx/lib/nx/binary_backend/matrix.ex +++ b/nx/lib/nx/binary_backend/matrix.ex @@ -1,8 +1,6 @@ defmodule Nx.BinaryBackend.Matrix do @moduledoc false use Complex.Kernel - import Kernel, except: [abs: 1] - import Complex, only: [abs: 1] import Nx.Shared @@ -116,107 +114,6 @@ defmodule Nx.BinaryBackend.Matrix do defp do_ts([], [], _idx, acc), do: acc - def lu(input_data, input_type, {n, n} = input_shape, p_type, l_type, u_type, opts) do - a = binary_to_matrix(input_data, input_type, input_shape) - eps = opts[:eps] - - {p, a_prime} = lu_validate_and_pivot(a, n) - - # We'll work with linear indices because of the way each matrix - # needs to be updated/accessed - zeros_matrix = List.duplicate(List.duplicate(0, n), n) - - {l, u} = - for j <- 0..(n - 1), reduce: {zeros_matrix, zeros_matrix} do - {l, u} -> - l = replace_matrix_element(l, j, j, 1.0) - - u = - for i <- 0..j, reduce: u do - u -> - u_slice = slice_matrix(u, [0, j], [i, 1]) - l_slice = slice_matrix(l, [i, 0], [1, i]) - sum = dot_matrix(u_slice, l_slice) - [a_ij] = get_matrix_elements(a_prime, [[i, j]]) - - value = a_ij - sum - - if abs(value) < eps do - replace_matrix_element(u, i, j, 0) - else - replace_matrix_element(u, i, j, value) - end - end - - l = - for i <- j..(n - 1), i != j, reduce: l do - l -> - u_slice = slice_matrix(u, [0, j], [i, 1]) - l_slice = slice_matrix(l, [i, 0], [1, i]) - sum = dot_matrix(u_slice, l_slice) - - [a_ij] = get_matrix_elements(a_prime, [[i, j]]) - [u_jj] = get_matrix_elements(u, [[j, j]]) - - value = - cond do - u_jj != 0 -> - (a_ij - sum) / u_jj - - a_ij >= sum -> - :infinity - - true -> - :neg_infinity - end - - if abs(value) < eps do - replace_matrix_element(l, i, j, 0) - else - replace_matrix_element(l, i, j, value) - end - end - - {l, u} - end - - # Transpose because since P is 
orthogonal, inv(P) = tranpose(P) - # and we want to return P such that A = P.L.U - {p |> transpose_matrix() |> matrix_to_binary(p_type), - l |> approximate_zeros(eps) |> matrix_to_binary(l_type), - u |> approximate_zeros(eps) |> matrix_to_binary(u_type)} - end - - defp lu_validate_and_pivot(a, n) do - # pivots a tensor so that the biggest elements of each column lie on the diagonal. - # if any of the diagonal elements ends up being 0, raises an ArgumentError - - identity = - Enum.map(0..(n - 1), fn i -> Enum.map(0..(n - 1), fn j -> if i == j, do: 1, else: 0 end) end) - - # For each row, find the max value by column. - # If its index (max_idx) is not in the diagonal (i.e. j != max_idx) - # we need to swap rows j and max_idx in both the permutation matrix - # and in the a matrix. - Enum.reduce(0..(n - 2), {identity, a}, fn j, {p, a} -> - [max_idx | _] = - Enum.sort_by(j..(n - 1), fn i -> a |> Enum.at(i) |> Enum.at(j) |> abs() end, &>=/2) - - if max_idx == j do - {p, a} - else - p_row = Enum.at(p, max_idx) - p_j = Enum.at(p, j) - p = p |> List.replace_at(max_idx, p_j) |> List.replace_at(j, p_row) - - a_row = Enum.at(a, max_idx) - a_j = Enum.at(a, j) - a = a |> List.replace_at(max_idx, a_j) |> List.replace_at(j, a_row) - {p, a} - end - end) - end - ## Matrix (2-D array) manipulation defp dot_matrix([], _), do: 0 @@ -279,41 +176,9 @@ defmodule Nx.BinaryBackend.Matrix do |> Enum.chunk_every(num_cols) end - defp slice_matrix(a, [row_start, col_start], [row_length, col_length]) do - a - |> Enum.slice(row_start, row_length) - |> Enum.flat_map(&Enum.slice(&1, col_start, col_length)) - end - defp get_matrix_column(m, col) do Enum.map(m, fn row -> Enum.at(row, col) end) end - - defp get_matrix_elements(m, row_col_pairs) do - Enum.map(row_col_pairs, fn [row, col] -> - m - |> Enum.at(row, []) - |> Enum.at(col) - |> case do - nil -> raise ArgumentError, "invalid index [#{row},#{col}] for matrix" - item -> item - end - end) - end - - defp replace_matrix_element(m, row, col, value) do - updated = m |> Enum.at(row) |> List.replace_at(col, value) - List.replace_at(m, row, updated) - end - - defp approximate_zeros(matrix, tol) do - do_round = fn x -> if Complex.abs(x) < tol, do: 0 * x, else: x end - - Enum.map(matrix, fn - row when is_list(row) -> Enum.map(row, do_round) - e -> do_round.(e) - end) - end end From 9e58ea19a1a05a29c25d12dabc1a353f10aeeaee Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Wed, 5 Mar 2025 12:28:19 -0300 Subject: [PATCH 7/7] fix: do not define lu in EXLA backend --- exla/lib/exla/backend.ex | 1 - 1 file changed, 1 deletion(-) diff --git a/exla/lib/exla/backend.ex b/exla/lib/exla/backend.ex index f242780fc9..6eb23692cd 100644 --- a/exla/lib/exla/backend.ex +++ b/exla/lib/exla/backend.ex @@ -406,7 +406,6 @@ defmodule EXLA.Backend do [:tensor, :source, :init_value]}, {:indexed_add, [:tensor, :indices, :updates, :opts], [:tensor, :indices, :updates]}, {:indexed_put, [:tensor, :indices, :updates, :opts], [:tensor, :indices, :updates]}, - {:lu, [:tensor, :opts], [:tensor]}, {:triangular_solve, [:a, :b, :opts], [:a, :b]}, {:fft, [:tensor, :opts], [:tensor]}, {:ifft, [:tensor, :opts], [:tensor]}
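
---

Notes on using the series (hedged sketches, not part of the patches above):

The tests assert that the returned factors reconstruct the input,
A = P.L.U, up to f32 rounding (hence the doctest deltas in patch 4).
A minimal hand check, reusing the 3x3 matrix from the unit test that
patch 1 adds and patch 2 removes in favor of the property test:

    a = Nx.tensor([[2, 3, 1], [4, 7, 3], [6, 18, 5]], type: :f32)
    {p, l, u} = Nx.LinAlg.lu(a)
    # P is oriented so that A = P.L.U (patch 5), so this product
    # should match `a` up to f32 rounding.
    reconstructed = p |> Nx.dot(l) |> Nx.dot(u)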
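
lu/2 takes a single option, eps (default 1.0e-10): computed entries of
L and U whose magnitude falls below it are stored as exact zeros. A
sketch with a looser, purely illustrative tolerance, on the same `a`
as above:

    # A larger eps flushes more of the rounding noise visible in the
    # patch 4 doctest deltas to exact zeros.
    {p, l, u} = Nx.LinAlg.lu(a, eps: 1.0e-6)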
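
Patch 2 revectorizes the leading axes before the 2-D kernel runs, so
batches of square matrices decompose in one call (the 3-D doctests
exercise this). A sketch with an illustrative {2, 2, 2} batch, the
first matrix taken from the removed 2x2 unit test and the second
invented for illustration:

    a_batched = Nx.tensor([
      [[4.0, 3.0], [6.0, 3.0]],
      [[2.0, 1.0], [1.0, 2.0]]
    ])
    # p, l and u each keep the leading batch axis, giving {2, 2, 2}.
    {p, l, u} = Nx.LinAlg.lu(a_batched)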
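
Patch 5 flips the comparison in lu_validate_and_pivot because
`Nx.new_axis(p, 1) == Nx.iota({1, n})` builds the transposed
permutation matrix, which is also its inverse since P is orthogonal.
A sketch outside defn with a hypothetical pivot vector, showing that
the two orientations differ by a transpose:

    # Hypothetical pivot order, for illustration only.
    p = Nx.tensor([2, 0, 1])
    # Orientation after patch 5: entry {i, j} is 1 when i == p[j],
    # yielding the P for which A = P.L.U.
    perm = Nx.equal(Nx.iota({3, 1}), Nx.new_axis(p, 0))
    # The pre-patch orientation is exactly its transpose:
    perm_t = Nx.equal(Nx.new_axis(p, 1), Nx.iota({1, 3}))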
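
Patches 3 and 4 route the result through custom_grad, moving the rule
previously hard-coded in Nx.Defn.Grad into Nx.LinAlg.LU itself:
dF = tril_strict(L^* . dL) + triu(dU . U^*) and
dA = P^t . (L^*)^-1 . dF . (U^*)^-1. A sketch of exercising that
gradient from defn; the module name and the scalar loss are
illustrative only:

    defmodule LUGradSketch do
      import Nx.Defn

      defn lu_sum_grad(a) do
        # Differentiate a scalar built from the LU factors; the
        # backward pass is the custom_grad defined in this series.
        grad(a, fn a ->
          {_p, l, u} = Nx.LinAlg.lu(a)
          Nx.sum(l) + Nx.sum(u)
        end)
      end
    end

    LUGradSketch.lu_sum_grad(Nx.tensor([[2.0, 3.0], [4.0, 7.0]]))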