diff --git a/eureka_ml_insights/metrics/__init__.py b/eureka_ml_insights/metrics/__init__.py index 2b19ec1..e31b08a 100644 --- a/eureka_ml_insights/metrics/__init__.py +++ b/eureka_ml_insights/metrics/__init__.py @@ -28,6 +28,7 @@ SpatialAndLayoutReasoningMetric, ) +from .aime_metrics import NumericMatch __all__ = [ Metric, ClassicMetric, @@ -52,4 +53,5 @@ SumAggregator, MMMUMetric, MaxTokenF1ScoreMetric, + NumericMatch, ] diff --git a/eureka_ml_insights/metrics/aime_metrics.py b/eureka_ml_insights/metrics/aime_metrics.py new file mode 100644 index 0000000..8106bb6 --- /dev/null +++ b/eureka_ml_insights/metrics/aime_metrics.py @@ -0,0 +1,20 @@ +from tqdm.auto import tqdm + +from eureka_ml_insights.metrics.metrics_base import ClassicMetric + +import numpy as np + +class NumericMatch(ClassicMetric): + """This class checks for a numeric match.""" + eps = 1e-6 + def __evaluate__(self, answer_text, target_text, is_valid): + if not is_valid: + return "none" + try: + diff = np.abs(float(target_text)-float(answer_text)) + except: + return "none" + if diff PipelineConfig: + pipeline = super().configure_pipeline(model_config=model_config, resume_from=resume_from) + # data preprocessing + self.data_processing_comp.prompt_template_path=os.path.join( + os.path.dirname(__file__), "../prompt_templates/aime_templates/Template_1direct.jinja" + ) + return pipeline + class AIME_PIPELINE16Run(AIME_PIPELINE): """This class specifies the config for running AIME benchmark 5 repeated times""" @@ -312,3 +339,20 @@ def configure_pipeline( MultiplyTransform(n_repeats=1024) ) return pipeline + + +class AIME_PIPELINETag(AIME_PIPELINE): + """This class specifies the config for running AIME benchmark 5 repeated times""" + + def configure_pipeline( + self, model_config: ModelConfig, resume_from: str = None, **kwargs: dict[str, Any] + ) -> PipelineConfig: + pipeline = super().configure_pipeline(model_config=model_config, resume_from=resume_from) + # data preprocessing + self.data_processing_comp.prompt_template_path = os.path.join( + os.path.dirname(__file__), "../prompt_templates/aime_templates/Template_tag1.jinja" + ) + # Each query is tagged with one or more topics from arithmetic, algebra, counting, geometry, number theory, and probability and other. + # These topics follow the description on the official website: https://artofproblemsolving.com/wiki/index.php/American_Invitational_Mathematics_Examination?srsltid=AfmBOooSIQ8ua5aJX00ZtYCKDuOAB4I4c-YE9zr1xYZ86fq8x5RL2sEg. + # In their own words, "The AIME tests mathematical problem solving with arithmetic, algebra, counting, geometry, number theory, and probability and other secondary school math topics" + return pipeline