Skip to content

Commit 27b4026

Browse files
Lu matrix decomposition (#1587)
Co-authored-by: Paulo Valente <[email protected]>
1 parent 1a237e1 commit 27b4026

File tree

10 files changed

+202
-228
lines changed

10 files changed

+202
-228
lines changed

exla/lib/exla/backend.ex

-1
Original file line numberDiff line numberDiff line change
@@ -381,7 +381,6 @@ defmodule EXLA.Backend do
381381
[:tensor, :source, :init_value]},
382382
{:indexed_add, [:tensor, :indices, :updates, :opts], [:tensor, :indices, :updates]},
383383
{:indexed_put, [:tensor, :indices, :updates, :opts], [:tensor, :indices, :updates]},
384-
{:lu, [:tensor, :opts], [:tensor]},
385384
{:triangular_solve, [:a, :b, :opts], [:a, :b]},
386385
{:fft, [:tensor, :opts], [:tensor]},
387386
{:ifft, [:tensor, :opts], [:tensor]}

nx/lib/nx/backend.ex

+1
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,7 @@ defmodule Nx.Backend do
175175
top_k: 3,
176176
fft2: 3,
177177
ifft2: 3,
178+
lu: 3,
178179
qr: 3,
179180
cholesky: 2,
180181
eigh: 3,

nx/lib/nx/binary_backend.ex

-31
Original file line numberDiff line numberDiff line change
@@ -1258,26 +1258,6 @@ defmodule Nx.BinaryBackend do
12581258
output_batch_groups |> Enum.with_index() |> Enum.map(fn {x, i} -> {x, rem(i, groups)} end)
12591259
end
12601260

1261-
@impl true
1262-
def lu(
1263-
{%{type: p_type} = p_holder, %{type: l_type} = l_holder, %{type: u_type} = u_holder},
1264-
%{type: input_type, shape: input_shape} = tensor,
1265-
opts
1266-
) do
1267-
bin = to_binary(tensor)
1268-
rank = tuple_size(input_shape)
1269-
n = elem(input_shape, rank - 1)
1270-
1271-
{p, l, u} =
1272-
bin_batch_reduce(bin, n * n, input_type, {<<>>, <<>>, <<>>}, fn matrix,
1273-
{p_acc, l_acc, u_acc} ->
1274-
{p, l, u} = B.Matrix.lu(matrix, input_type, {n, n}, p_type, l_type, u_type, opts)
1275-
{p_acc <> p, l_acc <> l, u_acc <> u}
1276-
end)
1277-
1278-
{from_binary(p_holder, p), from_binary(l_holder, l), from_binary(u_holder, u)}
1279-
end
1280-
12811261
@impl true
12821262
def triangular_solve(
12831263
%{type: output_type} = out,
@@ -2414,17 +2394,6 @@ defmodule Nx.BinaryBackend do
24142394
bin_zip_reduce_axis(rest1, rest2, s1, s2, bin, acc, fun)
24152395
end
24162396

2417-
defp bin_batch_reduce(bin, batch_size, {_, size}, acc, fun) do
2418-
batch_bit_size = batch_size * size
2419-
batches = bit_size(bin) |> div(batch_bit_size)
2420-
2421-
for i <- 0..(batches - 1), reduce: acc do
2422-
acc ->
2423-
batch = bitstring_part(bin, i * batch_bit_size, batch_bit_size)
2424-
fun.(batch, acc)
2425-
end
2426-
end
2427-
24282397
## Conversion helpers
24292398

24302399
defp bitstring_part(bitstring, skip, size) do

nx/lib/nx/binary_backend/matrix.ex

-135
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
defmodule Nx.BinaryBackend.Matrix do
22
@moduledoc false
33
use Complex.Kernel
4-
import Kernel, except: [abs: 1]
5-
import Complex, only: [abs: 1]
64

75
import Nx.Shared
86

@@ -116,107 +114,6 @@ defmodule Nx.BinaryBackend.Matrix do
116114

117115
defp do_ts([], [], _idx, acc), do: acc
118116

119-
def lu(input_data, input_type, {n, n} = input_shape, p_type, l_type, u_type, opts) do
120-
a = binary_to_matrix(input_data, input_type, input_shape)
121-
eps = opts[:eps]
122-
123-
{p, a_prime} = lu_validate_and_pivot(a, n)
124-
125-
# We'll work with linear indices because of the way each matrix
126-
# needs to be updated/accessed
127-
zeros_matrix = List.duplicate(List.duplicate(0, n), n)
128-
129-
{l, u} =
130-
for j <- 0..(n - 1), reduce: {zeros_matrix, zeros_matrix} do
131-
{l, u} ->
132-
l = replace_matrix_element(l, j, j, 1.0)
133-
134-
u =
135-
for i <- 0..j, reduce: u do
136-
u ->
137-
u_slice = slice_matrix(u, [0, j], [i, 1])
138-
l_slice = slice_matrix(l, [i, 0], [1, i])
139-
sum = dot_matrix(u_slice, l_slice)
140-
[a_ij] = get_matrix_elements(a_prime, [[i, j]])
141-
142-
value = a_ij - sum
143-
144-
if abs(value) < eps do
145-
replace_matrix_element(u, i, j, 0)
146-
else
147-
replace_matrix_element(u, i, j, value)
148-
end
149-
end
150-
151-
l =
152-
for i <- j..(n - 1), i != j, reduce: l do
153-
l ->
154-
u_slice = slice_matrix(u, [0, j], [i, 1])
155-
l_slice = slice_matrix(l, [i, 0], [1, i])
156-
sum = dot_matrix(u_slice, l_slice)
157-
158-
[a_ij] = get_matrix_elements(a_prime, [[i, j]])
159-
[u_jj] = get_matrix_elements(u, [[j, j]])
160-
161-
value =
162-
cond do
163-
u_jj != 0 ->
164-
(a_ij - sum) / u_jj
165-
166-
a_ij >= sum ->
167-
:infinity
168-
169-
true ->
170-
:neg_infinity
171-
end
172-
173-
if abs(value) < eps do
174-
replace_matrix_element(l, i, j, 0)
175-
else
176-
replace_matrix_element(l, i, j, value)
177-
end
178-
end
179-
180-
{l, u}
181-
end
182-
183-
# Transpose because since P is orthogonal, inv(P) = tranpose(P)
184-
# and we want to return P such that A = P.L.U
185-
{p |> transpose_matrix() |> matrix_to_binary(p_type),
186-
l |> approximate_zeros(eps) |> matrix_to_binary(l_type),
187-
u |> approximate_zeros(eps) |> matrix_to_binary(u_type)}
188-
end
189-
190-
defp lu_validate_and_pivot(a, n) do
191-
# pivots a tensor so that the biggest elements of each column lie on the diagonal.
192-
# if any of the diagonal elements ends up being 0, raises an ArgumentError
193-
194-
identity =
195-
Enum.map(0..(n - 1), fn i -> Enum.map(0..(n - 1), fn j -> if i == j, do: 1, else: 0 end) end)
196-
197-
# For each row, find the max value by column.
198-
# If its index (max_idx) is not in the diagonal (i.e. j != max_idx)
199-
# we need to swap rows j and max_idx in both the permutation matrix
200-
# and in the a matrix.
201-
Enum.reduce(0..(n - 2), {identity, a}, fn j, {p, a} ->
202-
[max_idx | _] =
203-
Enum.sort_by(j..(n - 1), fn i -> a |> Enum.at(i) |> Enum.at(j) |> abs() end, &>=/2)
204-
205-
if max_idx == j do
206-
{p, a}
207-
else
208-
p_row = Enum.at(p, max_idx)
209-
p_j = Enum.at(p, j)
210-
p = p |> List.replace_at(max_idx, p_j) |> List.replace_at(j, p_row)
211-
212-
a_row = Enum.at(a, max_idx)
213-
a_j = Enum.at(a, j)
214-
a = a |> List.replace_at(max_idx, a_j) |> List.replace_at(j, a_row)
215-
{p, a}
216-
end
217-
end)
218-
end
219-
220117
## Matrix (2-D array) manipulation
221118

222119
defp dot_matrix([], _), do: 0
@@ -279,41 +176,9 @@ defmodule Nx.BinaryBackend.Matrix do
279176
|> Enum.chunk_every(num_cols)
280177
end
281178

282-
defp slice_matrix(a, [row_start, col_start], [row_length, col_length]) do
283-
a
284-
|> Enum.slice(row_start, row_length)
285-
|> Enum.flat_map(&Enum.slice(&1, col_start, col_length))
286-
end
287-
288179
defp get_matrix_column(m, col) do
289180
Enum.map(m, fn row ->
290181
Enum.at(row, col)
291182
end)
292183
end
293-
294-
defp get_matrix_elements(m, row_col_pairs) do
295-
Enum.map(row_col_pairs, fn [row, col] ->
296-
m
297-
|> Enum.at(row, [])
298-
|> Enum.at(col)
299-
|> case do
300-
nil -> raise ArgumentError, "invalid index [#{row},#{col}] for matrix"
301-
item -> item
302-
end
303-
end)
304-
end
305-
306-
defp replace_matrix_element(m, row, col, value) do
307-
updated = m |> Enum.at(row) |> List.replace_at(col, value)
308-
List.replace_at(m, row, updated)
309-
end
310-
311-
defp approximate_zeros(matrix, tol) do
312-
do_round = fn x -> if Complex.abs(x) < tol, do: 0 * x, else: x end
313-
314-
Enum.map(matrix, fn
315-
row when is_list(row) -> Enum.map(row, do_round)
316-
e -> do_round.(e)
317-
end)
318-
end
319184
end

nx/lib/nx/defn/expr.ex

-8
Original file line numberDiff line numberDiff line change
@@ -1188,14 +1188,6 @@ defmodule Nx.Defn.Expr do
11881188
expr(out, context, :triangular_solve, [a, b, opts])
11891189
end
11901190

1191-
@impl true
1192-
def lu({p, l, u}, tensor, opts) do
1193-
tensor = to_expr(tensor)
1194-
context = tensor.data.context
1195-
out = %T{names: [], shape: {}, type: {:tuple, 3}}
1196-
tuple(expr(out, context, :lu, [{p, l, u}, tensor, opts]), [p, l, u])
1197-
end
1198-
11991191
@impl true
12001192
def sort(out, tensor, opts) do
12011193
%{data: %{context: context}} = tensor = to_expr(tensor)

nx/lib/nx/defn/grad.ex

-23
Original file line numberDiff line numberDiff line change
@@ -716,29 +716,6 @@ defmodule Nx.Defn.Grad do
716716
pairs
717717
end
718718

719-
defp grad(:lu, [{p, l, u}, input, _opts], ans, [_dp, dl, du]) do
720-
# Definition taken from: https://sethaxen.com/blog/2021/02/differentiating-the-lu-decomposition/
721-
# Where dF = tril_strict(L^* . dL) + triu(dU . U^*)
722-
# dA = P^t . (L^*)^-1 . dF . (U^*)^-1
723-
724-
{p, l, u} = Nx.Defn.Expr.tuple(ans, [p, l, u])
725-
726-
u_h = Nx.LinAlg.adjoint(u)
727-
l_h = Nx.LinAlg.adjoint(l)
728-
p_t = Nx.LinAlg.adjoint(p)
729-
730-
lh_dl = Nx.dot(l_h, dl)
731-
du_uh = Nx.dot(du, u_h)
732-
733-
lt_inv = Nx.LinAlg.invert(l_h)
734-
ut_inv = Nx.LinAlg.invert(u_h)
735-
736-
df = lh_dl |> Nx.tril(k: -1) |> Nx.add(Nx.triu(du_uh))
737-
da = p_t |> Nx.dot(lt_inv) |> Nx.dot(df) |> Nx.dot(ut_inv)
738-
739-
[{input, da}]
740-
end
741-
742719
defp grad(:sort, [t, opts], _ans, g) do
743720
idx = Nx.argsort(t, opts)
744721
take_along_opts = Keyword.take(opts, [:axis])

nx/lib/nx/lin_alg.ex

+27-27
Original file line numberDiff line numberDiff line change
@@ -1523,15 +1523,15 @@ defmodule Nx.LinAlg do
15231523
[
15241524
[1.0, 0.0, 0.0],
15251525
[0.5714285969734192, 1.0, 0.0],
1526-
[0.1428571492433548, 2.0, 1.0]
1526+
[0.1428571492433548, 2.0000009536743164, 1.0]
15271527
]
15281528
>
15291529
iex> u
15301530
#Nx.Tensor<
15311531
f32[3][3]
15321532
[
15331533
[7.0, 8.0, 9.0],
1534-
[0.0, 0.4285714328289032, 0.8571428656578064],
1534+
[0.0, 0.4285712242126465, 0.857142448425293],
15351535
[0.0, 0.0, 0.0]
15361536
]
15371537
>
@@ -1607,7 +1607,7 @@ defmodule Nx.LinAlg do
16071607
[
16081608
[1.0, 0.0, 0.0],
16091609
[0.6666666865348816, 1.0, 0.0],
1610-
[0.3333333432674408, 2.0, 1.0]
1610+
[0.3333333432674408, 1.9999992847442627, 1.0]
16111611
],
16121612
[
16131613
[1.0, 0.0, 0.0],
@@ -1622,8 +1622,8 @@ defmodule Nx.LinAlg do
16221622
[
16231623
[
16241624
[9.0, 8.0, 7.0],
1625-
[0.0, -0.3333333432674408, -0.6666666865348816],
1626-
[0.0, 0.0, 0.0]
1625+
[0.0, -0.33333349227905273, -0.6666669845581055],
1626+
[0.0, 0.0, 5.960464477539063e-8]
16271627
],
16281628
[
16291629
[-1.0, 0.0, -1.0],
@@ -1638,7 +1638,7 @@ defmodule Nx.LinAlg do
16381638
[
16391639
[
16401640
[9.0, 8.0, 7.0],
1641-
[6.0, 5.0, 4.0],
1641+
[6.0, 5.0, 3.999999761581421],
16421642
[3.0, 2.0, 1.0]
16431643
],
16441644
[
@@ -1676,7 +1676,7 @@ defmodule Nx.LinAlg do
16761676
[
16771677
[1.0, 0.0, 0.0],
16781678
[0.6666666865348816, 1.0, 0.0],
1679-
[0.3333333432674408, 2.0, 1.0]
1679+
[0.3333333432674408, 1.9999992847442627, 1.0]
16801680
],
16811681
[
16821682
[1.0, 0.0, 0.0],
@@ -1692,8 +1692,8 @@ defmodule Nx.LinAlg do
16921692
[
16931693
[
16941694
[9.0, 8.0, 7.0],
1695-
[0.0, -0.3333333432674408, -0.6666666865348816],
1696-
[0.0, 0.0, 0.0]
1695+
[0.0, -0.33333349227905273, -0.6666669845581055],
1696+
[0.0, 0.0, 5.960464477539063e-8]
16971697
],
16981698
[
16991699
[-1.0, 0.0, -1.0],
@@ -1709,22 +1709,22 @@ defmodule Nx.LinAlg do
17091709
** (ArgumentError) tensor must be a square matrix or a batch of square matrices, got shape: {3, 4}
17101710
"""
17111711
def lu(tensor, opts \\ []) do
1712-
apply_vectorized(tensor, fn tensor ->
1713-
opts = keyword!(opts, eps: 1.0e-10)
1714-
%T{type: type, shape: shape} = tensor
1715-
1716-
output_type = Nx.Type.to_floating(type)
1717-
{p_shape, l_shape, u_shape} = Nx.Shape.lu(shape)
1718-
names = List.duplicate(nil, tuple_size(shape))
1719-
1720-
impl!(tensor).lu(
1721-
{%{tensor | type: type, shape: p_shape, names: names},
1722-
%{tensor | type: output_type, shape: l_shape, names: names},
1723-
%{tensor | type: output_type, shape: u_shape, names: names}},
1724-
tensor,
1725-
opts
1726-
)
1727-
end)
1712+
opts = keyword!(opts, eps: 1.0e-10)
1713+
%T{vectorized_axes: vectorized_axes} = tensor = Nx.to_tensor(tensor)
1714+
%T{type: type, shape: shape} = tensor = Nx.devectorize(tensor)
1715+
1716+
output_type = Nx.Type.to_floating(type)
1717+
{p_shape, l_shape, u_shape} = Nx.Shape.lu(shape)
1718+
names = List.duplicate(nil, tuple_size(shape))
1719+
1720+
output =
1721+
{%{tensor | type: type, shape: p_shape, names: names},
1722+
%{tensor | type: output_type, shape: l_shape, names: names},
1723+
%{tensor | type: output_type, shape: u_shape, names: names}}
1724+
1725+
:lu
1726+
|> Nx.Shared.optional([tensor, opts], output, &Nx.LinAlg.LU.lu/2)
1727+
|> Nx.vectorize(vectorized_axes)
17281728
end
17291729

17301730
@doc """
@@ -1892,7 +1892,7 @@ defmodule Nx.LinAlg do
18921892
...> ]))
18931893
#Nx.Tensor<
18941894
f32
1895-
48.0
1895+
47.999996185302734
18961896
>
18971897
18981898
iex> Nx.LinAlg.determinant(Nx.tensor([
@@ -1904,7 +1904,7 @@ defmodule Nx.LinAlg do
19041904
...> ]))
19051905
#Nx.Tensor<
19061906
f32
1907-
48.0
1907+
47.999996185302734
19081908
>
19091909
19101910
iex> Nx.LinAlg.determinant(Nx.tensor([

0 commit comments

Comments
 (0)