Skip to content

Commit 8ae5b8f

Browse files
authored
Merge pull request #3 from priba/master
Initialization for the sinkhorn iterations
2 parents bd212b8 + a1ec9e0 commit 8ae5b8f

5 files changed

Lines changed: 165 additions & 9 deletions

File tree

examples/sinkhorn_loss_functional/main.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,14 @@
3232
print('Set B')
3333
print(set_b)
3434

35+
# Condition P*1_d = a and P^T*1_d = b
3536
a = torch.ones(set_a.shape[0:2],
3637
requires_grad=False,
37-
device=set_a.device) / set_a.shape[1]
38+
device=set_a.device)
3839

3940
b = torch.ones(set_b.shape[0:2],
4041
requires_grad=False,
41-
device=set_b.device) / set_b.shape[1]
42+
device=set_b.device)
4243

4344
# Compute the cost matrix
4445
M = pairwise_distances(set_a, set_b, p=args.lp_distance)
@@ -47,11 +48,19 @@
4748
print(M)
4849

4950
# Compute the transport matrix between each pair of sets in the minibatch (eps=1e-3; the updated call also sets max_iters=500 and stop_thresh=1e-8)
50-
P = sinkhorn(a, b, M, 1e-3)
51-
51+
P = sinkhorn(a, b, M, 1e-3, max_iters=500, stop_thresh=1e-8)
52+
5253
print('Transport Matrix')
5354
print(P)
5455

56+
print('Condition error')
57+
58+
aprox_a = P.sum(2)
59+
aprox_b = P.sum(1)
60+
61+
print('\t P*1_d mean error: {}'.format(torch.mean((aprox_a - a).abs()).item()))
62+
print('\t P^T*1_d mean error: {}'.format(torch.mean((aprox_b - b).abs()).item()))
63+
5564
# Compute the loss
5665
loss = (M * P).sum(2).sum(1)
5766

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
import argparse
2+
import torch
3+
from fml.functional import pairwise_distances, sinkhorn
4+
5+
if __name__ == '__main__':
6+
# Parse input arguments
7+
parser = argparse.ArgumentParser(
8+
description='Sinkhorn loss using the functional interface.')
9+
parser.add_argument('--batch_size', '-bz', type=int, default=3,
10+
help='Batch size.')
11+
parser.add_argument('--set1_size', '-sz1', type=int, default=5,
12+
help='Set size.')
13+
parser.add_argument('--set2_size', '-sz2', type=int, default=10,
14+
help='Set size.')
15+
parser.add_argument('--point_dim', '-pd', type=int, default=4,
16+
help='Point dimension.')
17+
parser.add_argument('--lp_distance', '-p', type=int, default=2,
18+
help='p for the Lp-distance.')
19+
20+
args = parser.parse_args()
21+
22+
# Set the parameters
23+
minibatch_size = args.batch_size
24+
set1_size = args.set1_size
25+
set2_size = args.set2_size
26+
point_dim = args.point_dim
27+
28+
# Create two minibatches of point sets: each batch item set_a[k, :, :] is a set of `set1_size` points and each set_b[k, :, :] is a set of `set2_size` points
29+
set_a = torch.rand([minibatch_size, set1_size, point_dim])
30+
set_b = torch.rand([minibatch_size, set2_size, point_dim])
31+
32+
print('Set A')
33+
print(set_a)
34+
35+
print('Set B')
36+
print(set_b)
37+
38+
# Condition P*1 = a and P^T*1 = b
39+
a = torch.ones(set_a.shape[0:2],
40+
requires_grad=False,
41+
device=set_a.device)
42+
43+
b = torch.ones(set_b.shape[0:2],
44+
requires_grad=False,
45+
device=set_b.device)
46+
# Rescale b so that, for each batch item, it has the same total mass as a
47+
b = b * a.sum(1, keepdim=True) / b.sum(1, keepdim=True)
48+
49+
# Compute the cost matrix
50+
M = pairwise_distances(set_a, set_b, p=args.lp_distance)
51+
52+
print('Distance')
53+
print(M)
54+
55+
# Compute the transport matrix between each pair of sets in the minibatch (eps=1e-3, max_iters=500, stop_thresh=1e-8)
56+
P = sinkhorn(a, b, M, 1e-3, max_iters=500, stop_thresh=1e-8)
57+
58+
print('Transport Matrix')
59+
print(P)
60+
61+
print('Condition error')
62+
63+
aprox_a = P.sum(2)
64+
aprox_b = P.sum(1)
65+
66+
print('\t P*1_d mean error: {}'.format(torch.mean((aprox_a - a).abs()).item()))
67+
print('\t P^T*1_d mean error: {}'.format(torch.mean((aprox_b - b).abs()).item()))
68+
69+
# Compute the loss
70+
loss = (M * P).sum(2).sum(1)
71+
72+
print('Loss')
73+
print(loss)
74+
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
import argparse
2+
import torch
3+
from fml.functional import pairwise_distances, sinkhorn
4+
5+
if __name__ == '__main__':
6+
# Parse input arguments
7+
parser = argparse.ArgumentParser(
8+
description='Sinkhorn loss using the functional interface.')
9+
parser.add_argument('--batch_size', '-bz', type=int, default=3,
10+
help='Batch size.')
11+
parser.add_argument('--set_size', '-sz', type=int, default=10,
12+
help='Set size.')
13+
parser.add_argument('--point_dim', '-pd', type=int, default=4,
14+
help='Point dimension.')
15+
parser.add_argument('--lp_distance', '-p', type=int, default=2,
16+
help='p for the Lp-distance.')
17+
18+
args = parser.parse_args()
19+
20+
# Set the parameters
21+
minibatch_size = args.batch_size
22+
set_size = args.set_size
23+
point_dim = args.point_dim
24+
25+
# Create two minibatches of point sets where each batch item set_a[k, :, :] is a set of `set_size` points
26+
set_a = torch.rand([minibatch_size, set_size, point_dim])
27+
set_b = torch.rand([minibatch_size, set_size, point_dim])
28+
29+
print('Set A')
30+
print(set_a)
31+
32+
print('Set B')
33+
print(set_b)
34+
35+
# Condition P*1 = a and P^T*1 = b
36+
a = torch.rand(set_a.shape[0:2],
37+
requires_grad=False,
38+
device=set_a.device)
39+
# Keep an average mass of 1 per node
40+
a = a * set_a.shape[1] / a.sum(1, keepdim=True)
41+
42+
b = torch.rand(set_b.shape[0:2],
43+
requires_grad=False,
44+
device=set_b.device)
45+
# Rescale b so that, for each batch item, it has the same total mass as a
46+
b = b * a.sum(1, keepdim=True) / b.sum(1, keepdim=True)
47+
48+
# Compute the cost matrix
49+
M = pairwise_distances(set_a, set_b, p=args.lp_distance)
50+
51+
print('Distance')
52+
print(M)
53+
54+
# Compute the transport matrix between each pair of sets in the minibatch (eps=1e-3, max_iters=500, stop_thresh=1e-8)
55+
P = sinkhorn(a, b, M, 1e-3, max_iters=500, stop_thresh=1e-8)
56+
57+
print('Transport Matrix')
58+
print(P)
59+
60+
print('Condition error')
61+
62+
aprox_a = P.sum(2)
63+
aprox_b = P.sum(1)
64+
65+
print('\t P*1_d mean error: {}'.format(torch.mean((aprox_a - a).abs()).item()))
66+
print('\t P^T*1_d mean error: {}'.format(torch.mean((aprox_b - b).abs()).item()))
67+
68+
# Compute the loss
69+
loss = (M * P).sum(2).sum(1)
70+
71+
print('Loss')
72+
print(loss)
73+

fml/functional.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ def pairwise_distances(a: torch.Tensor, b: torch.Tensor, p=2):
1515
raise ValueError("Invalid shape for a. Must be [m, n, d] but got", a.shape)
1616
if len(b.shape) != 3:
1717
raise ValueError("Invalid shape for a. Must be [m, n, d] but got", b.shape)
18-
1918
return (a.unsqueeze(2) - b.unsqueeze(1)).abs().pow(p).sum(3)
2019

2120

@@ -69,8 +68,9 @@ def sinkhorn(a: torch.Tensor, b: torch.Tensor, M: torch.Tensor, eps: float,
6968
raise ValueError("Got unexpected shape for tensor b (%s). Expected [nb, n] where M has shape [nb, m, n]." %
7069
str(b.shape))
7170

71+
# Initialize the iteration with the change of variable
7272
u = torch.zeros(a.shape, dtype=a.dtype, device=a.device)
73-
v = torch.zeros(b.shape, dtype=b.dtype, device=b.device)
73+
v = eps * torch.log(b)
7474

7575
M_t = torch.transpose(M, 1, 2)
7676

@@ -97,7 +97,7 @@ def stabilized_log_sum_exp(x):
9797
break
9898

9999
log_P = (-M + u.unsqueeze(2) + v.unsqueeze(1)) / eps
100-
100+
101101
P = torch.exp(log_P)
102102

103103
return P

fml/nn.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,14 +42,14 @@ def forward(self, predicted, expected, a=None, b=None):
4242
if a is None:
4343
a = torch.ones(predicted.shape[0:2],
4444
requires_grad=False,
45-
device=predicted.device) / predicted.shape[1]
45+
device=predicted.device)
4646
else:
4747
a = a.to(predicted.device)
4848

4949
if b is None:
5050
b = torch.ones(predicted.shape[0:2],
5151
requires_grad=False,
52-
device=predicted.device) / predicted.shape[1]
52+
device=predicted.device)
5353
else:
5454
b = b.to(predicted.device)
5555

0 commit comments

Comments
 (0)