-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmodel_wrapper.py
1460 lines (1262 loc) · 85.1 KB
/
model_wrapper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import torch
import torch.nn.functional as F
import numpy as np
import math
import torch.nn as nn
from attacks import adv_whitebox
from AIDomains.zonotope import HybridZonotope
import AIDomains.abstract_layers as abs_layers
from AIDomains.abstract_layers import AbstractModule, Sequential as absSequential
import AIDomains.concrete_layers as conc_layers
from AIDomains.wrapper import propagate_abs
from AIDomains.ai_util import construct_C
from typing import Callable, Iterable, List, Tuple, Final, Union, Optional
from utils import project_to_bounds, log_cuda_memory, seed_everything
import logging
import torch.jit as jit
def get_model_wrapper(args, net, device, input_dim):
    '''
    Build the model wrapper selected by the training-mode flags of args.
    @param
        args: the parsed arguments; the use_*_training flags are tested in a fixed order and the first truthy one selects the base wrapper
        net: the model to be wrapped
        device: the device handed to the wrapper
        input_dim: the input dimension handed to the wrapper
    @return
        the base wrapper, optionally wrapped (in this order) by GradAccuFunctionWrapper, SAMFunctionWrapper and WeightSmoothFunctionWrapper
    @remark
        Every base wrapper receives a freshly created nn.CrossEntropyLoss as its loss function.
        Raises NotImplementedError when no training-mode flag is set.
    '''
    def build(wrapper_cls, *extra, **options):
        # all base wrappers share the same five leading constructor arguments
        return wrapper_cls(net, nn.CrossEntropyLoss(), input_dim, device, args, *extra, **options)

    # --- pick the base wrapper ---
    if args.use_pgd_training:
        wrapped = build(PGDModelWrapper, enable_EDAC=args.use_EDAC_step)
    elif args.use_multipgd_training:
        wrapped = build(MultiPGDModelWrapper, enable_EDAC=args.use_EDAC_step)
    elif args.use_arow_training:
        wrapped = build(ARoWModelWrapper, args.arow_reg_weight, args.arow_label_smoothing, enable_EDAC=args.use_EDAC_step)
    elif args.use_mart_training:
        wrapped = build(MARTModelWrapper, args.mart_reg_weight, enable_EDAC=args.use_EDAC_step)
    elif args.use_ibp_training:
        if args.use_small_box:
            wrapped = build(SmallBoxModelWrapper, eps_shrinkage=args.eps_shrinkage, relu_shrinkage=args.relu_shrinkage)
        else:
            wrapped = build(BoxModelWrapper)
    elif args.use_taps_training:
        if args.use_small_box:
            wrapped = build(STAPSModelWrapper, block_sizes=args.block_sizes, eps_shrinkage=args.eps_shrinkage, relu_shrinkage=args.relu_shrinkage)
        else:
            wrapped = build(TAPSModelWrapper, block_sizes=args.block_sizes, relu_shrinkage=args.relu_shrinkage)
    elif args.use_DP_training:
        wrapped = build(DeepPolyModelWrapper, use_dp_box=False)
    elif args.use_DPBox_training:
        wrapped = build(DeepPolyModelWrapper, use_dp_box=True, loss_fusion=args.use_loss_fusion, keep_fusion_when_test=args.keep_fusion_when_test)
    elif args.use_mtlibp_training:
        wrapped = build(MTLIBPModelWrapper, args.ibp_coef, args.attack_range_scale)
    elif args.use_expibp_training:
        wrapped = build(EXPIBPModelWrapper, args.ibp_coef, args.attack_range_scale)
    elif args.use_ccibp_training:
        wrapped = build(CCIBPModelWrapper, args.ibp_coef, args.attack_range_scale)
    elif args.use_std_training:
        wrapped = build(BasicModelWrapper)
    else:
        raise NotImplementedError("Unknown training mode.")

    # --- stack the optional function wrappers; the order matters since each one wraps the previous result ---
    if args.grad_accu_batch is not None:
        wrapped = GradAccuFunctionWrapper(wrapped, args)
    if args.use_sam:
        wrapped = SAMFunctionWrapper(wrapped, args.sam_rho, args.adaptive_sam_rho)
    if args.use_weight_smooth:
        wrapped = WeightSmoothFunctionWrapper(wrapped, std_scale=args.weight_smooth_std_scale)
    return wrapped
def set_value_between(self, name:str, value, lower, upper, dtype:Callable):
    '''
    Value checker for properties. Casts the value with dtype, checks that the cast value lies in [lower, upper] (both ends inclusive) and stores it on self under the given attribute name.
    @param
        self: the object that receives the attribute
        name: str; the attribute name to set, e.g., "_robust_weight"
        value: the value to be cast and checked
        lower: the inclusive lower bound
        upper: the inclusive upper bound
        dtype: Callable; the cast applied to value before the range check
    @remark
        Raises AssertionError when the cast value is out of range; the attribute is left untouched in that case.
    '''
    value = dtype(value)
    # raise explicitly instead of using an assert statement: assert statements are removed under `python -O`,
    # which would silently disable this validation; the exception type and message are unchanged
    if not (lower <= value <= upper):
        raise AssertionError(f"{name} should be between {lower} and {upper}.")
    setattr(self, name, value)
def set_value_typecheck(self, name:str, value, dtype):
    '''
    Value checker for properties. Checks whether the value is an instance of the specified type and, if so, stores it unchanged on self under the given attribute name.
    @param
        self: the object that receives the attribute
        name: str; the attribute name to set, e.g., "_freeze_BN"
        value: the value to be checked; no cast is applied
        dtype: the type (or tuple of types) accepted by isinstance
    @remark
        Raises AssertionError when the value is not an instance of dtype; the attribute is left untouched in that case.
    '''
    # raise explicitly instead of using an assert statement: assert statements are removed under `python -O`,
    # which would silently disable this validation; the exception type and message are unchanged
    if not isinstance(value, dtype):
        raise AssertionError(f"{name} should be of type {dtype}.")
    setattr(self, name, value)
def set_value_typecast(self, name:str, value, dtype:Callable, constraint:Optional[Callable]=None, msg:Optional[str]=None):
    '''
    Value checker for properties. Optionally validates the value with a constraint, then casts it to the specified type and stores it on self under the given attribute name.
    @param
        self: the object that receives the attribute
        name: str; the attribute name to set, e.g., "_current_eps"
        value: the value to be checked and cast
        dtype: Callable; the cast applied to value before it is stored
        constraint: Callable; optional predicate evaluated on the value BEFORE the cast; skipped when not given
        msg: str; the error message used when the constraint is violated
    @remark
        Raises AssertionError(msg) when the constraint is violated; the attribute is left untouched in that case.
    '''
    # raise explicitly instead of using an assert statement: assert statements are removed under `python -O`,
    # which would silently disable this validation; the exception type and message are unchanged
    if constraint and not constraint(value):
        raise AssertionError(msg)
    setattr(self, name, dtype(value))
class BasicModelWrapper(nn.Module):
    '''
    Implements standard training procedure.
    @param
        net: Callable; the model to be trained
        loss_fn: Callable; the loss function
        input_dim: Tuple[int, ...]; the input dimension of the model
        device: the device to run the model, e.g., "cuda" or "cpu"
        args: the arguments from the argument parser
        data_range: Tuple[float, float]; the range of the input data, e.g., (0, 1) for images. Default is (0, 1)
    @property
        robust_weight: float in [0, 1]; the weight of the robust loss in the combined loss. Default is None, i.e., calling combine_loss will raise an error.
        summary_accu_stat: bool; whether to return the accuracy statistics in the summary. Default is True.
        grad_cleaner: read-only; the optimizer to clean the grad; used to call zero_grad only
        data_min: float; the minimum value of the input data. Default is 0.0.
        data_max: float; the maximum value of the input data. Default is 1.0.
        freeze_BN: bool; whether to freeze the BN stat during training. Default is False.
        current_eps: float; the current perturbation radius; must satisfy 0 <= current_eps <= max_eps. Default is 0.0.
        current_lr: float; must be positive when set. Default is None.
        grad_scaler: read-only; a torch.cuda.amp.GradScaler when args.use_amp is set, otherwise None
        max_eps: float; must not be smaller than current_eps. Default is the larger of args.train_eps and args.test_eps (missing or None counts as 0).
        device: the device to run the model, e.g., "cuda" or "cpu"; can only be modified by wrapper.to(device)
    '''
    robust_weight = property(
        fget=lambda self: self._robust_weight,
        fset=lambda self, value: set_value_between(self, "_robust_weight", value, 0, 1, float),
    )
    summary_accu_stat = property(
        fget=lambda self: self._summary_accu_stat,
        fset=lambda self, value: set_value_typecheck(self, "_summary_accu_stat", value, bool),
    )
    grad_cleaner = property(fget=lambda self: self._grad_cleaner, fset=None) # read-only
    data_min = property(
        fget=lambda self: self._data_min,
        fset=lambda self, value: set_value_typecast(self, "_data_min", value, float, lambda x: x <= self.data_max, "data_min should be less than or equal to data_max"),
    )
    data_max = property(
        fget=lambda self: self._data_max,
        fset=lambda self, value: set_value_typecast(self, "_data_max", value, float, lambda x: x >= self.data_min, "data_max should be greater than or equal to data_min"),
    )
    freeze_BN = property(
        fget=lambda self: self._freeze_BN,
        fset=lambda self, value: set_value_typecheck(self, "_freeze_BN", value, bool),
    )
    current_eps = property(
        fget=lambda self: self._current_eps,
        fset=lambda self, value: set_value_typecast(self, "_current_eps", value, float, lambda x: x>=0 and x<=self.max_eps, "max_eps needs to be set before current_eps, and current_eps must be a non-negative float smaller than or equal to max_eps."),
    )
    current_lr = property(
        fget=lambda self: self._current_lr,
        fset=lambda self, value: set_value_typecast(self, "_current_lr", value, float, lambda x: x>0, "current_lr must be a positive float."),
    )
    grad_scaler = property(fget=lambda self: self._grad_scaler, fset=None) # read-only
    max_eps = property(
        fget=lambda self: self._max_eps,
        fset=lambda self, value: set_value_typecast(self, "_max_eps", value, float, lambda x: x>=self.current_eps, "max_eps must be a float larger than current_eps."),
    )
    device = property(fget=lambda self: self._device, fset=None) # read-only; changed through to(device)

    def __init__(self, net:absSequential, loss_fn:Callable, input_dim:Tuple[int, ...], device, args, data_range:Tuple[float, float]=(0.0, 1.0)):
        super().__init__()
        self.net = net
        # only BN layers that are direct elements of the (sequential) net are collected here
        self.BNs = [layer for layer in self.net if isinstance(layer, abs_layers._BatchNorm)]
        self._freeze_BN:bool = False
        self.loss_fn = loss_fn
        self.args = args
        self.input_dim = input_dim # currently not used by any since the net already sets the dimensions
        self._device = device
        self._summary_accu_stat:bool = True
        self._data_min:float = float(data_range[0])
        self._data_max:float = float(data_range[1])
        assert self._data_min <= self._data_max, "data_min should be less than or equal to data_max."
        self._robust_weight = None
        self._current_eps = 0.
        self._current_lr = None
        self._grad_cleaner:Final = torch.optim.SGD(self.net.parameters(), lr=1) # will only call zero_grad on it
        # a missing or None train_eps / test_eps counts as 0
        train_eps = getattr(args, "train_eps", None)
        test_eps = getattr(args, "test_eps", None)
        self._max_eps = max(0 if train_eps is None else train_eps, 0 if test_eps is None else test_eps)
        self.use_amp = getattr(args, "use_amp", False)
        self._grad_scaler:Final = torch.cuda.amp.GradScaler() if self.use_amp else None

    def forward(self, x:torch.Tensor) -> torch.Tensor:
        '''Forward the input through the wrapped net.'''
        return self.net(x)

    def to(self, device):
        '''
        Move the wrapped net to the given device and record the device in the wrapper.
        @param
            device: the target device
        @return
            self
        '''
        self._device = device
        self.net = self.net.to(device)
        return self

    def Get_Performance(self, x:torch.Tensor, y:torch.Tensor, use_model:Optional[Callable]=None) -> Tuple[torch.Tensor, torch.Tensor, torch.BoolTensor]:
        '''
        Compute standard statistics from the clean input.
        @param
            x: torch.Tensor; the batched input data
            y: torch.Tensor; the batched label
            use_model: Callable; the model to be used; if None, use the model in the wrapper, i.e., self.net
        @return
            loss: torch.Tensor; the loss
            accu: torch.Tensor; the accuracy
            pred_correct: torch.BoolTensor; whether the prediction for each input in the batch is correct
        '''
        if use_model is None:
            outputs = self.net(x)
        else:
            outputs = use_model(x)
        loss = self.loss_fn(outputs, y)
        # detach the loss to save memory when natural loss is not used for training
        if self.robust_weight == 1:
            loss = loss.detach()
        accu, pred_correct = self._Get_Accuracy(outputs, y)
        return loss, accu, pred_correct

    def _Get_Accuracy(self, outputs:torch.Tensor, y:torch.Tensor) -> Tuple[torch.Tensor, torch.BoolTensor]:
        '''
        Compute the accuracy from the prediction and the label. Not supposed to be called by clients.
        @param
            outputs: torch.Tensor; the batched prediction
            y: torch.Tensor; the batched label
        @return
            accu: torch.Tensor; the accuracy
            pred_correct: torch.BoolTensor; whether the prediction for each input in the batch is correct
        '''
        assert len(outputs) == len(y), 'prediction and label should match.'
        pred_correct = torch.argmax(outputs, dim=1).eq(y)
        num_correct = torch.sum(pred_correct)
        return num_correct / len(y), pred_correct

    def _set_BN(self, BN_layers:List[abs_layers._BatchNorm], update_stat:Optional[bool]=None):
        '''
        Set the update_stat attribute of the BN layers.
        @param
            BN_layers: List[abs_layers._BatchNorm]; the list of BN layers
            update_stat: bool; whether to update the BN stat when training=True; if None, the layers are left unchanged
        '''
        if update_stat is not None:
            for layer in BN_layers:
                layer.update_stat = update_stat

    def compute_nat_loss_and_set_BN(self, x:torch.Tensor, y:torch.Tensor, **kwargs):
        '''
        Compute the natural statistics while letting the BN layers update their stat (unless frozen), then switch the BN stat update off again.
        @param
            x: torch.Tensor; the batched input data
            y: torch.Tensor; the batched label
            (optional kwargs) use_model: Callable; the model to be used; by default, use the model in the wrapper, i.e., self.net
        @return
            nat_loss: torch.Tensor; the natural loss
            nat_accu: torch.Tensor; the natural accuracy
            is_nat_accu: torch.BoolTensor; whether the prediction for each input in the batch is correct
        @remark
            Batch norm stat will not be updated if self.freeze_BN is True.
            update_stat of the BN layers is always False after this call, also when the forward pass raises.
        '''
        assert (x>=self.data_min).all() and (x<=self.data_max).all(), "Input data should be within the data range."
        self._set_BN(self.BNs, update_stat=not self.freeze_BN)
        try:
            return self.Get_Performance(x, y, **kwargs)
        finally:
            # restore the flag even if the forward pass fails, so that it is never left switched on for later passes
            self._set_BN(self.BNs, update_stat=False)

    def get_robust_stat_from_bounds(self, lb:torch.Tensor, ub:torch.Tensor, x:torch.Tensor, y:torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.BoolTensor]:
        '''
        Compute the robust loss and accuracy from the bounds of the input data.
        @param
            lb: torch.Tensor; the lower bound of the input data
            ub: torch.Tensor; the upper bound of the input data
            x: torch.Tensor; the batched input data
            y: torch.Tensor; the batched label
        @return
            robust_loss: torch.Tensor; the robust loss
            robust_accu: torch.Tensor; the robust accuracy
            is_robust_accu: torch.BoolTensor; whether the prediction for each input in the batch is robustly correct
        @remark
            Not implemented in this class; it always raises NotImplementedError. Subclasses provide the implementation.
        '''
        raise NotImplementedError

    def get_robust_stat_from_input_noise(self, eps:float, x:torch.Tensor, y:torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.BoolTensor]:
        '''
        Compute the robust loss and accuracy from the input noise, i.e., the perturbation radius eps projected into the given data range [self.data_min, self.data_max].
        @param
            eps: float; the perturbation radius
            x: torch.Tensor; the batched input data
            y: torch.Tensor; the batched label
        @return
            robust_loss: torch.Tensor; the robust loss
            robust_accu: torch.Tensor; the robust accuracy
            is_robust_accu: torch.BoolTensor; whether the prediction for each input in the batch is robustly correct
        @remark
            The input data x should be within the data range [self.data_min, self.data_max].
            This will internally call get_robust_stat_from_bounds with the projected bounds.
        '''
        return self.get_robust_stat_from_bounds((x - eps).clamp(min=self.data_min), (x + eps).clamp(max=self.data_max), x, y)

    def combine_loss(self, nat_loss:torch.Tensor, robust_loss:torch.Tensor) -> torch.Tensor:
        '''
        return (1 - self.robust_weight) * nat_loss + self.robust_weight * robust_loss
        @param
            nat_loss: torch.Tensor; the natural loss
            robust_loss: torch.Tensor; the robust loss
        @return
            loss: torch.Tensor; the combined loss
        @remark
            self.robust_weight should be set before calling combine_loss.
        '''
        assert self.robust_weight is not None, "robust_weight should be set before calling combine_loss."
        loss = (1 - self.robust_weight) * nat_loss + self.robust_weight * robust_loss
        return loss

    def grad_postprocess(self) -> None:
        '''
        A wrapper function to be called right before optimizer.step(); can be used to modify the grad.
        @remark
            By default, it clips the grad norm to self.args.grad_clip. This can be overridden to implement any gradient processing.
        '''
        # grad clipping
        torch.nn.utils.clip_grad_norm_(self.net.parameters(), self.args.grad_clip)

    def param_postprocess(self) -> None:
        '''
        A wrapper function to be called right after optimizer.step(); can be used to modify the parameters.
        @remark
            By default, it does nothing. This can be overridden to implement any parameter processing.
            For example, when gradients are computed with shifted parameters, the parameters can be shifted back here. For example, see SAMFunctionWrapper.
        '''
        pass

    def format_return(self, loss:torch.Tensor, nat_loss:torch.Tensor, nat_accu:torch.Tensor, is_nat_accu:torch.BoolTensor, robust_loss:torch.Tensor, robust_accu:torch.Tensor, is_robust_accu:torch.BoolTensor):
        '''
        Format the return values based on the self.summary_accu_stat.
        @param
            loss: torch.Tensor; the combined loss
            nat_loss: torch.Tensor; the natural loss
            nat_accu: torch.Tensor; the natural accuracy
            is_nat_accu: torch.BoolTensor; whether the prediction for each input in the batch is correct
            robust_loss: torch.Tensor; the robust loss
            robust_accu: torch.Tensor; the robust accuracy
            is_robust_accu: torch.BoolTensor; whether the prediction for each input in the batch is robustly correct
        @return
            (loss, nat_loss, robust_loss), (nat_accu, robust_accu) with the accuracies converted to Python numbers via .item();
            when self.summary_accu_stat is False, a third element (is_nat_accu, is_robust_accu) is appended.
        @remark
            This is usually the last function to be called in the compute_model_stat method to format the return values.
        '''
        if self.summary_accu_stat:
            return (loss, nat_loss, robust_loss), (nat_accu.item(), robust_accu.item())
        else:
            return (loss, nat_loss, robust_loss), (nat_accu.item(), robust_accu.item()), (is_nat_accu, is_robust_accu)

    def compute_model_stat(self, x:torch.Tensor, y:torch.Tensor, eps:float, **kwargs):
        '''
        The main function to compute the loss and accuracy statistics. Subclasses should override this function to implement the specific training procedure.
        @param
            x: torch.Tensor; the batched input data
            y: torch.Tensor; the batched label
            eps: float; the perturbation radius
        @return
            the return values formatted by self.format_return
        @remark
            Subclasses should call self.compute_nat_loss_and_set_BN and self.get_robust_stat_from_input_noise to compute the natural loss and robust loss, respectively. This method is responsible for updating self.current_eps.
            When self.freeze_BN is True, the BN stat will not be updated during training.
            If self.get_robust_stat_from_bounds is not implemented, then this will check whether eps is 0. If eps is zero, then the combined loss is the natural loss, the robust loss is set to 0 and the robust accuracy is set to natural accuracy. Otherwise, it raises a NotImplementedError.
        '''
        self.current_eps = eps
        nat_loss, nat_accu, is_nat_accu = self.compute_nat_loss_and_set_BN(x, y)
        try:
            robust_loss, robust_accu, is_robust_accu = self.get_robust_stat_from_input_noise(eps, x, y)
            loss = self.combine_loss(nat_loss, robust_loss)
        except NotImplementedError:
            if eps == 0:
                # no robust computation available: fall back to the natural statistics
                loss, robust_loss, robust_accu, is_robust_accu = nat_loss, torch.zeros_like(nat_accu), nat_accu, is_nat_accu
            else:
                raise NotImplementedError("Robust computation is not implemented.")
        return self.format_return(loss, nat_loss, nat_accu, is_nat_accu, robust_loss, robust_accu, is_robust_accu)

    def convert_pred_to_margin(self, y:torch.Tensor, pred:torch.Tensor, rearrange_label:bool=True) -> torch.Tensor:
        '''
        Convert the prediction to the margin, i.e., the difference between the prediction and the target class, i.e. pred_i - pred_{y}. If rearrange_label is True, then the true class will be rearranged to the first column, i.e., loss(pred, y) == loss(margin, 0). Otherwise, loss(pred, y) == loss(margin, y).
        @param
            y: torch.Tensor; the batched label
            pred: torch.Tensor; the batched prediction
            rearrange_label: bool; whether to rearrange the label to the first column
        @return
            margin: torch.Tensor; the batched margin
        @remark
            The margin is padded with 0 at the first column by definition when rearrange_label=True. This implementation is based on dense matrix product and thus might be inefficient when the number of classes is large. However, it is usually not a practical issue in certified training.
        '''
        assert len(y) == len(pred), "y and pred should have the same batch size."
        if rearrange_label:
            C = construct_C(self.net.output_dim[-1], y)
            margin = - torch.bmm(C, pred.unsqueeze(-1)).squeeze(-1)
            margin = torch.cat((torch.zeros(size=(margin.size(0), 1), dtype=margin.dtype, device=margin.device), margin), dim=1)
            # postcondition: margin[i, 0] == 0 for 0<=i<len(y)
        else:
            margin = pred - pred[torch.arange(len(y)), y].unsqueeze(-1)
            # postcondition: margin[i, y[i]] == 0 for 0<=i<len(y)
        return margin
class _EDACMixin:
    '''
    Implements Extragradient-type method to explicitly Decrease Adversarial Certainty (EDAC).
    EDAC is a method to reduce the adversarial certainty of the model to avoid adversarial overfitting.
    Reference: https://arxiv.org/abs/2310.04539; https://github.com/TrustMLRG/EDAC
    @property
        EDAC_optimizer: read-only; the optimizer for the EDAC step, set via register_EDAC_hyperparam
        EDAC_step_size: float; the step size for the EDAC step
    @remark
        Subclasses should call EDAC_step in the compute_model_stat method to perform the EDAC step.
        EDAC_step reads self.net, self.device, self.num_steps, self.restarts, self.use_amp and self.grad_scaler, so the host class has to provide them.
    '''
    EDAC_step_size = property(
        fget=lambda self: self._EDAC_step_size,
        fset=lambda self, value: set_value_typecast(self, "_EDAC_step_size", value, float),
    )
    EDAC_optimizer:Final = property(fget=lambda self: self._EDAC_optimizer, fset=None) # read-only

    def register_EDAC_hyperparam(self, optimizer, EDAC_step_size:float=0.3) -> None:
        '''
        Register the optimizer and the step size used by EDAC_step.
        @param
            optimizer: the optimizer that performs the EDAC pre-update
            EDAC_step_size: float; the factor multiplied to the adversarial certainty. Default is 0.3.
        '''
        self._EDAC_optimizer = optimizer
        self._EDAC_step_size = float(EDAC_step_size)

    def EDAC_step(self, lb:torch.Tensor, ub:torch.Tensor, x:torch.Tensor, y:torch.Tensor) -> None:
        '''
        Perform the EDAC step as described in the paper.
        @param
            lb: torch.Tensor; the lower bound of the input data
            ub: torch.Tensor; the upper bound of the input data
            x: torch.Tensor; the batched input data
            y: torch.Tensor; the batched label
        @remark
            This should be called in the compute_model_stat method to do a pre-update of the model.
            The loss is EDAC_step_size times the batch mean of the standard deviation (over dim=1) of the model output on adversarial examples; one step of EDAC_optimizer is taken on it.
            NOTE(review): the attack uses step_size=max(0.25, 2/num_steps) rather than a step_size attribute of the host class -- confirm this is intended.
        '''
        assert hasattr(self, "EDAC_optimizer") and hasattr(self, "EDAC_step_size"), "EDAC hyperparam not registered. Call register_EDAC_hyperparam(optimizer, EDAC_step_size) first."
        xadv = adv_whitebox(self.net, x, y, lb, ub, self.device, self.num_steps, step_size=max(0.25, 2/self.num_steps), restarts=self.restarts, lossFunc="pgd")
        yadv = self.net(xadv)
        robust_std = torch.std(yadv, dim=1)
        advCertainty = robust_std.mean()
        ac_loss = self.EDAC_step_size * advCertainty
        self.EDAC_optimizer.zero_grad()
        # torch.autocast expects a bare device-type string such as "cuda" or "cpu";
        # normalise so that torch.device objects and indexed strings like "cuda:0" are accepted as well
        device_type = torch.device(self.device).type
        with torch.autocast(device_type=device_type, enabled=False):
            if self.use_amp:
                self.grad_scaler.scale(ac_loss).backward()
                self.grad_scaler.step(self.EDAC_optimizer)
                self.grad_scaler.update()
            else:
                ac_loss.backward()
                self.EDAC_optimizer.step()
def cache_adv_xy_if_need(self, advx:Optional[torch.Tensor]=None, advy:Optional[torch.Tensor]=None) -> None:
    '''
    Store (or clear) the cached adversarial examples and the cached model output on them, as requested by self.cache_advx and self.cache_advy.
    @param
        self: the model wrapper; must expose cache_advx and cache_advy, and net when advy has to be inferred
        advx: torch.Tensor; the adversarial examples
        advy: torch.Tensor; the model's output on the adversarial examples
    @remark
        self.current_advx becomes a detached copy of advx when cache_advx is True and None otherwise; likewise for self.current_advy and cache_advy.
        When cache_advy is True and advy is None, advy is computed as self.net(advx), so advx must be provided in that case.
        Everything runs under torch.no_grad().
        This is a free function rather than a wrapper method, so that non-adversarial wrappers do not need a nested design, similar to the adv_whitebox function.
    '''
    flags_defined = hasattr(self, "cache_advx") and hasattr(self, "cache_advy")
    assert flags_defined, "cache_advx and cache_advy should be defined in the model wrapper in order to call cache_adv_xy_if_need."
    with torch.no_grad():
        # --- adversarial inputs ---
        if not self.cache_advx:
            self.current_advx = None
        else:
            assert isinstance(advx, torch.Tensor), "advx should be a torch.Tensor to be cached."
            self.current_advx = advx.clone().detach()

        # --- model outputs on the adversarial inputs ---
        if not self.cache_advy:
            self.current_advy = None
            return
        if advy is None:
            # the outputs were not handed in, so they are recomputed from the adversarial inputs
            assert advx is not None, "advy cannot be inferred from advx if advx is not cached."
            advy = self.net(advx)
        assert isinstance(advy, torch.Tensor), "advy should be a torch.Tensor to be cached."
        self.current_advy = advy.clone().detach()
class PGDModelWrapper(_EDACMixin, BasicModelWrapper):
    '''
    Implements PGD training.
    @param
        net: Callable; the model to be trained
        loss_fn: Callable; the loss function
        input_dim: Tuple[int, ...]; the input dimension of the model
        device: the device to run the model, e.g., "cuda" or "cpu"
        args: the arguments from the argument parser; train_steps, restarts and step_size are read from it
        enable_EDAC: bool; whether to run the EDAC step during training. When True, register_EDAC_hyperparam has to be called to register the optimizer and step size. Default is False.
        cache_advx: bool; whether to keep the adversarial examples found by PGD. Default is False.
        cache_advy: bool; whether to keep the model's output on the adversarial examples. Default is False.
    @property
        num_steps: int; the number of PGD steps; initialised from args.train_steps
        restarts: int; the number of PGD restarts; initialised from args.restarts
        step_size: float; the PGD step size; initialised from args.step_size, or max(0.25, 2/num_steps) when args.step_size is None
        enable_EDAC, cache_advx, cache_advy: bool; see the parameters of the same name
        current_advx: torch.Tensor; the cached adversarial examples (None when not cached)
        current_advy: torch.Tensor; the cached model output on the adversarial examples (None when not cached)
    @remark
        A subclass that overrides compute_model_stat and still wants the EDAC step has to replicate the EDAC part of compute_model_stat; without an override the EDAC step is inherited and is switched on with enable_EDAC=True plus a call to register_EDAC_hyperparam.
        Caching is handy when the same adversarial examples serve several purposes, e.g., adversarial training and adversarial evaluation. The cache costs extra memory and is overwritten whenever new adversarial examples are cached.
    '''
    num_steps = property(
        fget=lambda wrapper: wrapper._num_steps,
        fset=lambda wrapper, new: set_value_typecast(wrapper, "_num_steps", new, int, lambda x: x>0, "num_steps must be a positive integer."),
    )
    restarts = property(
        fget=lambda wrapper: wrapper._restarts,
        fset=lambda wrapper, new: set_value_typecast(wrapper, "_restarts", new, int, lambda x: x>0, "restarts must be a positive integer."),
    )
    enable_EDAC = property(
        fget=lambda wrapper: wrapper._enable_EDAC,
        fset=lambda wrapper, new: set_value_typecheck(wrapper, "_enable_EDAC", new, bool),
    )
    cache_advx = property(
        fget=lambda wrapper: wrapper._cache_advx,
        fset=lambda wrapper, new: set_value_typecheck(wrapper, "_cache_advx", new, bool),
    )
    cache_advy = property(
        fget=lambda wrapper: wrapper._cache_advy,
        fset=lambda wrapper, new: set_value_typecheck(wrapper, "_cache_advy", new, bool),
    )
    step_size = property(
        fget=lambda wrapper: wrapper._step_size,
        fset=lambda wrapper, new: set_value_typecast(wrapper, "_step_size", new, float, lambda x: x>0, "step_size must be a positive float."),
    )

    def __init__(self, net:absSequential, loss_fn:Callable, input_dim:Tuple[int, ...], device, args, enable_EDAC:bool=False, cache_advx:bool=False, cache_advy:bool=False, **kwargs):
        super().__init__(net, loss_fn, input_dim, device, args)
        # robust_weight is meant to be changed directly during steps instead of modifying args
        self._num_steps = int(args.train_steps)
        self._restarts = int(args.restarts)
        self._enable_EDAC = bool(enable_EDAC)
        self._cache_advx = bool(cache_advx)
        self._cache_advy = bool(cache_advy)
        # nothing is cached until the first attack has been run
        self.current_advx = None
        self.current_advy = None
        if args.step_size is None:
            self._step_size = max(0.25, 2/self.num_steps)
        else:
            self._step_size = args.step_size

    def get_robust_stat_from_bounds(self, lb:torch.Tensor, ub:torch.Tensor, x:torch.Tensor, y:torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.BoolTensor]:
        '''
        Run the attack (adv_whitebox with lossFunc="pgd") inside [lb, ub] and evaluate the net on the resulting adversarial examples.
        @param
            lb: torch.Tensor; the lower bound of the input data
            ub: torch.Tensor; the upper bound of the input data
            x: torch.Tensor; the batched input data
            y: torch.Tensor; the batched label
        @return
            adv_loss: torch.Tensor; self.loss_fn on the adversarial output
            adv_accu: torch.Tensor; the accuracy on the adversarial examples
            is_adv_accu: torch.BoolTensor; whether the prediction for each adversarial example is correct
        @remark
            The adversarial examples and outputs are cached (or the cache is cleared) according to cache_advx / cache_advy.
        '''
        adv_inputs = adv_whitebox(self.net, x, y, lb, ub, self.device, self.num_steps, step_size=self.step_size, restarts=self.restarts, lossFunc="pgd")
        adv_outputs = self.net(adv_inputs)
        accuracy, correct_mask = self._Get_Accuracy(adv_outputs, y)
        attack_loss = self.loss_fn(adv_outputs, y)
        cache_adv_xy_if_need(self, adv_inputs, adv_outputs)
        return attack_loss, accuracy, correct_mask

    def compute_model_stat(self, x:torch.Tensor, y:torch.Tensor, eps:float, **kwargs):
        '''
        Compute the natural and the PGD statistics and combine the two losses.
        @param
            x: torch.Tensor; the batched input data
            y: torch.Tensor; the batched label
            eps: float; the perturbation radius
        @return
            the return values formatted by self.format_return
        @remark
            When enable_EDAC is True and the net is in training mode, an EDAC step is performed after the first natural pass and the natural statistics are computed again afterwards.
        '''
        self.current_eps = eps
        # natural pass
        clean_stats = self.compute_nat_loss_and_set_BN(x, y)
        if self.enable_EDAC and self.net.training:
            lower = (x - eps).clamp(min=self.data_min)
            upper = (x + eps).clamp(max=self.data_max)
            self.EDAC_step(lower, upper, x, y)
            # the EDAC step has changed the parameters, so the natural statistics are stale and computed again
            clean_stats = self.compute_nat_loss_and_set_BN(x, y)
        nat_loss, nat_accu, is_nat_accu = clean_stats
        # adversarial (PGD) pass
        adv_loss, adv_accu, is_adv_accu = self.get_robust_stat_from_input_noise(eps, x, y)
        total_loss = self.combine_loss(nat_loss, adv_loss)
        return self.format_return(total_loss, nat_loss, nat_accu, is_nat_accu, adv_loss, adv_accu, is_adv_accu)
class _LabelSmoothingCrossEntropy(nn.Module):
    '''
    NLL loss with label smoothing. Used in ARoW. Taken from ARoW code.
    @param
        smoothing: float in [0, 1]; the smoothing factor. Default is 0.1.
    @property
        smoothing: float; the smoothing factor; must stay within [0, 1] when set
    '''
    smoothing = property(
        fget=lambda module: module._smoothing,
        fset=lambda module, new: set_value_between(module, "_smoothing", new, 0, 1, float),
    )

    def __init__(self, smoothing:float=0.1):
        super().__init__()
        assert 0 <= smoothing <= 1
        self._smoothing = float(smoothing)

    def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        '''
        Compute the batch mean of (1 - smoothing) * NLL(target) + smoothing * mean over classes of the negative log-probability.
        @param
            x: torch.Tensor; the batched logits; softmax is taken over the last dim
            target: torch.Tensor; the batched label, used as gather index along the last dim
        @return
            loss: torch.Tensor; the scalar smoothed loss
        '''
        log_probs = F.log_softmax(x, dim=-1)
        # negative log-probability of the labelled class, one value per sample
        picked = log_probs.gather(dim=-1, index=target.unsqueeze(1))
        target_term = (-picked).squeeze(1)
        # negative log-probability averaged over all classes, one value per sample
        uniform_term = -log_probs.mean(dim=-1)
        per_sample = (1 - self.smoothing) * target_term + self.smoothing * uniform_term
        return per_sample.mean()
class ARoWModelWrapper(PGDModelWrapper):
    '''
    Implements Anti-Robust Weighted Regularization (ARoW) training.
    Reference: https://arxiv.org/abs/2206.03353; https://github.com/dyoony/ARoW.
    @param
        reg_weight: float; the weight of the regularization term in the ARoW loss. Default is 7.
        smoothing: float; the smoothing factor for label smoothing. Default is 0.2.
        (the remaining parameters and kwargs are forwarded to PGDModelWrapper)
    @property
        LS_loss_fn: _LabelSmoothingCrossEntropy; the loss function with label smoothing
        LS_loss_fn.smoothing: float; the smoothing factor
        reg_weight: float; the weight of the regularization term in the ARoW loss; must be positive when set
    '''
    reg_weight = property(
        fget=lambda wrapper: wrapper._reg_weight,
        fset=lambda wrapper, new: set_value_typecast(wrapper, "_reg_weight", new, float, lambda x: x>0, "reg_weight must be a positive float."),
    )

    def __init__(self, net:absSequential, loss_fn:Callable, input_dim:Tuple[int, ...], device, args, reg_weight:float=7, smoothing:float=0.2, **kwargs):
        super().__init__(net, loss_fn, input_dim, device, args, **kwargs)
        self.LS_loss_fn = _LabelSmoothingCrossEntropy(smoothing=smoothing)
        self._reg_weight = float(reg_weight)

    def get_ARoW_loss(self, inputs:torch.Tensor, adv_outputs:torch.Tensor, targets:torch.Tensor, model:Callable) -> Tuple[torch.Tensor, torch.Tensor]:
        '''
        Compute the two terms of the ARoW loss.
        @param
            inputs: torch.Tensor; the batched clean input data
            adv_outputs: torch.Tensor; the model output on the adversarial examples
            targets: torch.Tensor; the batched label
            model: Callable; the model evaluated on the clean input
        @return
            sup_loss: torch.Tensor; the label-smoothed loss on the clean output
            rob_loss: torch.Tensor; the batch mean of the per-sample KL term between adversarial and clean probabilities, each weighted by (1 - adversarial probability of the true class)
        '''
        clean_outputs = model(inputs)
        p_adv = F.softmax(adv_outputs, dim=1)
        p_nat = F.softmax(clean_outputs, dim=1)
        # probability the adversarial output assigns to the labelled class
        label_index = (targets.unsqueeze(1)).long()
        p_true = torch.gather(p_adv, 1, label_index).squeeze()
        # the small constant keeps the log finite
        kl_per_sample = F.kl_div((p_adv+1e-12).log(), p_nat, reduction='none').sum(dim=1)
        weighted_kl = (kl_per_sample * (1. - p_true)).mean()
        smoothed_ce = self.LS_loss_fn(clean_outputs, targets)
        return smoothed_ce, weighted_kl

    def get_robust_stat_from_bounds(self, lb:torch.Tensor, ub:torch.Tensor, x:torch.Tensor, y:torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.BoolTensor]:
        '''
        Compute the robust statistics: the plain PGD statistics of the parent class when the net is not in training mode, the ARoW loss otherwise.
        @param
            lb: torch.Tensor; the lower bound of the input data
            ub: torch.Tensor; the upper bound of the input data
            x: torch.Tensor; the batched input data
            y: torch.Tensor; the batched label
        @return
            adv_loss: torch.Tensor; the PGD loss (not training) or sup_loss + reg_weight * reg_loss (training)
            adv_accu: torch.Tensor; the accuracy on the adversarial examples
            is_adv_accu: torch.BoolTensor; whether the prediction for each adversarial example is correct
        @remark
            In training mode the attack is run with lossFunc="KL"; the parent implementation is called under torch.no_grad().
        '''
        if not self.net.training:
            # evaluation: statistics based on PGD
            with torch.no_grad():
                adv_loss, adv_accu, is_adv_accu = super().get_robust_stat_from_bounds(lb, ub, x, y)
            return adv_loss, adv_accu, is_adv_accu

        # training: statistics based on ARoW
        adv_inputs = adv_whitebox(self.net, x, y, lb, ub, self.device, self.num_steps, step_size=self.step_size, restarts=self.restarts, lossFunc="KL")
        adv_outputs = self.net(adv_inputs)
        adv_accu, is_adv_accu = self._Get_Accuracy(adv_outputs, y)
        cache_adv_xy_if_need(self, adv_inputs, adv_outputs)
        sup_term, reg_term = self.get_ARoW_loss(x, adv_outputs, y, self.net)
        return sup_term + self.reg_weight * reg_term, adv_accu, is_adv_accu
class MARTModelWrapper(PGDModelWrapper):
    '''
    Implements MART training.
    Reference: https://openreview.net/forum?id=rklOg6EFwS; Code adapted from https://github.com/dyoony/ARoW.
    @param
        reg_weight: float; the weight of the regularization term in the MART loss
    @property
        reg_weight: float; the weight of the regularization term in the MART loss. Assignment is delegated to set_value_typecast with the predicate x>0.
    '''
    def __init__(self, net:absSequential, loss_fn:Callable, input_dim:Tuple[int,...], device, args, reg_weight:float=5, **kwargs):
        super().__init__(net, loss_fn, input_dim, device, args, **kwargs)
        self._reg_weight = float(reg_weight)

    @property
    def reg_weight(self):
        return self._reg_weight

    @reg_weight.setter
    def reg_weight(self, value):
        set_value_typecast(self, "_reg_weight", value, float, lambda x: x>0, "reg_weight must be a positive float.")

    def get_MART_loss(self, inputs:torch.Tensor, adv_inputs:torch.Tensor, targets:torch.Tensor, model:Callable) -> Tuple[torch.Tensor, torch.Tensor]:
        '''
        Compute the two terms of the MART objective. Both the natural and the adversarial inputs are forwarded through model here.
        @return
            sup_loss: torch.Tensor; the boosted cross-entropy on the adversarial logits
            rob_loss: torch.Tensor; batch mean of the per-sample KL term (adversarial log-probabilities as input, natural probabilities as target) weighted by (1 - natural probability of the true class)
        '''
        def boosted_cross_entropy(logits, labels):
            # Cross-entropy on the label plus a margin term on the strongest other class:
            # the top-1 class, unless that is the label, in which case the runner-up.
            probs = F.softmax(logits, dim=1)
            top_two = torch.argsort(probs, dim=1)[:, -2:]
            runner_up, best = top_two[:, -2], top_two[:, -1]
            rival = torch.where(best == labels, runner_up, best)
            ce_term = F.cross_entropy(logits, labels)
            margin_term = F.nll_loss(torch.log(1.0001 - probs + 1e-12), rival)
            return ce_term + margin_term

        nat_logits = model(inputs)
        adv_logits = model(adv_inputs)
        adv_probs = F.softmax(adv_logits, dim=1)
        nat_probs = F.softmax(nat_logits, dim=1)
        label_index = targets.unsqueeze(1).long()
        nat_true_probs = torch.gather(nat_probs, 1, label_index).squeeze()
        sup_loss = boosted_cross_entropy(adv_logits, targets)
        pointwise_kl = F.kl_div((adv_probs+1e-12).log(), nat_probs, reduction='none')
        rob_loss = (pointwise_kl.sum(dim=1) * (1. - nat_true_probs)).mean()
        return sup_loss, rob_loss

    def get_robust_stat_from_bounds(self, lb:torch.Tensor, ub:torch.Tensor, x:torch.Tensor, y:torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.BoolTensor]:
        '''
        Return (adv_loss, adv_accu, is_adv_accu) for the input region [lb, ub], where adv_loss = sup_loss + reg_weight * rob_loss from get_MART_loss.
        The adversarial inputs come from adv_whitebox with lossFunc="pgd"; the accuracy is measured on the net's output for them.
        '''
        xadv = adv_whitebox(self.net, x, y, lb, ub, self.device, self.num_steps, step_size=self.step_size, restarts=self.restarts, lossFunc="pgd")
        yadv = self.net(xadv)
        cache_adv_xy_if_need(self, xadv, yadv)
        adv_accu, is_adv_accu = self._Get_Accuracy(yadv, y)
        sup_loss, rob_loss = self.get_MART_loss(x, xadv, y, self.net)
        return sup_loss + self.reg_weight * rob_loss, adv_accu, is_adv_accu
class BoxModelWrapper(BasicModelWrapper):
    '''
    Implements IBP training
    Reference: https://arxiv.org/abs/1810.12715, https://arxiv.org/abs/2103.17268; Code: https://github.com/shizhouxing/Fast-Certified-Robust-Training/
    @param
        store_box_bounds: bool; whether to store the bounds of the box domain
        relu_shrinkage: float, optional; the shrinkage factor for ReLU layers. When set to a float between [0,1], the upper bound for unstable ReLUs will be multiplied by this factor. This behavior is introduced by SABR (https://openreview.net/forum?id=7oFuxtJtUMH).
    @property
        store_box_bounds: bool; whether to store the bounds of the box domain. When False, get_IBP_bounds calls reset_bounds() on the network after every propagation.
        relu_shrinkage: float, optional, final; the shrinkage factor for ReLU layers (read-only)
    '''
    def __init__(self, net:absSequential, loss_fn:Callable, input_dim:Tuple[int,...], device, args, store_box_bounds:bool=False, relu_shrinkage:Optional[float]=None, **kwargs):
        # NOTE: **kwargs is accepted but not forwarded to the parent constructor
        super().__init__(net, loss_fn, input_dim, device, args)
        self._store_box_bounds = bool(store_box_bounds)
        if relu_shrinkage is not None:
            # push the factor into every abstract ReLU layer of the net
            for layer in self.net:
                if isinstance(layer, abs_layers.ReLU):
                    layer.relu_shrinkage = relu_shrinkage
            logging.info(f"Setting ReLU shrinkage to {relu_shrinkage}")
        self._relu_shrinkage:Final = relu_shrinkage

    @property
    def store_box_bounds(self):
        return self._store_box_bounds

    @store_box_bounds.setter
    def store_box_bounds(self, value):
        set_value_typecheck(self, "_store_box_bounds", value, bool)

    @property
    def relu_shrinkage(self):
        # read-only: no setter is defined
        return self._relu_shrinkage

    def get_IBP_bounds(self, abs_net:absSequential, input_lb:torch.Tensor, input_ub:torch.Tensor, y:Optional[torch.Tensor]=None) -> Tuple[torch.Tensor, torch.Tensor]:
        '''
        Propagate the box [input_lb, input_ub] through abs_net.
        If y is specified, then use final layer elision trick and return upper bounds on margin (first column padded with zero) and pseudo-labels (all zero); otherwise, return the lower and upper bounds of the output.
        @param
            abs_net: absSequential; the model to be used
            input_lb: torch.Tensor; the lower bound of the batched input data
            input_ub: torch.Tensor; the upper bound of the batched input data
            y: Optional[torch.Tensor]; the batched label
        @remark
            Unless store_box_bounds is set, abs_net.reset_bounds() is called before returning.
        '''
        x_abs = HybridZonotope.construct_from_bounds(input_lb, input_ub, domain='box')
        if y is not None:
            first, second = propagate_abs(abs_net, "box", x_abs, y)
        else:
            first, second = abs_net(x_abs).concretize()
        if not self.store_box_bounds:
            abs_net.reset_bounds()
        return first, second

    def get_robust_stat_from_bounds(self, lb:torch.Tensor, ub:torch.Tensor, x:torch.Tensor, y:torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.BoolTensor]:
        '''
        Return (loss, robust_accu, is_robust_accu) computed from the IBP margin bounds of the whole net over [lb, ub].
        '''
        margin_bound, zero_labels = self.get_IBP_bounds(self.net, lb, ub, y)
        loss = self.loss_fn(margin_bound, zero_labels)
        robust_accu, is_robust_accu = self._Get_Accuracy(margin_bound, zero_labels)
        return loss, robust_accu, is_robust_accu
class DeepPolyModelWrapper(BasicModelWrapper):
    '''
    Implements DeepPoly / CROWN-IBP / CROWN-IBP(loss fusion) training
    @param
        use_dp_box: bool; whether to use CROWN-IBP. When True, all intermediate bounds will be computed using IBP.
        loss_fusion: bool; whether to use loss fusion. This will speed up the training by O(num_class) times by only estimating the upper bounds of the final loss instead of the margin.
        keep_fusion_when_test: bool; whether to keep loss fusion during evaluation
    @property
        use_dp_box: bool; whether to use CROWN-IBP
        loss_fusion: bool; whether to use loss fusion
        keep_fusion_when_test: bool; whether to keep loss fusion during evaluation
    @remark
        The loss_fusion is only applied when use_dp_box=True, and only during training unless keep_fusion_when_test=True. When DeepPoly is used, i.e., use_dp_box=False, the loss_fusion is ignored.
        Whenever loss fusion is active, the robust accuracy cannot be estimated and is reported as 0. Without keep_fusion_when_test, evaluation falls back to CROWN-IBP and the robust accuracy is based on it.
        Loss fusion is not always more precise since the margin is estimated with IBP bounds instead of CROWN-IBP bounds.
    '''
    use_dp_box = property(fget=lambda self: self._use_dp_box, fset=lambda self, value: set_value_typecheck(self, "_use_dp_box", value, bool))
    loss_fusion = property(fget=lambda self: self._loss_fusion, fset=lambda self, value: set_value_typecheck(self, "_loss_fusion", value, bool))
    keep_fusion_when_test = property(fget=lambda self: self._keep_fusion_when_test, fset=lambda self, value: set_value_typecheck(self, "_keep_fusion_when_test", value, bool))

    def __init__(self, net:absSequential, loss_fn:Callable, input_dim:Tuple[int, ...], device, args, use_dp_box:bool=False, loss_fusion:bool=False, keep_fusion_when_test:bool=False, **kwargs):
        # NOTE: **kwargs is accepted but not forwarded to the parent constructor
        super().__init__(net, loss_fn, input_dim, device, args)
        self._use_dp_box = bool(use_dp_box)
        self._loss_fusion = bool(loss_fusion)
        # normalised to bool like the two flags above, matching the typechecked setter
        self._keep_fusion_when_test = bool(keep_fusion_when_test)
        if self.loss_fusion and self.keep_fusion_when_test:
            assert args.model_selection in ["loss", None], "Test time loss fusion is only supported when model_selection is loss or None."

    def get_robust_stat_from_bounds(self, lb:torch.Tensor, ub:torch.Tensor, x:torch.Tensor, y:torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.BoolTensor]:
        '''
        Return (loss, robust_accu, is_robust_accu) for the input region [lb, ub].
        The propagation domain is "deeppoly" when use_dp_box is False, otherwise "deeppoly_box", upgraded to "deeppoly_box_loss_fusion" when loss fusion is active.
        With loss fusion the loss is the mean of the propagated bound and the accuracy outputs are all-zero placeholders.
        '''
        x_abs = HybridZonotope.construct_from_bounds(lb, ub, domain='box')
        domain = "deeppoly_box" if self.use_dp_box else "deeppoly"
        if domain == "deeppoly_box" and self.loss_fusion:
            # loss fusion is only applied during training, unless explicitly kept for evaluation
            if self.keep_fusion_when_test or self.net.training:
                domain = "deeppoly_box_loss_fusion"
        pseudo_bound, pseudo_labels = propagate_abs(self.net, domain, x_abs, y)
        if domain == "deeppoly_box_loss_fusion":
            loss = pseudo_bound.mean()
            # we cannot estimate robust accuracy from loss fusion; report zeros
            robust_accu = torch.zeros(1, device=self.device)
            is_robust_accu = torch.zeros(len(x), dtype=torch.bool, device=self.device)
        else:
            loss = self.loss_fn(pseudo_bound, pseudo_labels)
            robust_accu, is_robust_accu = self._Get_Accuracy(pseudo_bound, pseudo_labels)
        return loss, robust_accu, is_robust_accu
class _MultiPGDMixin:
    '''
    Implements PGD estimation for the margin. Each dimension in the margin will be estimated by a separate PGD attack.
    @remark
        self.args.estimation_batch can be used to control the batch size for the estimation as we are now having num_class-1 times more inputs than a PGD attack.
        The methods read self.device and self.args, which the class mixing this in has to provide.
    '''
    def _get_bound_estimation_from_pts(self, net:absSequential, pts:torch.Tensor, C:torch.Tensor) -> torch.Tensor:
        '''
        Get the estimated bounds from the pivotal points.
        @param
            net: absSequential; the model to be used
            pts: torch.Tensor; the pivotal points
            C: torch.Tensor; the query matrix constructed with y. This can be constructed by construct_C(net.output_dim[-1], y) for the final margin.
        @return
            estimated_bounds: torch.Tensor; the estimated bounds
        @remark
            The pivotal points are the adversarial examples in the latent space for the margin estimation. The estimated bounds are the upper bounds of the margin pred_i - pred_{y}. Pivot points are supposed to be constructed by _get_pivotal_points. This function will keep the gradient link from the pivotal points to the model's output. MultiPGDModelWrapper show-cases how to use this mixin.
        '''
        assert C is not None, "PGD estimation is supposed to be used for margins."
        # main idea: convert the num_class-1 adv inputs into one batch to compute the bound at the same time; involve many reshaping
        batch_C = C.unsqueeze(1).expand(-1, pts.shape[1], -1, -1).reshape(-1, *(C.shape[1:])) # may need shape adjustment
        batch_pts = pts.reshape(-1, *(pts.shape[2:]))
        out_pts = net(batch_pts, C=batch_C)
        out_pts = out_pts.reshape(*(pts.shape[:2]), *(out_pts.shape[1:]))
        out_pts = - out_pts # the out is the lower bound of yt - yi, transform it to the upper bound of yi - yt
        # postcondition: the out_pts is in shape (batch_size, n_class - 1, n_class - 1)
        # the i-th pivotal point only contributes its own (i-th) margin entry, hence the diagonal
        ub = torch.diagonal(out_pts, dim1=1, dim2=2) # shape: (batch_size, n_class - 1)
        zero_col = torch.zeros(size=(ub.shape[0],1), dtype=ub.dtype, device=ub.device)
        estimated_bounds = torch.cat([zero_col, ub], dim=1) # shape: (batch_size, n_class)
        return estimated_bounds

    def _get_pivotal_points_one_batch(self, net:absSequential, lb:torch.Tensor, ub:torch.Tensor, num_steps:int, restarts:int, C:torch.Tensor) -> torch.Tensor:
        '''
        Get adversarial examples in the latent space for the margin estimation and batched inputs.
        @param
            net: absSequential; the network to be used; net.output_dim[-1] - 1 points are searched per sample
            lb: torch.Tensor; the lower bound of the batched input
            ub: torch.Tensor; the upper bound of the batched input
            num_steps: int; the number of sign-gradient steps per restart. With 0, only the random initialization is evaluated.
            restarts: int; the number of random restarts. With 0, the search loop never runs and the zero-initialized buffer is returned.
            C: torch.Tensor; the query matrix, forwarded to _get_bound_estimation_from_pts
        @return
            best_pts: torch.Tensor; for every (sample, margin dimension) the point with the largest estimated bound seen; shape (batch_size, n_class - 1, *input_dims)
        @remark
            This function is not supposed to be called directly. Use _get_pivotal_points instead.
        '''
        num_pivotal = net.output_dim[-1] - 1 # only need to estimate n_class - 1 dim for the final output

        def init_pts(input_lb, input_ub):
            # one uniform sample inside [lb, ub] per (sample, margin dimension)
            rand_init = input_lb.unsqueeze(1) + (input_ub-input_lb).unsqueeze(1)*torch.rand(input_lb.shape[0], num_pivotal, *input_lb.shape[1:], device=self.device)
            return rand_init

        def select_schedule(num_steps):
            # iteration counts after which the step size is decayed
            if num_steps >= 10 and num_steps <= 50:
                lr_decay_milestones = [int(num_steps*0.7)]
            elif num_steps > 50 and num_steps <= 80:
                lr_decay_milestones = [int(num_steps*0.4), int(num_steps*0.7)]
            elif num_steps > 80:
                lr_decay_milestones = [int(num_steps*0.3), int(num_steps*0.6), int(num_steps*0.8)]
            else:
                lr_decay_milestones = []
            return lr_decay_milestones

        # TODO: move this to args factory? Maybe not; don't want to expose too much details and increase the space for hyperparameter tuning
        lr_decay_milestones = select_schedule(num_steps)
        lr_decay_factor = 0.2
        # this makes sure for num_steps<5, the attack is still able to reach the boundary;
        # num_steps == 0 takes no step at all, so fall back to the floor instead of dividing by zero
        init_lr = max(0.2, 2/num_steps) if num_steps > 0 else 0.2

        # This first draw only provides shape/dtype/device for the buffers below; it is kept
        # so that the sequence of random numbers consumed (seeded runs) does not change.
        pts = init_pts(lb, ub)
        variety = (ub - lb).unsqueeze(1).detach()
        best_estimation = torch.full(pts.shape[:2], float("-inf"), device=pts.device, dtype=pts.dtype)
        best_pts = torch.zeros_like(pts)
        with torch.enable_grad():
            for _ in range(restarts):
                lr = init_lr
                pts = init_pts(lb, ub)
                # num_steps updates, plus one final pass that only evaluates the last iterate
                for it in range(num_steps+1):
                    pts.requires_grad = True
                    estimated_pseudo_bound = self._get_bound_estimation_from_pts(net, pts, C=C)
                    # column 0 is the zero padding, the margins start at column 1
                    improve_idx = estimated_pseudo_bound[:, 1:] > best_estimation
                    best_estimation[improve_idx] = estimated_pseudo_bound[:, 1:][improve_idx].detach().float()
                    best_pts[improve_idx] = pts[improve_idx].detach()
                    if it != num_steps:
                        # wants to maximize the estimated bound
                        loss = - estimated_pseudo_bound.sum()
                        if torch.is_autocast_enabled():
                            # scale the loss up before differentiating when autocast is active
                            loss = loss * 2.**12
                        grad = torch.autograd.grad(loss, pts)[0]
                        assert not torch.isnan(grad).any(), "nan found in grad during attack; If automatic mixed precision is used, try a smaller scaling factor (usually not recommended). Otherwise, it usually indicates grad overflow due to inproper output scale."
                        # signed step, sized relative to the width of the box in every dimension
                        new_pts = pts - grad.sign() * lr * variety
                        pts = project_to_bounds(new_pts, lb.unsqueeze(1), ub.unsqueeze(1)).detach()
                        if (it+1) in lr_decay_milestones:
                            lr *= lr_decay_factor
        return best_pts

    def _get_pivotal_points(self, net:absSequential, input_lb:torch.Tensor, input_ub:torch.Tensor, num_steps:int, restarts:int, C:torch.Tensor) -> torch.Tensor:
        '''
        Get adversarial examples in the latent space for the margin estimation.
        @param
            net: absSequential; the network to be used
            input_lb: torch.Tensor; the lower bound of the batched input data
            input_ub: torch.Tensor; the upper bound of the batched input data
            num_steps: int; the number of steps for PGD
            restarts: int; the number of restarts for PGD
            C: torch.Tensor; the query matrix constructed with y. This can be constructed by construct_C(net.output_dim[-1], y) for the final margin.
        @return
            pts: torch.Tensor; the pivotal points of all sub-batches, concatenated along the batch dimension
        @remark
            This assumes the net is fixed in this procedure. If a BatchNorm is involved, freeze its stat before calling this function.
        '''
        assert C is not None, "Should only estimate for the final margin"
        lb, ub = input_lb.clone().detach(), input_ub.clone().detach()
        # split into batches if required; None means everything in one batch
        bs = self.args.estimation_batch
        if bs is None:
            bs = len(lb)

        def chunk(t):
            return [t[i*bs:(i+1)*bs] for i in range(math.ceil(len(t) / bs))]

        pt_list = [
            self._get_pivotal_points_one_batch(net, lb_one_batch, ub_one_batch, num_steps, restarts, C_one_batch)
            for lb_one_batch, ub_one_batch, C_one_batch in zip(chunk(lb), chunk(ub), chunk(C))
        ]
        pts = torch.cat(pt_list, dim=0)
        return pts
class MultiPGDModelWrapper(_MultiPGDMixin, PGDModelWrapper):
    '''
    Implements PGD training with PGD estimation for the margin, each dimension estimated by a separate PGD attack, i.e., K-1 separate PGD attacks for K classes.
    The search itself comes from _MultiPGDMixin; the constructor only forwards its arguments to the parent.
    '''
    def __init__(self, net:absSequential, loss_fn:Callable, input_dim:Tuple[int, ...], device, args, **kwargs):
        super().__init__(net, loss_fn, input_dim, device, args, **kwargs)

    def get_robust_stat_from_bounds(self, lb:torch.Tensor, ub:torch.Tensor, x:torch.Tensor, y:torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.BoolTensor]:
        '''
        Return (adv_loss, adv_accu, is_adv_accu) from the multi-PGD margin estimate over [lb, ub].
        The pivotal points are searched without gradient tracking; the margin is then re-evaluated at those points outside the no_grad block, and both the accuracy and the loss are computed against all-zero pseudo-labels.
        '''
        assert y is not None, "MultiPGDModelWrapper requires the label to compute the margin."
        query = construct_C(self.net.output_dim[-1], y)
        with torch.no_grad():
            # note that the batch norm stats have been frozen during natural loss computation, thus no need to freeze again
            pivots = self._get_pivotal_points(self.net, lb, ub, self.num_steps, self.restarts, query)
        margin_bound = self._get_bound_estimation_from_pts(self.net, pivots, query)
        # pseudo-label 0 for every sample, matching the zero column prepended to the margin
        zero_labels = margin_bound.new_zeros((margin_bound.size(0),), dtype=torch.int64)
        adv_accu, is_adv_accu = self._Get_Accuracy(margin_bound, zero_labels)
        adv_loss = self.loss_fn(margin_bound, zero_labels)
        return adv_loss, adv_accu, is_adv_accu
class TAPSModelWrapper(_MultiPGDMixin, BoxModelWrapper):
    '''
    Implements TAPS training
    Reference: https://arxiv.org/abs/2305.04574; Code: https://github.com/eth-sri/taps
    @param
        block_sizes: Tuple[int, int]; the sizes of the two blocks in the net, the first for IBP, the second for PGD. The sum of the two should be equal to the length of the net.
        Other parameters inherited from BoxModelWrapper
    @property
        soft_thre: float; the threshold for the soft-thresholding. By default, it is set to args.soft_thre. 0.5 is the recommended value.
        latent_search_steps: int; the number of steps for PGD. By default, it is set to args.train_steps.
        disable_TAPS: bool; whether to disable TAPS and use IBP only. By default, False. This will be useful during eps annealing.
        TAPS_grad_scale: float; the scale factor for the TAPS gradient (IBP gradient as 1). By default, it is set to args.taps_grad_scale. 5 is a good value but tuning is recommended.
        latent_search_restarts: int; the number of restarts for PGD. By default, it is set to args.restarts.
        no_ibp_multiplier: bool; whether to disable the IBP multiplier in the TAPS loss. By default, it is set to args.no_ibp_multiplier.
        net_blocks_abs: List[abs_layers.Sequential]; the abstract network that is split into blocks (read-only).
        Other properties inherited from BoxModelWrapper.
    '''
    latent_search_restarts = property(fget=lambda self: self._latent_search_restarts, fset=lambda self, value: set_value_typecast(self, "_latent_search_restarts", value, int, lambda x: x>0, "latent_search_restarts must be a positive integer."))
    soft_thre = property(fget=lambda self: self._soft_thre, fset=lambda self, value: set_value_between(self, "_soft_thre", value, 0, 1, float))
    latent_search_steps = property(fget=lambda self: self._latent_search_steps, fset=lambda self, value: set_value_typecast(self, "_latent_search_steps", value, int, lambda x: x>0, "latent_search_steps must be a positive integer."))
    disable_TAPS = property(fget=lambda self: self._disable_TAPS, fset=lambda self, value: set_value_typecheck(self, "_disable_TAPS", value, bool))
    TAPS_grad_scale = property(fget=lambda self: self._TAPS_grad_scale, fset=lambda self, value: set_value_typecast(self, "_TAPS_grad_scale", value, float, lambda x: x>0, "TAPS_grad_scale must be a positive float."))
    no_ibp_multiplier = property(fget=lambda self: self._no_ibp_multiplier, fset=lambda self, value: set_value_typecheck(self, "_no_ibp_multiplier", value, bool))
    net_blocks_abs = property(fget=lambda self: self._net_blocks_abs, fset=None) # read-only

    def __init__(self, net:absSequential, loss_fn:Callable, input_dim:Tuple[int, ...], device, args, block_sizes:Iterable[int], store_box_bounds:bool=False, relu_shrinkage:Optional[float]=None, **kwargs):
        super().__init__(net=net, loss_fn=loss_fn, input_dim=input_dim, device=device, args=args, store_box_bounds=store_box_bounds, relu_shrinkage=relu_shrinkage, **kwargs)
        self._net_blocks_abs = self._split_net_to_blocks(block_sizes)
        self._soft_thre = float(args.soft_thre)
        self._disable_TAPS = False # when true, TAPS is equivalent to IBP
        self._TAPS_grad_scale = float(args.taps_grad_scale)
        self._latent_search_steps = int(args.train_steps)
        self._latent_search_restarts = int(args.restarts)
        self._no_ibp_multiplier = bool(args.no_ibp_multiplier)

    def _split_net_to_blocks(self, block_sizes:Iterable[int]):
        '''
        Split self.net into consecutive abs_layers.Sequential blocks of the given sizes.
        @param
            block_sizes: Iterable[int]; exactly two sizes that sum up to len(self.net)
        @return
            blocks: list; one abs_layers.Sequential per size, each with output_dim copied from its last layer
        '''
        assert block_sizes is not None, "TAPS assume two blocks: the first uses IBP, the second uses PGD."
        # materialize once so that any iterable (not only sized containers) can be checked and re-iterated
        block_sizes = tuple(block_sizes)
        assert len(block_sizes) == 2, "TAPS assume two blocks: the first uses IBP, the second uses PGD."
        assert len(self.net) == sum(block_sizes), f"Provided block splits have {sum(block_sizes)} layers, but the net has {len(self.net)} layers."
        start = 0
        blocks = []
        for size in block_sizes:
            end = start + size
            abs_block = abs_layers.Sequential(*self.net[start:end])
            abs_block.output_dim = abs_block[-1].output_dim
            blocks.append(abs_block)
            start = end
        return blocks

    def get_robust_stat_from_bounds(self, lb:torch.Tensor, ub:torch.Tensor, x:torch.Tensor, y:torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.BoolTensor]:
        '''
        Return (loss, robust_accu, is_robust_accu) for the input region [lb, ub].
        With disable_TAPS set, this is the parent's IBP result. Otherwise every block but the last is bounded by IBP, and the last block is bounded both by the PGD-based estimate (TAPS) and by IBP.
        The loss is the TAPS loss with its gradient scaled by TAPS_grad_scale, multiplied by the IBP loss unless no_ibp_multiplier is set. The accuracy is computed from the TAPS bound.
        '''
        if self.disable_TAPS:
            return super().get_robust_stat_from_bounds(lb, ub, x, y)
        # propagate the bound block-wisely
        for block_id, block in enumerate(self.net_blocks_abs):
            if block_id + 1 < len(self.net_blocks_abs):
                lb, ub = self.get_IBP_bounds(block, lb, ub)
            else:
                # prepare PGD bounds, Box bounds for y_i - y_t
                TAPS_bound = self.get_TAPS_bounds(block, lb, ub, self.latent_search_steps, self.latent_search_restarts, y)
                Box_bound, pseudo_labels = self.get_IBP_bounds(block, lb, ub, y)
        loss = _GradExpander.apply(self.loss_fn(TAPS_bound, pseudo_labels), self.TAPS_grad_scale)
        if not self.no_ibp_multiplier:
            loss = loss * self.loss_fn(Box_bound, pseudo_labels)
        robust_accu, is_robust_accu = self._Get_Accuracy(TAPS_bound, pseudo_labels)
        return loss, robust_accu, is_robust_accu

    def get_TAPS_bounds(self, block:absSequential, input_lb:torch.Tensor, input_ub:torch.Tensor, num_steps:int, restarts:int, y:torch.Tensor):
        '''
        Estimate the margin bounds of block over [input_lb, input_ub] with the multi-PGD search, and connect the found points to the input bounds for back-propagation.
        @param
            block: absSequential; the block to be attacked
            input_lb: torch.Tensor; the lower bound of the block input
            input_ub: torch.Tensor; the upper bound of the block input
            num_steps: int; the number of PGD steps
            restarts: int; the number of PGD restarts
            y: torch.Tensor; the labels, required
        @return
            bounds: torch.Tensor; the estimated margin bounds from _get_bound_estimation_from_pts
        '''
        assert y is not None, "TAPS requires the target label to perform margin estimation."
        C = construct_C(block.output_dim[-1], y)
        with torch.no_grad():
            # note that the batch norm stats have been frozen during natural loss computation, thus no need to freeze again
            pts = self._get_pivotal_points(block, input_lb, input_ub, num_steps, restarts, C)
        # Establish gradient link between pivotal points and bound
        # via rectified linear link
        # the link expects the per-sample point axis first, so swap axes 0 and 1 around the call
        pts = torch.transpose(pts, 0, 1)
        # use the soft_thre property (initialized from args.soft_thre) so that runtime updates of it take effect
        pts = _RectifiedLinearGradientLink.apply(input_lb.unsqueeze(0), input_ub.unsqueeze(0), pts, self.soft_thre, 1e-5)
        pts = torch.transpose(pts, 0, 1)
        bounds = self._get_bound_estimation_from_pts(block, pts, C)
        return bounds
class _RectifiedLinearGradientLink(torch.autograd.Function):
    '''
    Belongs to TAPS.
    Establish a rectified linear gradient link between the input bounds and the input point.
    Note that this is not a valid gradient w.r.t. the forward function (the forward pass returns x unchanged).
    Take ub as an example:
        For dims that x[dim] in [lb, ub-c*(ub-lb)], the gradient w.r.t. ub is 0.
        For dims that x[dim] == ub, the gradient w.r.t. ub is 1.
        For dims that x[dim] in [ub-c*(ub-lb), ub], the gradient is linearly interpolated between 0 and 1.
    These factors multiply the incoming gradient and are summed over dim 0.
    x should be modified to shape (batch_size, *bound_dims) by reshaping.
    bounds should be of shape (1, *bound_dims)
    '''
    @staticmethod
    def forward(ctx, lb:torch.Tensor, ub:torch.Tensor, x:torch.Tensor, c:float, tol:float):
        # identity on x; everything is only stashed for the backward pass
        ctx.c, ctx.tol = c, tol
        ctx.save_for_backward(lb, ub, x)
        return x

    @staticmethod
    def backward(ctx, grad_x):
        lb, ub, x = ctx.saved_tensors
        c, tol = ctx.c, ctx.tol
        slack = c * (ub - lb)
        # both the distance and the slack are clamped from below by (a multiple of) tol
        denom = slack.clamp(min=tol)
        floor = 0.5 * tol

        # link to ub: active where x lies within the slack below ub
        upper_knee = ub - slack
        near_ub = x >= upper_knee
        grad_ub = (near_ub * grad_x * (x - upper_knee).clamp(min=floor) / denom).sum(dim=0, keepdim=True)

        # link to lb: active where x lies within the slack above lb
        lower_knee = lb + slack
        near_lb = x <= lower_knee
        grad_lb = (near_lb * grad_x * (lower_knee - x).clamp(min=floor) / denom).sum(dim=0, keepdim=True)

        # no gradient is returned for x, c and tol
        return grad_lb, grad_ub, None, None, None
class _GradExpander(torch.autograd.Function):
'''
Belongs to TAPS.
Multiply the gradient by alpha
'''
@staticmethod
def forward(ctx, x, alpha:float=1):
ctx.alpha = alpha
return x
@staticmethod
def backward(ctx, grad_x):