Skip to content

Commit

Permalink
more
Browse files Browse the repository at this point in the history
Signed-off-by: SumanthRH <sumanthrh@anyscale.com>
  • Loading branch information
SumanthRH committed Feb 27, 2025
1 parent e113874 commit 6fdca5c
Show file tree
Hide file tree
Showing 16 changed files with 1,457 additions and 137 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@ repos:
- id: ruff
args: [ --fix, --exit-non-zero-on-fix ]
# NOTE (sumanthrh): Many of the files excluded here are used for validating code generation, and linters do not recognize some of the logic in these files. skythought/train is excluded for now because it's a fork of Llamafactory
exclude: (^skythought/train/.*|^skythought/skythought-rl/.*|tasks/taco/pyext2\.py|tasks/taco/taco_util\.py|tasks/apps/apps_util\.py|scripts/prompts\.py|skythought/test-time-scaling/.*)$
exclude: (^skythought/train/.*|^skythought/skythought-rl/.*|pyext2\.py|taco_util\.py|apps_util\.py|scripts/prompts\.py|skythought/test-time-scaling/.*)$


# Black needs to be ran after ruff with --fix
- repo: https://github.com/psf/black
rev: 24.10.0
hooks:
- id: black
exclude: (^skythought/train/.*|^skythought/skythought-rl/.*|tasks/taco/pyext2\.py|skythought/test-time-scaling/.*)$
exclude: (^skythought/train/.*|^skythought/skythought-rl/.*|pyext2\.py|skythought/test-time-scaling/.*)$
44 changes: 0 additions & 44 deletions recipes/sky-t1-preview/postprocess.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,5 @@
import copy
import json
from typing import Any, Dict

import numpy as np
import ray

from skythought.evals.scoring.base import Scorer
from skythought.evals.tasks.apps.apps_util import run_test as apps_run_test
from skythought.evals.util.common import has_code

STILL2_SYSTEM_PROMPT = "Your role as an assistant involves thoroughly exploring questions through a systematic long \
thinking process before providing the final precise and accurate solutions. This requires \
engaging in a comprehensive cycle of analysis, summarizing, exploration, reassessment, reflection, \
Expand All @@ -30,41 +21,6 @@
Now, try to solve the following question through the above guidelines:"


class APPSScorer(Scorer):
def score(self, row: Dict[str, Any]):
TIMEOUT = 10
code_filter_result = has_code(row["response"])
if len(code_filter_result) == 0:
return False
else:
last_code = code_filter_result[-1]
problem_to_check = copy.deepcopy(row)
problem_to_check["input_output"] = json.loads(row["input_output"])
try:
problem_to_check["solutions"] = json.loads(row["solutions"])
except Exception:
problem_to_check["solutions"] = ""

@ray.remote
def _temp_run(problem, generation, debug):
try:
result = apps_run_test(problem=problem, test=generation, debug=debug)
return result
except Exception:
pass

result = ray.get(
_temp_run.remote(problem_to_check, last_code, False), timeout=TIMEOUT + 1
)

return bool(result and np.all(result[0]))


class TACOScorer(Scorer):
def score(self, row: Dict[str, Any]):
return True


def convert_to_sharegpt_format(row: Dict[str, Any]):
prompt = row["user_input"]
# accept
Expand Down
4 changes: 2 additions & 2 deletions recipes/sky-t1-preview/prompts.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
convert_prompt_example = ( # noqa: E501
CONVERT_PROMPT_EXAMPLE = ( # noqa: E501
"<|begin_of_thought|>\n\n"
"Okay, so I've got this problem here. Mr. Wang leaves home at 6 AM, riding his bike at 12 km/h, "
"and he stops to rest for 6 minutes after every 30 minutes of riding. Then, when he arrives at a park "
Expand Down Expand Up @@ -65,5 +65,5 @@
"{example}\n"
"Important: You should almost copy all the contents word-by-word of the original solution. Just convert them into two sections. "
"Make sure you include: <|begin_of_slow_thought|>, <|end_of_slow_thought|>, <|begin_of_solution|>,<|end_of_solution|> These four headers explicitly. "
"Content to be converted: {{content}}".format(example=convert_prompt_example)
"Content to be converted: {content}"
)
145 changes: 75 additions & 70 deletions recipes/sky-t1-preview/recipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@
vLLMEngineProcessorConfig,
)

from skythought.evals.scoring.apps import APPSScorer
from skythought.evals.scoring.math import MathEqualScorer
from skythought.evals.scoring.taco import TACOScorer

from .postprocess import APPSScorer, TACOScorer, convert_to_sharegpt_format
from .postprocess import convert_to_sharegpt_format
from .preprocess import APPSPreprocessor, NUMINAPreprocessor, TACOPreprocessor
from .prompts import CONVERT_PROMPT
from .prompts import CONVERT_PROMPT, CONVERT_PROMPT_EXAMPLE

parser = argparse.ArgumentParser()
parser.add_argument("--as-test", action="store_true")
Expand All @@ -26,11 +28,12 @@
SYSTEM_PROMPT = "You are a helpful and harmless assistant. You are Qwen developed by Alibaba. You should think step-by-step." # noqa: E501

# 1. Load datasets
apps_ds = datasets.load_dataset("codeparrot/apps", split="test", streaming=True)
taco_ds_medium = datasets.load_dataset(
"BAAI/TACO", split="test", name="MEDIUM", streaming=True
apps_ds = datasets.load_dataset(
"codeparrot/apps",
split="test",
)
numina_ds = datasets.load_dataset("AI-MO/NuminaMath-CoT", split="train", streaming=True)
taco_ds_medium = datasets.load_dataset("BAAI/TACO", split="test", name="MEDIUM")
numina_ds = datasets.load_dataset("AI-MO/NuminaMath-CoT", split="train")

# convert all to ray dataset
apps_ds = ray.data.from_huggingface(apps_ds)
Expand All @@ -45,11 +48,11 @@


if args.as_test:
apps_ds = apps_ds.limit(100)
taco_ds_medium = taco_ds_medium.limit(100)
numina_ds_amc_aime = numina_ds_amc_aime.limit(100)
numina_ds_olympiads = numina_ds_olympiads.limit(100)
numina_ds_math = numina_ds_math.limit(100)
apps_ds = apps_ds.limit(5)
taco_ds_medium = taco_ds_medium.limit(5)
numina_ds_amc_aime = numina_ds_amc_aime.limit(5)
numina_ds_olympiads = numina_ds_olympiads.limit(5)
numina_ds_math = numina_ds_math.limit(5)

# 2. Get model responses for each of the datasets
datasets = [
Expand All @@ -69,10 +72,20 @@
NUMINAPreprocessor(),
]

numina_scorer = MathEqualScorer(
response_column="formatted_response", answer_column="solution"
)
scorers = [
APPSScorer(response_column="formatted_response"),
TACOScorer(response_column="formatted_response"),
numina_scorer,
numina_scorer,
numina_scorer,
]

for i, ds in enumerate(datasets):
datasets[i] = ds.map(preprocess_fns[i])

# our API
config = vLLMEngineProcessorConfig(
# model="Qwen/QwQ-32B-Preview",
model="Qwen/Qwen2-0.5B-Instruct",
Expand All @@ -85,12 +98,11 @@
batch_size=64,
)

# our API
processor = build_llm_processor(
config,
preprocess=lambda row: dict(
messages=[
SYSTEM_PROMPT,
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": row["user_input"]},
],
sampling_params=dict(
Expand All @@ -104,63 +116,56 @@
**row, # This will return all the original columns in the dataset.
),
)
# our API
datasets[i] = processor(ds)

# 3. Reformat the examples into a structured format
# define a configuration for the reformatter
config = HttpRequestProcessorConfig(
url="https://api.openai.com/v1/chat/completions",
headers={"Authorization": f"Bearer {os.environ['OPENAI_API_KEY']}"},
# number of processors to run in parallel
# Each handles a batch of requests
concurrency=1,
)
# define the reformatter
reformatter = build_llm_processor(
config=config,
preprocess=lambda row: dict(
# define the payload / the exact arguments to the OpenAI chat completions API
payload=dict(
model="gpt-4o-mini",
messages=[
{"role": "system", "content": "You are a solution format convertor."},
{
"role": "user",
"content": CONVERT_PROMPT.format(
content=f"{row['question']}\n{row['assistant_response']}"
),
},
],
temperature=0.7,
max_tokens=16384,
datasets[i] = processor(datasets[i])

# 3. Reformat the examples into a structured format
# define a configuration for the reformatter
config = HttpRequestProcessorConfig(
url="https://api.openai.com/v1/chat/completions",
headers={"Authorization": f"Bearer {os.environ['OPENAI_API_KEY']}"},
# number of processors to run in parallel
# Each handles a batch of requests
concurrency=1,
)
# define the reformatter
reformatter = build_llm_processor(
config,
preprocess=lambda row: dict(
# define the payload / the exact arguments to the OpenAI chat completions API
payload=dict(
model="gpt-4o-mini",
messages=[
{
"role": "system",
"content": "You are a solution format convertor.",
},
{
"role": "user",
"content": CONVERT_PROMPT.format(
example=CONVERT_PROMPT_EXAMPLE,
content=f"{row['question']}\n{row['assistant_response']}",
),
},
],
temperature=0.7,
max_tokens=2048,
),
),
),
postprocess=lambda row: dict(
formatted_response=row["http_response"]["choices"][0]["message"]["content"],
),
batch_size=64,
)

for i, dataset in enumerate(datasets):
datasets[i] = reformatter(dataset)


# 4. Rejection Sampling based on scoring
# apps, taco, numina-amc-aime, numina-olympiads, numina-math
numina_scorer = MathEqualScorer(
response_key="formatted_response", answer_key="solution"
)
scorers = [APPSScorer(), TACOScorer(), numina_scorer, numina_scorer, numina_scorer]
postprocess=lambda row: dict(
formatted_response=row["http_response"]["choices"][0]["message"]["content"],
**row,
),
)
datasets[i] = reformatter(datasets[i])

for i, dataset in enumerate(datasets):
fn = scorers[i]
datasets[i] = dataset.map(fn)
# 4. Rejection Sampling based on scoring
# apps, taco, numina-amc-aime, numina-olympiads, numina-math
datasets[i] = datasets[i].map(scorers[i])
# datasets[i] = datasets[i].filter(lambda x: x[scorers[i].SCORE_COLUMN])

# 5. Convert to ShareGPT format
for i, dataset in enumerate(datasets):
datasets[i] = dataset.map(convert_to_sharegpt_format)
# 5. Convert to ShareGPT format
datasets[i] = datasets[i].map(convert_to_sharegpt_format)

# 6. Union + Save datasets
datasets = datasets[0].union(*datasets[1:])
datasets.write_parquet("sky-t1-preview.parquet")
# 6. Save datasets
dir_name = f"sky-t1-preview-{i}_parquet"
datasets[i].write_parquet(os.path.abspath(dir_name))
6 changes: 2 additions & 4 deletions skythought/evals/scoring/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from .base import Scorer
from .gsm8k import GSM8KScorer
from .ifeval import IfEvalScorer
from .livecodebench import LiveCodeBenchScorer
from .math import MathScorer
from .math import MathEqualScorer, MathVerifyScorer

__all__ = ["Scorer", "MathScorer", "GSM8KScorer", "LiveCodeBenchScorer", "IfEvalScorer"]
__all__ = ["Scorer", "MathEqualScorer", "MathVerifyScorer", "GSM8KScorer"]
3 changes: 3 additions & 0 deletions skythought/evals/scoring/apps/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .apps import APPSScorer

__all__ = ["APPSScorer"]
61 changes: 61 additions & 0 deletions skythought/evals/scoring/apps/apps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import copy
import json
from typing import Any, Dict

import numpy as np
import ray
from ray.exceptions import GetTimeoutError

from skythought.evals.scoring.base import Scorer
from skythought.evals.tasks.apps.apps_util import run_test as apps_run_test
from skythought.evals.util.common import has_code


class APPSScorer(Scorer):
SCORE_COLUMN = "apps_score"

def __init__(
self,
response_column="response",
answer_column="solutions",
input_column="input_output",
) -> None:
super().__init__()
self.response_column = response_column
self.answer_column = answer_column
self.input_column = input_column

def score(self, row: Dict[str, Any]):
TIMEOUT = 10
code_filter_result = has_code(row[self.response_column])
if len(code_filter_result) == 0:
return {self.SCORE_COLUMN: False}
else:
last_code = code_filter_result[-1]
problem_to_check = copy.deepcopy(row)
problem_to_check[self.input_column] = json.loads(row[self.input_column])
try:
problem_to_check[self.answer_column] = json.loads(
row[self.answer_column]
)
except Exception:
problem_to_check[self.answer_column] = ""

@ray.remote
def _temp_run(problem, generation, debug):
try:
result = apps_run_test(problem=problem, test=generation, debug=debug)
return result
except Exception:
pass

try:
result = ray.get(
_temp_run.remote(problem_to_check, last_code, False),
timeout=TIMEOUT + 1,
)
except GetTimeoutError:
result = []

score = bool(result and np.all(result[0]))
return {self.SCORE_COLUMN: score}
2 changes: 1 addition & 1 deletion skythought/evals/scoring/ifeval/instructions.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@
"""Library of instructions."""
import collections
import json
import logging
import random
import re
import string
from typing import Dict, Optional, Sequence, Union

import langdetect
from absl import logging

from . import instructions_util

Expand Down
Loading

0 comments on commit 6fdca5c

Please sign in to comment.