|
1 | 1 | defmodule Nx.BinaryBackend.Matrix do
|
2 | 2 | @moduledoc false
|
3 | 3 | use Complex.Kernel
|
4 |
| - import Kernel, except: [abs: 1] |
5 |
| - import Complex, only: [abs: 1] |
6 | 4 |
|
7 | 5 | import Nx.Shared
|
8 | 6 |
|
@@ -116,107 +114,6 @@ defmodule Nx.BinaryBackend.Matrix do
|
116 | 114 |
|
117 | 115 | defp do_ts([], [], _idx, acc), do: acc
|
118 | 116 |
|
119 |
| - def lu(input_data, input_type, {n, n} = input_shape, p_type, l_type, u_type, opts) do |
120 |
| - a = binary_to_matrix(input_data, input_type, input_shape) |
121 |
| - eps = opts[:eps] |
122 |
| - |
123 |
| - {p, a_prime} = lu_validate_and_pivot(a, n) |
124 |
| - |
125 |
| - # We'll work with linear indices because of the way each matrix |
126 |
| - # needs to be updated/accessed |
127 |
| - zeros_matrix = List.duplicate(List.duplicate(0, n), n) |
128 |
| - |
129 |
| - {l, u} = |
130 |
| - for j <- 0..(n - 1), reduce: {zeros_matrix, zeros_matrix} do |
131 |
| - {l, u} -> |
132 |
| - l = replace_matrix_element(l, j, j, 1.0) |
133 |
| - |
134 |
| - u = |
135 |
| - for i <- 0..j, reduce: u do |
136 |
| - u -> |
137 |
| - u_slice = slice_matrix(u, [0, j], [i, 1]) |
138 |
| - l_slice = slice_matrix(l, [i, 0], [1, i]) |
139 |
| - sum = dot_matrix(u_slice, l_slice) |
140 |
| - [a_ij] = get_matrix_elements(a_prime, [[i, j]]) |
141 |
| - |
142 |
| - value = a_ij - sum |
143 |
| - |
144 |
| - if abs(value) < eps do |
145 |
| - replace_matrix_element(u, i, j, 0) |
146 |
| - else |
147 |
| - replace_matrix_element(u, i, j, value) |
148 |
| - end |
149 |
| - end |
150 |
| - |
151 |
| - l = |
152 |
| - for i <- j..(n - 1), i != j, reduce: l do |
153 |
| - l -> |
154 |
| - u_slice = slice_matrix(u, [0, j], [i, 1]) |
155 |
| - l_slice = slice_matrix(l, [i, 0], [1, i]) |
156 |
| - sum = dot_matrix(u_slice, l_slice) |
157 |
| - |
158 |
| - [a_ij] = get_matrix_elements(a_prime, [[i, j]]) |
159 |
| - [u_jj] = get_matrix_elements(u, [[j, j]]) |
160 |
| - |
161 |
| - value = |
162 |
| - cond do |
163 |
| - u_jj != 0 -> |
164 |
| - (a_ij - sum) / u_jj |
165 |
| - |
166 |
| - a_ij >= sum -> |
167 |
| - :infinity |
168 |
| - |
169 |
| - true -> |
170 |
| - :neg_infinity |
171 |
| - end |
172 |
| - |
173 |
| - if abs(value) < eps do |
174 |
| - replace_matrix_element(l, i, j, 0) |
175 |
| - else |
176 |
| - replace_matrix_element(l, i, j, value) |
177 |
| - end |
178 |
| - end |
179 |
| - |
180 |
| - {l, u} |
181 |
| - end |
182 |
| - |
183 |
| - # Transpose because since P is orthogonal, inv(P) = tranpose(P) |
184 |
| - # and we want to return P such that A = P.L.U |
185 |
| - {p |> transpose_matrix() |> matrix_to_binary(p_type), |
186 |
| - l |> approximate_zeros(eps) |> matrix_to_binary(l_type), |
187 |
| - u |> approximate_zeros(eps) |> matrix_to_binary(u_type)} |
188 |
| - end |
189 |
| - |
190 |
| - defp lu_validate_and_pivot(a, n) do |
191 |
| - # pivots a tensor so that the biggest elements of each column lie on the diagonal. |
192 |
| - # if any of the diagonal elements ends up being 0, raises an ArgumentError |
193 |
| - |
194 |
| - identity = |
195 |
| - Enum.map(0..(n - 1), fn i -> Enum.map(0..(n - 1), fn j -> if i == j, do: 1, else: 0 end) end) |
196 |
| - |
197 |
| - # For each row, find the max value by column. |
198 |
| - # If its index (max_idx) is not in the diagonal (i.e. j != max_idx) |
199 |
| - # we need to swap rows j and max_idx in both the permutation matrix |
200 |
| - # and in the a matrix. |
201 |
| - Enum.reduce(0..(n - 2), {identity, a}, fn j, {p, a} -> |
202 |
| - [max_idx | _] = |
203 |
| - Enum.sort_by(j..(n - 1), fn i -> a |> Enum.at(i) |> Enum.at(j) |> abs() end, &>=/2) |
204 |
| - |
205 |
| - if max_idx == j do |
206 |
| - {p, a} |
207 |
| - else |
208 |
| - p_row = Enum.at(p, max_idx) |
209 |
| - p_j = Enum.at(p, j) |
210 |
| - p = p |> List.replace_at(max_idx, p_j) |> List.replace_at(j, p_row) |
211 |
| - |
212 |
| - a_row = Enum.at(a, max_idx) |
213 |
| - a_j = Enum.at(a, j) |
214 |
| - a = a |> List.replace_at(max_idx, a_j) |> List.replace_at(j, a_row) |
215 |
| - {p, a} |
216 |
| - end |
217 |
| - end) |
218 |
| - end |
219 |
| - |
220 | 117 | ## Matrix (2-D array) manipulation
|
221 | 118 |
|
222 | 119 | defp dot_matrix([], _), do: 0
|
@@ -279,41 +176,9 @@ defmodule Nx.BinaryBackend.Matrix do
|
279 | 176 | |> Enum.chunk_every(num_cols)
|
280 | 177 | end
|
281 | 178 |
|
282 |
| - defp slice_matrix(a, [row_start, col_start], [row_length, col_length]) do |
283 |
| - a |
284 |
| - |> Enum.slice(row_start, row_length) |
285 |
| - |> Enum.flat_map(&Enum.slice(&1, col_start, col_length)) |
286 |
| - end |
287 |
| - |
288 | 179 | defp get_matrix_column(m, col) do
|
289 | 180 | Enum.map(m, fn row ->
|
290 | 181 | Enum.at(row, col)
|
291 | 182 | end)
|
292 | 183 | end
|
293 |
| - |
294 |
| - defp get_matrix_elements(m, row_col_pairs) do |
295 |
| - Enum.map(row_col_pairs, fn [row, col] -> |
296 |
| - m |
297 |
| - |> Enum.at(row, []) |
298 |
| - |> Enum.at(col) |
299 |
| - |> case do |
300 |
| - nil -> raise ArgumentError, "invalid index [#{row},#{col}] for matrix" |
301 |
| - item -> item |
302 |
| - end |
303 |
| - end) |
304 |
| - end |
305 |
| - |
306 |
| - defp replace_matrix_element(m, row, col, value) do |
307 |
| - updated = m |> Enum.at(row) |> List.replace_at(col, value) |
308 |
| - List.replace_at(m, row, updated) |
309 |
| - end |
310 |
| - |
311 |
| - defp approximate_zeros(matrix, tol) do |
312 |
| - do_round = fn x -> if Complex.abs(x) < tol, do: 0 * x, else: x end |
313 |
| - |
314 |
| - Enum.map(matrix, fn |
315 |
| - row when is_list(row) -> Enum.map(row, do_round) |
316 |
| - e -> do_round.(e) |
317 |
| - end) |
318 |
| - end |
319 | 184 | end
|
0 commit comments