Skip to content

Commit

Permalink
Minor fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Shreyasi2002 committed Dec 21, 2023
1 parent b4b7044 commit 924b8a0
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ def visualise(model, dataset, val_loader, attack_type, eps_fgsm=EPS_FGSM, eps_pg
parser.add_argument('--attack', type=str, choices=['fgsm', 'pgd'], default='pgd', help='type of attack')
parser.add_argument('--dataset', type=str, choices=['mnist', 'fashion-mnist'], default='mnist', help='dataset to use')
parser.add_argument('--action', type=str, choices=['train', 'test'], default='train', help='train the model or test the model')
parser.add_argument('--use_pretrained', type=bool, default=True, help='use pretrained model (True/False) (set this to True while testing the model)')
parser.add_argument('--use_pretrained', type=str, choices=['True', 'False'], default=True, help='use pretrained model (set this to True while testing the model)')
parser.add_argument('--epsilon', type=float, default=0.3, required=('train' in argv),
help='strength of the Adversarial Attack while training. If FGSM attack is used, keep this value in the range [0, 1]. If PGD attack is used, keep this value in the range [0, 0.3], PGD being a stronger attack ...')
parser.add_argument('--epochs', type=int, default=10, required=('train' in argv), help='number of epochs to train the model')
Expand Down Expand Up @@ -278,16 +278,16 @@ def visualise(model, dataset, val_loader, attack_type, eps_fgsm=EPS_FGSM, eps_pg
train(model, args['epochs'], args['attack'],
train_loader, val_loader, args['dataset'], load_model=args['use_pretrained'],
eps_fgsm=args['epsilon'], eps_pgd=args['epsilon'])
visualise(model, args['dataset'], val_loader, args['attack'], eps_fgsm=args['epsilon'], eps_pgd=args['epsilon'])

else:
if not args['use_pretrained']:
if args['use_pretrained'] == 'False':
print('Please set use_pretrained to True to test the model ...')
print('Exiting ...')
exit()
print('Testing the model ...\n')
test(model, test_loader, args['attack'], args['dataset'], eps_fgsm=args['epsilon'],
eps_pgd=args['epsilon'])

visualise(model, args['dataset'], val_loader, args['attack'], eps_fgsm=args['epsilon'], eps_pgd=args['epsilon'])



0 comments on commit 924b8a0

Please sign in to comment.