Hotfix (#49)
- added a new_columns argument to the inference component to accommodate adding columns to the resume_from data
- improved the check for column mismatch in resume_from
- removed token counting from the prompt_processing component for a speed-up

---------

Co-authored-by: Safoora Yousefi <sayouse@microsoft.com>
safooray and Safoora Yousefi authored Nov 15, 2024
1 parent c56768f commit 1713e79
Showing 11 changed files with 48 additions and 23 deletions.
1 change: 1 addition & 0 deletions eureka_ml_insights/configs/config.py
@@ -111,6 +111,7 @@ class InferenceConfig(ComponentConfig):
data_loader_config: UtilityClassConfigType = None
model_config: UtilityClassConfigType = None
resume_from: str = None
new_columns: List[str] = None
requests_per_minute: int = None
max_concurrent: int = 1

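A hedged usage sketch of the new field follows (placeholder values throughout; only fields visible in this diff are set, and any required fields inherited from ComponentConfig are omitted, so this is illustrative rather than definitive):

# Illustrative only: placeholder path, and loader/model configs set to None
# just to keep the sketch self-contained.
from eureka_ml_insights.configs.config import InferenceConfig

inference_config = InferenceConfig(
    data_loader_config=None,   # a real UtilityClassConfigType would go here
    model_config=None,         # a real model config would go here
    resume_from="outputs/previous_run/inference_result.jsonl",  # placeholder path
    new_columns=["usage"],     # columns the old results file lacks; back-filled with None
    requests_per_minute=60,
    max_concurrent=1,
)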
2 changes: 1 addition & 1 deletion eureka_ml_insights/configs/mmmu.py
@@ -83,7 +83,7 @@ def configure_pipeline(self, model_config: ModelConfig, resume_from: str = None,
"format": ".jsonl",
"transform": SequenceTransform(
[
CopyColumn(column_name_src="task", column_name_dst="category"),
CopyColumn(column_name_src="__hf_task", column_name_dst="category"),
MapStringsTransform(
columns=["category"],
mapping=MMMUTaskToCategories,
2 changes: 1 addition & 1 deletion eureka_ml_insights/configs/nondeterminism.py
@@ -50,6 +50,6 @@ def configure_pipeline(self, **kwargs):
config = super().configure_pipeline(**kwargs)
# Downsample the data and repeat each prompt 3 times
self.data_processing_comp.data_reader_config.init_args["transform"].transforms.extend(
[SamplerTransform(random_seed=42, sample_count=5, stratify_by="task"), MultiplyTransform(n_repeats=3)]
[SamplerTransform(random_seed=42, sample_count=5, stratify_by="__hf_task"), MultiplyTransform(n_repeats=3)]
)
return config
28 changes: 21 additions & 7 deletions eureka_ml_insights/core/inference.py
@@ -10,12 +10,12 @@
from eureka_ml_insights.data_utils.data import DataReader, JsonLinesWriter

from .pipeline import Component

from .reserved_names import INFERENCE_RESERVED_NAMES
MINUTE = 60


class Inference(Component):
def __init__(self, model_config, data_config, output_dir, resume_from=None, requests_per_minute=None, max_concurrent=1):
def __init__(self, model_config, data_config, output_dir, resume_from=None, new_columns=None, requests_per_minute=None, max_concurrent=1):

"""
Initialize the Inference component.
@@ -24,6 +24,7 @@ def __init__(self, model_config, data_config, output_dir, resume_from=None, requ
data_config (dict): DataSetConfig object.
output_dir (str): Directory to save the inference results.
resume_from (str): optional. Path to the file where previous inference results are stored.
new_columns (list): optional. List of new columns to be added to resume_from data to match the current inference response.
requests_per_minute (int): optional. Number of inference requests to be made per minute, used for rate limiting. If not provided, rate limiting will not be applied.
max_concurrent (int): optional. Maximum number of concurrent inferences to run. Default is 1.
"""
@@ -35,6 +36,7 @@ def __init__(self, model_config, data_config, output_dir, resume_from=None, requ
self.resume_from = resume_from
if resume_from and not os.path.exists(resume_from):
raise FileNotFoundError(f"File {resume_from} not found.")
self.new_columns = new_columns

# rate limiting parameters
self.requests_per_minute = requests_per_minute
@@ -51,6 +53,7 @@ def from_config(cls, config):
config.data_loader_config,
config.output_dir,
resume_from=config.resume_from,
new_columns=config.new_columns,
requests_per_minute=config.requests_per_minute,
max_concurrent=config.max_concurrent,
)
@@ -59,7 +62,13 @@ def fetch_previous_inference_results(self):
# fetch previous results from the provided resume_from file
logging.info(f"Resuming inference from {self.resume_from}")
pre_inf_results_df = DataReader(self.resume_from, format=".jsonl").load_dataset()


# add new columns listed by the user to the previous inference results
if self.new_columns:
for col in self.new_columns:
if col not in pre_inf_results_df.columns:
pre_inf_results_df[col] = None

# validate the resume_from contents
with self.data_loader as loader:
_, sample_model_input = self.data_loader.get_sample_model_input()
@@ -73,11 +82,16 @@
sample_response_dict = self.model.generate(*sample_model_input)
# check if the inference response dictionary contains the same keys as the resume_from file
eventual_keys = set(sample_response_dict.keys()) | set(sample_data_keys)
if set(eventual_keys) != set(pre_inf_results_df.columns):

# in case of resuming from a file that was generated by an older version of the model,
# we let the discrepancy in the reserved keys slide and later set the missing keys to None
match_keys = set(pre_inf_results_df.columns) | set(INFERENCE_RESERVED_NAMES)

if set(eventual_keys) != match_keys:
diff = set(eventual_keys) ^ set(match_keys)
raise ValueError(
f"Columns in resume_from file do not match the current inference response. "
f"Current inference response keys: {sample_response_dict.keys()}. "
f"Resume_from file columns: {pre_inf_results_df.columns}."
f"Columns in resume_from file do not match the current input data and inference response. "
f"Problemtaic columns: {diff}"
)

# find the last uid that was inferenced
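The resume_from reconciliation added above can be summarized in a standalone pandas sketch (synthetic data and column names, no repo imports). It mirrors the new behavior: user-listed new columns are back-filled with None, and keys missing only from the reserved inference set are tolerated:

import pandas as pd

INFERENCE_RESERVED_NAMES = ["model_output", "is_valid", "response_time", "n_output_tokens"]

# old results file, written before the model started returning "usage"
pre_inf_results_df = pd.DataFrame(
    {"uid": [0], "prompt": ["..."], "model_output": ["..."], "is_valid": [True]}
)
new_columns = ["usage"]  # supplied by the user via the new config field
sample_data_keys = ["uid", "prompt"]
sample_response_dict = {
    "model_output": "...", "is_valid": True, "response_time": 0.1, "n_output_tokens": 5, "usage": {},
}

# back-fill the user-listed columns that the old file predates
for col in new_columns:
    if col not in pre_inf_results_df.columns:
        pre_inf_results_df[col] = None

eventual_keys = set(sample_response_dict) | set(sample_data_keys)
# tolerate reserved keys missing from older files; they are set to None later
match_keys = set(pre_inf_results_df.columns) | set(INFERENCE_RESERVED_NAMES)
if eventual_keys != match_keys:
    raise ValueError(f"Problematic columns: {eventual_keys ^ match_keys}")  # not raised in this example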
8 changes: 0 additions & 8 deletions eureka_ml_insights/core/prompt_processing.py
@@ -1,11 +1,8 @@
import logging
import os
import statistics
from hashlib import md5
from typing import List, Optional

from transformers import GPT2TokenizerFast

from eureka_ml_insights.data_utils import JinjaPromptTemplate

from .data_processing import DataProcessing
@@ -70,16 +67,13 @@ def run(self) -> None:
prompt_hashes = [compute_hash(prompt) for prompt in prompts]
# otherwise, use the prompt data processor to generate prompts and save in the "prompt" column
else:
prompt_num_tokens = []
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
with open(prompt_output_file, "w", encoding="utf-8") as writer:
for i, row in input_df.iterrows():

placeholders = row.to_dict()
try:
prompt = self.prompt_data_processor.create(placeholders)
success_indexes.append(i)
prompt_num_tokens.append(len(tokenizer.tokenize(prompt)))
prompt_hashes.append(compute_hash(prompt))
prompts.append(prompt)
writer.write(prompt + "\n")
@@ -91,8 +85,6 @@
else:
raise e

logging.info(f"Average prompt num tokens: {statistics.fmean(prompt_num_tokens)}.")

input_df = self.get_desired_columns(input_df)
# Remove `model_output`, `is_valid`, `response_time`, `n_output_tokens` columns if they exist
# in the data because these names are reserved for the inference component's use.
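If the average-token-count log line is still wanted, a hedged offline sketch like the following could recover it from the saved prompt file after the run, keeping the hot loop fast (the file path is a placeholder, not a name taken from this repo):

# Optional, offline replacement for the removed in-loop token counting.
import statistics
from transformers import GPT2TokenizerFast

tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
with open("outputs/prompt_processing/prompts.txt", encoding="utf-8") as f:  # placeholder path
    prompts = [line.rstrip("\n") for line in f if line.strip()]

prompt_num_tokens = [len(tokenizer.tokenize(p)) for p in prompts]
print(f"Average prompt num tokens: {statistics.fmean(prompt_num_tokens):.1f}")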
3 changes: 2 additions & 1 deletion eureka_ml_insights/core/reserved_names.py
@@ -1,2 +1,3 @@
# if your data has any of these columns, they may be removed or overwritten by Eureka
INFERENCE_RESERVED_NAMES = ["model_output", "is_valid", "response_time", "n_output_tokens"]
PROMPT_PROC_RESERVED_NAMES = ["prompt_hash", "prompt", "uid", "data_point_id", "data_repeat_id"]
PROMPT_PROC_RESERVED_NAMES = ["prompt_hash", "prompt", "uid", "data_point_id", "data_repeat_id", "__hf_task", "__hf_split"]
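A minimal sketch (not repo code) of how these lists can be used to guard user data before prompt processing and inference; the example columns are made up:

import pandas as pd

INFERENCE_RESERVED_NAMES = ["model_output", "is_valid", "response_time", "n_output_tokens"]
PROMPT_PROC_RESERVED_NAMES = ["prompt_hash", "prompt", "uid", "data_point_id", "data_repeat_id", "__hf_task", "__hf_split"]

df = pd.DataFrame({"question": ["2+2?"], "model_output": ["stale value"]})
clashes = [c for c in df.columns if c in INFERENCE_RESERVED_NAMES + PROMPT_PROC_RESERVED_NAMES]
# these columns would otherwise be removed or overwritten downstream
df = df.drop(columns=clashes)
print(clashes, list(df.columns))  # ['model_output'] ['question']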
6 changes: 3 additions & 3 deletions eureka_ml_insights/data_utils/data.py
@@ -613,14 +613,14 @@ def _load_dataset(self) -> pd.DataFrame:
hf_dataset = load_dataset(self.path, cache_dir=self.cache_dir, split=self.split)
for i, data_split in enumerate(hf_dataset):
task_df = self._hf_to_dataframe(data_split)
task_df["split"] = self.split[i]
task_df["__hf_split"] = self.split[i]
df_frames.append(task_df)
else:
for task in self.tasks:
hf_dataset = load_dataset(self.path, task, cache_dir=self.cache_dir, split=self.split)
for i, data_split in enumerate(hf_dataset):
task_df = self._hf_to_dataframe(data_split)
task_df["task"] = task
task_df["split"] = self.split[i]
task_df["__hf_task"] = task
task_df["__hf_split"] = self.split[i]
df_frames.append(task_df)
return pd.concat(df_frames)
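The double-underscore prefix is what keeps these bookkeeping columns from colliding with dataset columns that are literally named "task" or "split". A standalone sketch of the tagging pattern (synthetic rows, nothing loaded from Hugging Face):

import pandas as pd

tasks = {"Art": [{"question": "..."}], "Math": [{"question": "..."}]}
splits = ["dev", "validation"]

df_frames = []
for task, rows in tasks.items():
    for split in splits:
        task_df = pd.DataFrame(rows)
        task_df["__hf_task"] = task    # prefixed so it cannot clash with a real "task" column
        task_df["__hf_split"] = split
        df_frames.append(task_df)

df = pd.concat(df_frames)
print(df.groupby(["__hf_task", "__hf_split"]).size())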
2 changes: 1 addition & 1 deletion eureka_ml_insights/data_utils/toxigen_utils.py
@@ -68,7 +68,7 @@ def transform(self, df: pd.DataFrame) -> pd.DataFrame:
df[[self.model_output_column, "is_valid"]] = df[self.model_output_column].apply(
lambda x: pd.Series([parse_output(x, delimiters, True)[0], parse_output(x, delimiters, True)[1]])
)
df[[self.gt_column, self.category]] = df["split"].apply(
df[[self.gt_column, self.category]] = df["__hf_split"].apply(
lambda x: pd.Series([label_category_map(x)[0], label_category_map(x)[1]])
)
df[self.merged_group] = df[self.category] + "_" + df[self.gt_column]
2 changes: 2 additions & 0 deletions eureka_ml_insights/models/models.py
@@ -480,6 +480,8 @@ def get_response(self, request):
openai_response = completion.model_dump()
self.model_output = openai_response["choices"][0]["message"]["content"]
self.response_time = end_time - start_time
if "usage" in openai_response:
return {"usage": openai_response["usage"]}


@dataclass
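A hedged sketch of how a caller might fold the optional usage payload returned above into a per-row result before writing it out; the merge helper and the surrounding inference loop are assumptions, not code from this commit:

from typing import Optional


def merge_usage(response_record: dict, model_return: Optional[dict]) -> dict:
    """Attach token-usage accounting to a result row when the model reports it."""
    if model_return and "usage" in model_return:
        response_record["usage"] = model_return["usage"]
    return response_record


row = merge_usage(
    {"model_output": "hello", "is_valid": True, "response_time": 0.4},
    {"usage": {"prompt_tokens": 12, "completion_tokens": 3, "total_tokens": 15}},
)
print(row["usage"]["total_tokens"])  # 15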
15 changes: 15 additions & 0 deletions tests/pipeline_tests.py
@@ -32,6 +32,7 @@
MetricConfig,
ModelConfig,
ToxiGen_Discriminative_PIPELINE,
ToxiGen_Generative_PIPELINE,
)
from eureka_ml_insights.core import Pipeline
from eureka_ml_insights.data_utils.transform import (
@@ -259,6 +260,17 @@ def configure_pipeline(self):
}
return config

class TEST_TOXIGEN_GEN_PIPELINE(ToxiGen_Generative_PIPELINE):
def configure_pipeline(self):
config = super().configure_pipeline(model_config=ModelConfig(GenericTestModel, {}))
self.inference_comp.data_loader_config.class_name = TestDataLoader
self.inference_comp.data_loader_config.init_args = {
"path": os.path.join(self.data_pre_processing.output_dir, "transformed_data.jsonl"),
"n_iter": N_ITER,
}
self.eval_inference_comp.model_config = ModelConfig(ToxiGenTestModel, {})
return config


class TEST_MMMU_PIPELINE(MMMU_BASELINE_PIPELINE):
# Test config the MMMU benchmark with MultipleChoiceTestModel and TestMMDataLoader
@@ -425,6 +437,9 @@ class TOXIGEN_PipelineTest(PipelineTest, unittest.TestCase):
def get_config(self):
return TEST_TOXIGEN_PIPELINE().pipeline_config

class TOXIGEN_GEN_PipelineTest(PipelineTest, unittest.TestCase):
def get_config(self):
return TEST_TOXIGEN_GEN_PIPELINE().pipeline_config

class KITAB_ONE_BOOK_CONSTRAINT_PIPELINE_PipelineTest(PipelineTest, unittest.TestCase):
def get_config(self):
2 changes: 1 addition & 1 deletion tests/test_utils.py
@@ -16,7 +16,7 @@ def __init__(self, model_name="generic_test_model"):

def generate(self, text_prompt, query_images=None):
time.sleep(0.1)
return {"model_output": "model output", "is_valid": True}
return {"model_output": "model output", "is_valid": True, "response_time": 0, "n_output_tokens": 0}


class TestHFDataReader(HFDataReader):
