shakedrop.py
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn


class ShakeDropFunction(torch.autograd.Function):
    """ShakeDrop regularization: a residual branch is randomly dropped and, when
    dropped, rescaled by independent random factors in the forward pass (alpha)
    and the backward pass (beta)."""

    @staticmethod
    def forward(ctx, x, training=True, p_drop=0.5, alpha_range=(-1, 1)):
        if training:
            # Bernoulli gate: gate == 0 with probability p_drop means the branch
            # output is replaced by a randomly scaled version of itself.
            gate = torch.empty(1, device=x.device).bernoulli_(1 - p_drop)
            ctx.save_for_backward(gate)
            if gate.item() == 0:
                # Per-sample alpha, broadcast over channel and spatial dimensions.
                alpha = torch.empty(x.size(0), device=x.device).uniform_(*alpha_range)
                alpha = alpha.view(-1, 1, 1, 1).expand_as(x)
                return alpha * x
            else:
                return x
        else:
            # At test time, scale by the expected coefficient
            # E[gate + (1 - gate) * alpha] = 1 - p_drop (alpha has zero mean).
            return (1 - p_drop) * x

    @staticmethod
    def backward(ctx, grad_output):
        gate = ctx.saved_tensors[0]
        if gate.item() == 0:
            # Independent per-sample beta in [0, 1] for the backward pass.
            beta = torch.empty(grad_output.size(0), device=grad_output.device).uniform_(0, 1)
            beta = beta.view(-1, 1, 1, 1).expand_as(grad_output)
            return beta * grad_output, None, None, None
        else:
            return grad_output, None, None, None


class ShakeDrop(nn.Module):
    def __init__(self, p_drop=0.5, alpha_range=(-1, 1)):
        super(ShakeDrop, self).__init__()
        self.p_drop = p_drop
        self.alpha_range = alpha_range

    def forward(self, x):
        return ShakeDropFunction.apply(x, self.training, self.p_drop, self.alpha_range)
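

# --- Usage sketch (not part of the original file) ---
# A minimal, hypothetical example of wiring ShakeDrop into a residual branch,
# in the spirit of PyramidNet-style blocks where ShakeDrop is typically used.
# The block structure, layer choices, and names below are assumptions for
# illustration only; only the ShakeDrop module above comes from this file.
class ResidualBranchWithShakeDrop(nn.Module):
    def __init__(self, channels, p_drop=0.5):
        super().__init__()
        self.branch = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(channels),
        )
        self.shake_drop = ShakeDrop(p_drop=p_drop)

    def forward(self, x):
        # ShakeDrop perturbs only the residual branch; the identity path is untouched.
        return x + self.shake_drop(self.branch(x))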