From 3c0f11ca35b03bcfbf807675d25846c67f5b9873 Mon Sep 17 00:00:00 2001 From: Pau Riba Date: Tue, 19 Feb 2019 18:37:58 -0500 Subject: [PATCH 1/5] Correct initialization of v --- examples/sinkhorn_loss_functional/main.py | 17 +++++++++++++---- fml/functional.py | 6 +++--- fml/nn.py | 4 ++-- 3 files changed, 18 insertions(+), 9 deletions(-) diff --git a/examples/sinkhorn_loss_functional/main.py b/examples/sinkhorn_loss_functional/main.py index d2564a8..e672663 100644 --- a/examples/sinkhorn_loss_functional/main.py +++ b/examples/sinkhorn_loss_functional/main.py @@ -32,13 +32,14 @@ print('Set B') print(set_b) + # Condition P*b = a and P^T*a = b a = torch.ones(set_a.shape[0:2], requires_grad=False, - device=set_a.device) / set_a.shape[1] + device=set_a.device) b = torch.ones(set_b.shape[0:2], requires_grad=False, - device=set_b.device) / set_b.shape[1] + device=set_b.device) # Compute the cost matrix M = pairwise_distances(set_a, set_b, p=args.lp_distance) @@ -47,11 +48,19 @@ print(M) # Compute the transport matrix between each pair of sets in the minibatch with default parameters - P = sinkhorn(a, b, M, 1e-3) - + P = sinkhorn(a, b, M, 1e-3, max_iters=5000, stop_thresh=1e-12) + print('Transport Matrix') print(P) + print('Condition error') + + aprox_a = torch.bmm(P, b.unsqueeze(2)).squeeze(2) + aprox_b = torch.bmm(P.transpose(1,2), a.unsqueeze(2)).squeeze(2) + + print('\t P*a mean error: {}'.format(torch.mean(aprox_b - b).item())) + print('\t P^T*b mean error: {}'.format(torch.mean(aprox_a - a).item())) + # Compute the loss loss = (M * P).sum(2).sum(1) diff --git a/fml/functional.py b/fml/functional.py index 86698e6..9600c1a 100644 --- a/fml/functional.py +++ b/fml/functional.py @@ -15,7 +15,6 @@ def pairwise_distances(a: torch.Tensor, b: torch.Tensor, p=2): raise ValueError("Invalid shape for a. Must be [m, n, d] but got", a.shape) if len(b.shape) != 3: raise ValueError("Invalid shape for a. Must be [m, n, d] but got", b.shape) - return (a.unsqueeze(2) - b.unsqueeze(1)).abs().pow(p).sum(3) @@ -69,8 +68,9 @@ def sinkhorn(a: torch.Tensor, b: torch.Tensor, M: torch.Tensor, eps: float, raise ValueError("Got unexpected shape for tensor b (%s). Expected [nb, n] where M has shape [nb, m, n]." % str(b.shape)) + # Initialize the iteration with the change of variable u = torch.zeros(a.shape, dtype=a.dtype, device=a.device) - v = torch.zeros(b.shape, dtype=b.dtype, device=b.device) + v = eps * torch.log(b) M_t = torch.transpose(M, 1, 2) @@ -97,7 +97,7 @@ def stabilized_log_sum_exp(x): break log_P = (-M + u.unsqueeze(2) + v.unsqueeze(1)) / eps - + P = torch.exp(log_P) return P diff --git a/fml/nn.py b/fml/nn.py index c875c1c..d74ba5d 100644 --- a/fml/nn.py +++ b/fml/nn.py @@ -42,14 +42,14 @@ def forward(self, predicted, expected, a=None, b=None): if a is None: a = torch.ones(predicted.shape[0:2], requires_grad=False, - device=predicted.device) / predicted.shape[1] + device=predicted.device) else: a = a.to(predicted.device) if b is None: b = torch.ones(predicted.shape[0:2], requires_grad=False, - device=predicted.device) / predicted.shape[1] + device=predicted.device) else: b = b.to(predicted.device) From 3e25b679c2838e5a7a6fb5e52cbe77edb0fd8e9a Mon Sep 17 00:00:00 2001 From: Pau Riba Date: Tue, 19 Feb 2019 18:42:38 -0500 Subject: [PATCH 2/5] Reduce the number of sinkhorn iterations in the example --- examples/sinkhorn_loss_functional/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/sinkhorn_loss_functional/main.py b/examples/sinkhorn_loss_functional/main.py index e672663..3dc323c 100644 --- a/examples/sinkhorn_loss_functional/main.py +++ b/examples/sinkhorn_loss_functional/main.py @@ -48,7 +48,7 @@ print(M) # Compute the transport matrix between each pair of sets in the minibatch with default parameters - P = sinkhorn(a, b, M, 1e-3, max_iters=5000, stop_thresh=1e-12) + P = sinkhorn(a, b, M, 1e-3, max_iters=500, stop_thresh=1e-8) print('Transport Matrix') print(P) From f720c1856c0f8234eb0ed81b35e74a2045c7dbea Mon Sep 17 00:00:00 2001 From: Pau Riba Date: Wed, 20 Feb 2019 11:21:15 -0500 Subject: [PATCH 3/5] Add a new exampled with a weighted set --- examples/sinkhorn_loss_functional/main.py | 10 ++-- examples/sinkhorn_loss_weighted/main.py | 73 +++++++++++++++++++++++ 2 files changed, 78 insertions(+), 5 deletions(-) create mode 100644 examples/sinkhorn_loss_weighted/main.py diff --git a/examples/sinkhorn_loss_functional/main.py b/examples/sinkhorn_loss_functional/main.py index 3dc323c..8bb2482 100644 --- a/examples/sinkhorn_loss_functional/main.py +++ b/examples/sinkhorn_loss_functional/main.py @@ -32,7 +32,7 @@ print('Set B') print(set_b) - # Condition P*b = a and P^T*a = b + # Condition P*1_d = a and P^T*1_d = b a = torch.ones(set_a.shape[0:2], requires_grad=False, device=set_a.device) @@ -55,11 +55,11 @@ print('Condition error') - aprox_a = torch.bmm(P, b.unsqueeze(2)).squeeze(2) - aprox_b = torch.bmm(P.transpose(1,2), a.unsqueeze(2)).squeeze(2) + aprox_a = P.sum(2) + aprox_b = P.sum(1) - print('\t P*a mean error: {}'.format(torch.mean(aprox_b - b).item())) - print('\t P^T*b mean error: {}'.format(torch.mean(aprox_a - a).item())) + print('\t P*1_d mean error: {}'.format(torch.mean((aprox_a - a).abs()).item())) + print('\t P^T*1_d mean error: {}'.format(torch.mean((aprox_b - b).abs()).item())) # Compute the loss loss = (M * P).sum(2).sum(1) diff --git a/examples/sinkhorn_loss_weighted/main.py b/examples/sinkhorn_loss_weighted/main.py new file mode 100644 index 0000000..0ffd4d2 --- /dev/null +++ b/examples/sinkhorn_loss_weighted/main.py @@ -0,0 +1,73 @@ +import argparse +import torch +from fml.functional import pairwise_distances, sinkhorn + +if __name__ == '__main__': + # Parse input arguments + parser = argparse.ArgumentParser( + description='Sinkhorn loss using the functional interface.') + parser.add_argument('--batch_size', '-bz', type=int, default=3, + help='Batch size.') + parser.add_argument('--set_size', '-sz', type=int, default=10, + help='Set size.') + parser.add_argument('--point_dim', '-pd', type=int, default=4, + help='Point dimension.') + parser.add_argument('--lp_distance', '-p', type=int, default=2, + help='p for the Lp-distance.') + + args = parser.parse_args() + + # Set the parameters + minibatch_size = args.batch_size + set_size = args.set_size + point_dim = args.point_dim + + # Create two minibatches of point sets where each batch item set_a[k, :, :] is a set of `set_size` points + set_a = torch.rand([minibatch_size, set_size, point_dim]) + set_b = torch.rand([minibatch_size, set_size, point_dim]) + + print('Set A') + print(set_a) + + print('Set B') + print(set_b) + + # Condition P*1 = a and P^T*1 = b + a = torch.rand(set_a.shape[0:2], + requires_grad=False, + device=set_a.device) + # Keep an average mass of 1 per node + a = a * set_a.shape[1] / a.sum(1, keepdim=True) + + b = torch.rand(set_b.shape[0:2], + requires_grad=False, + device=set_b.device) + # Have the same total mass than set_a + b = b * a.sum(1, keepdim=True) / b.sum(1, keepdim=True) + + # Compute the cost matrix + M = pairwise_distances(set_a, set_b, p=args.lp_distance) + + print('Distance') + print(M) + + # Compute the transport matrix between each pair of sets in the minibatch with default parameters + P = sinkhorn(a, b, M, 1e-3, max_iters=500, stop_thresh=1e-8) + + print('Transport Matrix') + print(P) + + print('Condition error') + + aprox_a = P.sum(2) + aprox_b = P.sum(1) + + print('\t P*1_d mean error: {}'.format(torch.mean((aprox_a - a).abs()).item())) + print('\t P^T*1_d mean error: {}'.format(torch.mean((aprox_b - b).abs()).item())) + + # Compute the loss + loss = (M * P).sum(2).sum(1) + + print('Loss') + print(loss) + From 399319ceebe5dfdb7eb6dc5e72266adc2748afff Mon Sep 17 00:00:00 2001 From: Pau Riba Date: Wed, 20 Feb 2019 11:25:32 -0500 Subject: [PATCH 4/5] Add example with unbalance sets. --- examples/sinkhorn_loss_unbalanced/main.py | 74 +++++++++++++++++++++++ 1 file changed, 74 insertions(+) create mode 100644 examples/sinkhorn_loss_unbalanced/main.py diff --git a/examples/sinkhorn_loss_unbalanced/main.py b/examples/sinkhorn_loss_unbalanced/main.py new file mode 100644 index 0000000..66cff8d --- /dev/null +++ b/examples/sinkhorn_loss_unbalanced/main.py @@ -0,0 +1,74 @@ +import argparse +import torch +from fml.functional import pairwise_distances, sinkhorn + +if __name__ == '__main__': + # Parse input arguments + parser = argparse.ArgumentParser( + description='Sinkhorn loss using the functional interface.') + parser.add_argument('--batch_size', '-bz', type=int, default=3, + help='Batch size.') + parser.add_argument('--set1_size', '-sz1', type=int, default=5, + help='Set size.') + parser.add_argument('--set2_size', '-sz2', type=int, default=10, + help='Set size.') + parser.add_argument('--point_dim', '-pd', type=int, default=4, + help='Point dimension.') + parser.add_argument('--lp_distance', '-p', type=int, default=2, + help='p for the Lp-distance.') + + args = parser.parse_args() + + # Set the parameters + minibatch_size = args.batch_size + set1_size = args.set1_size + set2_size = args.set2_size + point_dim = args.point_dim + + # Create two minibatches of point sets where each batch item set_a[k, :, :] is a set of `set_size` points + set_a = torch.rand([minibatch_size, set1_size, point_dim]) + set_b = torch.rand([minibatch_size, set2_size, point_dim]) + + print('Set A') + print(set_a) + + print('Set B') + print(set_b) + + # Condition P*1 = a and P^T*1 = b + a = torch.rand(set_a.shape[0:2], + requires_grad=False, + device=set_a.device) + + b = torch.rand(set_b.shape[0:2], + requires_grad=False, + device=set_b.device) + # Have the same total mass than set_a + b = b * a.sum(1, keepdim=True) / b.sum(1, keepdim=True) + + # Compute the cost matrix + M = pairwise_distances(set_a, set_b, p=args.lp_distance) + + print('Distance') + print(M) + + # Compute the transport matrix between each pair of sets in the minibatch with default parameters + P = sinkhorn(a, b, M, 1e-3, max_iters=500, stop_thresh=1e-8) + + print('Transport Matrix') + print(P) + + print('Condition error') + + aprox_a = P.sum(2) + aprox_b = P.sum(1) + + print('\t P*1_d mean error: {}'.format(torch.mean((aprox_a - a).abs()).item())) + print('\t P^T*1_d mean error: {}'.format(torch.mean((aprox_b - b).abs()).item())) + + # Compute the loss + loss = (M * P).sum(2).sum(1) + + print('Loss') + print(loss) + From a1ec9e04ea8d78694ff3146621c46cd639d49a4b Mon Sep 17 00:00:00 2001 From: Pau Riba Date: Wed, 20 Feb 2019 11:30:15 -0500 Subject: [PATCH 5/5] Set mass of set_a to ones --- examples/sinkhorn_loss_unbalanced/main.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/sinkhorn_loss_unbalanced/main.py b/examples/sinkhorn_loss_unbalanced/main.py index 66cff8d..0a6abd5 100644 --- a/examples/sinkhorn_loss_unbalanced/main.py +++ b/examples/sinkhorn_loss_unbalanced/main.py @@ -36,11 +36,11 @@ print(set_b) # Condition P*1 = a and P^T*1 = b - a = torch.rand(set_a.shape[0:2], + a = torch.ones(set_a.shape[0:2], requires_grad=False, device=set_a.device) - b = torch.rand(set_b.shape[0:2], + b = torch.ones(set_b.shape[0:2], requires_grad=False, device=set_b.device) # Have the same total mass than set_a