From 91be0265aad7acb625e5a2df0b1b8c3e83d6222e Mon Sep 17 00:00:00 2001 From: Pau Riba Date: Mon, 11 Feb 2019 18:47:34 -0500 Subject: [PATCH 1/2] Add examples in the example folder --- README.md | 2 +- examples/pairwise_distance/main.py | 42 ++++++++++++++++ examples/sinkhorn_loss/main.py | 49 ++++++++++++++++++ examples/sinkhorn_loss_functional/main.py | 60 +++++++++++++++++++++++ fml/functional.py | 2 +- 5 files changed, 153 insertions(+), 2 deletions(-) create mode 100644 examples/pairwise_distance/main.py create mode 100644 examples/sinkhorn_loss/main.py create mode 100644 examples/sinkhorn_loss_functional/main.py diff --git a/README.md b/README.md index 700e495..524ca6f 100644 --- a/README.md +++ b/README.md @@ -97,7 +97,7 @@ loss = loss_fun(set_a, set_b) ``` ### Computing pairwise distances between point sets -``` +```python import torch from fml.functional import pairwise_distances diff --git a/examples/pairwise_distance/main.py b/examples/pairwise_distance/main.py new file mode 100644 index 0000000..1187e06 --- /dev/null +++ b/examples/pairwise_distance/main.py @@ -0,0 +1,42 @@ +import argparse +import torch +from fml.functional import pairwise_distances + +if __name__ == '__main__': + # Parse input arguments + parser = argparse.ArgumentParser( + description='Pairwise distance between two batchs of points.') + parser.add_argument('--batch_size', '-bz', type=int, default=3, + help='Batch size.') + parser.add_argument('--set_size', '-sz', type=int, default=10, + help='Set size.') + parser.add_argument('--point_dim', '-pd', type=int, default=4, + help='Point dimension.') + parser.add_argument('--lp_distance', '-p', type=int, default=2, + help='p for the Lp-distance.') + + args = parser.parse_args() + + # Set the parameters + minibatch_size = args.batch_size + set_size = args.set_size + point_dim = args.point_dim + + # Create two minibatches of point sets where each batch item set_a[k, :, :] is a set of `set_size` points + set_a = torch.rand([minibatch_size, set_size, point_dim]) + set_b = torch.rand([minibatch_size, set_size, point_dim]) + + # Compute the pairwise distances between each pair of sets in the minibatch + # distances is a tensor of shape [minibatch_size, set_size, set_size] where each + # distances[k, i, j] = ||set_a[k, i] - set_b[k, j]||^2 + distances = pairwise_distances(set_a, set_b, p=args.lp_distance) + + print('Set A') + print(set_a) + + print('Set B') + print(set_b) + + print('Distance') + print(distances) + diff --git a/examples/sinkhorn_loss/main.py b/examples/sinkhorn_loss/main.py new file mode 100644 index 0000000..89bd547 --- /dev/null +++ b/examples/sinkhorn_loss/main.py @@ -0,0 +1,49 @@ +import argparse +import torch +from fml.nn import SinkhornLoss + +if __name__ == '__main__': + # Parse input arguments + parser = argparse.ArgumentParser( + description='SikhornLoss between two batchs of points.') + parser.add_argument('--batch_size', '-bz', type=int, default=3, + help='Batch size.') + parser.add_argument('--set_size', '-sz', type=int, default=10, + help='Set size.') + parser.add_argument('--point_dim', '-pd', type=int, default=4, + help='Point dimension.') + parser.add_argument('--transport_matrix', '-tm', action='store_true', + help='Return transport matrix.') + + args = parser.parse_args() + + # Set the parameters + minibatch_size = args.batch_size + set_size = args.set_size + point_dim = args.point_dim + + # Create two minibatches of point sets where each batch item set_a[k, :, :] is a set of `set_size` points + set_a = torch.rand([minibatch_size, set_size, point_dim]) + set_b = torch.rand([minibatch_size, set_size, point_dim]) + + print('Set A') + print(set_a) + + print('Set B') + print(set_b) + + # Create a loss function module with default parameters. See the class documentation for optional parameters. + loss_fun = SinkhornLoss(return_transport_matrix=args.transport_matrix) + + # Compute the loss between each pair of sets in the minibatch + # loss is a tensor with [minibatch_size] elements which can be backpropagated through + if args.transport_matrix: + loss, P = loss_fun(set_a, set_b) + print('Transport Matrix') + print(P) + else: + loss = loss_fun(set_a, set_b) + + print('Loss') + print(loss) + diff --git a/examples/sinkhorn_loss_functional/main.py b/examples/sinkhorn_loss_functional/main.py new file mode 100644 index 0000000..d2564a8 --- /dev/null +++ b/examples/sinkhorn_loss_functional/main.py @@ -0,0 +1,60 @@ +import argparse +import torch +from fml.functional import pairwise_distances, sinkhorn + +if __name__ == '__main__': + # Parse input arguments + parser = argparse.ArgumentParser( + description='Sinkhorn loss using the functional interface.') + parser.add_argument('--batch_size', '-bz', type=int, default=3, + help='Batch size.') + parser.add_argument('--set_size', '-sz', type=int, default=10, + help='Set size.') + parser.add_argument('--point_dim', '-pd', type=int, default=4, + help='Point dimension.') + parser.add_argument('--lp_distance', '-p', type=int, default=2, + help='p for the Lp-distance.') + + args = parser.parse_args() + + # Set the parameters + minibatch_size = args.batch_size + set_size = args.set_size + point_dim = args.point_dim + + # Create two minibatches of point sets where each batch item set_a[k, :, :] is a set of `set_size` points + set_a = torch.rand([minibatch_size, set_size, point_dim]) + set_b = torch.rand([minibatch_size, set_size, point_dim]) + + print('Set A') + print(set_a) + + print('Set B') + print(set_b) + + a = torch.ones(set_a.shape[0:2], + requires_grad=False, + device=set_a.device) / set_a.shape[1] + + b = torch.ones(set_b.shape[0:2], + requires_grad=False, + device=set_b.device) / set_b.shape[1] + + # Compute the cost matrix + M = pairwise_distances(set_a, set_b, p=args.lp_distance) + + print('Distance') + print(M) + + # Compute the transport matrix between each pair of sets in the minibatch with default parameters + P = sinkhorn(a, b, M, 1e-3) + + print('Transport Matrix') + print(P) + + # Compute the loss + loss = (M * P).sum(2).sum(1) + + print('Loss') + print(loss) + diff --git a/fml/functional.py b/fml/functional.py index d8df137..86698e6 100644 --- a/fml/functional.py +++ b/fml/functional.py @@ -16,7 +16,7 @@ def pairwise_distances(a: torch.Tensor, b: torch.Tensor, p=2): if len(b.shape) != 3: raise ValueError("Invalid shape for a. Must be [m, n, d] but got", b.shape) - return (a.unsqueeze(2) - b.unsqueeze(1)).pow(p).sum(3) + return (a.unsqueeze(2) - b.unsqueeze(1)).abs().pow(p).sum(3) def chamfer(a, b): From c0cfc2c1531319592034ad26a60fa99b196d748d Mon Sep 17 00:00:00 2001 From: Pau Riba Date: Mon, 11 Feb 2019 23:11:59 -0500 Subject: [PATCH 2/2] Chamfer example --- examples/chamfer_loss/main.py | 42 +++++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) create mode 100644 examples/chamfer_loss/main.py diff --git a/examples/chamfer_loss/main.py b/examples/chamfer_loss/main.py new file mode 100644 index 0000000..d0aff0a --- /dev/null +++ b/examples/chamfer_loss/main.py @@ -0,0 +1,42 @@ +import argparse +import torch +from fml.nn import ChamferLoss + +if __name__ == '__main__': + # Parse input arguments + parser = argparse.ArgumentParser( + description='SikhornLoss between two batchs of points.') + parser.add_argument('--batch_size', '-bz', type=int, default=3, + help='Batch size.') + parser.add_argument('--set_size', '-sz', type=int, default=10, + help='Set size.') + parser.add_argument('--point_dim', '-pd', type=int, default=4, + help='Point dimension.') + + args = parser.parse_args() + + # Set the parameters + minibatch_size = args.batch_size + set_size = args.set_size + point_dim = args.point_dim + + # Create two minibatches of point sets where each batch item set_a[k, :, :] is a set of `set_size` points + set_a = torch.rand([minibatch_size, set_size, point_dim]) + set_b = torch.rand([minibatch_size, set_size, point_dim]) + + print('Set A') + print(set_a) + + print('Set B') + print(set_b) + + # Create a loss function module with default parameters. See the class documentation for optional parameters. + loss_fun = ChamferLoss() + + # Compute the loss between each pair of sets in the minibatch + # loss is a tensor with [minibatch_size] elements which can be backpropagated through + loss = loss_fun(set_a, set_b) + + print('Loss') + print(loss) +