Skip to content

Commit

Permalink
Merge pull request #2 from priba/master
Browse files Browse the repository at this point in the history
pairwise distance abs()
  • Loading branch information
fwilliams authored Feb 12, 2019
2 parents 4a35f15 + c0cfc2c commit bd212b8
Show file tree
Hide file tree
Showing 6 changed files with 195 additions and 2 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ loss = loss_fun(set_a, set_b)
```

### Computing pairwise distances between point sets
```
```python
import torch
from fml.functional import pairwise_distances

Expand Down
42 changes: 42 additions & 0 deletions examples/chamfer_loss/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import argparse
import torch
from fml.nn import ChamferLoss

if __name__ == '__main__':
# Parse input arguments
parser = argparse.ArgumentParser(
description='SikhornLoss between two batchs of points.')
parser.add_argument('--batch_size', '-bz', type=int, default=3,
help='Batch size.')
parser.add_argument('--set_size', '-sz', type=int, default=10,
help='Set size.')
parser.add_argument('--point_dim', '-pd', type=int, default=4,
help='Point dimension.')

args = parser.parse_args()

# Set the parameters
minibatch_size = args.batch_size
set_size = args.set_size
point_dim = args.point_dim

# Create two minibatches of point sets where each batch item set_a[k, :, :] is a set of `set_size` points
set_a = torch.rand([minibatch_size, set_size, point_dim])
set_b = torch.rand([minibatch_size, set_size, point_dim])

print('Set A')
print(set_a)

print('Set B')
print(set_b)

# Create a loss function module with default parameters. See the class documentation for optional parameters.
loss_fun = ChamferLoss()

# Compute the loss between each pair of sets in the minibatch
# loss is a tensor with [minibatch_size] elements which can be backpropagated through
loss = loss_fun(set_a, set_b)

print('Loss')
print(loss)

42 changes: 42 additions & 0 deletions examples/pairwise_distance/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import argparse
import torch
from fml.functional import pairwise_distances

if __name__ == '__main__':
# Parse input arguments
parser = argparse.ArgumentParser(
description='Pairwise distance between two batchs of points.')
parser.add_argument('--batch_size', '-bz', type=int, default=3,
help='Batch size.')
parser.add_argument('--set_size', '-sz', type=int, default=10,
help='Set size.')
parser.add_argument('--point_dim', '-pd', type=int, default=4,
help='Point dimension.')
parser.add_argument('--lp_distance', '-p', type=int, default=2,
help='p for the Lp-distance.')

args = parser.parse_args()

# Set the parameters
minibatch_size = args.batch_size
set_size = args.set_size
point_dim = args.point_dim

# Create two minibatches of point sets where each batch item set_a[k, :, :] is a set of `set_size` points
set_a = torch.rand([minibatch_size, set_size, point_dim])
set_b = torch.rand([minibatch_size, set_size, point_dim])

# Compute the pairwise distances between each pair of sets in the minibatch
# distances is a tensor of shape [minibatch_size, set_size, set_size] where each
# distances[k, i, j] = ||set_a[k, i] - set_b[k, j]||^2
distances = pairwise_distances(set_a, set_b, p=args.lp_distance)

print('Set A')
print(set_a)

print('Set B')
print(set_b)

print('Distance')
print(distances)

49 changes: 49 additions & 0 deletions examples/sinkhorn_loss/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import argparse
import torch
from fml.nn import SinkhornLoss

if __name__ == '__main__':
# Parse input arguments
parser = argparse.ArgumentParser(
description='SikhornLoss between two batchs of points.')
parser.add_argument('--batch_size', '-bz', type=int, default=3,
help='Batch size.')
parser.add_argument('--set_size', '-sz', type=int, default=10,
help='Set size.')
parser.add_argument('--point_dim', '-pd', type=int, default=4,
help='Point dimension.')
parser.add_argument('--transport_matrix', '-tm', action='store_true',
help='Return transport matrix.')

args = parser.parse_args()

# Set the parameters
minibatch_size = args.batch_size
set_size = args.set_size
point_dim = args.point_dim

# Create two minibatches of point sets where each batch item set_a[k, :, :] is a set of `set_size` points
set_a = torch.rand([minibatch_size, set_size, point_dim])
set_b = torch.rand([minibatch_size, set_size, point_dim])

print('Set A')
print(set_a)

print('Set B')
print(set_b)

# Create a loss function module with default parameters. See the class documentation for optional parameters.
loss_fun = SinkhornLoss(return_transport_matrix=args.transport_matrix)

# Compute the loss between each pair of sets in the minibatch
# loss is a tensor with [minibatch_size] elements which can be backpropagated through
if args.transport_matrix:
loss, P = loss_fun(set_a, set_b)
print('Transport Matrix')
print(P)
else:
loss = loss_fun(set_a, set_b)

print('Loss')
print(loss)

60 changes: 60 additions & 0 deletions examples/sinkhorn_loss_functional/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import argparse
import torch
from fml.functional import pairwise_distances, sinkhorn

if __name__ == '__main__':
# Parse input arguments
parser = argparse.ArgumentParser(
description='Sinkhorn loss using the functional interface.')
parser.add_argument('--batch_size', '-bz', type=int, default=3,
help='Batch size.')
parser.add_argument('--set_size', '-sz', type=int, default=10,
help='Set size.')
parser.add_argument('--point_dim', '-pd', type=int, default=4,
help='Point dimension.')
parser.add_argument('--lp_distance', '-p', type=int, default=2,
help='p for the Lp-distance.')

args = parser.parse_args()

# Set the parameters
minibatch_size = args.batch_size
set_size = args.set_size
point_dim = args.point_dim

# Create two minibatches of point sets where each batch item set_a[k, :, :] is a set of `set_size` points
set_a = torch.rand([minibatch_size, set_size, point_dim])
set_b = torch.rand([minibatch_size, set_size, point_dim])

print('Set A')
print(set_a)

print('Set B')
print(set_b)

a = torch.ones(set_a.shape[0:2],
requires_grad=False,
device=set_a.device) / set_a.shape[1]

b = torch.ones(set_b.shape[0:2],
requires_grad=False,
device=set_b.device) / set_b.shape[1]

# Compute the cost matrix
M = pairwise_distances(set_a, set_b, p=args.lp_distance)

print('Distance')
print(M)

# Compute the transport matrix between each pair of sets in the minibatch with default parameters
P = sinkhorn(a, b, M, 1e-3)

print('Transport Matrix')
print(P)

# Compute the loss
loss = (M * P).sum(2).sum(1)

print('Loss')
print(loss)

2 changes: 1 addition & 1 deletion fml/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def pairwise_distances(a: torch.Tensor, b: torch.Tensor, p=2):
if len(b.shape) != 3:
raise ValueError("Invalid shape for a. Must be [m, n, d] but got", b.shape)

return (a.unsqueeze(2) - b.unsqueeze(1)).pow(p).sum(3)
return (a.unsqueeze(2) - b.unsqueeze(1)).abs().pow(p).sum(3)


def chamfer(a, b):
Expand Down

0 comments on commit bd212b8

Please sign in to comment.