Skip to content

Commit

Permalink
No public description
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 644028999
  • Loading branch information
minsukkahng authored and RyanMullins committed Jun 27, 2024
1 parent 6a37c54 commit 4eae13e
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 35 deletions.
6 changes: 5 additions & 1 deletion python/src/llm_comparator/llm_judge_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,11 @@ def parse_output(raw_output: str):
if (rating_label := rating_label.text) is None:
return None

score = self.rating_to_score_map[rating_label]
try:
score = self.rating_to_score_map[rating_label]
except KeyError:
print(f'LLM judge returned an unknown rating label: {rating_label}')
return None
return (score, rating_label, rationale.strip(' \n'))

max_example_index = max([ex['example_index'] for ex in inputs_for_judge])
Expand Down
51 changes: 26 additions & 25 deletions python/src/llm_comparator/rationale_bullet_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class _BulletGeneratorInput(TypedDict):
"""Intermediate output for rationale bullet generator."""

selected_rationales: list[str]
ex_win_side: Literal['A', 'B']
ex_win_side: Optional[Literal['A', 'B']]


class RationaleBulletGenerator:
Expand Down Expand Up @@ -80,29 +80,27 @@ def _prepare_inputs_for_generating_bullets(
elif ex_score < b_wins_score:
ex_win_side = 'B'
else:
raise ValueError(
f'No winner for example with score: {ex_score}. For A to win, score'
f' > {a_wins_score}. For B to win, score < {b_wins_score}.'
)
ex_win_side = None

# Select rationales for ratings whose winners are the same as the winner
# for the example.
winners_rationales = []
for rating in ex['individual_rater_scores']:
# Rewrite rationales for flipped cases.
rating['rationale'] = self._rewrite_flipped_ratings(
rating['rationale'], rating['is_flipped']
)

if rating['score'] > a_wins_score:
rating_win_side = 'A'
elif rating['score'] < b_wins_score:
rating_win_side = 'B'
else:
continue

if ex_win_side == rating_win_side:
winners_rationales.append(rating['rationale'])
if ex_win_side is not None:
for rating in ex['individual_rater_scores']:
# Rewrite rationales for flipped cases.
rating['rationale'] = self._rewrite_flipped_ratings(
rating['rationale'], rating['is_flipped']
)

if rating['score'] > a_wins_score:
rating_win_side = 'A'
elif rating['score'] < b_wins_score:
rating_win_side = 'B'
else:
continue

if ex_win_side == rating_win_side:
winners_rationales.append(rating['rationale'])

inputs_for_generating_bullets.append({
'selected_rationales': winners_rationales,
Expand Down Expand Up @@ -169,12 +167,15 @@ def _generate_rationale_bullets_for_examples(
rationale_bullets_for_examples = []
for input_for_generation in tqdm.auto.tqdm(inputs_for_generating_bullets):
# Run LLMs to summarize several rationales into a set of short phrases.
output = self._generate_rationale_bullets_for_example(
input_for_generation['selected_rationales'],
input_for_generation['ex_win_side'],
)
if input_for_generation['ex_win_side']:
output = self._generate_rationale_bullets_for_example(
input_for_generation['selected_rationales'],
input_for_generation['ex_win_side'],
)

bullets = self._parse_xml_formatted_rationale_bullets(output)
bullets = self._parse_xml_formatted_rationale_bullets(output)
else:
bullets = []
rationale_bullets_for_examples.append(bullets)

print('Done generating rationale bullets')
Expand Down
16 changes: 7 additions & 9 deletions python/src/llm_comparator/rationale_cluster_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from llm_comparator import utils


_RationaleBullet = types.RationaleBullet
_RationaleBulletWithClusterSimilarity = (
types.RationaleBulletWithClusterSimilarity
)
Expand All @@ -34,13 +33,13 @@ def __init__(
self._embedder = emb_model_helper

def _flatten_rationales(
self, rationale_bullets_for_examples: Sequence[Sequence[_RationaleBullet]]
self, rationale_bullets_for_examples: Sequence[Sequence[str]]
) -> Sequence[str]:
"""Flatten rationale bullets and remove duplicates."""
flattened_rationales = []
for rationale_bullets_for_example in rationale_bullets_for_examples:
for bullet in rationale_bullets_for_example:
flattened_rationales.append(bullet['rationale'])
for rationale in rationale_bullets_for_example:
flattened_rationales.append(rationale)
return list(set(flattened_rationales))

def _paraphrase_rationales(
Expand Down Expand Up @@ -136,7 +135,7 @@ def _generate_cluster_titles(
)

output = self._generator.predict(
prompt_for_clustering, temperatue=temperature_for_clustering
prompt_for_clustering, temperature=temperature_for_clustering
)

output_parsed = utils.extract_xml_part(output, 'groups')
Expand Down Expand Up @@ -185,15 +184,14 @@ def _compute_similarities_to_clusters(

def _store_similarities_to_rationale_bullets(
self,
rationale_bullets_for_examples: Sequence[Sequence[_RationaleBullet]],
rationale_bullets_for_examples: Sequence[Sequence[str]],
similarities_for_rationales: Mapping[str, Sequence[float]],
) -> Sequence[Sequence[_RationaleBulletWithClusterSimilarity]]:
"""Store similarities to bullets by iterating over the nested lists."""
rationale_bullets_with_similarities = []
for rationale_bullets_for_example in rationale_bullets_for_examples:
rationale_bullets_with_similarities_for_example = []
for bullet in rationale_bullets_for_example:
rationale = bullet['rationale']
for rationale in rationale_bullets_for_example:
similarities = similarities_for_rationales[rationale]
rationale_bullets_with_similarities_for_example.append(
_RationaleBulletWithClusterSimilarity(
Expand All @@ -210,7 +208,7 @@ def _store_similarities_to_rationale_bullets(

def run(
self,
rationale_bullets_for_examples: Sequence[Sequence[_RationaleBullet]],
rationale_bullets_for_examples: Sequence[Sequence[str]],
num_clusters: int = 10,
) -> tuple[
Sequence[_RationaleCluster],
Expand Down

0 comments on commit 4eae13e

Please sign in to comment.