|
| 1 | +import argparse |
| 2 | +import torch |
| 3 | +from fml.functional import pairwise_distances, sinkhorn |
| 4 | + |
| 5 | +if __name__ == '__main__': |
| 6 | + # Parse input arguments |
| 7 | + parser = argparse.ArgumentParser( |
| 8 | + description='Sinkhorn loss using the functional interface.') |
| 9 | + parser.add_argument('--batch_size', '-bz', type=int, default=3, |
| 10 | + help='Batch size.') |
| 11 | + parser.add_argument('--set1_size', '-sz1', type=int, default=5, |
| 12 | + help='Set size.') |
| 13 | + parser.add_argument('--set2_size', '-sz2', type=int, default=10, |
| 14 | + help='Set size.') |
| 15 | + parser.add_argument('--point_dim', '-pd', type=int, default=4, |
| 16 | + help='Point dimension.') |
| 17 | + parser.add_argument('--lp_distance', '-p', type=int, default=2, |
| 18 | + help='p for the Lp-distance.') |
| 19 | + |
| 20 | + args = parser.parse_args() |
| 21 | + |
| 22 | + # Set the parameters |
| 23 | + minibatch_size = args.batch_size |
| 24 | + set1_size = args.set1_size |
| 25 | + set2_size = args.set2_size |
| 26 | + point_dim = args.point_dim |
| 27 | + |
| 28 | + # Create two minibatches of point sets where each batch item set_a[k, :, :] is a set of `set_size` points |
| 29 | + set_a = torch.rand([minibatch_size, set1_size, point_dim]) |
| 30 | + set_b = torch.rand([minibatch_size, set2_size, point_dim]) |
| 31 | + |
| 32 | + print('Set A') |
| 33 | + print(set_a) |
| 34 | + |
| 35 | + print('Set B') |
| 36 | + print(set_b) |
| 37 | + |
| 38 | + # Condition P*1 = a and P^T*1 = b |
| 39 | + a = torch.ones(set_a.shape[0:2], |
| 40 | + requires_grad=False, |
| 41 | + device=set_a.device) |
| 42 | + |
| 43 | + b = torch.ones(set_b.shape[0:2], |
| 44 | + requires_grad=False, |
| 45 | + device=set_b.device) |
| 46 | + # Have the same total mass than set_a |
| 47 | + b = b * a.sum(1, keepdim=True) / b.sum(1, keepdim=True) |
| 48 | + |
| 49 | + # Compute the cost matrix |
| 50 | + M = pairwise_distances(set_a, set_b, p=args.lp_distance) |
| 51 | + |
| 52 | + print('Distance') |
| 53 | + print(M) |
| 54 | + |
| 55 | + # Compute the transport matrix between each pair of sets in the minibatch with default parameters |
| 56 | + P = sinkhorn(a, b, M, 1e-3, max_iters=500, stop_thresh=1e-8) |
| 57 | + |
| 58 | + print('Transport Matrix') |
| 59 | + print(P) |
| 60 | + |
| 61 | + print('Condition error') |
| 62 | + |
| 63 | + aprox_a = P.sum(2) |
| 64 | + aprox_b = P.sum(1) |
| 65 | + |
| 66 | + print('\t P*1_d mean error: {}'.format(torch.mean((aprox_a - a).abs()).item())) |
| 67 | + print('\t P^T*1_d mean error: {}'.format(torch.mean((aprox_b - b).abs()).item())) |
| 68 | + |
| 69 | + # Compute the loss |
| 70 | + loss = (M * P).sum(2).sum(1) |
| 71 | + |
| 72 | + print('Loss') |
| 73 | + print(loss) |
| 74 | + |
0 commit comments