diff --git a/Makefile b/Makefile index 161a83b4..bcfd0049 100644 --- a/Makefile +++ b/Makefile @@ -3,7 +3,7 @@ # make sure to test the local checkout in scripts and not the pre-installed one (don't use quotes!) export PYTHONPATH = src -check_dirs := herm scripts tests +check_dirs := herm scripts analysis tests style: python -m black --line-length 119 --target-version py310 $(check_dirs) setup.py diff --git a/analysis/draw_model_histogram.py b/analysis/draw_model_histogram.py new file mode 100644 index 00000000..169aa89e --- /dev/null +++ b/analysis/draw_model_histogram.py @@ -0,0 +1,68 @@ +"""Script to draw the distribution of model counts in a histogram""" + +import argparse +from pathlib import Path + +from herm.visualization import draw_model_source_histogram + + +def get_args(): + parser = argparse.ArgumentParser() + # positional arguments + parser.add_argument("output_path", type=Path, help="Filepath to save the generated figure.") + # optional arguments + parser.add_argument( + "--dataset_name", + type=str, + default="ai2-adapt-dev/rm-benchmark-dev", + help="The HuggingFace dataset name to source the eval dataset.", + ) + parser.add_argument( + "--keys", + type=lambda x: x.split(","), + default="chosen_model,rejected_model", + help="Comma-separated columns to include in the histogram.", + ) + parser.add_argument( + "--figsize", + type=int, + nargs=2, + default=[12, 8], + help="Control the figure size when plotting.", + ) + parser.add_argument( + "--normalize", + action="store_true", + help="Normalize the values based on the total number of completions.", + ) + parser.add_argument( + "--log_scale", + action="store_true", + help="Set the y-axis to a logarithmic scale.", + ) + parser.add_argument( + "--top_n", + type=int, + default=None, + help="Only plot the top-n models in the histogram.", + ) + + args = parser.parse_args() + return args + + +def main(): + args = get_args() + draw_model_source_histogram( + dataset_name=args.dataset_name, + output_path=args.output_path, + keys=args.keys, + figsize=args.figsize, + normalize=args.normalize, + log_scale=args.log_scale, + top_n=args.top_n, + ) + + +if __name__ == "__main__": + main() diff --git a/herm/visualization.py b/herm/visualization.py new file mode 100644 index 00000000..6668451c --- /dev/null +++ b/herm/visualization.py @@ -0,0 +1,85 @@ +"""Module for visualizing datasets and post-hoc analyses""" + +from collections import Counter +from typing import List, Optional, Tuple + +import datasets +import matplotlib +import matplotlib.pyplot as plt +import numpy as np + + +def draw_model_source_histogram( + dataset_name: str = "ai2-adapt-dev/rm-benchmark-dev", + output_path: Optional[str] = None, + keys: List[str] = ["chosen_model", "rejected_model"], + figsize: Tuple[int, int] = (12, 8), + normalize: bool = False, + log_scale: bool = False, + top_n: Optional[int] = None, +) -> "matplotlib.axes.Axes": + """Draw a histogram of the evaluation dataset that shows completion counts between models and humans. + + dataset_name (str): the HuggingFace dataset name to source the eval dataset. + output_path (Optional[Path]): if set, then save the figure in the specified path. + keys (List[str]): the dataset columns to include in the histogram. + figsize (Tuple[int, int]): control the figure size when plotting. + normalize (bool): set to True to normalize the values based on total number completions. + log_scale (bool): set the y-axis to logarithmic scale. + top_n (Optional[int]): if set, then only plot the top-n models in the histogram. + RETURNS (matplotlib.axes.Axes): an Axes class containing the histogram. + """ + dataset = datasets.load_dataset(dataset_name, split="filtered") + + if not all(key in dataset.features for key in keys): + raise ValueError(f"Your dataset has missing keys. Please ensure that {keys} is/are available.") + + models = [] + for example in dataset: + for key in keys: + model = example[key] + models.append(model) + counter = Counter(models) + + if normalize: + total = sum(counter.values(), 0.0) + for key in counter: + counter[key] /= total + + # Draw the histogram + fig, ax = plt.subplots(figsize=figsize) + labels, values = zip(*counter.most_common()) + + if top_n: + labels = labels[:top_n] + values = values[:top_n] + + indices = np.arange(len(labels)) + width = 1 + + ax.bar(indices, values, width) + ax.set_xticks(indices, labels, rotation=90) + ax.spines.right.set_visible(False) + ax.spines.top.set_visible(False) + + title = f"Source of completions ({', '.join(keys)})" + + if normalize: + ax.set_ylim(top=1.00) + title += " , normalized" + + if log_scale: + ax.set_yscale("log") + title += ", log-scale" + + if top_n: + title += f", showing top-{top_n}" + + ax.set_title(title) + fig.tight_layout() + + if output_path: + print(f"Saving histogram to {output_path}") + plt.savefig(output_path, transparent=True, dpi=120) + + return ax