
Added Carlini LInf attack.
samuelemarro committed Mar 6, 2020
1 parent 6f6de54 commit 5a7a14d
Showing 1 changed file with 136 additions and 0 deletions.
advertorch/attacks/carlini_wagner.py
@@ -34,6 +34,7 @@
UPPER_CHECK = 1e9
PREV_LOSS_INIT = 1e6
TARGET_MULT = 10000.0
EPS = 1e-6
NUM_CHECKS = 10


@@ -245,3 +246,138 @@ def perturb(self, x, y=None):
loss_coeffs, coeff_upper_bound, coeff_lower_bound)

return final_advs


class CarliniWagnerLInfAttack(Attack, LabelMixin):
def __init__(self, predict, num_classes, min_tau=1/255,
tau_multiplier=0.9, const_multiplier=2, halve_const=True,
confidence=0, targeted=False, learning_rate=0.01,
max_iterations=10000, abort_early=True,
initial_const=1e-3, clip_min=0., clip_max=1.):
        super(CarliniWagnerLInfAttack, self).__init__(
            predict, None, clip_min, clip_max)
        self.predict = predict
self.num_classes = num_classes
self.min_tau = min_tau
self.tau_multiplier = tau_multiplier
self.const_multiplier = const_multiplier
self.halve_const = halve_const
self.confidence = confidence
self.targeted = targeted
self.learning_rate = learning_rate
self.max_iterations = max_iterations
self.abort_early = abort_early
self.initial_const = initial_const
self.clip_min = clip_min
self.clip_max = clip_max

def loss(self, x, delta, y, const, tau):
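        # Two-term objective: a hinge-style misclassification term scaled by
        # const, plus an LInf penalty that only charges the components of
        # delta whose magnitude exceeds the current threshold tau.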
adversarials = x + delta
output = self.predict(adversarials)
y_onehot = to_one_hot(y, self.num_classes).float()

real = (y_onehot * output).sum(dim=1)

other = ((1.0 - y_onehot) * output - (y_onehot * TARGET_MULT)
).max(dim=1)[0]
        # Subtracting y_onehot * TARGET_MULT ensures the true class is never
        # selected as the most likely other class.

if self.targeted:
loss1 = torch.clamp(other - real + self.confidence, min=0.)
else:
loss1 = torch.clamp(real - other + self.confidence, min=0.)

penalties = torch.clamp(torch.abs(delta) - tau, min=0)

loss1 = const * loss1
loss2 = torch.sum(penalties, dim=(1, 2, 3))

assert loss1.shape == loss2.shape

loss = loss1 + loss2
return loss, output

    def successful(self, outputs, y):
        adversarial_labels = torch.argmax(outputs, dim=1)

        if self.targeted:
            return torch.eq(adversarial_labels, y)
        else:
            return ~torch.eq(adversarial_labels, y)

    # Scales a value in [0, 1] to the [clip_min, clip_max] range
def scale_to_bounds(self, value):
return self.clip_min + value * (self.clip_max - self.clip_min)

def run_attack(self, x, y, initial_const, tau):
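        # For a fixed tau, optimize the perturbations with Adam, escalating
        # const until every sample is adversarial or the coefficient cap
        # CARLINI_COEFF_UPPER is reached.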
batch_size = len(x)
best_adversarials = x.clone().detach()

        # Track which samples are still being optimized.
        active_samples = torch.ones((batch_size,),
                                    dtype=torch.bool, device=x.device)

ws = torch.nn.Parameter(torch.zeros_like(x))
optimizer = optim.Adam([ws], lr=self.learning_rate)

const = initial_const

while torch.any(active_samples) and const < CARLINI_COEFF_UPPER:
for i in range(self.max_iterations):
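                # tanh change of variables keeps x + deltas (approximately)
                # within the valid pixel bounds for any unconstrained ws.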
deltas = self.scale_to_bounds((0.5 + EPS) * (torch.tanh(ws) + 1)) - x

optimizer.zero_grad()
                losses, outputs = self.loss(
                    x[active_samples], deltas[active_samples],
                    y[active_samples], const, tau)
total_loss = torch.sum(losses)

total_loss.backward()
optimizer.step()

adversarials = (x + deltas).detach()
best_adversarials[active_samples] = adversarials[active_samples]

                # If early aborting is enabled, drop successful samples with
                # small losses (the current adversarials are saved regardless
                # of whether they are dropped).
                if self.abort_early:
                    # TODO: check that this is correct
                    successful = self.successful(outputs, y[active_samples])
                    small_losses = losses < 0.0001 * const

drop = successful & small_losses

active_samples[active_samples] = ~drop
                    if not active_samples.any():
                        break

            const *= self.const_multiplier

return best_adversarials


def perturb(self, x, y=None):
x, y = self._verify_and_process_inputs(x, y)

        # _verify_and_process_inputs has already replicated x and filled in
        # y with predicted labels where needed
batch_size = len(x)
final_adversarials = x.clone()

        # Track which samples still lack a successful adversarial.
        active_samples = torch.ones((batch_size,),
                                    dtype=torch.bool, device=x.device)

initial_const = self.initial_const
tau = 1

while torch.any(active_samples) and tau >= self.min_tau:
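            # Attack with the current tau; samples that fail here are dropped
            # for good, since smaller taus are only harder to satisfy.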
adversarials = self.run_attack(x[active_samples], y[active_samples], initial_const, tau)

            # Drop the failed samples (without saving their adversarials)
            successful = self.successful(
                self.predict(adversarials), y[active_samples])
            active_samples[active_samples] = successful
            final_adversarials[active_samples] = adversarials[successful]

tau *= self.tau_multiplier

if self.halve_const:
initial_const /= 2

return final_adversarials
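
For reference, a minimal usage sketch of the new attack (not part of the commit; the toy model, the random batch, and the reduced max_iterations are illustrative assumptions, and the import path assumes the class is exposed from advertorch/attacks/carlini_wagner.py as in this diff):

import torch
import torch.nn as nn

from advertorch.attacks.carlini_wagner import CarliniWagnerLInfAttack

# Hypothetical toy setup: a linear classifier over 3x32x32 images and a
# random batch of 8 samples with random labels.
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))
images = torch.rand(8, 3, 32, 32)
labels = torch.randint(0, 10, (8,))

# Untargeted attack; max_iterations is lowered from the default 10000 only
# to keep the sketch fast.
attack = CarliniWagnerLInfAttack(model, num_classes=10, max_iterations=100)
adversarials = attack.perturb(images, labels)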
