@@ -215,42 +215,3 @@ def solve_balanced_FRLC(
215215 return nx .dot (nx .dot (Q_new , X_new ), R_new .T ) # Shape (n, m)
216216
217217 Q , R , T , X = Q_new , R_new , T_new , X_new
218-
219-
220- if __name__ == "__main__" :
221- import torch
222-
223- grid_size = 4
224- torch .manual_seed (42 )
225- x_vals = torch .linspace (0 , 3 , grid_size )
226- y_vals = torch .linspace (0 , 3 , grid_size )
227- X , Y = torch .meshgrid (x_vals , y_vals , indexing = "ij" )
228- source_points = torch .stack ([X .ravel (), Y .ravel ()], dim = - 1 ) # (16, 2)
229- a = torch .ones (len (source_points )) / len (source_points ) # Uniform distribution
230-
231- # Generate Target Distribution (Gaussian Samples)
232- mean = torch .tensor ([2.0 , 2.0 ])
233- cov = torch .tensor ([[1.0 , 0.5 ], [0.5 , 1.0 ]])
234- target_points = torch .distributions .MultivariateNormal (
235- mean , covariance_matrix = cov
236- ).sample ((len (source_points ),)) # (16, 2)
237- b = torch .ones (len (target_points )) / len (target_points ) # Uniform distribution
238-
239- # Compute Cost Matrix (Squared Euclidean Distance)
240- C = torch .cdist (source_points , target_points , p = 2 ) ** 2
241-
242- # Solve OT problem (assuming you have PyTorch versions of these functions)
243- print (type (a .numpy ()))
244- P = solve_balanced_FRLC (
245- a .to (torch .float64 ),
246- b .to (torch .float64 ),
247- C .to (torch .float64 ),
248- 10 ,
249- tau = 1e2 ,
250- gamma = 1e2 ,
251- stopThr = 1e-7 ,
252- numItermax = 100 ,
253- log = True ,
254- )
255- P = sinkhorn (a , b , C , reg = 1 )
256- print (torch .sum (P * C ))
0 commit comments