Commit 2819480

add model inform

Mghao committed Dec 5, 2024
1 parent 8004172 commit 2819480
Showing 2 changed files with 97 additions and 1 deletion.
10 changes: 9 additions & 1 deletion rewardbench/models/__init__.py
@@ -22,7 +22,7 @@
MixtralForCausalLM,
T5ForConditionalGeneration,
)

from .inform import INFORMForSequenceClassification
from .armorm import ArmoRMPipeline
from .beaver import BeaverCostPipeline, BeaverPipeline, LlamaForScore
from .betterpairrm import BetterPairRMPipeline
@@ -222,6 +222,14 @@
"custom_dialogue": False,
"model_type": "Seq. Classifier",
},
"infly/INF-ORM-Llama3.1-70B": {
"model_builder": INFORMForSequenceClassification.from_pretrained,
"pipeline_builder": RewardBenchPipeline,
"quantized": False,
"custom_dialogue": False,
"model_type": "Seq. Classifier",
"torch_dtype": torch.bfloat16,
},
}

DPO_MODEL_CONFIG = {
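For context, here is a minimal sketch (not part of this commit) of how an entry like the one added above could be consumed by a loader. The dict name model_config and the keyword plumbing are assumptions for illustration, not rewardbench's actual loading code:

# Hypothetical consumer of the new config entry; names are illustrative only.
import torch

model_config = {
    "model_builder": INFORMForSequenceClassification.from_pretrained,
    "quantized": False,
    "custom_dialogue": False,
    "torch_dtype": torch.bfloat16,
}
model = model_config["model_builder"](
    "infly/INF-ORM-Llama3.1-70B",
    torch_dtype=model_config["torch_dtype"],
    device_map="auto",
)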
88 changes: 88 additions & 0 deletions rewardbench/models/inform.py
@@ -0,0 +1,88 @@
# Copyright 2024 AllenAI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional, Union

import torch
import torch.nn as nn
from transformers import LlamaConfig, LlamaModel, LlamaPreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutputWithPast

class INFORMForSequenceClassification(LlamaPreTrainedModel):
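    """Llama backbone with a two-layer MLP scoring head that produces one scalar reward per sequence."""
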
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.model = LlamaModel(config)
self.score = nn.Sequential(
nn.Linear(config.hidden_size, config.hidden_size),
nn.ReLU(),
nn.Linear(config.hidden_size, 1)
)

# Initialize weights and apply final processing
self.post_init()


def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):

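        # Note: labels, use_cache, output_attentions, output_hidden_states, and
        # return_dict are accepted for HF API compatibility but are not used below.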
transformer_outputs = self.model(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
)
hidden_states = transformer_outputs[0]
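        # The MLP head scores every position; the pooling below keeps only the last non-pad token.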
logits = self.score(hidden_states)

if input_ids is not None:
batch_size = input_ids.shape[0]
else:
batch_size = inputs_embeds.shape[0]

if self.config.pad_token_id is None and batch_size != 1:
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
if self.config.pad_token_id is None:
sequence_lengths = -1
else:
if input_ids is not None:
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]
sequence_lengths = sequence_lengths.to(logits.device)
else:
sequence_lengths = -1

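        # Select each sequence's score at its last non-pad position.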
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]

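        # Inference-only head: no loss is ever computed here.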
loss = None
return SequenceClassifierOutputWithPast(
loss=loss,
logits=pooled_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
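
A hypothetical end-to-end usage sketch (not from this commit): loading the 70B reward model and scoring a single conversation. The chat-template call and device placement are assumptions about the released checkpoint's tokenizer, not verified details.

# Illustrative only; assumes the tokenizer ships a chat template.
import torch
from transformers import AutoTokenizer

model_name = "infly/INF-ORM-Llama3.1-70B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = INFORMForSequenceClassification.from_pretrained(
    model_name, torch_dtype=torch.bfloat16, device_map="auto"
)

messages = [
    {"role": "user", "content": "What is the capital of France?"},
    {"role": "assistant", "content": "The capital of France is Paris."},
]
input_ids = tokenizer.apply_chat_template(messages, return_tensors="pt").to(model.device)
with torch.no_grad():
    reward = model(input_ids=input_ids).logits.item()  # scalar score for the whole dialogue
print(reward)

The scalar comes from the last non-pad token's pooled logit, matching the pooling logic in forward above.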
