-
Notifications
You must be signed in to change notification settings - Fork 56
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Print per-token reward over an RM #9
Conversation
Will review later today! |
args = get_args() | ||
quantized = True # only Starling isn't quantized for now | ||
custom_dialogue = False | ||
# some models need custom code to be run | ||
if "oasst" in args.model or "oasst" in args.chat_template: | ||
from herm.models import openassistant # noqa | ||
|
||
model_builder = AutoModelForSequenceClassification.from_pretrained | ||
pipeline_builder = pipeline | ||
elif "Starling" in args.model or "Starling" in args.chat_template: | ||
from herm.models.starling import StarlingPipeline, build_starling_rm | ||
|
||
model_builder = build_starling_rm | ||
pipeline_builder = StarlingPipeline | ||
quantized = False | ||
elif "openbmb" in args.model or "openbmb" in args.chat_template: | ||
from herm.models.openbmb import LlamaRewardModel, OpenBMBPipeline | ||
|
||
model_builder = LlamaRewardModel.from_pretrained | ||
pipeline_builder = OpenBMBPipeline | ||
elif "PairRM" in args.model or "PairRM" in args.chat_template: | ||
from herm.models.pairrm import DebertaV2PairRM, PairRMPipeline | ||
|
||
custom_dialogue = True | ||
model_builder = DebertaV2PairRM.from_pretrained | ||
pipeline_builder = PairRMPipeline | ||
elif "SHP" in args.model or "SHP" in args.chat_template: | ||
from herm.models.shp import SHPPipeline | ||
|
||
custom_dialogue = True | ||
model_builder = T5ForConditionalGeneration.from_pretrained | ||
pipeline_builder = SHPPipeline | ||
else: | ||
model_builder = AutoModelForSequenceClassification.from_pretrained | ||
pipeline_builder = pipeline | ||
|
||
if custom_dialogue: | ||
raise ValueError("Custom dialogue formatting not yet supported in this script") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In case we're going to reuse this code block in the future, we should factor this logic out (so that we can reuse it on run_rm.py
), but imo for v1 it's fine for now 🤔
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah I agree @ljvmiranda921 , and maybe add a test case.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
Brief documentation in
analysis/README.md
We can add visualizing it soon :)
E.g.
Reward: -0.544 | Substring: I
Reward: -0.556 | Substring: I love
Reward: -0.566 | Substring: I love to
Reward: 0.099 | Substring: I love to walk
Reward: 0.096 | Substring: I love to walk the
Reward: 0.092 | Substring: I love to walk the dog
Reward: 0.09 | Substring: I love to walk the dog,
Reward: 0.087 | Substring: I love to walk the dog, what
Reward: 0.085 | Substring: I love to walk the dog, what do
Reward: 0.089 | Substring: I love to walk the dog, what do you
Reward: 0.09 | Substring: I love to walk the dog, what do you like
Reward: 0.093 | Substring: I love to walk the dog, what do you like?