From 4c23ade02e40c0a6a783b4f58f999a1070cb7fbc Mon Sep 17 00:00:00 2001
From: Samuele Marro
Date: Sun, 8 Mar 2020 16:22:57 +0100
Subject: [PATCH] Implemented Carlini Wagner LInf.

---
 advertorch/attacks/__init__.py       |   1 +
 advertorch/attacks/carlini_wagner.py | 222 ++++++++++++++++++---------
 advertorch/test_utils.py             |   5 +-
 tests/test_attacks_running.py        |   2 +
 4 files changed, 156 insertions(+), 74 deletions(-)

diff --git a/advertorch/attacks/__init__.py b/advertorch/attacks/__init__.py
index 522ec58..f292f62 100644
--- a/advertorch/attacks/__init__.py
+++ b/advertorch/attacks/__init__.py
@@ -29,6 +29,7 @@ from .iterative_projected_gradient import LinfMomentumIterativeAttack
 from .carlini_wagner import CarliniWagnerL2Attack
+from .carlini_wagner import CarliniWagnerLinfAttack
 from .ead import ElasticNetL1Attack
 from .decoupled_direction_norm import DDNL2Attack

diff --git a/advertorch/attacks/carlini_wagner.py b/advertorch/attacks/carlini_wagner.py
index 410c0c6..0a0bd0c 100644
--- a/advertorch/attacks/carlini_wagner.py
+++ b/advertorch/attacks/carlini_wagner.py
@@ -37,25 +37,30 @@ EPS = 1e-6
 NUM_CHECKS = 10
 
+try:
+    boolean_type = torch.bool
+except AttributeError:
+    # Older PyTorch versions have no torch.bool; fall back to torch.uint8
+    boolean_type = torch.uint8
+
 
 class CarliniWagnerL2Attack(Attack, LabelMixin):
     """
     The Carlini and Wagner L2 Attack, https://arxiv.org/abs/1608.04644
 
     :param predict: forward pass function.
-    :param num_classes: number of clasess.
+    :param num_classes: number of classes.
     :param confidence: confidence of the adversarial examples.
     :param targeted: if the attack is targeted.
-    :param learning_rate: the learning rate for the attack algorithm
+    :param learning_rate: the learning rate for the attack algorithm.
     :param binary_search_steps: number of binary search times to find the
-        optimum
+        optimum.
-    :param max_iterations: the maximum number of iterations
+    :param max_iterations: the maximum number of iterations.
     :param abort_early: if set to true, abort early if getting stuck in local
-        min
+        min.
-    :param initial_const: initial value of the constant c
+    :param initial_const: initial value of the constant c.
     :param clip_min: mininum value per input dimension.
     :param clip_max: maximum value per input dimension.
-    :param loss_fn: loss function
+    :param loss_fn: loss function.
     """
 
@@ -67,7 +72,7 @@ def __init__(self, predict, num_classes, confidence=0,
         if loss_fn is not None:
             import warnings
             warnings.warn(
-                "This Attack currently do not support a different loss"
+                "This Attack currently does not support a different loss"
                 " function other than the default. Setting loss_fn manually"
                 " is not effective."
            )
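For context on the margin term both attacks minimize: the `confidence` parameter documented above sets how far the adversarial class must lead before the classification loss reaches zero. A minimal illustrative sketch, not part of the patch (`cw_margin_loss` is a hypothetical helper name):

import torch

def cw_margin_loss(outputs, y_onehot, confidence=0., targeted=False):
    # Logit of the true (or target) class.
    real = (y_onehot * outputs).sum(dim=1)
    # Highest logit among the remaining classes; the large offset
    # keeps the true class from being picked by max().
    other = ((1.0 - y_onehot) * outputs - y_onehot * 1e4).max(dim=1)[0]
    if targeted:
        # Zero once the target class leads by at least `confidence`.
        return torch.clamp(other - real + confidence, min=0.)
    # Zero once any other class leads by at least `confidence`.
    return torch.clamp(real - other + confidence, min=0.)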
@@ -247,106 +252,156 @@ def perturb(self, x, y=None):
 
         return final_advs
 
 
-class CarliniWagnerLInfAttack(advertorch.attacks.Attack, advertorch.attacks.LabelMixin):
-    def __init__(self, predict, num_classes, min_tau=1/255,
-                 tau_multiplier=0.9, const_multiplier=2, halve_const=True,
-                 confidence=0, targeted=False, learning_rate=0.01,
-                 max_iterations=10000, abort_early=True,
-                 initial_const=1e-3, clip_min=0., clip_max=1.):
+class CarliniWagnerLinfAttack(Attack, LabelMixin):
+    """
+    The Carlini and Wagner LInfinity Attack, https://arxiv.org/abs/1608.04644
+
+    :param predict: forward pass function (pre-softmax).
+    :param num_classes: number of classes.
+    :param min_tau: the minimum value of tau.
+    :param initial_tau: the initial value of tau.
+    :param tau_factor: the decay rate of tau (between 0 and 1).
+    :param initial_const: initial value of the constant c.
+    :param max_const: the maximum value of the constant c.
+    :param const_factor: the rate of growth of the constant c.
+    :param reduce_const: if True, the initial value of c is halved every
+        time tau is reduced.
+    :param warm_start: if True, use the previous adversarials as starting
+        points for the next iteration.
+    :param targeted: if the attack is targeted.
+    :param learning_rate: the learning rate for the attack algorithm.
+    :param max_iterations: the maximum number of iterations.
+    :param abort_early: if set to true, abort early if getting stuck in local
+        min.
+    :param clip_min: minimum value per input dimension.
+    :param clip_max: maximum value per input dimension.
+    :param loss_fn: loss function.
+    """
+
+    def __init__(self, predict, num_classes, min_tau=1/256,
+                 initial_tau=1, tau_factor=0.9, initial_const=1e-5,
+                 max_const=20, const_factor=2, reduce_const=False,
+                 warm_start=True, targeted=False, learning_rate=5e-3,
+                 max_iterations=1000, abort_early=True, clip_min=0.,
+                 clip_max=1., loss_fn=None):
+        """Carlini Wagner LInfinity Attack implementation in pytorch."""
+        if loss_fn is not None:
+            import warnings
+            warnings.warn(
+                "This Attack currently does not support a different loss"
+                " function other than the default. Setting loss_fn manually"
+                " is not effective."
+            )
+
+            loss_fn = None
+
+        super(CarliniWagnerLinfAttack, self).__init__(
+            predict, loss_fn, clip_min, clip_max)
+
         self.predict = predict
         self.num_classes = num_classes
         self.min_tau = min_tau
-        self.tau_multiplier = tau_multiplier
-        self.const_multiplier = const_multiplier
-        self.halve_const = halve_const
-        self.confidence = confidence
+        self.initial_tau = initial_tau
+        self.tau_factor = tau_factor
+        self.initial_const = initial_const
+        self.max_const = max_const
+        self.const_factor = const_factor
+        self.reduce_const = reduce_const
+        self.warm_start = warm_start
         self.targeted = targeted
         self.learning_rate = learning_rate
         self.max_iterations = max_iterations
         self.abort_early = abort_early
-        self.initial_const = initial_const
         self.clip_min = clip_min
         self.clip_max = clip_max
 
-    def loss(self, x, delta, y, const, tau):
-        adversarials = x + delta
-        output = self.predict(adversarials)
+    def _get_arctanh_x(self, x):
+        result = clamp((x - self.clip_min) / (self.clip_max - self.clip_min),
+                       min=0., max=1.) * 2 - 1
+        return torch_arctanh(result * ONE_MINUS_EPS)
+
+    def _outputs_and_loss(self, x, modifiers, starting_atanh, y, const, taus):
+        adversarials = tanh_rescale(starting_atanh + modifiers,
+                                    self.clip_min, self.clip_max)
+
+        outputs = self.predict(adversarials)
         y_onehot = to_one_hot(y, self.num_classes).float()
 
-        real = (y_onehot * output).sum(dim=1)
+        real = (y_onehot * outputs).sum(dim=1)
 
-        other = ((1.0 - y_onehot) * output - (y_onehot * TARGET_MULT)
+        other = ((1.0 - y_onehot) * outputs - (y_onehot * TARGET_MULT)
                  ).max(dim=1)[0]
         # - (y_onehot * TARGET_MULT) is for the true label not to be selected
 
         if self.targeted:
-            loss1 = torch.clamp(other - real + self.confidence, min=0.)
+            loss1 = torch.clamp(other - real, min=0.)
         else:
-            loss1 = torch.clamp(real - other + self.confidence, min=0.)
-
-        penalties = torch.clamp(torch.abs(delta) - tau, min=0)
 
         loss1 = const * loss1
-        loss2 = torch.sum(penalties, dim=(1, 2, 3))
+
+        image_dimensions = tuple(range(1, len(x.shape)))
+        taus_shape = (-1,) + (1,) * (len(x.shape) - 1)
+
+        penalties = torch.clamp(
+            torch.abs(x - adversarials) - taus.view(taus_shape), min=0)
+        loss2 = torch.sum(penalties, dim=image_dimensions)
 
         assert loss1.shape == loss2.shape
 
         loss = loss1 + loss2
-        return loss, output
+        return outputs.detach(), loss
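The `_get_arctanh_x`/`tanh_rescale` pair above is the paper's change of variables: optimization runs in an unbounded tanh space, so every decoded adversarial lands in [clip_min, clip_max] by construction and no clipping step is needed. A self-contained sketch for bounds (0, 1), assuming a PyTorch recent enough to ship `torch.atanh` (the patch itself uses advertorch's existing `torch_arctanh` helper):

import torch

ONE_MINUS_EPS = 0.999999

def to_atanh(x):
    # Map [0, 1] into tanh space; scaling by ONE_MINUS_EPS avoids
    # atanh(+/-1) = +/-inf at the boundaries.
    return torch.atanh((x * 2 - 1) * ONE_MINUS_EPS)

def from_atanh(w):
    # Map tanh space back into [0, 1]; in range by construction.
    return (torch.tanh(w) + 1) / 2

x = torch.rand(2, 3, 4, 4)  # a toy batch in [0, 1]
assert torch.allclose(from_atanh(to_atanh(x)), x, atol=1e-4)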
+
+    def _successful(self, outputs, y):
+        adversarial_labels = torch.argmax(outputs, dim=1)
 
-    def successful(self, outputs, y):
         if self.targeted:
-            adversarial_labels = torch.argmax(outputs, dim=1)
             return torch.eq(adversarial_labels, y)
         else:
             return ~torch.eq(adversarial_labels, y)
 
-    # Scales a 0-1 value to clip_min-clip_max range
-    def scale_to_bounds(self, value):
-        return self.clip_min + value * (self.clip_max - self.clip_min)
-
-    def run_attack(self, x, y, initial_const, tau):
+    def _run_attack(self, x, y, initial_const, taus, prev_adversarials):
+        assert len(x) == len(taus)
         batch_size = len(x)
         best_adversarials = x.clone().detach()
-
-        active_samples = torch.ones((batch_size,), dtype=torch.bool, device=x.device)
-        ws = torch.nn.Parameter(torch.zeros_like(x))
-        optimizer = optim.Adam([ws], lr=self.learning_rate)
+        if self.warm_start:
+            starting_atanh = self._get_arctanh_x(prev_adversarials.clone())
+        else:
+            starting_atanh = self._get_arctanh_x(x.clone())
+
+        modifiers = torch.nn.Parameter(torch.zeros_like(starting_atanh))
+
+        # Boolean mask of the samples that have not converged yet
+        active = torch.ones((batch_size,), dtype=boolean_type,
+                            device=x.device)
+        optimizer = optim.Adam([modifiers], lr=self.learning_rate)
 
         const = initial_const
 
-        while torch.any(active_samples) and const < CARLINI_COEFF_UPPER:
-            for i in range(self.max_iterations):
-                deltas = self.scale_to_bounds((0.5 + EPS) * (torch.tanh(ws) + 1)) - x
-
+        while torch.any(active) and const < self.max_const:
+            for _ in range(self.max_iterations):
                 optimizer.zero_grad()
-                losses, outputs = self.loss(x[active_samples], deltas[active_samples], y[active_samples], const, tau)
-                total_loss = torch.sum(losses)
+                outputs, loss = self._outputs_and_loss(
+                    x[active], modifiers[active], starting_atanh[active],
+                    y[active], const, taus[active])
+                total_loss = torch.sum(loss)
 
                 total_loss.backward()
                 optimizer.step()
 
-                adversarials = (x + deltas).detach()
-                best_adversarials[active_samples] = adversarials[active_samples]
+                adversarials = tanh_rescale(
+                    starting_atanh + modifiers,
+                    self.clip_min, self.clip_max).detach()
+                best_adversarials[active] = adversarials[active]
 
-                # If early aborting is enabled, drop successful samples with small losses
-                # (Notice that the current adversarials are saved regardless of whether they are dropped)
-
+                # If early aborting is enabled, drop successful samples
+                # with small losses (the current adversarials are saved
+                # regardless of whether they are dropped)
                 if self.abort_early:
-                    # TODO: Controllare che sia corretto
-                    successful = self.successful(outputs, y[active_samples])
-                    small_losses = losses < 0.0001 * const
+                    successful = self._successful(outputs, y[active])
+                    small_loss = loss < 0.0001 * const
 
-                    drop = successful & small_losses
+                    drop = successful & small_loss
+                    active[active] = ~drop
 
-                    active_samples[active_samples] = ~drop
-                    if not active_samples.any():
+                    if not active.any():
                         break
-
-            const *= self.const_multiplier
-            #print('Const: {}'.format(const))
+            # Increase const to give more weight to the classification loss
+            const *= self.const_factor
 
         return best_adversarials
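A note on the `active[active] = ~drop` pattern used above, which is easy to misread: indexing a boolean mask with itself addresses only the positions that are still True, so converged samples can be switched off without reshuffling the batch. A small standalone illustration (not from the patch):

import torch

active = torch.tensor([True, True, False, True])  # samples 0, 1, 3 active
# `drop` is computed only for the currently active samples (0, 1, 3):
drop = torch.tensor([False, True, False])         # drop sample 1
active[active] = ~drop
print(active)  # tensor([ True, False, False,  True])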
@@ -357,27 +412,48 @@ def perturb(self, x, y=None):
         # Initialization
         if y is None:
             y = self._get_predicted_label(x)
-
+
         x = replicate_input(x)
         batch_size = len(x)
         final_adversarials = x.clone()
-
-        active_samples = torch.ones((batch_size,), dtype=torch.bool, device=x.device)
+
+        # Boolean mask of the samples that have not converged yet
+        active = torch.ones((batch_size,), dtype=boolean_type,
+                            device=x.device)
 
         initial_const = self.initial_const
-        tau = 1
+        taus = torch.ones((batch_size,), device=x.device) * self.initial_tau
+
+        # The previous adversarials. This is used to perform a "warm start"
+        # during optimization.
+        prev_adversarials = x.clone()
+
+        while torch.any(active):
+            adversarials = self._run_attack(
+                x[active], y[active], initial_const, taus[active],
+                prev_adversarials[active].clone())
 
-        while torch.any(active_samples) and tau >= self.min_tau:
-            adversarials = self.run_attack(x[active_samples], y[active_samples], initial_const, tau)
+            # Store the adversarials for the next iteration, even if
+            # they failed
+            prev_adversarials[active] = adversarials
 
-            # Drop the failed adversarials (without saving them)
-            successful = self.successful(adversarials, y[active_samples])
-            active_samples[active_samples] = successful
-            final_adversarials[active_samples] = adversarials[successful]
+            # Drop failed adversarials (without saving them)
+            adversarial_outputs = self.predict(adversarials)
+            successful = self._successful(adversarial_outputs, y[active])
+            active[active] = successful
 
-            tau *= self.tau_multiplier
+            # Save the remaining adversarials
+            final_adversarials[active] = adversarials[successful]
 
-            if self.halve_const:
+            # If the Linf distance is lower than tau and the adversarial
+            # is successful, use it as the new tau
+            linf_distances = torch.max(
+                torch.abs(final_adversarials - x).view(batch_size, -1),
+                dim=1)[0]
+            linf_lower = linf_distances < taus
+            taus[linf_lower] = linf_distances[linf_lower]
+
+            taus *= self.tau_factor
+
+            if self.reduce_const:
                 initial_const /= 2
 
+            # Drop samples with a low tau
+            low_tau = taus[active] <= self.min_tau
+            active[active] = ~low_tau
+
         return final_adversarials
\ No newline at end of file
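For completeness, this is roughly how the new attack would be driven once the patch is applied (an illustrative sketch: the toy linear model, shapes, and parameter values are placeholders, not part of the patch):

import torch
import torch.nn as nn
from advertorch.attacks import CarliniWagnerLinfAttack

# A toy pre-softmax classifier on 1x28x28 inputs, standing in for a
# real trained model.
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))

attack = CarliniWagnerLinfAttack(
    model, num_classes=10, max_iterations=100, initial_tau=0.5)

x = torch.rand(4, 1, 28, 28)    # batch of inputs in [0, 1]
y = torch.randint(0, 10, (4,))  # true labels (untargeted attack)
adv = attack.perturb(x, y)
# Per-sample Linf distances of the returned adversarials:
print((adv - x).abs().view(4, -1).max(dim=1)[0])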
diff --git a/advertorch/test_utils.py b/advertorch/test_utils.py
index 665867c..4ca6f7d 100644
--- a/advertorch/test_utils.py
+++ b/advertorch/test_utils.py
@@ -5,7 +5,6 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 #
-
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -16,6 +15,7 @@ from advertorch.attacks import JacobianSaliencyMapAttack
 from advertorch.attacks import LBFGSAttack
 from advertorch.attacks import CarliniWagnerL2Attack
+from advertorch.attacks import CarliniWagnerLinfAttack
 from advertorch.attacks import DDNL2Attack
 from advertorch.attacks import FastFeatureAttack
 from advertorch.attacks import MomentumIterativeAttack
@@ -228,6 +228,7 @@ def generate_data_model_on_img():
     MomentumIterativeAttack,
     FastFeatureAttack,
     CarliniWagnerL2Attack,
+    CarliniWagnerLinfAttack,
     ElasticNetL1Attack,
     LBFGSAttack,
     JacobianSaliencyMapAttack,
@@ -254,6 +255,7 @@ def generate_data_model_on_img():
     LinfPGDAttack,
     MomentumIterativeAttack,
     CarliniWagnerL2Attack,
+    CarliniWagnerLinfAttack,
     ElasticNetL1Attack,
     LBFGSAttack,
     JacobianSaliencyMapAttack,
@@ -286,6 +288,7 @@ def generate_data_model_on_img():
     LinfSPSAAttack,
     # FABAttack,
     # CarliniWagnerL2Attack,  # XXX: not exactly sure: test says no
+    CarliniWagnerLinfAttack,
     # LBFGSAttack,  # XXX: not exactly sure: test says no
     # SpatialTransformAttack,  # XXX: not exactly sure: test says no
 ]
diff --git a/tests/test_attacks_running.py b/tests/test_attacks_running.py
index 9e82866..274c1b5 100644
--- a/tests/test_attacks_running.py
+++ b/tests/test_attacks_running.py
@@ -26,6 +26,7 @@ from advertorch.attacks import MomentumIterativeAttack
 from advertorch.attacks import FastFeatureAttack
 from advertorch.attacks import CarliniWagnerL2Attack
+from advertorch.attacks import CarliniWagnerLinfAttack
 from advertorch.attacks import DDNL2Attack
 from advertorch.attacks import ElasticNetL1Attack
 from advertorch.attacks import LBFGSAttack
@@ -76,6 +77,7 @@
     LinfPGDAttack: {"rand_init": False, "nb_iter": 5},
     MomentumIterativeAttack: {"nb_iter": 5},
     CarliniWagnerL2Attack: {"num_classes": NUM_CLASS, "max_iterations": 10},
+    CarliniWagnerLinfAttack: {"num_classes": NUM_CLASS, "max_iterations": 2,
+                              "initial_tau": 1/128},
     ElasticNetL1Attack: {"num_classes": NUM_CLASS, "max_iterations": 10},
     FastFeatureAttack: {"rand_init": False, "nb_iter": 5},
     LBFGSAttack: {"num_classes": NUM_CLASS},
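As a rough runtime intuition (not from the patch): ignoring early exits and the successful-distance shortcut in perturb, the geometric tau schedule bounds the number of outer iterations. For the defaults initial_tau=1, tau_factor=0.9, min_tau=1/256:

import math

initial_tau, tau_factor, min_tau = 1.0, 0.9, 1 / 256

# tau_k = initial_tau * tau_factor ** k; the outer loop cannot run
# past the first k with tau_k <= min_tau.
max_outer = math.ceil(math.log(min_tau / initial_tau, tau_factor))
print(max_outer)  # 53 at the defaults above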