Skip to content

Commit

Permalink
Changed audio_loader to load audio from samples and added script to query audioldm
Browse files Browse the repository at this point in the history
  • Loading branch information
mateusztobiasz committed Dec 30, 2024
1 parent dac8403 commit be05c75
Show file tree
Hide file tree
Showing 9 changed files with 90 additions and 46 deletions.
10 changes: 5 additions & 5 deletions wimudp/data_poisoning/dirty_label/audio_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
from yt_dlp import YoutubeDL
from yt_dlp.utils import DownloadError, download_range_func

from wimudp.data_poisoning.dirty_label.utils import (
AUDIOS_DIR,
CSV_CONCEPT_A_FILE,
from wimudp.data_poisoning.utils import (
AUDIOS_SAMPLES_DIR,
CSV_NS_SAMPLES_FILE,
THREADS_NUMBER,
read_csv,
)
Expand Down Expand Up @@ -51,7 +51,7 @@ def download_audios_parallel(
def setup_yt_dlp(range: Tuple[int]) -> dict:
return {
"format": "bestaudio/best",
"outtmpl": f"{AUDIOS_DIR}/%(id)s.%(ext)s",
"outtmpl": f"{AUDIOS_SAMPLES_DIR}/%(id)s.%(ext)s",
"download_ranges": download_range_func(None, [range]),
"postprocessors": [{"key": "FFmpegExtractAudio", "preferredcodec": "wav"}],
"force_keyframes_at_cuts": True,
Expand All @@ -60,6 +60,6 @@ def setup_yt_dlp(range: Tuple[int]) -> dict:


if __name__ == "__main__":
df = read_csv(CSV_CONCEPT_A_FILE)
df = read_csv(CSV_NS_SAMPLES_FILE)
yt_urls, ranges = build_urls_and_ranges(df)
download_audios_parallel(yt_urls, ranges, THREADS_NUMBER)
11 changes: 6 additions & 5 deletions wimudp/data_poisoning/dirty_label/dataset_filter.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import pandas as pd

from wimudp.data_poisoning.utils import (
CONCEPT_A_ACTION,
CSV_CONCEPT_A_FILE,
CONCEPT_A,
CONCEPT_C_ACTION,
CSV_CONCEPT_C_FILE,
CSV_DATASET_FILE,
ROWS_NUMBER,
read_csv,
Expand All @@ -18,11 +19,11 @@ def process_csv_file() -> pd.DataFrame:


def filter_caption_len(row: pd.Series) -> bool:
splitted_cap = row["caption"].split(",")
#splitted_cap = row["caption"].split(",")

return CONCEPT_A_ACTION in row["caption"]
return CONCEPT_C_ACTION in row["caption"] and CONCEPT_A not in row["caption"]


if __name__ == "__main__":
df = process_csv_file()
df.head(ROWS_NUMBER).to_csv(CSV_CONCEPT_A_FILE)
df.head(ROWS_NUMBER).to_csv(CSV_CONCEPT_C_FILE)
4 changes: 2 additions & 2 deletions wimudp/data_poisoning/dirty_label/label_mismatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
AUDIOS_DIR,
CONCEPT_C,
CONCEPT_C_ACTION,
CSV_CONCEPT_A_FILE,
CSV_CONCEPT_C_FILE,
CSV_MISMATCHED_FILE,
read_csv,
)
Expand Down Expand Up @@ -41,5 +41,5 @@ def create_dirty_label_dataset(df: pd.DataFrame):


if __name__ == "__main__":
df = read_csv(CSV_CONCEPT_A_FILE)
df = read_csv(CSV_CONCEPT_C_FILE)
create_dirty_label_dataset(df)
26 changes: 0 additions & 26 deletions wimudp/data_poisoning/dirty_label/query_audioldm.py

This file was deleted.

4 changes: 3 additions & 1 deletion wimudp/data_poisoning/nightshade/clap.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import List

import laion_clap
import torch


class CLAP:
Expand All @@ -9,4 +10,5 @@ def __init__(self):
self.model.load_ckpt(verbose=False)

def get_text_features(self, texts: List[str]) -> List[str]:
return self.model.get_text_embedding(texts, use_tensor=True)
with torch.no_grad():
return self.model.get_text_embedding(texts, use_tensor=True)
17 changes: 14 additions & 3 deletions wimudp/data_poisoning/nightshade/data_extractor.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,30 @@
import os
import pandas as pd
import torch
from torch.nn.functional import cosine_similarity

from wimudp.data_poisoning.nightshade.clap import CLAP
from wimudp.data_poisoning.utils import (
AUDIOS_SAMPLES_DIR,
CONCEPT_C,
CONCEPT_C_ACTION,
CSV_DATASET_FILE,
CSV_CONCEPT_C_FILE,
CSV_NS_SAMPLES_FILE,
SAMPLES_NUMBER,
read_csv,
)


def get_samples() -> pd.DataFrame:
df = read_csv(CSV_DATASET_FILE)
df = read_csv(CSV_CONCEPT_C_FILE)
similarities = calculate_similiarities(df)
candidates = get_top_candidates(df, similarities)

return candidates


def calculate_similiarities(df: pd.DataFrame) -> torch.Tensor:
target_caption = [f"{CONCEPT_C} is {CONCEPT_C_ACTION}ing."]
target_caption = [f"{CONCEPT_C.capitalize()} is {CONCEPT_C_ACTION}ing"]
captions = df["caption"].to_list()
clap = CLAP()

Expand All @@ -38,6 +40,9 @@ def get_top_candidates(df: pd.DataFrame, similarities: torch.Tensor):

for i in candidates_indices:
index = i.item()
if not check_audio_file(df.iloc[index]["youtube_id"]):
continue

candidates_df.loc[index] = [
df.iloc[index]["youtube_id"],
df.iloc[index]["caption"],
Expand All @@ -46,6 +51,12 @@ def get_top_candidates(df: pd.DataFrame, similarities: torch.Tensor):
return candidates_df


def check_audio_file(youtube_id: str) -> bool:
    """Tell whether the WAV file for ``youtube_id`` is present on disk.

    The path checked is ``<current working directory>/<AUDIOS_SAMPLES_DIR>/
    <youtube_id>.wav``.

    Args:
        youtube_id: Identifier used as the audio file's base name.

    Returns:
        ``True`` if that path exists, ``False`` otherwise.
    """
    wav_name = f"{youtube_id}.wav"
    samples_root = os.path.join(os.getcwd(), AUDIOS_SAMPLES_DIR)

    return os.path.exists(os.path.join(samples_root, wav_name))


if __name__ == "__main__":
samples = get_samples()

Expand Down
26 changes: 24 additions & 2 deletions wimudp/data_poisoning/nightshade/nightshade_notebook.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -5146,7 +5146,7 @@
"source": [
"import os\n",
"\n",
"os.chdir(\"./wimudp/data_poisoning/nightshade\")"
"os.chdir(\"./wimudp/data_poisoning/dirty_label\")"
]
},
{
Expand All @@ -5155,7 +5155,29 @@
"metadata": {},
"outputs": [],
"source": [
"!poetry run python data_extractor.py"
"!poetry run python dataset_filter.py\n",
"!poetry run python audio_loader.py"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"os.chdir(\"../nightshade\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!poetry run python data_extractor.py\n",
"!poetry run python query_audioldm.py"
]
}
],
Expand Down
33 changes: 33 additions & 0 deletions wimudp/data_poisoning/nightshade/query_audioldm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import subprocess

import pandas as pd

from wimudp.data_poisoning.utils import (
AUDIOS_SAMPLES_DIR,
CONCEPT_A,
CONCEPT_A_ACTION,
CSV_NS_SAMPLES_FILE,
read_csv,
)


def query_audioldm(df: pd.DataFrame):
    """Run the AudioLDM command-line tool with a prompt describing concept A.

    The prompt is ``"<CONCEPT_A capitalized> is <CONCEPT_A_ACTION>ing"``. The
    tool is started through ``poetry run audioldm`` with the
    ``audioldm-s-full`` model, and ``AUDIOS_SAMPLES_DIR`` is passed as the
    value of the ``-s`` option.

    Args:
        df: Samples table loaded by the caller. It is not read by this
            function; the parameter is kept so existing callers keep working.
    """
    caption = f"{CONCEPT_A.capitalize()} is {CONCEPT_A_ACTION}ing"
    command = [
        "poetry",
        "run",
        "audioldm",
        "--model_name",
        "audioldm-s-full",
        "-t",
        # The arguments are handed to the process as a list without a shell,
        # so each element arrives verbatim. Wrapping the caption in quote
        # characters would make those quotes part of the prompt text.
        caption,
        "-s",
        AUDIOS_SAMPLES_DIR,
    ]
    subprocess.run(command)


if __name__ == "__main__":
    # Script entry point: load the samples CSV and pass it to the AudioLDM query.
    query_audioldm(read_csv(CSV_NS_SAMPLES_FILE))
5 changes: 3 additions & 2 deletions wimudp/data_poisoning/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@
CONCEPT_C_ACTION = "meow"
DATA_DIR = "../data"
CSV_DATASET_FILE = f"{DATA_DIR}/audiocaps_train.csv"
CSV_CONCEPT_A_FILE = f"{DATA_DIR}/audiocaps_{CONCEPT_A}.csv"
CSV_NS_SAMPLES_FILE = f"{DATA_DIR}/{CONCEPT_A}_samples.csv"
AUDIOS_SAMPLES_DIR = f"{DATA_DIR}/audios"
CSV_CONCEPT_C_FILE = f"{DATA_DIR}/audiocaps_{CONCEPT_C}.csv"
CSV_NS_SAMPLES_FILE = f"{DATA_DIR}/{CONCEPT_C}_samples.csv"
CSV_MISMATCHED_FILE = f"{AUDIOLDM_DATASET_DIR}/audioset_{CONCEPT_C}_{CONCEPT_A}.csv"
ROWS_NUMBER = 3000
THREADS_NUMBER = 20
Expand Down

0 comments on commit be05c75

Please sign in to comment.