From 28f1ab742c5b78f1f1180ebdaaaa569293a7b225 Mon Sep 17 00:00:00 2001 From: Lj Miranda Date: Tue, 13 Feb 2024 19:47:07 -0800 Subject: [PATCH 01/11] Make parser arguments more readable --- analysis/per_token_reward.py | 31 ++++++++++++++++++++++++++----- 1 file changed, 26 insertions(+), 5 deletions(-) diff --git a/analysis/per_token_reward.py b/analysis/per_token_reward.py index e215b192..b7c80f3a 100644 --- a/analysis/per_token_reward.py +++ b/analysis/per_token_reward.py @@ -37,15 +37,36 @@ def get_args(): Parse arguments strings model and chat_template """ parser = argparse.ArgumentParser() - parser.add_argument("--model", type=str, default="natolambert/gpt2-dummy-rm", help="path to model") parser.add_argument( - "--tokenizer", type=str, default=None, help="path to non-matching tokenizer, requires --direct_load" + "--model", + type=str, + default="natolambert/gpt2-dummy-rm", + help="path to model", ) - parser.add_argument("--chat_template", type=str, default="tulu", help="path to chat template") parser.add_argument( - "--batch_size", type=int, default=64, help="batch size for inference (if above number of tokens)" + "--tokenizer", + type=str, + default=None, + help="path to non-matching tokenizer, requires --direct_load", + ) + parser.add_argument( + "--chat_template", + type=str, + default="tulu", + help="path to chat template", + ) + parser.add_argument( + "--batch_size", + type=int, + default=64, + help="batch size for inference (if above number of tokens)", + ) + parser.add_argument( + "--text", + type=str, + default="I love to drink coffee at work.", + help="text to evaluate", ) - parser.add_argument("--text", type=str, default="I love to drink coffee at work.", help="text to evaluate") args = parser.parse_args() if "PairRM" in args.model or "PairRM" in args.chat_template or "SHP" in args.model or "SHP" in args.chat_template: From 8dee1476ed30c7c05d10023274aad4372a46bb0b Mon Sep 17 00:00:00 2001 From: Lj Miranda Date: Tue, 13 Feb 2024 19:49:04 -0800 Subject: [PATCH 02/11] Update model loading --- analysis/per_token_reward.py | 313 +++++++++++++++++++---------------- herm/models/__init__.py | 14 ++ 2 files changed, 181 insertions(+), 146 deletions(-) diff --git a/analysis/per_token_reward.py b/analysis/per_token_reward.py index b7c80f3a..4e35fb53 100644 --- a/analysis/per_token_reward.py +++ b/analysis/per_token_reward.py @@ -17,6 +17,7 @@ import argparse import logging import sys +from typing import Any, Dict, List, Optional import torch import transformers @@ -31,6 +32,47 @@ pipeline, ) +from herm import models + +REWARD_MODEL_CONFIG = { + "default": { + "model_builder": AutoModelForSequenceClassification.from_pretrained, + "pipeline_builder": pipeline, + "quantized": True, + "custom_dialogue": False, + }, + "oasst": { + "model_builder": AutoModelForSequenceClassification.from_pretrained, + "pipeline_builder": pipeline, + "quantized": True, + "custom_dialogue": False, + }, + "Starling": { + "model_builder": models.starling.build_starling_rm, + "pipeline_builder": models.starling.StarlingPipeline, + "quantized": False, + "custom_dialogue": False, + }, + "openbmb": { + "model_builder": models.openbmb.LlamaRewardModel.from_pretrained, + "pipeline_builder": models.openbmb.OpenBMBPipeline, + "quantized": True, + "custom_dialogue": False, + }, + "PairRM": { + "model_builder": models.pairrm.DebertaV2Model.from_pretrained, + "pipeline_builder": models.pairrm.PairRMPipeline, + "quantized": True, + "custom_dialogue": True, + }, + "SHP": { + "model_builder": 
T5ForConditionalGeneration.from_pretrained, + "pipeline_builder": models.shp.SHPPipeline, + "quantized": True, + "custom_dialogue": True, + }, +} + def get_args(): """ @@ -41,217 +83,196 @@ def get_args(): "--model", type=str, default="natolambert/gpt2-dummy-rm", - help="path to model", + help="Path to the model or HuggingFace link.", ) parser.add_argument( "--tokenizer", type=str, default=None, - help="path to non-matching tokenizer, requires --direct_load", + help="Path to non-matching tokenizer, requires --direct_load.", ) parser.add_argument( "--chat_template", type=str, default="tulu", - help="path to chat template", + help="Path to the chat template.", ) parser.add_argument( "--batch_size", type=int, default=64, - help="batch size for inference (if above number of tokens)", + help="Batch size for inference (if above number of tokens).", ) parser.add_argument( "--text", type=str, default="I love to drink coffee at work.", - help="text to evaluate", + help="Text to evaluate.", ) args = parser.parse_args() - if "PairRM" in args.model or "PairRM" in args.chat_template or "SHP" in args.model or "SHP" in args.chat_template: - # Note: SHP could be used in single-output mode, but the code is not yet added - raise ValueError("PairRM and SHP require pairwise inputs, not supported") + # Input validation + def _validate_require_pairwise_inputs(models): + for model in models: + if args.model == model or args.chat_template == model: + raise ValueError(f"{model} require pairwise inputs, not supported") + + _validate_require_pairwise_inputs(models=["PairRM", "SHP"]) + return args def main(): args = get_args() - quantized = True # only Starling isn't quantized for now - custom_dialogue = False - # some models need custom code to be run - if "oasst" in args.model or "oasst" in args.chat_template: - from herm.models import openassistant # noqa - - model_builder = AutoModelForSequenceClassification.from_pretrained - pipeline_builder = pipeline - elif "Starling" in args.model or "Starling" in args.chat_template: - from herm.models.starling import StarlingPipeline, build_starling_rm - - model_builder = build_starling_rm - pipeline_builder = StarlingPipeline - quantized = False - elif "openbmb" in args.model or "openbmb" in args.chat_template: - from herm.models.openbmb import LlamaRewardModel, OpenBMBPipeline - - model_builder = LlamaRewardModel.from_pretrained - pipeline_builder = OpenBMBPipeline - elif "PairRM" in args.model or "PairRM" in args.chat_template: - from herm.models.pairrm import DebertaV2PairRM, PairRMPipeline - - custom_dialogue = True - model_builder = DebertaV2PairRM.from_pretrained - pipeline_builder = PairRMPipeline - elif "SHP" in args.model or "SHP" in args.chat_template: - from herm.models.shp import SHPPipeline - - custom_dialogue = True - model_builder = T5ForConditionalGeneration.from_pretrained - pipeline_builder = SHPPipeline - else: - model_builder = AutoModelForSequenceClassification.from_pretrained - pipeline_builder = pipeline + model_name = args.model if args.model in REWARD_MODEL_CONFIG.keys() else "default" + config = REWARD_MODEL_CONFIG.get(model_name) - if custom_dialogue: + if config["custom_dialogue"]: raise ValueError("Custom dialogue formatting not yet supported in this script") - ############### - # Setup logging - ############### - accelerator = Accelerator() + # Setup the accelerate state first before using logging since it errors out + # if you do the other first. 
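+    # Note: cpu=True forces the Accelerator to run on CPU even when a GPU is available.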
+ accelerator = Accelerator(cpu=True) current_device = accelerator.process_index - logger = get_logger(__name__) - logging.basicConfig( - format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", - datefmt="%Y-%m-%d %H:%M:%S", - handlers=[logging.StreamHandler(sys.stdout)], - ) - log_level = logging.INFO - logger.setLevel(log_level) - transformers.utils.logging.set_verbosity(log_level) - transformers.utils.logging.enable_default_handler() - transformers.utils.logging.enable_explicit_format() - + # Setup logging + logger = setup_logging(name=__name__) logger.info(f"Running reward model on {args.model} with chat template {args.chat_template}") - ############################ - # Load reward model pipeline - ############################ - tokenizer_path = args.tokenizer if args.tokenizer else args.model - tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) + # Prepare dataset and tokenizer + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer or args.model) + + def _tokenify_string(string): + tokens = tokenizer.tokenize(string) + cumulative_texts = [tokenizer.convert_tokens_to_string(tokens[: i + 1]) for i, _ in enumerate(tokens)] + return cumulative_texts - BATCH_SIZE = args.batch_size - logger.info("*** Load reward model ***") + substrings = _tokenify_string(args.text) + dataset = Dataset.from_list([{"text": substring} for substring in substrings]) + + # Load reward model pipeline + logger.info("Loading reward model") + reward_pipeline = load_reward_pipeline( + args.model, + config=config, + tokenizer=tokenizer, + process_index=current_device, + ) reward_pipeline_kwargs = { - "batch_size": BATCH_SIZE, # eval_args.inference_batch_size, + "batch_size": args.batch_size, # eval_args.inference_batch_size, "truncation": True, "padding": True, "max_length": 2048, "function_to_apply": "none", # Compute raw logits "return_token_type_ids": False, } - if quantized: - model_kwargs = { - "load_in_8bit": True, - "device_map": {"": current_device}, - "torch_dtype": torch.float16 if torch.cuda.is_available() else None, - } - else: - model_kwargs = {"device_map": {"": current_device}} - # TODO remove direct load logic - # if pipeline_builder is pipeline, use built in pipeline, else custom + + # Perform inference and get per-token reward + per_token_rewards = get_per_token_reward( + dataset, + reward_pipeline=reward_pipeline, + reward_pipeline_kwargs=reward_pipeline_kwargs, + accelerator=accelerator, + is_custom_pipeline=config["pipeline_builder"] == pipeline, + logger=logger, + dataloader_batch_size=args.batch_size, + ) + + # Report the results + for reward, token in zip(per_token_rewards, substrings): + print(f"Reward: {round(reward, 3)} | Substring: {token}") + + +def setup_logging(name: Optional[str] = None) -> logging.Logger: + logger = get_logger(name or __name__) + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + handlers=[logging.StreamHandler(sys.stdout)], + ) + log_level = logging.INFO + logger.setLevel(log_level) + transformers.utils.logging.set_verbosity(log_level) + transformers.utils.logging.enable_default_handler() + transformers.utils.logging.enable_explicit_format() + return logger + + +def load_reward_pipeline( + model_name: str, + *, + config: Dict[str, Any], + tokenizer: "transformers.PreTrainedTokenizer", + process_index: int, +): + model_kwargs = {"device_map": {"": process_index}} + if config["quantized"]: + model_kwargs.update( + { + "load_in_8bit": True, + "torch_dtype": torch.float16 if 
torch.cuda.is_available() else None, + } + ) + model_builder = config["model_builder"] + pipeline_builder = config["pipeline_builder"] if not pipeline == pipeline_builder: - model = model_builder(args.model, **model_kwargs) - tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) - reward_pipe = pipeline_builder( + reward_pipeline = pipeline_builder( "text-classification", - model=model, + model=model_builder(model_name, **model_kwargs), tokenizer=tokenizer, ) else: - reward_pipe = pipeline( + reward_pipeline = pipeline( "text-classification", - model=args.model, + model=model_name, tokenizer=tokenizer, revision="main", model_kwargs=model_kwargs, ) - - ############################ - # Tokenization settings & dataset preparation - ############################ - # set pad token to eos token if not set - if reward_pipe.tokenizer.pad_token_id is None: - reward_pipe.model.config.pad_token_id = reward_pipe.tokenizer.eos_token_id - reward_pipe.tokenizer.pad_token_id = reward_pipe.tokenizer.eos_token_id - - def tokenify_string(string, tokenizer): - # Tokenize the entire text - tokens = tokenizer.tokenize(string) - - cumulative_texts = [] - # Iterate over each token - for i, _ in enumerate(tokens): - # Append the current cumulative text to the list - cumulative_texts.append(tokenizer.convert_tokens_to_string(tokens[: i + 1])) - - return cumulative_texts - - substrings = tokenify_string(args.text, tokenizer) - # create dataset from list of strings substrings with huggingface - dataset = [{"text": substring} for substring in substrings] - dataset = Dataset.from_list(dataset) - - ############################ - # Run inference [1/2]" built in transformers - ############################ - # if using HF pipeline, can pass entire dataset and get results - # first, handle custom pipelines that we must batch normally - if not pipeline_builder == pipeline: - logger.info("*** Running forward pass via built in pipeline abstraction ***") - # this setup can be optimized slightly with one pipeline call - # prepare for inference - reward_pipe = accelerator.prepare(reward_pipe) - - rewards = reward_pipe(dataset["text"], **reward_pipeline_kwargs) - - ############################ - # Run inference [2/2] custom pipelines - ############################ - else: - logger.info("*** Running dataloader to collect results ***") - + # Tokenization settings + if reward_pipeline.tokenizer.pad_token_id is None: + reward_pipeline.model.config.pad_token_id = reward_pipeline.tokenizer.eos_token_id + reward_pipeline.tokenizer.pad_token_id = reward_pipeline.tokenizer.eos_token_id + + return reward_pipeline + + +def get_per_token_reward( + dataset: Dataset, + *, + reward_pipeline: "transformers.Pipeline", + reward_pipeline_kwargs: Dict[str, Any], + accelerator: "Accelerator", + is_custom_pipeline: bool, + logger: "logging.Logger", + dataloader_batch_size: int, +) -> List[float]: + if is_custom_pipeline: + logger.info("Running dataloader to collect results") dataloader = torch.utils.data.DataLoader( dataset, - batch_size=BATCH_SIZE, + batch_size=dataloader_batch_size, collate_fn=None, shuffle=False, drop_last=False, ) - - dataloader, model = accelerator.prepare(dataloader, reward_pipe.model) - reward_pipe.model = model + dataloader, model = accelerator.prepare(dataloader, reward_pipeline.model) + reward_pipeline.model = model results = [] for step, batch in enumerate(tqdm(dataloader, desc="RM batch steps")): logger.info(f"RM inference step {step}/{len(dataloader)}") - rewards = reward_pipe(batch["text"], **reward_pipeline_kwargs) - 
- # for each item in batch, record 1 if chosen > rejected - # extra score from dict within batched results (e.g. logits) - # [{'label': 'LABEL_1', 'score': 0.6826171875},... ] - if isinstance(rewards[0], dict): - scores = [result["score"] for result in rewards] - # for classes that directly output scores (custom code) - else: - scores = rewards.cpu().numpy().tolist() - + rewards = reward_pipeline(batch["text"], **reward_pipeline_kwargs) + # Some pipeline implementations return a list of dictionaries, if that's the + # case, we only take the value in the 'score' key. Else, we just return the list. + scores = [r["score"] for r in rewards] if isinstance(rewards[0], dict) else rewards.cpu().numpy().tolist() results.extend(scores) + else: + logger.info("Running forward pass via built-in pipeline abstraction") + reward_pipeline = accelerator.prepare(reward_pipeline) + results = reward_pipeline(dataset["text"], reward_pipeline_kwargs) - # print the results - for i, substring in enumerate(substrings): - print(f"Reward: {round(results[i], 3)} | Substring: {substring}") + return results if __name__ == "__main__": diff --git a/herm/models/__init__.py b/herm/models/__init__.py index e69de29b..85dda474 100644 --- a/herm/models/__init__.py +++ b/herm/models/__init__.py @@ -0,0 +1,14 @@ +from .openbmb import LlamaRewardModel, OpenBMBPipeline +from .pairrm import DebertaV2PairRM, PairRMPipeline +from .shp import SHPPipeline +from .starling import StarlingPipeline, build_starling_rm + +__all__ = [ + "LlamaRewardModel", + "OpenBMBPipeline", + "DebertaV2PairRM", + "PairRMPipeline", + "SHPPipeline", + "StarlingPipeline", + "build_starling_rm", +] From 4fb0368e23ff0f85abea1fda31f916f9faea7393 Mon Sep 17 00:00:00 2001 From: ljvmiranda921 Date: Mon, 19 Feb 2024 09:47:24 -0800 Subject: [PATCH 03/11] Implement results saving and hashing --- analysis/per_token_reward.py | 90 +++++++++++++++++++++++++++++++----- 1 file changed, 78 insertions(+), 12 deletions(-) diff --git a/analysis/per_token_reward.py b/analysis/per_token_reward.py index 4e35fb53..ff7c1fbb 100644 --- a/analysis/per_token_reward.py +++ b/analysis/per_token_reward.py @@ -15,8 +15,11 @@ # Script to output the per-token reward across a piece of text given a reward model import argparse +import hashlib +import json import logging import sys +from pathlib import Path from typing import Any, Dict, List, Optional import torch @@ -79,6 +82,13 @@ def get_args(): Parse arguments strings model and chat_template """ parser = argparse.ArgumentParser() + # positional arguments + parser.add_argument( + "text", + type=str, + help="Text to evaluate.", + ) + # optional arguments parser.add_argument( "--model", type=str, @@ -97,18 +107,19 @@ def get_args(): default="tulu", help="Path to the chat template.", ) + parser.add_argument( + "--output_dir", + type=Path, + default="per-token-reward", + help="Directory to store the hashes and token information.", + ) parser.add_argument( "--batch_size", type=int, default=64, help="Batch size for inference (if above number of tokens).", ) - parser.add_argument( - "--text", - type=str, - default="I love to drink coffee at work.", - help="Text to evaluate.", - ) + parser.add_argument("--random_seed", type=int, default=None, help="Random seed for reproducibility.") args = parser.parse_args() # Input validation @@ -127,6 +138,10 @@ def main(): model_name = args.model if args.model in REWARD_MODEL_CONFIG.keys() else "default" config = REWARD_MODEL_CONFIG.get(model_name) + if args.random_seed: + print(f"Setting random seed to 
{args.random_seed}") + torch.manual_seed(args.random_seed) + if config["custom_dialogue"]: raise ValueError("Custom dialogue formatting not yet supported in this script") @@ -143,11 +158,12 @@ def main(): tokenizer = AutoTokenizer.from_pretrained(args.tokenizer or args.model) def _tokenify_string(string): - tokens = tokenizer.tokenize(string) - cumulative_texts = [tokenizer.convert_tokens_to_string(tokens[: i + 1]) for i, _ in enumerate(tokens)] - return cumulative_texts + _tokens = tokenizer.tokenize(string) + cumulative_texts = [tokenizer.convert_tokens_to_string(_tokens[: i + 1]) for i, _ in enumerate(_tokens)] + tokens = tokenizer.convert_tokens_to_string(_tokens).split(" ") + return cumulative_texts, tokens - substrings = _tokenify_string(args.text) + substrings, tokens = _tokenify_string(args.text) dataset = Dataset.from_list([{"text": substring} for substring in substrings]) # Load reward model pipeline @@ -179,8 +195,19 @@ def _tokenify_string(string): ) # Report the results - for reward, token in zip(per_token_rewards, substrings): - print(f"Reward: {round(reward, 3)} | Substring: {token}") + for reward, span in zip(per_token_rewards, substrings): + print(f"Reward: {round(reward, 3)} | Substring: {span}") + + # Save the results + save_results( + output_dir=args.output_dir, + text=args.text, + model=args.model, + chat_template=args.chat_template, + substrings=substrings, + tokens=tokens, + rewards=per_token_rewards, + ) def setup_logging(name: Optional[str] = None) -> logging.Logger: @@ -275,5 +302,44 @@ def get_per_token_reward( return results +def save_results( + output_dir: Path, + text: str, + model: str, + chat_template: str, + substrings: List[str], + tokens: List[str], + rewards: List[str], +): + # Hash the text first using base16 + text_hash = hashlib.shake_256(text.encode()).hexdigest(5) + text_dir = output_dir / text_hash + text_dir.mkdir(parents=True, exist_ok=True) + + # Hash the model and chat_template combination + MODEL_CHAT_DELIMITER = "___" + model_chat_text = model + MODEL_CHAT_DELIMITER + chat_template + model_chat_hash = hashlib.shake_256(model_chat_text.encode()).hexdigest(5) + + # Output file will be the model_chat_hash + output_file = text_dir / f"{model_chat_hash}.json" + print(f"Saving results to {text_dir}") + + reward_info = { + "text": text, + "text_hash": text_hash, + "model": model, + "chat_template": chat_template, + "model_chat_hash": model_chat_hash, + "substrings": substrings, + "tokens": tokens, + "rewards": rewards, + } + + # Assumes the model output is a pointer to a HuggingFace repository + with open(output_file, "w") as f: + json.dump(reward_info, f, indent=4) + + if __name__ == "__main__": main() From 2cca0eeab1702e5c4b7181411cc487f0e9424366 Mon Sep 17 00:00:00 2001 From: ljvmiranda921 Date: Mon, 19 Feb 2024 11:35:58 -0800 Subject: [PATCH 04/11] Preserve tokenization artifacts --- analysis/per_token_reward.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/analysis/per_token_reward.py b/analysis/per_token_reward.py index ff7c1fbb..f8be730f 100644 --- a/analysis/per_token_reward.py +++ b/analysis/per_token_reward.py @@ -160,7 +160,9 @@ def main(): def _tokenify_string(string): _tokens = tokenizer.tokenize(string) cumulative_texts = [tokenizer.convert_tokens_to_string(_tokens[: i + 1]) for i, _ in enumerate(_tokens)] - tokens = tokenizer.convert_tokens_to_string(_tokens).split(" ") + # Hacky approach. Ideally we can do a str.split(" ") but we want to + # preserve the subword tokenization by the tokenizer. 
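+        # Decoding each token separately also keeps this list the same length as
+        # cumulative_texts, i.e. one entry per subword token.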
+ tokens = [tokenizer.convert_tokens_to_string([t]) for t in _tokens] return cumulative_texts, tokens substrings, tokens = _tokenify_string(args.text) From 0ee01f45fddb5027af00460ece64f1a763aad9eb Mon Sep 17 00:00:00 2001 From: ljvmiranda921 Date: Mon, 19 Feb 2024 11:50:01 -0800 Subject: [PATCH 05/11] Fix model name implementation Model name is actually a HuggingFace model. --- analysis/per_token_reward.py | 97 ++++++++++++++++++++++++++++++++++-- 1 file changed, 92 insertions(+), 5 deletions(-) diff --git a/analysis/per_token_reward.py b/analysis/per_token_reward.py index f8be730f..8a5b8b51 100644 --- a/analysis/per_token_reward.py +++ b/analysis/per_token_reward.py @@ -49,30 +49,50 @@ "pipeline_builder": pipeline, "quantized": True, "custom_dialogue": False, + "models": [ + "OpenAssistant/oasst-rm-2-pythia-6.9b-epoch-1", + "OpenAssistant/oasst-rm-2.1-pythia-1.4b-epoch-2.5", + "OpenAssistant/reward-model-deberta-v3-base", + "OpenAssistant/reward-model-deberta-v3-large", + "OpenAssistant/reward-model-deberta-v3-large-v2", + "OpenAssistant/reward-model-electra-large-discriminator", + ], }, "Starling": { "model_builder": models.starling.build_starling_rm, "pipeline_builder": models.starling.StarlingPipeline, "quantized": False, "custom_dialogue": False, + "models": [ + "berkeley-nest/Starling-RM-7B-alpha", + ], }, "openbmb": { "model_builder": models.openbmb.LlamaRewardModel.from_pretrained, "pipeline_builder": models.openbmb.OpenBMBPipeline, "quantized": True, "custom_dialogue": False, + "models": ["openbmb/UltraRM-13b"], }, "PairRM": { "model_builder": models.pairrm.DebertaV2Model.from_pretrained, "pipeline_builder": models.pairrm.PairRMPipeline, "quantized": True, "custom_dialogue": True, + "models": [ + "llm-blender/PairRM", + "llm-blender/PairRM-hf", + ], }, "SHP": { "model_builder": T5ForConditionalGeneration.from_pretrained, "pipeline_builder": models.shp.SHPPipeline, "quantized": True, "custom_dialogue": True, + "models": [ + "stanfordnlp/SteamSHP-flan-t5-large", + "stanfordnlp/SteamSHP-flan-t5-xl", + ], }, } @@ -119,13 +139,18 @@ def get_args(): default=64, help="Batch size for inference (if above number of tokens).", ) - parser.add_argument("--random_seed", type=int, default=None, help="Random seed for reproducibility.") + parser.add_argument( + "--random_seed", + type=int, + default=None, + help="Random seed for reproducibility.", + ) args = parser.parse_args() # Input validation def _validate_require_pairwise_inputs(models): for model in models: - if args.model == model or args.chat_template == model: + if args.model in model or args.chat_template in model: raise ValueError(f"{model} require pairwise inputs, not supported") _validate_require_pairwise_inputs(models=["PairRM", "SHP"]) @@ -136,6 +161,7 @@ def _validate_require_pairwise_inputs(models): def main(): args = get_args() model_name = args.model if args.model in REWARD_MODEL_CONFIG.keys() else "default" + config = REWARD_MODEL_CONFIG.get(model_name) if args.random_seed: @@ -155,7 +181,9 @@ def main(): logger.info(f"Running reward model on {args.model} with chat template {args.chat_template}") # Prepare dataset and tokenizer - tokenizer = AutoTokenizer.from_pretrained(args.tokenizer or args.model) + tokenizer_path = args.tokenizer if args.tokenizer else args.model + print(f"Loading tokenizer from {tokenizer_path}") + tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) def _tokenify_string(string): _tokens = tokenizer.tokenize(string) @@ -212,7 +240,35 @@ def _tokenify_string(string): ) +def get_config(model_name: str, 
default_if_missing: bool = True) -> Dict[str, Any]: + """Get the appropriate loading configuration given a model name + + We only do minimal string matching here, basically checking if a substring, say, + oasst or others exist in REWARD_MODEL_CONFIG.keys(). + + model_name (str): the HuggingFace link or pointer to the model. + default_if_missing (bool): if True, will return the default configuration if + model is missing from our config templates. If False, then it raises + a ValueError. + RETURNS (Dict[str, Any]): the loading configuration for a given model. + """ + for tpl, config in REWARD_MODEL_CONFIG.items(): + available_models = config["models"] + if model_name in available_models: + config = config.pop("models") + print(f"Returning configuration from {tpl}. Config={config}") + return config + + # If model_name is not found anywhere + if default_if_missing: + print("Model {model_name} not found in available models. Returning default configuration") + return REWARD_MODEL_CONFIG.get("default") + else: + raise ValueError(f"Model {model_name} not found in available models!") + + def setup_logging(name: Optional[str] = None) -> logging.Logger: + """Create a logger""" logger = get_logger(name or __name__) logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", @@ -233,7 +289,15 @@ def load_reward_pipeline( config: Dict[str, Any], tokenizer: "transformers.PreTrainedTokenizer", process_index: int, -): +) -> transformers.Pipeline: + """Load a reward model pipeline given a model configuration and its tokenizer. + + model_name (str): the HuggingFace link or pointer to the model. + config (Dict[str, Any]): the model configuration. + tokenizer (transformers.PreTrainedTokenizer): the tokenizer to use with the model. + process_index (int): the machine to run the process. + RETURNS (transformers.Pipeline) the reward model pipeline + """ model_kwargs = {"device_map": {"": process_index}} if config["quantized"]: model_kwargs.update( @@ -245,9 +309,10 @@ def load_reward_pipeline( model_builder = config["model_builder"] pipeline_builder = config["pipeline_builder"] if not pipeline == pipeline_builder: + model = model_builder(model_name, **model_kwargs) reward_pipeline = pipeline_builder( "text-classification", - model=model_builder(model_name, **model_kwargs), + model=model, tokenizer=tokenizer, ) else: @@ -276,6 +341,16 @@ def get_per_token_reward( logger: "logging.Logger", dataloader_batch_size: int, ) -> List[float]: + """Get the reward per subtoken + + dataset (datasets.Dataset): the HuggingFace dataset to source the text from. + reward_pipeline (transformers.Pipeline): the reward pipeline that will provide the scores. + accelerator (Accelerator): accelerator class for training performance. + is_custom_pipeline (bool): flag to check if we need to run a data loader to collate the results. + logger (logging.Logger): logger class. + dataloader_batch_size (int): control the batch size passed to the data loader. + RETURNS (List[float]): list of computed rewards for each token. + """ if is_custom_pipeline: logger.info("Running dataloader to collect results") dataloader = torch.utils.data.DataLoader( @@ -313,6 +388,18 @@ def save_results( tokens: List[str], rewards: List[str], ): + """Save results to disk + + This function will first hash the prompt, and then the model with the chat template. + Then, it will save the model result in a JSON file on disk. + + output_dir (Path): directory to save the files. + text (str): the text used to hash. 
The hashed string will be the name of the subdirectory. + model (str): the name of the model + chat_template (str): the name of the chat template. + tokens (List[str]): the tokens extracted by the reward pipeline's tokenizer. + rewards (List[str]): the rewards computed by the reward pipeline. + """ # Hash the text first using base16 text_hash = hashlib.shake_256(text.encode()).hexdigest(5) text_dir = output_dir / text_hash From 69835a01293c60ddb1490f8f7d7bd4eaa221a927 Mon Sep 17 00:00:00 2001 From: ljvmiranda921 Date: Mon, 19 Feb 2024 13:51:10 -0800 Subject: [PATCH 06/11] Improve structure and add verbs --- analysis/draw_per_token_reward.py | 0 analysis/{per_token_reward.py => get_per_token_reward.py} | 0 2 files changed, 0 insertions(+), 0 deletions(-) create mode 100644 analysis/draw_per_token_reward.py rename analysis/{per_token_reward.py => get_per_token_reward.py} (100%) diff --git a/analysis/draw_per_token_reward.py b/analysis/draw_per_token_reward.py new file mode 100644 index 00000000..e69de29b diff --git a/analysis/per_token_reward.py b/analysis/get_per_token_reward.py similarity index 100% rename from analysis/per_token_reward.py rename to analysis/get_per_token_reward.py From 7af4900b01c5af23e8881ccfd6918b2f67e6575c Mon Sep 17 00:00:00 2001 From: ljvmiranda921 Date: Mon, 19 Feb 2024 14:35:47 -0800 Subject: [PATCH 07/11] Update visualization and try alignment --- analysis/draw_per_token_reward.py | 71 +++++++++++++++++++++++++++++++ herm/visualization.py | 8 +++- 2 files changed, 78 insertions(+), 1 deletion(-) diff --git a/analysis/draw_per_token_reward.py b/analysis/draw_per_token_reward.py index e69de29b..63e96874 100644 --- a/analysis/draw_per_token_reward.py +++ b/analysis/draw_per_token_reward.py @@ -0,0 +1,71 @@ +# Copyright 2023 AllenAI. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Draw the per token reward + +import json +from pathlib import Path + +import argparse + +from herm.visualization import draw_per_token_reward + +DEFAULT_DIRNAME = "per-token-reward" + + +def get_args(): + parser = argparse.ArgumentParser() + # positional arguments + parser.add_argument("text_hash", type=str, help="Path or pointer to the text hash to plot.") + parser.add_argument("output_path", type=Path, help="Filepath to save the generated figure.") + # optional arguments + parser.add_argument( + "--local", + action="store_true", + help="Find the file locally.", + ) + parser.add_argument( + "--figsize", + type=int, + nargs=2, + default=[12, 8], + help="Control the figure size when plotting.", + ) + args = parser.parse_args() + return args + + +def main(): + args = get_args() + if args.local: + input_dir = Path.cwd() / DEFAULT_DIRNAME / args.text_hash + assert input_dir.exists(), f"Directory {input_dir} does not exist!" + + rewards = {} + for file in input_dir.glob("*.json"): + with open(file) as f: + results = json.load(f) + rewards[results["model"]] = results + + assert len(rewards) > 0, f"Directory {input_dir} is empty!" 
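+        # rewards maps each model name to the per-token reward record saved by
+        # analysis/get_per_token_reward.py.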
+ + else: + # TODO: Source from a huggingface repo + ... + + breakpoint() + + +if __name__ == "__main__": + main() diff --git a/herm/visualization.py b/herm/visualization.py index 055f4a28..7a2c31b7 100644 --- a/herm/visualization.py +++ b/herm/visualization.py @@ -14,8 +14,9 @@ # Module for visualizing datasets and post-hoc analyses. +from pathlib import Path from collections import Counter -from typing import List, Optional, Tuple +from typing import Dict, List, Optional, Tuple import datasets import matplotlib @@ -24,6 +25,11 @@ import pandas as pd +def draw_per_token_reward(tokens: List[str], rewards: Dict[str, List[float]]) -> "matplotlib.axes.Axes": + """Draw a heatmap that combines the rewards""" + breakpoint() + + def print_model_statistics( dataset_name: str = "ai2-adapt-dev/rm-benchmark-dev", keys: List[str] = ["chosen_model", "rejected_model"], From 58f1f9d0133fb4f101f4e96dbb280e6aa73cc009 Mon Sep 17 00:00:00 2001 From: ljvmiranda921 Date: Mon, 19 Feb 2024 15:10:59 -0800 Subject: [PATCH 08/11] Add spacy-alignments for aligning tokens --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index 193ae1ae..e2b25557 100644 --- a/setup.py +++ b/setup.py @@ -42,6 +42,7 @@ "isort>=5.12.0", "pandas", "pytest", + "spacy-alignments", # dependency for aligning tokenizers "scipy", "tabulate", # dependency for markdown rendering in pandas "tokenizers", From 95db8ca2bc4879250d63697c512045c13168ca2b Mon Sep 17 00:00:00 2001 From: ljvmiranda921 Date: Mon, 19 Feb 2024 15:11:06 -0800 Subject: [PATCH 09/11] Fix visualization code --- analysis/draw_per_token_reward.py | 64 +++++++++++++++++++++------- herm/visualization.py | 70 +++++++++++++++++++++++++++++-- 2 files changed, 116 insertions(+), 18 deletions(-) diff --git a/analysis/draw_per_token_reward.py b/analysis/draw_per_token_reward.py index 63e96874..704b0db7 100644 --- a/analysis/draw_per_token_reward.py +++ b/analysis/draw_per_token_reward.py @@ -16,9 +16,12 @@ import json from pathlib import Path - +from typing import List import argparse +import numpy as np +import spacy_alignments as tokenizations + from herm.visualization import draw_per_token_reward DEFAULT_DIRNAME = "per-token-reward" @@ -39,32 +42,63 @@ def get_args(): "--figsize", type=int, nargs=2, - default=[12, 8], + default=[8, 8], help="Control the figure size when plotting.", ) args = parser.parse_args() return args +def align_tokens(reference_tokens: List[str], predicted_tokens: List[str], rewards: List[float]) -> List[float]: + """Align tokens and recompute the reward + + reference_tokens (List[str]): the reference tokenization to base the alignment on. + predicted_tokens (List[str]): the tokens from the reward pipeline. + rewards (List[float]): the per-token reward. + RETURNS (List[float]): the recomputed per-token reward. + """ + a2b, _ = tokenizations.get_alignments(reference_tokens, predicted_tokens) + rewards_list = [] + for aligned_idxs in a2b: + rewards_list.append([rewards[idx] for idx in aligned_idxs]) + aligned_rewards = list(map(np.mean, rewards_list)) + return aligned_rewards + + def main(): args = get_args() - if args.local: - input_dir = Path.cwd() / DEFAULT_DIRNAME / args.text_hash - assert input_dir.exists(), f"Directory {input_dir} does not exist!" + # Read the results first + input_dir = Path.cwd() / DEFAULT_DIRNAME / args.text_hash + assert input_dir.exists(), f"Directory {input_dir} does not exist!" 
+ + rewards = {} + for file in input_dir.glob("*.json"): + with open(file) as f: + results = json.load(f) + rewards[results["model"]] = results - rewards = {} - for file in input_dir.glob("*.json"): - with open(file) as f: - results = json.load(f) - rewards[results["model"]] = results + assert len(rewards) > 0, f"Directory {input_dir} is empty!" - assert len(rewards) > 0, f"Directory {input_dir} is empty!" + # Get reference alignment + first_key = next(iter(rewards)) # should be the same all throughout + text = rewards[first_key]["text"] + whitespace_tokenizer = lambda x: x.split(" ") + reference_tokens = whitespace_tokenizer(text) - else: - # TODO: Source from a huggingface repo - ... + for _, results in rewards.items(): + results["aligned_rewards"] = align_tokens( + reference_tokens=reference_tokens, + predicted_tokens=results["tokens"], + rewards=results["rewards"], + ) - breakpoint() + draw_per_token_reward( + tokens=reference_tokens, + rewards=[reward["aligned_rewards"] for _, reward in rewards.items()], + model_names=[model_name for model_name, _ in rewards.items()], + output_path=args.output_path, + figsize=args.figsize, + ) if __name__ == "__main__": diff --git a/herm/visualization.py b/herm/visualization.py index 7a2c31b7..67be28f2 100644 --- a/herm/visualization.py +++ b/herm/visualization.py @@ -25,9 +25,73 @@ import pandas as pd -def draw_per_token_reward(tokens: List[str], rewards: Dict[str, List[float]]) -> "matplotlib.axes.Axes": - """Draw a heatmap that combines the rewards""" - breakpoint() +def draw_per_token_reward( + tokens: List[str], + rewards: List[List[float]], + model_names: List[str], + font_size: int = 12, + output_path: Path = None, + figsize: Tuple[int, int] = (12, 12), +) -> "matplotlib.axes.Axes": + """Draw a heatmap that combines the rewards + + tokens (List[str]): the canonical tokens that was used as reference during alignment. + rewards (List[List[float]]): list of rewards-per-token for each model. + model_names (List[str]): list of models + font_size (int) + output_path (Optional[Path]): if set, then save the figure in the specified path. + figsize (Tuple[int, int]): control the figure size when plotting. + RETURNS (matplotlib.axes.Axes): an Axes class containing the heatmap. 
+ """ + fig, ax = plt.subplots(figsize=figsize) + matplotlib.rcParams.update( + { + "font.size": font_size, + "xtick.labelsize": font_size, + "ytick.labelsize": font_size, + } + ) + rewards = np.array(rewards) + im = ax.imshow( + rewards, + cmap="viridis", + vmax=np.max(rewards), + vmin=np.min(rewards), + ) + cbar = fig.colorbar(im, ax=ax, orientation="horizontal", aspect=20, location="bottom") + ax.set_xticks(np.arange(len(tokens)), [f'"{token}"' for token in tokens]) + ax.set_yticks(np.arange(len(model_names)), model_names) + + # Add text + avg = np.mean(rewards) + for i in range(len(model_names)): + for j in range(len(tokens)): + color = "k" if rewards[i, j] >= avg else "w" + text = ax.text( + j, + i, + round(rewards[i, j], 4), + ha="center", + va="center", + color=color, + ) + + # Make it look better + ax.xaxis.tick_top() + ax.tick_params(left=False, top=False) + ax.spines[["right", "top", "left", "bottom"]].set_visible(False) + + # Added information + title = "Cumulative substring rewards" + ax.set_title(title, pad=20) + + # fig.tight_layout() + fig.subplots_adjust(left=0.5) + if output_path: + print(f"Saving per-token-reward heatmap to {output_path}") + plt.savefig(output_path, transparent=True, dpi=120) + + plt.show() def print_model_statistics( From bd02e5f9cbbe47cb4558c78d65c10ecf9ae41dab Mon Sep 17 00:00:00 2001 From: ljvmiranda921 Date: Mon, 19 Feb 2024 17:48:10 -0800 Subject: [PATCH 10/11] Add data to .gitignore --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 0fdbf129..2667bd53 100644 --- a/.gitignore +++ b/.gitignore @@ -173,3 +173,4 @@ beaker_configs/auto_created # generated local directory hf_snapshot_evals/ +data/ \ No newline at end of file From 9ac6660a42d707869752adb5785c8f4287db726e Mon Sep 17 00:00:00 2001 From: ljvmiranda921 Date: Mon, 19 Feb 2024 18:00:27 -0800 Subject: [PATCH 11/11] Ensure that make quality passes --- analysis/draw_per_token_reward.py | 2 +- herm/visualization.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/analysis/draw_per_token_reward.py b/analysis/draw_per_token_reward.py index 704b0db7..f5af285d 100644 --- a/analysis/draw_per_token_reward.py +++ b/analysis/draw_per_token_reward.py @@ -82,7 +82,7 @@ def main(): # Get reference alignment first_key = next(iter(rewards)) # should be the same all throughout text = rewards[first_key]["text"] - whitespace_tokenizer = lambda x: x.split(" ") + whitespace_tokenizer = lambda x: x.split(" ") # noqa reference_tokens = whitespace_tokenizer(text) for _, results in rewards.items(): diff --git a/herm/visualization.py b/herm/visualization.py index 67be28f2..a8753194 100644 --- a/herm/visualization.py +++ b/herm/visualization.py @@ -16,7 +16,7 @@ from pathlib import Path from collections import Counter -from typing import Dict, List, Optional, Tuple +from typing import List, Optional, Tuple import datasets import matplotlib @@ -58,7 +58,7 @@ def draw_per_token_reward( vmax=np.max(rewards), vmin=np.min(rewards), ) - cbar = fig.colorbar(im, ax=ax, orientation="horizontal", aspect=20, location="bottom") + fig.colorbar(im, ax=ax, orientation="horizontal", aspect=20, location="bottom") ax.set_xticks(np.arange(len(tokens)), [f'"{token}"' for token in tokens]) ax.set_yticks(np.arange(len(model_names)), model_names) @@ -67,7 +67,7 @@ def draw_per_token_reward( for i in range(len(model_names)): for j in range(len(tokens)): color = "k" if rewards[i, j] >= avg else "w" - text = ax.text( + ax.text( j, i, round(rewards[i, j], 4),