-
Notifications
You must be signed in to change notification settings - Fork 14
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Neel/spatial map answer extraction #57
base: main
Are you sure you want to change the base?
Changes from 5 commits
51cc1db
bb2dd08
c4b6bb0
46e3573
09ca286
b890e1a
74107ad
c7b848e
ae456f1
c4784b3
82a4eb2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -157,25 +157,34 @@ def extract_answer_from_text_grid(text, question_type): | |
return None # Return None if no numbers are found | ||
|
||
|
||
def extract_answer_from_text_map(text, question_type, model_name): | ||
def extract_answer_from_text_map_and_maze(model_output_raw, options): | ||
""" | ||
Extracts the answer from the text based on specific patterns, | ||
and as a fallback, extracts the first number if no patterns match. | ||
The code is from: https://github.com/alvinmingwisc/spatial_reason_vlm/tree/main/eval, | ||
and included with minimal modifications. | ||
Extracts the answer from the text based on known model output patterns. | ||
Searches for both a letter and whole word answer and returns both as they are not | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oh is this why you have added the OR metrics? If so, I don't think this justifies adding new metric classes, instead the answer extractor should return some combination of these two, maybe simply "x or y". There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Happy to do this another way. Especially if that means less new code. However, if let's say I return "x or y", I still need to a metric that knows how to check this. the current ones just look for case-sensitive or insensitive exact match. One though was to make one metric that can take a single or a list (instead of having two as now) and an optional parameter for how to combine, i.e. any ("or") or all ("and"). what do you think? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Now returning the "x or y" string |
||
always consistent. | ||
|
||
Args: | ||
- text (str): The text containing the model's answer. | ||
- question_type (str): The text containing the question type. | ||
- model_name (str): The model name. | ||
- model_output_raw (str): The text containing the model's answer. | ||
- options (str): The list of options. | ||
|
||
Returns: | ||
- str or None: The extracted answer, or None if no answer could be extracted. | ||
- str or None: The extracted answers, or empty strings if no answer could be extracted. | ||
""" | ||
# Mapping of textual numbers to their numeric equivalents | ||
|
||
# replace common subsitutions in model outputs | ||
|
||
model_output_parsed_letter = "" | ||
model_output_parsed = "" | ||
|
||
if not model_output_raw: | ||
return [model_output_parsed, model_output_parsed_letter] | ||
|
||
model_output_raw = re.sub(r"\bno objects\b", "0 objects", model_output_raw, re.IGNORECASE) | ||
model_output_raw = re.sub(r"\bnot\b", "no", model_output_raw, re.IGNORECASE) | ||
model_output_raw = re.sub(r"\bshould be\b", "is", model_output_raw, re.IGNORECASE) | ||
|
||
number_mapping = { | ||
"zero": 0, | ||
"no": 0, | ||
"zero": 0, | ||
"one": 1, | ||
"two": 2, | ||
"three": 3, | ||
|
@@ -187,127 +196,71 @@ def extract_answer_from_text_map(text, question_type, model_name): | |
"nine": 9, | ||
} | ||
|
||
dirs = ["southeast", "northeast", "northwest", "southwest"] | ||
dir_pattern = rf"\b({'|'.join(dirs)})\b" | ||
|
||
if text is None: | ||
return None | ||
|
||
question_id = int(re.search("[0-9]", re.search("Q[0-9]", question_type).group()).group()) | ||
|
||
if question_id == 0: | ||
direction_match = re.search(r"\b[A-D]\.\s*(" + "|".join(dirs) + r")\b", text, re.IGNORECASE) | ||
if direction_match: | ||
return direction_match.group(1).lower() | ||
|
||
match = re.search(dir_pattern, text, re.IGNORECASE) | ||
if match: | ||
return match.group(1) | ||
return None | ||
|
||
elif question_id == 1: | ||
match = re.search( | ||
rf"^([\w\s\'\']+?)\s+is\s+(?:located\s+|in\s+the\s+|located\s+to\s+the\s+)({dir_pattern})", | ||
text, | ||
re.IGNORECASE, | ||
) | ||
|
||
if match: | ||
string = match.group(1) | ||
return string | ||
|
||
match = re.search(r"\b[A-D]\.\s*(.*)", text) # problem with extracting . | ||
|
||
if match: | ||
string = match.group(1) | ||
string = remove_redundancy(string) | ||
string = extract_before_is(string) | ||
return string | ||
|
||
match = re.search(r"\b([ABCD][.,]|[(][abcdABCD][)])\s*(.*?)(?=\sis\b|\.|,|<|$)", text) | ||
if match: | ||
answer = match.group(1).strip() | ||
# Remove trailing punctuation if any | ||
answer = re.sub(r"[\.,\?!<]+$", "", answer) | ||
return answer | ||
|
||
match = re.search( | ||
rf"Therefore, the object in the {dir_pattern} of [\w\s\'\']+ is ([\w\s\'\']+)", text, re.IGNORECASE | ||
) | ||
if match: | ||
string = match.group(2) | ||
return string | ||
|
||
if "claude" in model_name.lower(): | ||
match = re.search(rf"^([\w\s\'\']+?)\s+is\s+(to\s+the\s+)({dir_pattern})", text, re.IGNORECASE) | ||
if match: | ||
string = match.group(1) | ||
return string | ||
|
||
if "gemini" in model_name.lower(): | ||
patterns = [ | ||
rf"\*\*Concise Answer:\*\*\n([\w\s\'\']+?)\s+is\s+(?:located\s+|in\s+the\s+|in\s+|located\s+to\s+the\s+)({dir_pattern})", | ||
rf"\*\*Answer:\*\*\s+([\w\s\'\']+?)\s+is\s+in\s+the\s+({dir_pattern})\s+of\s+([\w\s\'\']+)", | ||
r"\*\*Answer:\*\*\n([\w\s\'\']+)", | ||
r"\*\*Answer\*\*:\s+([\w\s\'\']+)", | ||
r"\*\*Answer:\*\*\s+([\w\s\'\']+)", | ||
] | ||
|
||
for pattern in patterns: | ||
match = re.search(pattern, text, re.IGNORECASE) | ||
if match: | ||
return match.group(1) | ||
|
||
if "gpt-4o" in model_name.lower() or "gpt4o" in model_name.lower(): | ||
match = re.search( | ||
rf"Concise Answer:\s+([\w\s\'\']+?)\s+is\s+(?:located\s+|in\s+the\s+|in\s+|located\s+to\s+the\s+)({dir_pattern})", | ||
text, | ||
re.IGNORECASE, | ||
) | ||
if match: | ||
string = match.group(1) | ||
return string | ||
|
||
# If no match, check for an answer following "is", with specific end markers defined | ||
match = re.search(r"\bis\b\s+(.*?)(?=\.|,|<|$)", text) | ||
if match: | ||
answer = match.group(1).strip() | ||
# Remove trailing punctuation if any | ||
answer = re.sub(r"[\.,\?!<]+$", "", answer) | ||
return answer | ||
|
||
return None # Return None if no match is found | ||
|
||
elif question_id == 2: | ||
match = re.search(r"\b[A-D]\.\s*(\d+)", text) # match number only | ||
if match: | ||
return match.group(1) | ||
# Create a list to store all found numbers along with their positions | ||
found_numbers = [] | ||
|
||
# Check for textual numbers and their positions | ||
for text_num, num in number_mapping.items(): | ||
for match in re.finditer(rf"\b{text_num}\b", text, re.IGNORECASE): | ||
found_numbers.append((match.start(), num)) | ||
|
||
# Check for digit sequences and their positions, specifically ignoring list markers at the start | ||
# Exclude numbers following "\n\n" and directly followed by ". " | ||
text = re.sub(r"^\n\n\d+\.\s", "", text) # Remove the leading list marker if it exists | ||
|
||
for match in re.finditer(r"\d+", text): | ||
found_numbers.append((match.start(), int(match.group(0)))) | ||
|
||
# Sort found numbers by their positions (smallest position first) | ||
if found_numbers: | ||
found_numbers.sort(key=lambda x: x[0]) | ||
# Return the number associated with the earliest position | ||
return str(found_numbers[0][1]) | ||
return None | ||
|
||
else: | ||
raise ValueError(f"Question ID {question_id} is not supported.") | ||
|
||
return None # Return None if no numbers are found | ||
for k, v in number_mapping.items(): | ||
model_output_raw = re.sub(rf"\b{k}\b", str(v), model_output_raw, re.IGNORECASE) | ||
|
||
# get dict of options from options string | ||
options_dict = {x.split(".")[0].strip().lower():x.split(".")[1].strip().lower() for x in options} | ||
|
||
|
||
model_output_parsed_letter = "" | ||
model_output_parsed = "" | ||
|
||
answers = [v for k, v in options_dict.items()] | ||
answers_pattern = rf"\b({'|'.join(answers)})\b" | ||
|
||
if "Answer:".lower() in model_output_raw.lower(): | ||
pattern_letter = r"^\**Answer:\**\s+(\w)\. (\w+)" | ||
matches = re.search(pattern_letter, model_output_raw, re.IGNORECASE) | ||
if matches: | ||
match_option = matches.group(1).lower() | ||
neelsj marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if match_option in options_dict: | ||
model_output_parsed_letter = options_dict[match_option] | ||
else: | ||
model_output_parsed_letter = match_option | ||
|
||
pattern_phrase = r"Answer:\**\s+([^\n]+)" | ||
matches = re.search(pattern_phrase, model_output_raw, re.IGNORECASE) | ||
if matches: | ||
model_output_answer_line = matches.group(1) | ||
|
||
answers_match = re.search(answers_pattern, model_output_answer_line, re.IGNORECASE) | ||
|
||
if answers_match: | ||
model_output_parsed = answers_match.group(1) | ||
else: | ||
letters = [k for k, v in options_dict.items()] | ||
letters_pattern = rf"\b({'|'.join(letters)})\b" | ||
letters_pattern_match = re.search(letters_pattern, model_output_answer_line, re.IGNORECASE) | ||
|
||
if letters_pattern_match: | ||
match_option = letters_pattern_match.group(1).lower() | ||
model_output_parsed_letter = options_dict[match_option] | ||
|
||
elif "answer is".lower() in model_output_raw.lower(): | ||
pattern_letter = r'answer is:*\s*\**([\w\d]+)[\s:.]*\**' | ||
|
||
# first look for a single letter answer | ||
matches = re.search(pattern_letter, model_output_raw, re.IGNORECASE) | ||
if matches: | ||
match_option = matches.group(1).lower() | ||
if match_option in options_dict: | ||
model_output_parsed_letter = options_dict[match_option] | ||
else: | ||
model_output_parsed_letter = match_option | ||
|
||
# next look if any of the options names are present in the first line | ||
|
||
model_output_answer_line = model_output_raw.splitlines()[0] | ||
|
||
answers = [v for k, v in options_dict.items()] | ||
answers_pattern = rf"\b({'|'.join(answers)})\b" | ||
answers_match = re.search(answers_pattern, model_output_answer_line, re.IGNORECASE) | ||
|
||
if answers_match: | ||
model_output_parsed = answers_match.group(1) | ||
|
||
return [model_output_parsed, model_output_parsed_letter] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would you mind adding a test for the extractor to showcase/check all these cases that you have covered? regex is hard to read and review, a test would give me peace of mind. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. added in a new file under tests/data_utils_tests |
||
|
||
|
||
def extract_answer_from_text_maze(text, question_type): | ||
|
@@ -443,43 +396,59 @@ def transform(self, df: pd.DataFrame) -> pd.DataFrame: | |
) | ||
return df | ||
|
||
|
||
@dataclass | ||
class ExtractAnswerGrid(ExtractAnswer): | ||
"""This class is an answer extractor for the GRID benchmark.""" | ||
class ExtractQuestionOptions(DFTransformBase): | ||
"""This class is for extracting the option list from a prompt.""" | ||
|
||
answer_column_name: str | ||
extracted_answer_column_name: str | ||
question_type_column_name: str | ||
mode: str | ||
prompt_column_name: str | ||
extracted_options_column_name: str | ||
|
||
@abstractmethod | ||
def _parse_answer_function(self, answer_text, question_type): | ||
return extract_answer_from_text_grid(answer_text, question_type) | ||
def _extract_options_from_text_map(self, prompt): | ||
""" | ||
Extracts the options list from the text. | ||
neelsj marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
Args: | ||
- text (str): The text containing the prompt. | ||
|
||
Returns: | ||
- str or None: The extracted list of options. | ||
""" | ||
|
||
# get list of options from prompt | ||
prompt_lines = prompt.splitlines() | ||
matches = [i for i, x in enumerate(prompt_lines) if "Available options:" in x] | ||
options = prompt_lines[matches[0]+1:matches[0]+5] | ||
|
||
return options | ||
|
||
def transform(self, df: pd.DataFrame) -> pd.DataFrame: | ||
df[self.extracted_options_column_name] = df[self.prompt_column_name].apply(self._extract_options_from_text_map) | ||
return df | ||
|
||
@dataclass | ||
class ExtractAnswerSpatialMap(ExtractAnswer): | ||
"""This class is an answer extractor for the SPATIAL_MAP benchmark.""" | ||
class ExtractAnswerGrid(ExtractAnswer): | ||
"""This class is an answer extractor for the GRID benchmark.""" | ||
|
||
answer_column_name: str | ||
extracted_answer_column_name: str | ||
question_type_column_name: str | ||
model_name: str | ||
mode: str | ||
|
||
@abstractmethod | ||
def _parse_answer_function(self, answer_text, question_type): | ||
return extract_answer_from_text_map(answer_text, question_type, self.model_name) | ||
return extract_answer_from_text_grid(answer_text, question_type) | ||
|
||
|
||
@dataclass | ||
class ExtractAnswerMaze(ExtractAnswer): | ||
"""This class is an answer extractor for the MAZE benchmark.""" | ||
class ExtractAnswerSpatialMapAndMaze(DFTransformBase): | ||
"""This class is an answer extractor for the SPATIAL_MAP and MAZE benchmark.""" | ||
|
||
answer_column_name: str | ||
extracted_answer_column_name: str | ||
question_type_column_name: str | ||
extracted_options_column_name: str | ||
|
||
@abstractmethod | ||
def _parse_answer_function(self, answer_text, question_type): | ||
return extract_answer_from_text_maze(answer_text, question_type) | ||
def transform(self, df: pd.DataFrame) -> pd.DataFrame: | ||
df[self.extracted_answer_column_name] = df.apply( | ||
lambda x: extract_answer_from_text_map_and_maze(x[self.answer_column_name], x[self.extracted_options_column_name]), axis=1 | ||
) | ||
return df |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please add "ExtractQuestionOptions" so that when we run the formatters this import does not get removed (considered unused).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ahh, yes will do
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done and checked in