Skip to content

Commit

Permalink
Merge pull request #6 from ljvmiranda921/add/histogram
Browse files Browse the repository at this point in the history
Add function to draw histograms on the evaluation dataset
  • Loading branch information
natolambert authored Feb 6, 2024
2 parents 64b8797 + dbbb591 commit 654e62e
Show file tree
Hide file tree
Showing 3 changed files with 154 additions and 1 deletion.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# make sure to test the local checkout in scripts and not the pre-installed one (don't use quotes!)
export PYTHONPATH = src

check_dirs := herm scripts tests
check_dirs := herm scripts analysis tests

style:
python -m black --line-length 119 --target-version py310 $(check_dirs) setup.py
Expand Down
68 changes: 68 additions & 0 deletions analysis/draw_model_histogram.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
"""Script to draw the distribution of model counts in a histogram"""

import argparse
from pathlib import Path

from herm.visualization import draw_model_source_histogram


def get_args():
parser = argparse.ArgumentParser()
# positional arguments
parser.add_argument("output_path", type=Path, help="Filepath to save the generated figure.")
# optional arguments
parser.add_argument(
"--dataset_name",
type=str,
default="ai2-adapt-dev/rm-benchmark-dev",
help="The HuggingFace dataset name to source the eval dataset.",
)
parser.add_argument(
"--keys",
type=lambda x: x.split(","),
default="chosen_model,rejected_model",
help="Comma-separated columns to include in the histogram.",
)
parser.add_argument(
"--figsize",
type=int,
nargs=2,
default=[12, 8],
help="Control the figure size when plotting.",
)
parser.add_argument(
"--normalize",
action="store_true",
help="Normalize the values based on the total number of completions.",
)
parser.add_argument(
"--log_scale",
action="store_true",
help="Set the y-axis to a logarithmic scale.",
)
parser.add_argument(
"--top_n",
type=int,
default=None,
help="Only plot the top-n models in the histogram.",
)

args = parser.parse_args()
return args


def main():
args = get_args()
draw_model_source_histogram(
dataset_name=args.dataset_name,
output_path=args.output_path,
keys=args.keys,
figsize=args.figsize,
normalize=args.normalize,
log_scale=args.log_scale,
top_n=args.top_n,
)


if __name__ == "__main__":
main()
85 changes: 85 additions & 0 deletions herm/visualization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
"""Module for visualizing datasets and post-hoc analyses"""

from collections import Counter
from typing import List, Optional, Tuple

import datasets
import matplotlib
import matplotlib.pyplot as plt
import numpy as np


def draw_model_source_histogram(
dataset_name: str = "ai2-adapt-dev/rm-benchmark-dev",
output_path: Optional[str] = None,
keys: List[str] = ["chosen_model", "rejected_model"],
figsize: Tuple[int, int] = (12, 8),
normalize: bool = False,
log_scale: bool = False,
top_n: Optional[int] = None,
) -> "matplotlib.axes.Axes":
"""Draw a histogram of the evaluation dataset that shows completion counts between models and humans.
dataset_name (str): the HuggingFace dataset name to source the eval dataset.
output_path (Optional[Path]): if set, then save the figure in the specified path.
keys (List[str]): the dataset columns to include in the histogram.
figsize (Tuple[int, int]): control the figure size when plotting.
normalize (bool): set to True to normalize the values based on total number completions.
log_scale (bool): set the y-axis to logarithmic scale.
top_n (Optional[int]): if set, then only plot the top-n models in the histogram.
RETURNS (matplotlib.axes.Axes): an Axes class containing the histogram.
"""
dataset = datasets.load_dataset(dataset_name, split="filtered")

if not all(key in dataset.features for key in keys):
raise ValueError(f"Your dataset has missing keys. Please ensure that {keys} is/are available.")

models = []
for example in dataset:
for key in keys:
model = example[key]
models.append(model)
counter = Counter(models)

if normalize:
total = sum(counter.values(), 0.0)
for key in counter:
counter[key] /= total

# Draw the histogram
fig, ax = plt.subplots(figsize=figsize)
labels, values = zip(*counter.most_common())

if top_n:
labels = labels[:top_n]
values = values[:top_n]

indices = np.arange(len(labels))
width = 1

ax.bar(indices, values, width)
ax.set_xticks(indices, labels, rotation=90)
ax.spines.right.set_visible(False)
ax.spines.top.set_visible(False)

title = f"Source of completions ({', '.join(keys)})"

if normalize:
ax.set_ylim(top=1.00)
title += " , normalized"

if log_scale:
ax.set_yscale("log")
title += ", log-scale"

if top_n:
title += f", showing top-{top_n}"

ax.set_title(title)
fig.tight_layout()

if output_path:
print(f"Saving histogram to {output_path}")
plt.savefig(output_path, transparent=True, dpi=120)

return ax

0 comments on commit 654e62e

Please sign in to comment.