From 4eae13e8e05d24d28e17a6e560e57634d006c09f Mon Sep 17 00:00:00 2001 From: Minsuk Kahng Date: Mon, 17 Jun 2024 16:08:13 +0000 Subject: [PATCH] No public description PiperOrigin-RevId: 644028999 --- python/src/llm_comparator/llm_judge_runner.py | 6 ++- .../rationale_bullet_generator.py | 51 ++++++++++--------- .../rationale_cluster_generator.py | 16 +++--- 3 files changed, 38 insertions(+), 35 deletions(-) diff --git a/python/src/llm_comparator/llm_judge_runner.py b/python/src/llm_comparator/llm_judge_runner.py index 6d92931..566278b 100644 --- a/python/src/llm_comparator/llm_judge_runner.py +++ b/python/src/llm_comparator/llm_judge_runner.py @@ -160,7 +160,11 @@ def parse_output(raw_output: str): if (rating_label := rating_label.text) is None: return None - score = self.rating_to_score_map[rating_label] + try: + score = self.rating_to_score_map[rating_label] + except KeyError: + print(f'LLM judge returned an unknown rating label: {rating_label}') + return None return (score, rating_label, rationale.strip(' \n')) max_example_index = max([ex['example_index'] for ex in inputs_for_judge]) diff --git a/python/src/llm_comparator/rationale_bullet_generator.py b/python/src/llm_comparator/rationale_bullet_generator.py index ecb6fdd..3fec704 100644 --- a/python/src/llm_comparator/rationale_bullet_generator.py +++ b/python/src/llm_comparator/rationale_bullet_generator.py @@ -19,7 +19,7 @@ class _BulletGeneratorInput(TypedDict): """Intermediate output for rationale bullet generator.""" selected_rationales: list[str] - ex_win_side: Literal['A', 'B'] + ex_win_side: Optional[Literal['A', 'B']] class RationaleBulletGenerator: @@ -80,29 +80,27 @@ def _prepare_inputs_for_generating_bullets( elif ex_score < b_wins_score: ex_win_side = 'B' else: - raise ValueError( - f'No winner for example with score: {ex_score}. For A to win, score' - f' > {a_wins_score}. For B to win, score < {b_wins_score}.' - ) + ex_win_side = None # Select rationales for ratings whose winners are the same as the winner # for the example. winners_rationales = [] - for rating in ex['individual_rater_scores']: - # Rewrite rationales for flipped cases. - rating['rationale'] = self._rewrite_flipped_ratings( - rating['rationale'], rating['is_flipped'] - ) - - if rating['score'] > a_wins_score: - rating_win_side = 'A' - elif rating['score'] < b_wins_score: - rating_win_side = 'B' - else: - continue - - if ex_win_side == rating_win_side: - winners_rationales.append(rating['rationale']) + if ex_win_side is not None: + for rating in ex['individual_rater_scores']: + # Rewrite rationales for flipped cases. + rating['rationale'] = self._rewrite_flipped_ratings( + rating['rationale'], rating['is_flipped'] + ) + + if rating['score'] > a_wins_score: + rating_win_side = 'A' + elif rating['score'] < b_wins_score: + rating_win_side = 'B' + else: + continue + + if ex_win_side == rating_win_side: + winners_rationales.append(rating['rationale']) inputs_for_generating_bullets.append({ 'selected_rationales': winners_rationales, @@ -169,12 +167,15 @@ def _generate_rationale_bullets_for_examples( rationale_bullets_for_examples = [] for input_for_generation in tqdm.auto.tqdm(inputs_for_generating_bullets): # Run LLMs to summarize several rationales into a set of short phrases. - output = self._generate_rationale_bullets_for_example( - input_for_generation['selected_rationales'], - input_for_generation['ex_win_side'], - ) + if input_for_generation['ex_win_side']: + output = self._generate_rationale_bullets_for_example( + input_for_generation['selected_rationales'], + input_for_generation['ex_win_side'], + ) - bullets = self._parse_xml_formatted_rationale_bullets(output) + bullets = self._parse_xml_formatted_rationale_bullets(output) + else: + bullets = [] rationale_bullets_for_examples.append(bullets) print('Done generating rationale bullets') diff --git a/python/src/llm_comparator/rationale_cluster_generator.py b/python/src/llm_comparator/rationale_cluster_generator.py index 6b79033..218ffcb 100644 --- a/python/src/llm_comparator/rationale_cluster_generator.py +++ b/python/src/llm_comparator/rationale_cluster_generator.py @@ -12,7 +12,6 @@ from llm_comparator import utils -_RationaleBullet = types.RationaleBullet _RationaleBulletWithClusterSimilarity = ( types.RationaleBulletWithClusterSimilarity ) @@ -34,13 +33,13 @@ def __init__( self._embedder = emb_model_helper def _flatten_rationales( - self, rationale_bullets_for_examples: Sequence[Sequence[_RationaleBullet]] + self, rationale_bullets_for_examples: Sequence[Sequence[str]] ) -> Sequence[str]: """Flatten rationale bullets and remove duplicates.""" flattened_rationales = [] for rationale_bullets_for_example in rationale_bullets_for_examples: - for bullet in rationale_bullets_for_example: - flattened_rationales.append(bullet['rationale']) + for rationale in rationale_bullets_for_example: + flattened_rationales.append(rationale) return list(set(flattened_rationales)) def _paraphrase_rationales( @@ -136,7 +135,7 @@ def _generate_cluster_titles( ) output = self._generator.predict( - prompt_for_clustering, temperatue=temperature_for_clustering + prompt_for_clustering, temperature=temperature_for_clustering ) output_parsed = utils.extract_xml_part(output, 'groups') @@ -185,15 +184,14 @@ def _compute_similarities_to_clusters( def _store_similarities_to_rationale_bullets( self, - rationale_bullets_for_examples: Sequence[Sequence[_RationaleBullet]], + rationale_bullets_for_examples: Sequence[Sequence[str]], similarities_for_rationales: Mapping[str, Sequence[float]], ) -> Sequence[Sequence[_RationaleBulletWithClusterSimilarity]]: """Store similarities to bullets by iterating over the nested lists.""" rationale_bullets_with_similarities = [] for rationale_bullets_for_example in rationale_bullets_for_examples: rationale_bullets_with_similarities_for_example = [] - for bullet in rationale_bullets_for_example: - rationale = bullet['rationale'] + for rationale in rationale_bullets_for_example: similarities = similarities_for_rationales[rationale] rationale_bullets_with_similarities_for_example.append( _RationaleBulletWithClusterSimilarity( @@ -210,7 +208,7 @@ def _store_similarities_to_rationale_bullets( def run( self, - rationale_bullets_for_examples: Sequence[Sequence[_RationaleBullet]], + rationale_bullets_for_examples: Sequence[Sequence[str]], num_clusters: int = 10, ) -> tuple[ Sequence[_RationaleCluster],