Skip to content

Commit

Permalink
Implemented Carlini Wagner LInf.
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelemarro committed Mar 9, 2020
1 parent 5a7a14d commit 4c23ade
Show file tree
Hide file tree
Showing 4 changed files with 156 additions and 74 deletions.
1 change: 1 addition & 0 deletions advertorch/attacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from .iterative_projected_gradient import LinfMomentumIterativeAttack

from .carlini_wagner import CarliniWagnerL2Attack
from .carlini_wagner import CarliniWagnerLinfAttack
from .ead import ElasticNetL1Attack

from .decoupled_direction_norm import DDNL2Attack
Expand Down
222 changes: 149 additions & 73 deletions advertorch/attacks/carlini_wagner.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,25 +37,30 @@
EPS = 1e-6
NUM_CHECKS = 10

try:
boolean_type = torch.bool
except AttributeError:
# Old version, use torch.uint8
boolean_type = torch.uint8

class CarliniWagnerL2Attack(Attack, LabelMixin):
"""
The Carlini and Wagner L2 Attack, https://arxiv.org/abs/1608.04644
:param predict: forward pass function.
:param num_classes: number of clasess.
:param num_classes: number of classes.
:param confidence: confidence of the adversarial examples.
:param targeted: if the attack is targeted.
:param learning_rate: the learning rate for the attack algorithm
:param learning_rate: the learning rate for the attack algorithm.
:param binary_search_steps: number of binary search times to find the
optimum
:param max_iterations: the maximum number of iterations
optimum.
:param max_iterations: the maximum number of iterations.
:param abort_early: if set to true, abort early if getting stuck in local
min
:param initial_const: initial value of the constant c
min.
:param initial_const: initial value of the constant c.
:param clip_min: mininum value per input dimension.
:param clip_max: maximum value per input dimension.
:param loss_fn: loss function
:param loss_fn: loss function.
"""

def __init__(self, predict, num_classes, confidence=0,
Expand All @@ -67,7 +72,7 @@ def __init__(self, predict, num_classes, confidence=0,
if loss_fn is not None:
import warnings
warnings.warn(
"This Attack currently do not support a different loss"
"This Attack currently does not support a different loss"
" function other than the default. Setting loss_fn manually"
" is not effective."
)
Expand Down Expand Up @@ -247,106 +252,156 @@ def perturb(self, x, y=None):

return final_advs

class CarliniWagnerLInfAttack(advertorch.attacks.Attack, advertorch.attacks.LabelMixin):
def __init__(self, predict, num_classes, min_tau=1/255,
tau_multiplier=0.9, const_multiplier=2, halve_const=True,
confidence=0, targeted=False, learning_rate=0.01,
max_iterations=10000, abort_early=True,
initial_const=1e-3, clip_min=0., clip_max=1.):
class CarliniWagnerLinfAttack(Attack, LabelMixin):
"""
The Carlini and Wagner LInfinity Attack, https://arxiv.org/abs/1608.04644
:param predict: forward pass function (pre-softmax).
:param num_classes: number of classes.
:param min_tau: the minimum value of tau.
:param initial_tau: the initial value of tau.
:param tau_factor: the decay rate of tau (between 0 and 1)
:param initial_const: initial value of the constant c.
:param max_const: the maximum value of the constant c.
:param const_factor: the rate of growth of the constant c.
:param reduce_const: if True, the inital value of c is halved every
time tau is reduced.
:param warm_start: if True, use the previous adversarials as starting point
for the next iteration.
:param targeted: if the attack is targeted.
:param learning_rate: the learning rate for the attack algorithm.
:param max_iterations: the maximum number of iterations.
:param abort_early: if set to true, abort early if getting stuck in local
min.
:param clip_min: mininum value per input dimension.
:param clip_max: maximum value per input dimension.
:param loss_fn: loss function
"""
def __init__(self, predict, num_classes, min_tau=1/256,
initial_tau = 1, tau_factor=0.9, initial_const=1e-5,
max_const=20, const_factor=2, reduce_const=False,
warm_start=True, targeted=False, learning_rate=5e-3,
max_iterations=1000, abort_early=True, clip_min=0.,
clip_max=1., loss_fn=None):
"""Carlini Wagner LInfinity Attack implementation in pytorch."""
if loss_fn is not None:
import warnings
warnings.warn(
"This Attack currently does not support a different loss"
" function other than the default. Setting loss_fn manually"
" is not effective."
)

loss_fn = None

super(CarliniWagnerLinfAttack, self).__init__(
predict, loss_fn, clip_min, clip_max)

self.predict = predict
self.num_classes = num_classes
self.min_tau = min_tau
self.tau_multiplier = tau_multiplier
self.const_multiplier = const_multiplier
self.halve_const = halve_const
self.confidence = confidence
self.initial_tau = initial_tau
self.tau_factor = tau_factor
self.initial_const = initial_const
self.max_const = max_const
self.const_factor = const_factor
self.reduce_const = reduce_const
self.warm_start = warm_start
self.targeted = targeted
self.learning_rate = learning_rate
self.max_iterations = max_iterations
self.abort_early = abort_early
self.initial_const = initial_const
self.clip_min = clip_min
self.clip_max = clip_max

def loss(self, x, delta, y, const, tau):
adversarials = x + delta
output = self.predict(adversarials)
def _get_arctanh_x(self, x):
result = clamp((x - self.clip_min) / (self.clip_max - self.clip_min),
min=0., max=1.) * 2 - 1
return torch_arctanh(result * ONE_MINUS_EPS)

def _outputs_and_loss(self, x, modifiers, starting_atanh, y, const, taus):
adversarials = tanh_rescale(starting_atanh + modifiers, self.clip_min, self.clip_max)

outputs = self.predict(adversarials)
y_onehot = to_one_hot(y, self.num_classes).float()

real = (y_onehot * output).sum(dim=1)
real = (y_onehot * outputs).sum(dim=1)

other = ((1.0 - y_onehot) * output - (y_onehot * TARGET_MULT)
other = ((1.0 - y_onehot) * outputs - (y_onehot * TARGET_MULT)
).max(dim=1)[0]
# - (y_onehot * TARGET_MULT) is for the true label not to be selected

if self.targeted:
loss1 = torch.clamp(other - real + self.confidence, min=0.)
loss1 = torch.clamp(other - real, min=0.)
else:
loss1 = torch.clamp(real - other + self.confidence, min=0.)

penalties = torch.clamp(torch.abs(delta) - tau, min=0)
loss1 = torch.clamp(real - other, min=0.)

loss1 = const * loss1
loss2 = torch.sum(penalties, dim=(1, 2, 3))

image_dimensions = tuple(range(1, len(x.shape)))
taus_shape = (-1,) + (1,) * (len(x.shape) - 1)

penalties = torch.clamp(torch.abs(x - adversarials) - taus.view(taus_shape), min=0)
loss2 = torch.sum(penalties, dim=image_dimensions)

assert loss1.shape == loss2.shape

loss = loss1 + loss2
return loss, output
return outputs.detach(), loss

def _successful(self, outputs, y):
adversarial_labels = torch.argmax(outputs, dim=1)

def successful(self, outputs, y):
if self.targeted:
adversarial_labels = torch.argmax(outputs, dim=1)
return torch.eq(adversarial_labels, y)
else:
return ~torch.eq(adversarial_labels, y)

# Scales a 0-1 value to clip_min-clip_max range
def scale_to_bounds(self, value):
return self.clip_min + value * (self.clip_max - self.clip_min)

def run_attack(self, x, y, initial_const, tau):
def _run_attack(self, x, y, initial_const, taus, prev_adversarials):
assert len(x) == len(taus)
batch_size = len(x)
best_adversarials = x.clone().detach()

active_samples = torch.ones((batch_size,), dtype=torch.bool, device=x.device)

ws = torch.nn.Parameter(torch.zeros_like(x))
optimizer = optim.Adam([ws], lr=self.learning_rate)
if self.warm_start:
starting_atanh = self._get_arctanh_x(prev_adversarials.clone())
else:
starting_atanh = self._get_arctanh_x(x.clone())

modifiers = torch.nn.Parameter(torch.zeros_like(starting_atanh))

# An array of booleans that stores which samples have not converged
# yet
active = torch.ones((batch_size,), dtype=boolean_type, device=x.device)
optimizer = optim.Adam([modifiers], lr=self.learning_rate)

const = initial_const

while torch.any(active_samples) and const < CARLINI_COEFF_UPPER:
for i in range(self.max_iterations):
deltas = self.scale_to_bounds((0.5 + EPS) * (torch.tanh(ws) + 1)) - x

while torch.any(active) and const < self.max_const:
for _ in range(self.max_iterations):
optimizer.zero_grad()
losses, outputs = self.loss(x[active_samples], deltas[active_samples], y[active_samples], const, tau)
total_loss = torch.sum(losses)
outputs, loss = self._outputs_and_loss(x[active], modifiers[active], starting_atanh[active], y[active], const, taus[active])
total_loss = torch.sum(loss)

total_loss.backward()
optimizer.step()

adversarials = (x + deltas).detach()
best_adversarials[active_samples] = adversarials[active_samples]
adversarials = tanh_rescale(starting_atanh + modifiers, self.clip_min, self.clip_max).detach()
best_adversarials[active] = adversarials[active]

# If early aborting is enabled, drop successful samples with small losses
# (Notice that the current adversarials are saved regardless of whether they are dropped)

# If early abortion is enabled, drop successful samples with a small loss
# (The current adversarials are saved regardless of whether they are dropped)
if self.abort_early:
# TODO: Controllare che sia corretto
successful = self.successful(outputs, y[active_samples])
small_losses = losses < 0.0001 * const
successful = self._successful(outputs, y[active])
small_loss = loss < 0.0001 * const

drop = successful & small_losses
drop = successful & small_loss
active[active] = ~drop

active_samples[active_samples] = ~drop
if not active_samples.any():
if not active.any():
break


const *= self.const_multiplier
#print('Const: {}'.format(const))
# Give more weight to the output loss
const *= self.const_factor

return best_adversarials

Expand All @@ -357,27 +412,48 @@ def perturb(self, x, y=None):
# Initialization
if y is None:
y = self._get_predicted_label(x)

x = replicate_input(x)
batch_size = len(x)
final_adversarials = x.clone()

active_samples = torch.ones((batch_size,), dtype=torch.bool, device=x.device)

# An array of booleans that stores which samples have not converged
# yet
active = torch.ones((batch_size,), dtype=boolean_type, device=x.device)

initial_const = self.initial_const
tau = 1
taus = torch.ones((batch_size,), device=x.device) * self.initial_tau

# The previous adversarials. This is used to perform a "warm start"
# during optimization
prev_adversarials = x.clone()

while torch.any(active):
adversarials = self._run_attack(x[active], y[active], initial_const, taus[active], prev_adversarials[active].clone())

while torch.any(active_samples) and tau >= self.min_tau:
adversarials = self.run_attack(x[active_samples], y[active_samples], initial_const, tau)
# Store the adversarials for the next iteration, even if they failed
prev_adversarials[active] = adversarials

# Drop the failed adversarials (without saving them)
successful = self.successful(adversarials, y[active_samples])
active_samples[active_samples] = successful
final_adversarials[active_samples] = adversarials[successful]
# Drop failed adversarials (without saving them)
adversarial_outputs = self.predict(adversarials)
successful = self._successful(adversarial_outputs, y[active])
active[active] = successful

tau *= self.tau_multiplier
# Save the remaining adversarials
final_adversarials[active] = adversarials[successful]

if self.halve_const:
# If the Linf distance is lower than tau and the adversarial is successful, use it as the new tau
linf_distances = torch.max(torch.abs(final_adversarials - x).view(batch_size, -1), dim=1)[0]
linf_lower = linf_distances < taus
taus[linf_lower] = linf_distances[linf_lower]

taus *= self.tau_factor

if self.reduce_const:
initial_const /= 2

# Drop samples with a low tau
low_tau = taus[active] <= self.min_tau
active[active] = ~low_tau

return final_adversarials
5 changes: 4 additions & 1 deletion advertorch/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

import torch
import torch.nn as nn
import torch.nn.functional as F
Expand All @@ -16,6 +15,7 @@
from advertorch.attacks import JacobianSaliencyMapAttack
from advertorch.attacks import LBFGSAttack
from advertorch.attacks import CarliniWagnerL2Attack
from advertorch.attacks import CarliniWagnerLinfAttack
from advertorch.attacks import DDNL2Attack
from advertorch.attacks import FastFeatureAttack
from advertorch.attacks import MomentumIterativeAttack
Expand Down Expand Up @@ -228,6 +228,7 @@ def generate_data_model_on_img():
MomentumIterativeAttack,
FastFeatureAttack,
CarliniWagnerL2Attack,
CarliniWagnerLinfAttack,
ElasticNetL1Attack,
LBFGSAttack,
JacobianSaliencyMapAttack,
Expand All @@ -254,6 +255,7 @@ def generate_data_model_on_img():
LinfPGDAttack,
MomentumIterativeAttack,
CarliniWagnerL2Attack,
CarliniWagnerLinfAttack,
ElasticNetL1Attack,
LBFGSAttack,
JacobianSaliencyMapAttack,
Expand Down Expand Up @@ -286,6 +288,7 @@ def generate_data_model_on_img():
LinfSPSAAttack,
# FABAttack,
# CarliniWagnerL2Attack, # XXX: not exactly sure: test says no
CarliniWagnerLinfAttack
# LBFGSAttack, # XXX: not exactly sure: test says no
# SpatialTransformAttack, # XXX: not exactly sure: test says no
]
Expand Down
2 changes: 2 additions & 0 deletions tests/test_attacks_running.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from advertorch.attacks import MomentumIterativeAttack
from advertorch.attacks import FastFeatureAttack
from advertorch.attacks import CarliniWagnerL2Attack
from advertorch.attacks import CarliniWagnerLinfAttack
from advertorch.attacks import DDNL2Attack
from advertorch.attacks import ElasticNetL1Attack
from advertorch.attacks import LBFGSAttack
Expand Down Expand Up @@ -76,6 +77,7 @@
LinfPGDAttack: {"rand_init": False, "nb_iter": 5},
MomentumIterativeAttack: {"nb_iter": 5},
CarliniWagnerL2Attack: {"num_classes": NUM_CLASS, "max_iterations": 10},
CarliniWagnerLinfAttack : {"num_classes" : NUM_CLASS, "max_iterations" : 2, "initial_tau" : 1/128},
ElasticNetL1Attack: {"num_classes": NUM_CLASS, "max_iterations": 10},
FastFeatureAttack: {"rand_init": False, "nb_iter": 5},
LBFGSAttack: {"num_classes": NUM_CLASS},
Expand Down

0 comments on commit 4c23ade

Please sign in to comment.