@@ -105,11 +105,11 @@ def solve_balanced_FRLC(
105105 .. math::
106106 \textbf{P} = \mathop{\arg \min}_P \quad \langle \textbf{P}, \mathbf{M} \rangle_F
107107
108- \text{s.t.} \ \textbf{P} = \textbf{Q} \operatorname{diag}(1/g_Q)\textbf{T}\operatorname{diag}(1/g_R)\textbf{R}^T
108+ \text{s.t.} \textbf{P} = \textbf{Q} \operatorname{diag}(1/g_Q)\textbf{T}\operatorname{diag}(1/g_R)\textbf{R}^T
109109
110- \textbf{Q} & \in \Pi_{a,\cdot}, \quad \textbf{R} \in \Pi_{b,\cdot}, \quad \textbf{T} \in \Pi_{g_Q,g_R}
110+ \textbf{Q} \in \Pi_{a,\cdot}, \quad \textbf{R} \in \Pi_{b,\cdot}, \quad \textbf{T} \in \Pi_{g_Q,g_R}
111111
112- \textbf{Q}, \textbf {R}, \ textbf{T} &\geq 0
112+ \textbf{Q} \in \mathbb {R}^+_{n,r},\textbf{R} \in \mathbb{R}^+_{m,r},\ textbf{T} \in \mathbb{R}^+_{r,r}
113113
114114 where:
115115
@@ -215,3 +215,42 @@ def solve_balanced_FRLC(
215215 return nx .dot (nx .dot (Q_new , X_new ), R_new .T ) # Shape (n, m)
216216
217217 Q , R , T , X = Q_new , R_new , T_new , X_new
218+
219+
220+ if __name__ == "__main__" :
221+ import torch
222+
223+ grid_size = 4
224+ torch .manual_seed (42 )
225+ x_vals = torch .linspace (0 , 3 , grid_size )
226+ y_vals = torch .linspace (0 , 3 , grid_size )
227+ X , Y = torch .meshgrid (x_vals , y_vals , indexing = "ij" )
228+ source_points = torch .stack ([X .ravel (), Y .ravel ()], dim = - 1 ) # (16, 2)
229+ a = torch .ones (len (source_points )) / len (source_points ) # Uniform distribution
230+
231+ # Generate Target Distribution (Gaussian Samples)
232+ mean = torch .tensor ([2.0 , 2.0 ])
233+ cov = torch .tensor ([[1.0 , 0.5 ], [0.5 , 1.0 ]])
234+ target_points = torch .distributions .MultivariateNormal (
235+ mean , covariance_matrix = cov
236+ ).sample ((len (source_points ),)) # (16, 2)
237+ b = torch .ones (len (target_points )) / len (target_points ) # Uniform distribution
238+
239+ # Compute Cost Matrix (Squared Euclidean Distance)
240+ C = torch .cdist (source_points , target_points , p = 2 ) ** 2
241+
242+ # Solve OT problem (assuming you have PyTorch versions of these functions)
243+ print (type (a .numpy ()))
244+ P = solve_balanced_FRLC (
245+ a .to (torch .float64 ),
246+ b .to (torch .float64 ),
247+ C .to (torch .float64 ),
248+ 10 ,
249+ tau = 1e2 ,
250+ gamma = 1e2 ,
251+ stopThr = 1e-7 ,
252+ numItermax = 100 ,
253+ log = True ,
254+ )
255+ P = sinkhorn (a , b , C , reg = 1 )
256+ print (torch .sum (P * C ))
0 commit comments