Skip to content

Commit 3168387

Browse files
committed
update default value
1 parent ba515c2 commit 3168387

File tree

1 file changed

+0
-39
lines changed

1 file changed

+0
-39
lines changed

ot/low_rank/_factor_relaxation.py

Lines changed: 0 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -215,42 +215,3 @@ def solve_balanced_FRLC(
215215
return nx.dot(nx.dot(Q_new, X_new), R_new.T) # Shape (n, m)
216216

217217
Q, R, T, X = Q_new, R_new, T_new, X_new
218-
219-
220-
if __name__ == "__main__":
221-
import torch
222-
223-
grid_size = 4
224-
torch.manual_seed(42)
225-
x_vals = torch.linspace(0, 3, grid_size)
226-
y_vals = torch.linspace(0, 3, grid_size)
227-
X, Y = torch.meshgrid(x_vals, y_vals, indexing="ij")
228-
source_points = torch.stack([X.ravel(), Y.ravel()], dim=-1) # (16, 2)
229-
a = torch.ones(len(source_points)) / len(source_points) # Uniform distribution
230-
231-
# Generate Target Distribution (Gaussian Samples)
232-
mean = torch.tensor([2.0, 2.0])
233-
cov = torch.tensor([[1.0, 0.5], [0.5, 1.0]])
234-
target_points = torch.distributions.MultivariateNormal(
235-
mean, covariance_matrix=cov
236-
).sample((len(source_points),)) # (16, 2)
237-
b = torch.ones(len(target_points)) / len(target_points) # Uniform distribution
238-
239-
# Compute Cost Matrix (Squared Euclidean Distance)
240-
C = torch.cdist(source_points, target_points, p=2) ** 2
241-
242-
# Solve OT problem (assuming you have PyTorch versions of these functions)
243-
print(type(a.numpy()))
244-
P = solve_balanced_FRLC(
245-
a.to(torch.float64),
246-
b.to(torch.float64),
247-
C.to(torch.float64),
248-
10,
249-
tau=1e2,
250-
gamma=1e2,
251-
stopThr=1e-7,
252-
numItermax=100,
253-
log=True,
254-
)
255-
P = sinkhorn(a, b, C, reg=1)
256-
print(torch.sum(P * C))

0 commit comments

Comments
 (0)