Skip to content

Commit c85bf21

Browse files
committed
add the definition of r to the description
1 parent 73260f1 commit c85bf21

File tree

1 file changed

+42
-3
lines changed

1 file changed

+42
-3
lines changed

ot/low_rank/_factor_relaxation.py

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -105,11 +105,11 @@ def solve_balanced_FRLC(
105105
.. math::
106106
\textbf{P} = \mathop{\arg \min}_P \quad \langle \textbf{P}, \mathbf{M} \rangle_F
107107
108-
\text{s.t.} \ \textbf{P} = \textbf{Q} \operatorname{diag}(1/g_Q)\textbf{T}\operatorname{diag}(1/g_R)\textbf{R}^T
108+
\text{s.t.} \textbf{P} = \textbf{Q} \operatorname{diag}(1/g_Q)\textbf{T}\operatorname{diag}(1/g_R)\textbf{R}^T
109109
110-
\textbf{Q} &\in \Pi_{a,\cdot}, \quad \textbf{R} \in \Pi_{b,\cdot}, \quad \textbf{T} \in \Pi_{g_Q,g_R}
110+
\textbf{Q} \in \Pi_{a,\cdot}, \quad \textbf{R} \in \Pi_{b,\cdot}, \quad \textbf{T} \in \Pi_{g_Q,g_R}
111111
112-
\textbf{Q}, \textbf{R}, \textbf{T} &\geq 0
112+
\textbf{Q} \in \mathbb{R}^+_{n,r},\textbf{R} \in \mathbb{R}^+_{m,r},\textbf{T} \in \mathbb{R}^+_{r,r}
113113
114114
where:
115115
@@ -215,3 +215,42 @@ def solve_balanced_FRLC(
215215
return nx.dot(nx.dot(Q_new, X_new), R_new.T) # Shape (n, m)
216216

217217
Q, R, T, X = Q_new, R_new, T_new, X_new
218+
219+
220+
if __name__ == "__main__":
221+
import torch
222+
223+
grid_size = 4
224+
torch.manual_seed(42)
225+
x_vals = torch.linspace(0, 3, grid_size)
226+
y_vals = torch.linspace(0, 3, grid_size)
227+
X, Y = torch.meshgrid(x_vals, y_vals, indexing="ij")
228+
source_points = torch.stack([X.ravel(), Y.ravel()], dim=-1) # (16, 2)
229+
a = torch.ones(len(source_points)) / len(source_points) # Uniform distribution
230+
231+
# Generate Target Distribution (Gaussian Samples)
232+
mean = torch.tensor([2.0, 2.0])
233+
cov = torch.tensor([[1.0, 0.5], [0.5, 1.0]])
234+
target_points = torch.distributions.MultivariateNormal(
235+
mean, covariance_matrix=cov
236+
).sample((len(source_points),)) # (16, 2)
237+
b = torch.ones(len(target_points)) / len(target_points) # Uniform distribution
238+
239+
# Compute Cost Matrix (Squared Euclidean Distance)
240+
C = torch.cdist(source_points, target_points, p=2) ** 2
241+
242+
# Solve OT problem (assuming you have PyTorch versions of these functions)
243+
print(type(a.numpy()))
244+
P = solve_balanced_FRLC(
245+
a.to(torch.float64),
246+
b.to(torch.float64),
247+
C.to(torch.float64),
248+
10,
249+
tau=1e2,
250+
gamma=1e2,
251+
stopThr=1e-7,
252+
numItermax=100,
253+
log=True,
254+
)
255+
P = sinkhorn(a, b, C, reg=1)
256+
print(torch.sum(P * C))

0 commit comments

Comments
 (0)