From 28194803dedb96abd0a834a1feafba2ceb589628 Mon Sep 17 00:00:00 2001 From: Mghao <19211311@bjtu.edu.cn> Date: Thu, 5 Dec 2024 15:49:32 +0000 Subject: [PATCH] add model inform --- rewardbench/models/__init__.py | 10 +++- rewardbench/models/inform.py | 88 ++++++++++++++++++++++++++++++++++ 2 files changed, 97 insertions(+), 1 deletion(-) create mode 100644 rewardbench/models/inform.py diff --git a/rewardbench/models/__init__.py b/rewardbench/models/__init__.py index 532e884..5457c68 100644 --- a/rewardbench/models/__init__.py +++ b/rewardbench/models/__init__.py @@ -22,7 +22,7 @@ MixtralForCausalLM, T5ForConditionalGeneration, ) - +from .inform import INFORMForSequenceClassification from .armorm import ArmoRMPipeline from .beaver import BeaverCostPipeline, BeaverPipeline, LlamaForScore from .betterpairrm import BetterPairRMPipeline @@ -222,6 +222,14 @@ "custom_dialogue": False, "model_type": "Seq. Classifier", }, + "infly/INF-ORM-Llama3.1-70B": { + "model_builder": INFORMForSequenceClassification.from_pretrained, + "pipeline_builder": RewardBenchPipeline, + "quantized": False, + "custom_dialogue": False, + "model_type": "Seq. Classifier", + "torch_dtype": torch.bfloat16, + }, } DPO_MODEL_CONFIG = { diff --git a/rewardbench/models/inform.py b/rewardbench/models/inform.py new file mode 100644 index 0000000..fd0f458 --- /dev/null +++ b/rewardbench/models/inform.py @@ -0,0 +1,88 @@ +# Copyright 2024 AllenAI. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional, Union + +import torch +import torch.nn as nn +from transformers import LlamaPreTrainedModel,LlamaConfig, LlamaModel +from transformers.modeling_outputs import SequenceClassifierOutputWithPast + +class INFORMForSequenceClassification(LlamaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = LlamaModel(config) + self.score = nn.Sequential( + nn.Linear(config.hidden_size, config.hidden_size), + nn.ReLU(), + nn.Linear(config.hidden_size, 1) + ) + + # Initialize weights and apply final processing + self.post_init() + + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) \ No newline at end of file