-
Notifications
You must be signed in to change notification settings - Fork 55
/
sparse_softmax.py
137 lines (102 loc) · 4.47 KB
/
sparse_softmax.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
"""
An original implementation of sparsemax (Martins & Astudillo, 2016) is available at
https://github.com/OpenNMT/OpenNMT-py/blob/master/onmt/modules/sparse_activations.py.
See `From Softmax to Sparsemax: A Sparse Model of Attention and Multi-Label Classification, ICML 2016`
for detailed description.
We make some modifications to make it work at scatter operation scenarios, e.g., calculate softmax according to batch
indicators.
Usage:
>> x = torch.tensor([ 1.7301, 0.6792, -1.0565, 1.6614, -0.3196, -0.7790, -0.3877, -0.4943,
0.1831, -0.0061])
>> batch = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1, 1, 1])
>> sparse_attention = Sparsemax()
>> res = sparse_attention(x, batch)
>> print(res)
tensor([0.5343, 0.0000, 0.0000, 0.4657, 0.0612, 0.0000, 0.0000, 0.0000, 0.5640,
0.3748])
"""
import torch
import torch.nn as nn
from torch.autograd import Function
from torch_scatter import scatter_add, scatter_max
def scatter_sort(x, batch, fill_value=-1e16):
    """Sort ``x`` in descending order within each group and compute per-group running sums.

    Elements are scattered into a dense ``(num_groups, max_group_size)`` matrix padded
    with ``fill_value``, each row is sorted, and the padding is dropped again.

    Args:
        x: 1-D tensor of values.
        batch: 1-D integer tensor holding the group id of each element of ``x``. Each
            element's slot is computed from its offset relative to the first element of
            its group, so elements of the same group must be contiguous and group ids must
            appear in ascending order (e.g. ``[0, 0, 1, 1, 1]``).
        fill_value: padding value; it must be smaller than every value in ``x`` so that
            the padding sorts to the end of each row.

    Returns:
        ``(sorted_x, cumsum_sorted_x)``, both 1-D with ``x.size(0)`` entries: the values of
        group 0 sorted descending, then those of group 1, and so on. ``cumsum_sorted_x``
        holds the running sum, restarted at every group boundary.
    """
    if x.numel() == 0:
        # ``num_nodes.max()`` below is undefined for an empty tensor.
        return x.new_empty(0), x.new_empty(0)
    num_nodes = torch.bincount(batch)
    batch_size, max_num_nodes = num_nodes.size(0), num_nodes.max().item()
    cum_num_nodes = torch.cat([num_nodes.new_zeros(1), num_nodes.cumsum(dim=0)[:-1]], dim=0)
    # Flat slot in the dense matrix: row = group id, column = offset within the group.
    index = torch.arange(batch.size(0), dtype=torch.long, device=x.device)
    index = (index - cum_num_nodes[batch]) + (batch * max_num_nodes)
    dense_x = x.new_full((batch_size * max_num_nodes,), fill_value)
    dense_x[index] = x
    dense_x = dense_x.view(batch_size, max_num_nodes)
    sorted_x, _ = dense_x.sort(dim=-1, descending=True)
    cumsum_sorted_x = sorted_x.cumsum(dim=-1)
    # Padding sorts to the end of each row, so the real entries of group g are its first
    # num_nodes[g] columns. Selecting by position (instead of ``!= fill_value``) keeps
    # real values that happen to equal ``fill_value``, and it does not depend on how the
    # padding value is rounded in low-precision dtypes.
    columns = torch.arange(max_num_nodes, device=x.device)
    valid = columns.unsqueeze(0) < num_nodes.unsqueeze(1)
    return sorted_x[valid], cumsum_sorted_x[valid]
def _make_ix_like(batch):
num_nodes = scatter_add(batch.new_ones(batch.size(0)), batch, dim=0)
idx = [torch.arange(1, i + 1, dtype=torch.long, device=batch.device) for i in num_nodes]
idx = torch.cat(idx, dim=0)
return idx
def _threshold_and_support(x, batch):
    """Compute the per-group sparsemax threshold and support size.

    Args:
        x: input tensor to apply the sparsemax to
        batch: group indicators (groups must be contiguous, as required by ``scatter_sort``)

    Returns:
        ``(tau, support_size)``, both indexed by group id: ``tau`` is the threshold
        subtracted from every element of the group, and ``support_size`` is the number of
        elements in the group that stay non-zero (dtype of ``batch``).
    """
    group_sizes = scatter_add(batch.new_ones(x.size(0)), batch, dim=0)
    group_offsets = torch.cat([group_sizes.new_zeros(1), group_sizes.cumsum(dim=0)[:-1]], dim=0)

    sorted_vals, running_sum = scatter_sort(x, batch)
    running_sum = running_sum - 1.0
    ranks = _make_ix_like(batch).to(x.dtype)

    # Element k (1-based, sorted descending) is in the support while k * z_k > sum_{j<=k} z_j - 1.
    in_support = ranks * sorted_vals > running_sum
    support_size = scatter_add(in_support.to(batch.dtype), batch)

    # Flat index of each group's last supported element. Group ids that never occur in
    # ``batch`` (ids not starting at 0, or with gaps) have support size 0, which can make
    # this index negative; clamp it so ``gather`` stays in range.
    last_idx = (support_size + group_offsets - 1).clamp(min=0)
    tau = running_sum.gather(0, last_idx) / support_size.to(x.dtype)
    return tau, support_size
class SparsemaxFunction(Function):
    """Autograd function applying sparsemax independently within each group of ``batch``."""

    @staticmethod
    def forward(ctx, x, batch):
        """sparsemax: normalizing sparse transform
        Parameters:
            ctx: context object
            x (Tensor): shape (N, ); not modified
            batch: group indicator for each element of ``x``
        Returns:
            output (Tensor): same shape as input
        """
        max_val, _ = scatter_max(x, batch)
        # Shift by the per-group max for numerical stability (sparsemax is shift-invariant).
        # Done out of place: ``x -= ...`` would silently overwrite the caller's tensor
        # (an unmarked in-place modification of a Function input).
        x = x - max_val[batch]
        tau, supp_size = _threshold_and_support(x, batch)
        output = torch.clamp(x - tau[batch], min=0)
        ctx.save_for_backward(supp_size, output, batch)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Backpropagate through sparsemax.

        For elements with non-zero output, the gradient is ``grad_output`` minus its mean
        over the non-zero elements of the same group. It is zero for all other elements.
        ``batch`` receives no gradient.
        """
        supp_size, output, batch = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[output == 0] = 0
        # Per-group mean of the gradient over the support.
        v_hat = scatter_add(grad_input, batch) / supp_size.to(output.dtype)
        grad_input = torch.where(output != 0, grad_input - v_hat[batch], grad_input)
        return grad_input, None
# Functional form: ``sparsemax(x, batch)`` runs SparsemaxFunction with autograd support.
sparsemax = SparsemaxFunction.apply
class Sparsemax(nn.Module):
    """``nn.Module`` wrapper around the functional ``sparsemax``.

    Has no parameters. ``forward(x, batch)`` applies sparsemax separately to each group
    of ``x`` identified by ``batch`` and returns a tensor shaped like ``x``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, batch):
        result = sparsemax(x, batch)
        return result
if __name__ == '__main__':
    # Demo: the first four values form group 0, the remaining six form group 1.
    values = torch.tensor([1.7301, 0.6792, -1.0565, 1.6614, -0.3196, -0.7790, -0.3877, -0.4943, 0.1831, -0.0061])
    groups = torch.tensor([0] * 4 + [1] * 6, dtype=torch.long)
    print(Sparsemax()(values, groups))