Skip to content

Commit

Permalink
Updated poison generator
Browse files Browse the repository at this point in the history
  • Loading branch information
mateusztobiasz committed Jan 3, 2025
1 parent ba66577 commit e98b24a
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 36 deletions.
75 changes: 39 additions & 36 deletions wimudp/data_poisoning/nightshade/poison_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,60 +2,63 @@
import torch
import threading

from concurrent.futures import ProcessPoolExecutor
from wimudp.data_poisoning.nightshade.pipeline import Pipeline
from wimudp.data_poisoning.nightshade.vocoder import Vocoder
from wimudp.data_poisoning.utils import CSV_NS_SAMPLES_FILE, AUDIOS_SAMPLES_DIR, read_csv, pad_waveforms, normalize_tensor
from wimudp.data_poisoning.utils import CSV_NS_SAMPLES_FILE, AUDIOS_SAMPLES_DIR, AUDIOS_GEN_DIR, read_csv, pad_waveforms, normalize_tensor

MAX_EPOCHS = 200
EPS = 0.05


def generate_poison(row: pd.Series, vocoder: Vocoder, pipeline: Pipeline) -> torch.Tensor:
    """Craft a Nightshade-style poisoned mel spectrogram for one audio sample.

    Runs a PGD-like optimization: a perturbation ``delta``, bounded in
    L-infinity norm by ``EPS * 2``, is added to the sample's normalized mel
    so that its latent (from ``pipeline``) moves toward the latent of the
    anchor audio ``big.wav``.

    Args:
        row: Dataset row; ``row['youtube_id']`` names the source WAV file
            inside ``AUDIOS_SAMPLES_DIR``.
        vocoder: Loads audio and converts waveform <-> mel spectrogram.
        pipeline: Provides the latent encoder the attack targets.

    Returns:
        The poisoned mel spectrogram, mapped back to the value range of the
        original sample's mel so it can be vocoded.
    """
    w_1 = vocoder.load_audio(f"{AUDIOS_SAMPLES_DIR}/{row['youtube_id']}.wav")
    w_2 = vocoder.load_audio(f"{AUDIOS_SAMPLES_DIR}/big.wav")

    # Align lengths, then work on normalized half-precision mels.
    w_1, w_2 = pad_waveforms(w_1, w_2)
    w_1_mel = vocoder.gen_mel(w_1).half()
    w_2_mel = vocoder.gen_mel(w_2).half()
    w_1_mel_norm = normalize_tensor(w_1_mel)
    w_2_mel_norm = normalize_tensor(w_2_mel)

    # The target latent is fixed; detach it so gradients flow only
    # through delta.
    target_latent = pipeline.get_latent(w_2_mel_norm)
    target_latent = target_latent.detach()

    delta = torch.clone(w_1_mel_norm) * 0.0
    max_change = EPS * 2
    step_size = max_change

    for i in range(MAX_EPOCHS):
        # Linearly decay the step from step_size down to step_size / 100.
        actual_step_size = step_size - (step_size - step_size / 100) / MAX_EPOCHS * i
        delta.requires_grad_()

        pert_mel = torch.clamp(delta + w_1_mel_norm, -1, 1)
        per_latent = pipeline.get_latent(pert_mel)
        diff_latent = per_latent - target_latent

        # Minimize the distance between the perturbed and target latents.
        loss = diff_latent.norm()
        grad = torch.autograd.grad(loss, delta)[0]

        # Signed-gradient descent step, then project back into the
        # L-inf ball of radius max_change.
        delta = delta - torch.sign(grad) * actual_step_size
        delta = torch.clamp(delta, -max_change, max_change)
        delta = delta.detach()

        if i % 20 == 0:
            print(f"Loss: {loss}")

    final_mel_norm = torch.clamp(delta + w_1_mel_norm, -1, 1)
    # Denormalize back to the original mel's value range before vocoding.
    return normalize_tensor(final_mel_norm, True, w_1_mel.max(), w_1_mel.min())
def generate_all(df: pd.DataFrame):
    """Generate and save a poisoned audio file for every row in ``df``.

    The heavyweight ``Vocoder`` and ``Pipeline`` objects are constructed
    once and shared across all samples.
    """
    vocoder = Vocoder()
    pipeline = Pipeline()

    for _, row in df.iterrows():
        final_mel = generate_poison(row, vocoder, pipeline)
        final_wav = vocoder.gen_wav(final_mel)
        # NOTE(review): output name carries no ".wav" extension — confirm
        # that save_audio appends one (the samples loaded above use ".wav").
        vocoder.save_audio(final_wav, f"{AUDIOS_GEN_DIR}/{row['youtube_id']}")


if __name__ == "__main__":
df = read_csv(CSV_NS_SAMPLES_FILE)
generate_poison(df)
df = read_csv(CSV_NS_SAMPLES_FILE).head(50)
generate_all(df)
1 change: 1 addition & 0 deletions wimudp/data_poisoning/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
DATA_DIR = "../data"
CSV_DATASET_FILE = f"{DATA_DIR}/audiocaps_train.csv"
AUDIOS_SAMPLES_DIR = f"{DATA_DIR}/audios"
AUDIOS_GEN_DIR = f"{DATA_DIR}/audios_gen"
CSV_CONCEPT_C_FILE = f"{DATA_DIR}/audiocaps_{CONCEPT_C}.csv"
CSV_NS_SAMPLES_FILE = f"{DATA_DIR}/{CONCEPT_C}_samples.csv"
CSV_MISMATCHED_FILE = f"{AUDIOLDM_DATASET_DIR}/audioset_{CONCEPT_C}_{CONCEPT_A}.csv"
Expand Down

0 comments on commit e98b24a

Please sign in to comment.