Skip to content

Commit

Permalink
Merge branch 'efficient-transfer' of https://github.com/bminixhofer/wtpsplit into efficient-transfer
Browse files Browse the repository at this point in the history
  • Loading branch information
markus583 committed May 15, 2024
2 parents a4f91af + c4a4ff3 commit 4fcb260
Showing 1 changed file with 25 additions and 13 deletions.
38 changes: 25 additions & 13 deletions wtpsplit/evaluation/llm_sentence.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ class Args:
save_suffix: str = "Pv2"
include_langs: List[str] = None
custom_language_list: str = None
max_n_test_sentences: int = sys.maxsize
max_n_test_sentences: int = -1
k: int = 10
n_shots: int = 0

Expand Down Expand Up @@ -199,8 +199,12 @@ def load_or_compute_logits(args, eval_data, save_str: str = None):
else:
lang_group = f[lang_code]
for dataset_name, dataset in tqdm(eval_data[lang_code]["sentence"].items(), desc=lang_code):
if "corrupted" in dataset_name and (
dataset_name != "ted2020-corrupted-asr" and not ("lyrics" in dataset_name and "asr" in dataset_name)
if "corrupted-asr" in dataset_name and (
"lyrics" not in dataset_name
and "short" not in dataset_name
and "code" not in dataset_name
and "ted" not in dataset_name
and "legal" not in dataset_name
):
print("SKIP: ", lang_code, dataset_name)
continue
Expand All @@ -209,6 +213,8 @@ def load_or_compute_logits(args, eval_data, save_str: str = None):
continue
if "social-media" in dataset_name:
continue
if "nllb" in dataset_name:
continue
if dataset_name not in lang_group:
dset_group = lang_group.create_group(dataset_name)
else:
Expand Down Expand Up @@ -542,8 +548,14 @@ def main(args):
eval_data = torch.load(eval_data_path)

save_str = (
f"{args.model.split('/')[-1]}_k{args.k}_n{args.max_n_test_sentences}_s{args.n_shots}{args.save_suffix}"
f"{args.model.split('/')[-1]}_k{args.k}_s{args.n_shots}"
).replace("/", "_")

if args.max_n_test_sentences < sys.maxsize and args.max_n_test_sentences != -1:
save_str += f"_n{args.max_n_test_sentences}"
if args.max_n_test_sentences == -1:
args.max_n_test_sentences = sys.maxsize
save_str += f"{args.save_suffix}"
save_str += f"-{args.type}"

print(save_str)
Expand Down Expand Up @@ -602,8 +614,8 @@ def concatenate_texts(group):
results = {}
indices = {}
for lang_code in df["lang"].unique():
results[lang_code] = {}
indices[lang_code] = {}
results[lang_code][dataset_name] = {args.model: {}}

Check failure on line 617 in wtpsplit/evaluation/llm_sentence.py

View workflow job for this annotation

GitHub Actions / build (3.11)

Ruff (F821)

wtpsplit/evaluation/llm_sentence.py:617:28: F821 Undefined name `dataset_name`
indices[lang_code][dataset_name] = {args.model: {}}

Check failure on line 618 in wtpsplit/evaluation/llm_sentence.py

View workflow job for this annotation

GitHub Actions / build (3.11)

Ruff (F821)

wtpsplit/evaluation/llm_sentence.py:618:28: F821 Undefined name `dataset_name`
for dataset_name in df["dataset_name"].unique():
if "lyrics" in dataset_name or "short" in dataset_name:
exclude_every_k = 0
Expand All @@ -613,7 +625,7 @@ def concatenate_texts(group):
if n_docs == 0:
# combination non-existing
continue
indices[lang_code][dataset_name] = {}
indices[lang_code][dataset_name][args.model] = {}
if n_docs > 1:
# list of lists, TODO
rows = df[(df["lang"] == lang_code) & (df["dataset_name"] == dataset_name)]
Expand Down Expand Up @@ -667,8 +679,8 @@ def concatenate_texts(group):
avg_results[key] = sum(avg_results[key]) / len(avg_results[key])

# Store the results and indices
results[lang_code][dataset_name] = avg_results
indices[lang_code][dataset_name] = concat_indices
results[lang_code][dataset_name][args.model] = avg_results
indices[lang_code][dataset_name][args.model] = concat_indices
else:
# one long string
row = df[(df["lang"] == lang_code) & (df["dataset_name"] == dataset_name)].iloc[0]
Expand All @@ -679,10 +691,10 @@ def concatenate_texts(group):
metrics = evaluate_sentences_llm(labels, preds, return_indices=True, exclude_every_k=exclude_every_k)
metrics["hallucination_rate"] = row["hallucination_rate"]
metrics["deletion_rate"] = row["deletion_rate"]
indices[lang_code][dataset_name]["true_indices"] = metrics.pop("true_indices")
indices[lang_code][dataset_name]["predicted_indices"] = metrics.pop("predicted_indices")
indices[lang_code][dataset_name]["length"] = metrics.pop("length")
results[lang_code][dataset_name] = metrics
indices[lang_code][dataset_name][args.model]["true_indices"] = [metrics.pop("true_indices")]
indices[lang_code][dataset_name][args.model]["predicted_indices"] = [metrics.pop("predicted_indices")]
indices[lang_code][dataset_name][args.model]["length"] = [metrics.pop("length")]
results[lang_code][dataset_name][args.model] = metrics

out_dict = {
"metrics": calculate_global_metric_averages(results),
Expand Down

0 comments on commit 4fcb260

Please sign in to comment.