Skip to content

Commit

Permalink
Updated poison generator to take the best delta
Browse files Browse the repository at this point in the history
  • Loading branch information
mateusztobiasz committed Jan 5, 2025
1 parent bf67b54 commit 61ede2d
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 6 deletions.
12 changes: 8 additions & 4 deletions wimudp/data_poisoning/nightshade/poison_generator.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import pandas as pd
import torch
import threading

from wimudp.data_poisoning.nightshade.pipeline import Pipeline
from wimudp.data_poisoning.nightshade.vocoder import Vocoder
from wimudp.data_poisoning.utils import CSV_NS_SAMPLES_FILE, AUDIOS_SAMPLES_DIR, AUDIOS_DIR, read_csv, pad_waveforms, normalize_tensor

MAX_EPOCHS = 30
MAX_EPOCHS = 500
EPS = 0.05


Expand All @@ -24,8 +23,10 @@ def generate_poison(row: pd.Series, vocoder: Vocoder, pipeline: Pipeline) -> tor
target_latent = target_latent.detach()

delta = torch.clone(w_1_mel_norm) * 0.0
best_delta = torch.clone(delta)
max_change = EPS * 2
step_size = max_change
min_loss = float("inf")

for i in range(MAX_EPOCHS):
actual_step_size = step_size - (step_size - step_size / 100) / MAX_EPOCHS * i
Expand All @@ -38,14 +39,17 @@ def generate_poison(row: pd.Series, vocoder: Vocoder, pipeline: Pipeline) -> tor
loss = diff_latent.norm()
grad = torch.autograd.grad(loss, delta)[0]

if min_loss > loss:
best_delta = torch.clone(delta)

delta = delta - torch.sign(grad) * actual_step_size
delta = torch.clamp(delta, -max_change, max_change)
delta = delta.detach()

if i % 20 == 0:
print(f"Loss: {loss}")
print(f"[{row['audio']}] in {i}. epoch - loss: {loss}")

final_mel_norm = torch.clamp(delta + w_1_mel_norm, -1, 1)
final_mel_norm = torch.clamp(best_delta + w_1_mel_norm, -1, 1)
return normalize_tensor(final_mel_norm, True, w_1_mel.max(), w_1_mel.min())


Expand Down
2 changes: 1 addition & 1 deletion wimudp/data_poisoning/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
CSV_MISMATCHED_FILE = f"{AUDIOLDM_DATASET_DIR}/audioset_{CONCEPT_C}_{CONCEPT_A}.csv"
ROWS_NUMBER = 3000
THREADS_NUMBER = 20
SAMPLES_NUMBER = 200
SAMPLES_NUMBER = 250


def read_csv(csv_file: str) -> pd.DataFrame:
Expand Down

0 comments on commit 61ede2d

Please sign in to comment.