From 61ede2dd8ca614433a359617dcf3e42ce40b14a4 Mon Sep 17 00:00:00 2001
From: mtobiasz
Date: Sun, 5 Jan 2025 15:29:32 +0100
Subject: [PATCH] Updated poison generator to take the best delta

Build the poisoned mel from the perturbation that reached the lowest
loss during optimisation instead of from the last iterate. min_loss is
refreshed on every improvement so that best_delta is the real arg-min;
without that update it would stay at +inf and best_delta would simply
follow the most recent delta. The snapshot is detached so it does not
keep the autograd graph alive.

Also raises MAX_EPOCHS to 500 and SAMPLES_NUMBER to 250, drops the
threading import, bumps the audioldm submodule, and prefixes the
progress log with the audio name and epoch.
---
 wimudp/data_poisoning/finetuning/audioldm            |  2 +-
 wimudp/data_poisoning/nightshade/poison_generator.py | 16 ++++++++++++----
 wimudp/data_poisoning/utils.py                       |  2 +-
 3 files changed, 14 insertions(+), 6 deletions(-)

diff --git a/wimudp/data_poisoning/finetuning/audioldm b/wimudp/data_poisoning/finetuning/audioldm
index b85e78f..871da1f 160000
--- a/wimudp/data_poisoning/finetuning/audioldm
+++ b/wimudp/data_poisoning/finetuning/audioldm
@@ -1 +1 @@
-Subproject commit b85e78f23bcea2e2ab5407484309b8a76fecfab4
+Subproject commit 871da1fd95ddc0fc341f56253bc5b516f729d5c0
diff --git a/wimudp/data_poisoning/nightshade/poison_generator.py b/wimudp/data_poisoning/nightshade/poison_generator.py
index 4e0bead..6c8014a 100644
--- a/wimudp/data_poisoning/nightshade/poison_generator.py
+++ b/wimudp/data_poisoning/nightshade/poison_generator.py
@@ -1,12 +1,11 @@
 import pandas as pd
 import torch
-import threading
 
 from wimudp.data_poisoning.nightshade.pipeline import Pipeline
 from wimudp.data_poisoning.nightshade.vocoder import Vocoder
 from wimudp.data_poisoning.utils import CSV_NS_SAMPLES_FILE, AUDIOS_SAMPLES_DIR, AUDIOS_DIR, read_csv, pad_waveforms, normalize_tensor
 
-MAX_EPOCHS = 30
+MAX_EPOCHS = 500
 EPS = 0.05
 
 
@@ -24,8 +23,10 @@ def generate_poison(row: pd.Series, vocoder: Vocoder, pipeline: Pipeline) -> tor
     target_latent = target_latent.detach()
 
     delta = torch.clone(w_1_mel_norm) * 0.0
+    best_delta = torch.clone(delta)
     max_change = EPS * 2
     step_size = max_change
+    min_loss = float("inf")
 
     for i in range(MAX_EPOCHS):
         actual_step_size = step_size - (step_size - step_size / 100) / MAX_EPOCHS * i
@@ -38,14 +39,21 @@ def generate_poison(row: pd.Series, vocoder: Vocoder, pipeline: Pipeline) -> tor
         loss = diff_latent.norm()
         grad = torch.autograd.grad(loss, delta)[0]
 
+        # Keep the perturbation with the lowest loss seen so far. min_loss has
+        # to be refreshed on every improvement; otherwise it stays at +inf and
+        # best_delta merely mirrors the most recent delta.
+        if min_loss > loss:
+            min_loss = loss.item()
+            best_delta = delta.detach().clone()
+
         delta = delta - torch.sign(grad) * actual_step_size
         delta = torch.clamp(delta, -max_change, max_change)
         delta = delta.detach()
 
         if i % 20 == 0:
-            print(f"Loss: {loss}")
+            print(f"[{row['audio']}] in {i}. epoch - loss: {loss}")
 
-    final_mel_norm = torch.clamp(delta + w_1_mel_norm, -1, 1)
+    final_mel_norm = torch.clamp(best_delta + w_1_mel_norm, -1, 1)
     return normalize_tensor(final_mel_norm, True, w_1_mel.max(), w_1_mel.min())
 
 
diff --git a/wimudp/data_poisoning/utils.py b/wimudp/data_poisoning/utils.py
index 65986ae..93ce2e8 100644
--- a/wimudp/data_poisoning/utils.py
+++ b/wimudp/data_poisoning/utils.py
@@ -18,7 +18,7 @@
 CSV_MISMATCHED_FILE = f"{AUDIOLDM_DATASET_DIR}/audioset_{CONCEPT_C}_{CONCEPT_A}.csv"
 ROWS_NUMBER = 3000
 THREADS_NUMBER = 20
-SAMPLES_NUMBER = 200
+SAMPLES_NUMBER = 250
 
 
 def read_csv(csv_file: str) -> pd.DataFrame: