Skip to content

Commit bc52756

Browse files
Add some matrix valued tests
1 parent 03e4b20 commit bc52756

File tree

3 files changed

+96
-6
lines changed

3 files changed

+96
-6
lines changed

src/OptimizationDIExt.jl

+3-2
Original file line numberDiff line numberDiff line change
@@ -384,7 +384,7 @@ function instantiate_function(
384384
end
385385

386386
if f.cons_vjp === nothing && cons_vjp == true && cons !== nothing
387-
extras_pullback = prepare_pullback(cons, adtype, x)
387+
extras_pullback = prepare_pullback(cons, adtype, x, ones(eltype(x), num_cons))
388388
function cons_vjp!(θ, v)
389389
return pullback(cons, adtype, θ, v, extras_pullback)
390390
end
@@ -395,7 +395,8 @@ function instantiate_function(
395395
end
396396

397397
if f.cons_jvp === nothing && cons_jvp == true && cons !== nothing
398-
extras_pushforward = prepare_pushforward(cons, adtype, x)
398+
extras_pushforward = prepare_pushforward(
399+
cons, adtype, x, ones(eltype(x), length(x)))
399400
function cons_jvp!(θ, v)
400401
return pushforward(cons, adtype, θ, v, extras_pushforward)
401402
end

src/OptimizationDISparseExt.jl

+6-4
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ function instantiate_function(
236236
end
237237

238238
if f.cons_vjp === nothing && cons_vjp == true
239-
extras_pullback = prepare_pullback(cons_oop, adtype, x)
239+
extras_pullback = prepare_pullback(cons_oop, adtype, x, ones(eltype(x), num_cons))
240240
function cons_vjp!(J, θ, v)
241241
pullback!(cons_oop, J, adtype.dense_ad, θ, v, extras_pullback)
242242
end
@@ -247,7 +247,8 @@ function instantiate_function(
247247
end
248248

249249
if f.cons_jvp === nothing && cons_jvp == true
250-
extras_pushforward = prepare_pushforward(cons_oop, adtype, x)
250+
extras_pushforward = prepare_pushforward(
251+
cons_oop, adtype, x, ones(eltype(x), length(x)))
251252
function cons_jvp!(J, θ, v)
252253
pushforward!(cons_oop, J, adtype.dense_ad, θ, v, extras_pushforward)
253254
end
@@ -469,7 +470,7 @@ function instantiate_function(
469470
end
470471

471472
if f.cons_vjp === nothing && cons_vjp == true
472-
extras_pullback = prepare_pullback(cons, adtype, x)
473+
extras_pullback = prepare_pullback(cons, adtype, x, ones(eltype(x), num_cons))
473474
function cons_vjp!(θ, v)
474475
pullback(cons, adtype, θ, v, extras_pullback)
475476
end
@@ -480,7 +481,8 @@ function instantiate_function(
480481
end
481482

482483
if f.cons_jvp === nothing && cons_jvp == true
483-
extras_pushforward = prepare_pushforward(cons, adtype, x)
484+
extras_pushforward = prepare_pushforward(
485+
cons, adtype, x, ones(eltype(x), length(x)))
484486
function cons_jvp!(θ, v)
485487
pushforward(cons, adtype, θ, v, extras_pushforward)
486488
end

test/matrixvalued.jl

+87
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
using OptimizationBase, LinearAlgebra, ForwardDiff, Zygote, FiniteDiff,
2+
DifferentiationInterface, SparseConnectivityTracer
3+
using Test, ReverseDiff
4+
5+
@testset "Matrix Valued" begin
6+
for adtype in [AutoForwardDiff(), AutoZygote(), AutoFiniteDiff(),
7+
AutoSparse(AutoForwardDiff(), sparsity_detector = TracerLocalSparsityDetector()),
8+
AutoSparse(AutoZygote(), sparsity_detector = TracerLocalSparsityDetector()),
9+
AutoSparse(AutoFiniteDiff(), sparsity_detector = TracerLocalSparsityDetector())]
10+
# 1. Matrix Factorization
11+
function matrix_factorization_objective(X, A)
12+
U, V = @view(X[1:size(A, 1), 1:Int(size(A, 2) / 2)]),
13+
@view(X[1:size(A, 1), (Int(size(A, 2) / 2) + 1):size(A, 2)])
14+
return norm(A - U * V')
15+
end
16+
function non_negative_constraint(X, A)
17+
U, V = X
18+
return [all(U .>= 0) && all(V .>= 0)]
19+
end
20+
A_mf = rand(4, 4) # Original matrix
21+
U_mf = rand(4, 2) # Factor matrix U
22+
V_mf = rand(4, 2) # Factor matrix V
23+
24+
optf = OptimizationFunction{false}(
25+
matrix_factorization_objective, adtype, cons = non_negative_constraint)
26+
optf = OptimizationBase.instantiate_function(
27+
optf, hcat(U_mf, V_mf), adtype, A_mf, g = true, h = true,
28+
cons_j = true, cons_h = true)
29+
optf.grad(hcat(U_mf, V_mf))
30+
optf.hess(hcat(U_mf, V_mf))
31+
if adtype != AutoSparse(
32+
AutoForwardDiff(), sparsity_detector = TracerLocalSparsityDetector()) &&
33+
adtype !=
34+
AutoSparse(AutoZygote(), sparsity_detector = TracerLocalSparsityDetector()) &&
35+
adtype !=
36+
AutoSparse(AutoFiniteDiff(), sparsity_detector = TracerLocalSparsityDetector())
37+
optf.cons_j(hcat(U_mf, V_mf))
38+
optf.cons_h(hcat(U_mf, V_mf))
39+
end
40+
41+
# 2. Principal Component Analysis (PCA)
42+
function pca_objective(X, A)
43+
return -tr(X' * A * X) # Minimize the negative of the trace for maximization
44+
end
45+
function orthogonality_constraint(X, A)
46+
return [norm(X' * X - I) < 1e-6]
47+
end
48+
A_pca = rand(4, 4) # Covariance matrix (can be symmetric positive definite)
49+
X_pca = rand(4, 2) # Matrix to hold principal components
50+
51+
optf = OptimizationFunction{false}(
52+
pca_objective, adtype, cons = orthogonality_constraint)
53+
optf = OptimizationBase.instantiate_function(
54+
optf, X_pca, adtype, A_pca, g = true, h = true,
55+
cons_j = true, cons_h = true)
56+
optf.grad(X_pca)
57+
optf.hess(X_pca)
58+
if adtype != AutoSparse(
59+
AutoForwardDiff(), sparsity_detector = TracerLocalSparsityDetector()) &&
60+
adtype !=
61+
AutoSparse(AutoZygote(), sparsity_detector = TracerLocalSparsityDetector()) &&
62+
adtype !=
63+
AutoSparse(AutoFiniteDiff(), sparsity_detector = TracerLocalSparsityDetector())
64+
optf.cons_j(X_pca)
65+
optf.cons_h(X_pca)
66+
end
67+
68+
# 3. Matrix Completion
69+
function matrix_completion_objective(X, P)
70+
A, Omega = P
71+
return norm(Omega .* (A - X))
72+
end
73+
# r = 2 # Rank of the matrix to be completed
74+
# function rank_constraint(X, P)
75+
# return [rank(X) <= r]
76+
# end
77+
A_mc = rand(4, 4) # Original matrix with missing entries
78+
Omega_mc = rand(4, 4) .> 0.5 # Mask for observed entries (boolean matrix)
79+
X_mc = rand(4, 4) # Matrix to be completed
80+
optf = OptimizationFunction{false}(
81+
matrix_completion_objective, adtype, cons = rank_constraint)
82+
optf = OptimizationBase.instantiate_function(
83+
optf, X_mc, adtype, (A_mc, Omega_mc), g = true, h = true)
84+
optf.grad(X_mc)
85+
optf.hess(X_mc)
86+
end
87+
end

0 commit comments

Comments
 (0)