Update layout prompts
VikParuchuri committed Jan 15, 2025
1 parent b0edbe2 commit dbe1fc4
Showing 33 changed files with 186 additions and 40 deletions.
105 changes: 92 additions & 13 deletions benchmarks/table/table.py
@@ -1,4 +1,10 @@
import os
from typing import List

import numpy as np

from marker.renderers.json import JSONOutput, JSONBlockOutput

os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # Transformers uses .isin for a simple op, which is not supported on MPS

import base64
@@ -10,8 +16,9 @@
from tabulate import tabulate
import json
from bs4 import BeautifulSoup
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
from pypdfium2._helpers.misc import PdfiumError
from marker.util import matrix_intersection_area

from marker.config.parser import ConfigParser
from marker.converters.table import TableConverter
@@ -27,13 +34,24 @@ def update_teds_score(result):
return result


def extract_tables(children: List[JSONBlockOutput]):
tables = []
for child in children:
if child.block_type == 'Table':
tables.append(child)
elif child.children:
tables.extend(extract_tables(child.children))
return tables


@click.command(help="Benchmark Table to HTML Conversion")
@click.argument("out_file", type=str)
@click.option("--dataset", type=str, default="datalab-to/fintabnet-test", help="Dataset to use")
@click.option("--max_rows", type=int, default=None, help="Maximum number of PDFs to process")
def main(out_file: str, dataset: str, max_rows: int):
@click.option("--max_workers", type=int, default=16, help="Maximum number of workers to use")
def main(out_file: str, dataset: str, max_rows: int, max_workers: int):
models = create_model_dict()
config_parser = ConfigParser({'output_format': 'html'})
config_parser = ConfigParser({'output_format': 'json'})
start = time.time()


@@ -45,6 +63,7 @@ def main(out_file: str, dataset: str, max_rows: int):
iterations = min(max_rows, len(dataset))

results = []
total_unaligned = 0
for i in tqdm(range(iterations), desc='Converting Tables'):
try:
row = dataset[i]
@@ -61,19 +80,74 @@ def main(out_file: str, dataset: str, max_rows: int):
with tempfile.NamedTemporaryFile(suffix=".pdf", mode="wb") as temp_pdf_file:
temp_pdf_file.write(pdf_binary)
temp_pdf_file.seek(0)
marker_table_html = converter(temp_pdf_file.name).html
tqdm.disable = True
marker_json = converter(temp_pdf_file.name).children
tqdm.disable = False

marker_table_soup = BeautifulSoup(marker_table_html, 'html.parser')
marker_detected_tables = marker_table_soup.find_all('table')
if len(marker_detected_tables)==0:
if len(marker_json) == 0 or len(gt_tables) == 0:
print(f'No tables detected, skipping...')
total_unaligned += len(gt_tables)
continue

marker_tables = extract_tables(marker_json)
marker_table_boxes = [table.bbox for table in marker_tables]
page_bbox = marker_json[0].bbox

# Normalize the bboxes
for bbox in marker_table_boxes:
bbox[0] = bbox[0] / page_bbox[2]
bbox[1] = bbox[1] / page_bbox[3]
bbox[2] = bbox[2] / page_bbox[2]
bbox[3] = bbox[3] / page_bbox[3]

gt_boxes = [table['normalized_bbox'] for table in gt_tables]
gt_areas = [(bbox[2] - bbox[0]) * (bbox[3] - bbox[1]) for bbox in gt_boxes]
marker_areas = [(bbox[2] - bbox[0]) * (bbox[3] - bbox[1]) for bbox in marker_table_boxes]
table_alignments = matrix_intersection_area(gt_boxes, marker_table_boxes)

aligned_tables = []
used_tables = set()
unaligned_tables = set()
for table_idx, alignment in enumerate(table_alignments):
try:
max_area = np.max(alignment)
aligned_idx = np.argmax(alignment)
except ValueError:
# No alignment found
unaligned_tables.add(table_idx)
continue

if aligned_idx in used_tables:
# Marker table already aligned with another gt table
unaligned_tables.add(table_idx)
continue

# Gt table doesn't align well with any marker table
gt_table_pct = gt_areas[table_idx] / max_area
if not .75 < gt_table_pct < 1.25:
unaligned_tables.add(table_idx)
continue

# Marker table doesn't align with gt table
marker_table_pct = marker_areas[aligned_idx] / max_area
if not .75 < marker_table_pct < 1.25:
unaligned_tables.add(table_idx)
continue

aligned_tables.append(
(marker_tables[aligned_idx], gt_tables[table_idx])
)
used_tables.add(aligned_idx)

total_unaligned += len(unaligned_tables)

for marker_table_soup, gt_table in zip(marker_detected_tables, gt_tables):
for marker_table, gt_table in aligned_tables:
gt_table_html = gt_table['html']

#marker wraps the table in <tbody> which fintabnet data doesn't
marker_table_soup.find('tbody').unwrap()
#Fintabnet doesn't use th tags, need to be replaced for fair comparison
marker_table_soup = BeautifulSoup(marker_table.html, 'html.parser')
marker_table_soup.find('tbody').unwrap()
for th_tag in marker_table_soup.find_all('th'):
th_tag.name = 'td'
marker_table_html = str(marker_table_soup)
@@ -86,10 +160,15 @@ def main(out_file: str, dataset: str, max_rows: int):
print('Broken PDF, Skipping...')
continue

print(f"Total time: {time.time() - start}")
print(f"Total time: {time.time() - start}.")
print(f"Could not align {total_unaligned} tables from fintabnet.")

with ThreadPoolExecutor(max_workers=16) as executor:
results = list(tqdm(executor.map(update_teds_score, results), desc='Computing alignment scores', total=len(results)))
with ProcessPoolExecutor(max_workers=max_workers) as executor:
results = list(
tqdm(
executor.map(update_teds_score, results), desc='Computing alignment scores', total=len(results)
)
)
avg_score = sum([r["score"] for r in results]) / len(results)

headers = ["Avg score", "Total tables"]
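The benchmark now converts to JSON instead of HTML, walks the block tree for Table blocks, normalizes their bboxes against the page bbox, and pairs them with ground-truth tables by intersection area, discarding pairs whose areas disagree with the overlap by more than 25%. A minimal standalone sketch of that matching step, assuming plain normalized [x0, y0, x1, y1] boxes and a local pairwise-intersection helper in place of marker.util.matrix_intersection_area:

import numpy as np

def intersection_matrix(gt_boxes, pred_boxes):
    # Pairwise intersection areas for normalized [x0, y0, x1, y1] boxes.
    out = np.zeros((len(gt_boxes), len(pred_boxes)))
    for i, g in enumerate(gt_boxes):
        for j, p in enumerate(pred_boxes):
            w = min(g[2], p[2]) - max(g[0], p[0])
            h = min(g[3], p[3]) - max(g[1], p[1])
            out[i, j] = max(w, 0) * max(h, 0)
    return out

def align_tables(gt_boxes, pred_boxes, tol=0.25):
    # Greedily pair each gt table with the prediction it overlaps most,
    # rejecting pairs whose box areas differ from the overlap by more than tol.
    area = lambda b: (b[2] - b[0]) * (b[3] - b[1])
    gt_areas = [area(b) for b in gt_boxes]
    pred_areas = [area(b) for b in pred_boxes]
    overlaps = intersection_matrix(gt_boxes, pred_boxes)
    aligned, used = [], set()
    for gt_idx, row in enumerate(overlaps):
        if row.size == 0 or row.max() == 0:
            continue  # nothing overlaps this ground-truth table
        pred_idx = int(row.argmax())
        max_area = float(row.max())
        if pred_idx in used:
            continue  # prediction already claimed by another gt table
        if not (1 - tol) < gt_areas[gt_idx] / max_area < (1 + tol):
            continue  # gt box much larger or smaller than the overlap
        if not (1 - tol) < pred_areas[pred_idx] / max_area < (1 + tol):
            continue  # predicted box much larger or smaller than the overlap
        aligned.append((gt_idx, pred_idx))
        used.add(pred_idx)
    return aligned

# One well-aligned pair plus one spurious detection -> only the pair survives.
print(align_tables([[0.1, 0.1, 0.5, 0.4]],
                   [[0.11, 0.1, 0.5, 0.41], [0.6, 0.6, 0.9, 0.9]]))  # [(0, 0)]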
1 change: 0 additions & 1 deletion convert_single.py
@@ -9,7 +9,6 @@

from marker.config.parser import ConfigParser
from marker.config.printer import CustomClickPrinter
from marker.converters.pdf import PdfConverter
from marker.logger import configure_logging
from marker.models import create_model_dict
from marker.output import save_output
61 changes: 42 additions & 19 deletions marker/builders/llm_layout.py
@@ -30,7 +30,7 @@ class LLMLayoutBuilder(LayoutBuilder):
confidence_threshold: Annotated[
float,
"The confidence threshold to use for relabeling.",
] = 0.7
] = 0.75
picture_height_threshold: Annotated[
float,
"The height threshold for pictures that may actually be complex regions.",
@@ -57,38 +57,42 @@ class LLMLayoutBuilder(LayoutBuilder):
"Default is a string containing the Gemini relabelling prompt."
] = """You are a layout expert specializing in document analysis.
Your task is to relabel layout blocks in images to improve the accuracy of an existing layout model.
You will be provided with an image of a layout block and the top k predictions from the current model, along with their confidence scores.
You will be provided with an image of a layout block and the top k predictions from the current model, along with the per-label confidence scores.
Your job is to analyze the image and choose the single most appropriate label from the provided top k predictions.
Do not invent any new labels.
Carefully examine the image and consider the provided predictions.
Choose the label you believe is the most accurate representation of the layout block.
Carefully examine the image and consider the provided predictions. Take the model confidence scores into account. If the existing label is the most appropriate, you should not change it.
**Instructions**
1. Analyze the image and consider the provided top k predictions.
2. Write a short description of the image, and which of the potential labels you believe is the most accurate representation of the layout block.
3. Choose the single most appropriate label from the provided top k predictions.
Here are the top k predictions from the model followed by the image:
Here are descriptions of the layout blocks you can choose from:
{potential_labels}
Here are the top k predictions from the model:
{top_k}
"""
complex_relabeling_prompt: Annotated[
str,
"The prompt to use for complex relabelling blocks.",
"Default is a string containing the complex relabelling prompt."
] = """You are a layout expert specializing in document analysis.
Your task is to relabel layout blocks in images to improve the accuracy of an existing layout model.
You will be provided with an image of a layout block and some potential labels.
You will be provided with an image of a layout block and some potential labels that might be appropriate.
Your job is to analyze the image and choose the single most appropriate label from the provided labels.
Do not invent any new labels.
Carefully examine the image and consider the provided predictions.
Choose the label you believe is the most accurate representation of the layout block.
**Instructions**
1. Analyze the image and consider the potential labels.
2. Write a short description of the image, and which of the potential labels you believe is the most accurate representation of the layout block.
3. Choose the single most appropriate label from the provided labels.
Potential labels:
- Picture
- Table
- Form
- Figure - A graph or diagram with text.
- ComplexRegion - a complex region containing multiple text and other elements.
{potential_labels}
Respond only with one of `Figure`, `Picture`, `ComplexRegion`, `Table`, or `Form`.
Here is the image of the layout block:
"""

def __init__(self, layout_model: LayoutPredictor, ocr_error_model: OCRErrorPredictor, config=None):
@@ -126,22 +130,41 @@ def relabel_blocks(self, document: Document):
pbar.close()

def process_block_topk_relabeling(self, document: Document, page: PageGroup, block: Block):
topk = {str(k): round(v, 3) for k, v in block.top_k.items()}
topk_types = list(block.top_k.keys())
potential_labels = ""
for block_type in topk_types:
label_cls = get_block_class(block_type)
potential_labels += f"- `{block_type}` - {label_cls.block_description}\n"

topk = ""
for k,v in block.top_k.items():
topk += f"- `{k}` - Confidence {round(v, 3)}\n"

prompt = self.topk_relabelling_prompt.replace("{potential_labels}", potential_labels).replace("{top_k}", topk)
print(prompt)

prompt = self.topk_relabelling_prompt + '```json' + json.dumps(topk) + '```\n'
return self.process_block_relabeling(document, page, block, prompt)

def process_block_complex_relabeling(self, document: Document, page: PageGroup, block: Block):
complex_prompt = self.complex_relabeling_prompt
potential_labels = ""
for block_type in [BlockTypes.Figure, BlockTypes.Picture, BlockTypes.ComplexRegion, BlockTypes.Table, BlockTypes.Form]:
label_cls = get_block_class(block_type)
potential_labels += f"- `{block_type}` - {label_cls.block_description}\n"

complex_prompt = self.complex_relabeling_prompt.replace("{potential_labels}", potential_labels)
print(complex_prompt)
return self.process_block_relabeling(document, page, block, complex_prompt)

def process_block_relabeling(self, document: Document, page: PageGroup, block: Block, prompt: str):
image = self.extract_image(document, block)
response_schema = content.Schema(
type=content.Type.OBJECT,
enum=[],
required=["label"],
required=["image_description", "label"],
properties={
"image_description": content.Schema(
type=content.Type.STRING,
),
"label": content.Schema(
type=content.Type.STRING,
),
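The relabelling prompts now interpolate two generated sections: {potential_labels}, built from each candidate type's block_description, and {top_k}, built from the model's per-label confidences. A rough sketch of that assembly, with a hand-written description dict standing in for marker's get_block_class registry (the descriptions below are illustrative, not marker's actual text):

BLOCK_DESCRIPTIONS = {  # assumed descriptions, for illustration only
    "Picture": "An image with no embedded text.",
    "Figure": "A graph or diagram with text.",
    "Table": "A grid of rows and columns containing data.",
}

PROMPT_TEMPLATE = """Here are descriptions of the layout blocks you can choose from:
{potential_labels}
Here are the top k predictions from the model:
{top_k}
"""

def build_relabel_prompt(top_k: dict[str, float]) -> str:
    potential_labels = ""
    for block_type in top_k:
        description = BLOCK_DESCRIPTIONS.get(block_type, "")
        potential_labels += f"- `{block_type}` - {description}\n"

    top_k_section = ""
    for block_type, confidence in top_k.items():
        top_k_section += f"- `{block_type}` - Confidence {round(confidence, 3)}\n"

    # .replace is used instead of str.format, mirroring the builder above,
    # so other braces in the prompt template are left untouched.
    return (PROMPT_TEMPLATE
            .replace("{potential_labels}", potential_labels)
            .replace("{top_k}", top_k_section))

print(build_relabel_prompt({"Picture": 0.48, "Figure": 0.45, "Table": 0.07}))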
2 changes: 2 additions & 0 deletions marker/converters/table.py
@@ -1,3 +1,4 @@
from functools import cache
from typing import Tuple, List

from marker.builders.document import DocumentBuilder
@@ -23,6 +24,7 @@ class TableConverter(PdfConverter):
)
converter_block_types: List[BlockTypes] = (BlockTypes.Table, BlockTypes.Form, BlockTypes.TableOfContents)

@cache
def build_document(self, filepath: str):
provider_cls = provider_from_filepath(filepath)
layout_builder = self.resolve_dependencies(self.layout_builder_class)
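build_document is now memoized with functools.cache, so repeated conversions of the same filepath within one converter instance reuse the parsed document instead of rebuilding it. A small illustration of the pattern (simplified class, not marker's actual converter):

from functools import cache

class CachedConverter:
    # Simplified stand-in for TableConverter; only the memoization is shown.
    @cache
    def build_document(self, filepath: str):
        print(f"parsing {filepath} ...")  # expensive parse runs once per (instance, path)
        return {"filepath": filepath, "pages": []}

converter = CachedConverter()
doc1 = converter.build_document("example.pdf")  # parses
doc2 = converter.build_document("example.pdf")  # cache hit: same object returned
assert doc1 is doc2

One trade-off to keep in mind: the cache keys on (self, filepath) and holds references to both, so cached documents live as long as the converter instance does.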
23 changes: 17 additions & 6 deletions marker/processors/llm/llm_table_merge.py
@@ -32,6 +32,10 @@ class LLMTableMergeProcessor(BaseLLMProcessor):
int,
"The maximum distance between table edges for adjacency."
] = 20
column_gap_threshold: Annotated[
int,
"The maximum gap between columns to merge tables"
] = 50
gemini_table_merge_prompt: Annotated[
str,
"The prompt to use for rewriting text.",
@@ -133,10 +137,8 @@ def rewrite_blocks(self, document: Document):
for page in document.pages:
page_blocks = page.contained_blocks(document, self.block_types)
for block in page_blocks:
if prev_block is None:
subsequent_page_table = False
same_page_vertical_table = False
else:
merge_condition = False
if prev_block is not None:
prev_cells = prev_block.contained_blocks(document, (BlockTypes.TableCell,))
curr_cells = block.contained_blocks(document, (BlockTypes.TableCell,))
row_match = abs(self.get_row_count(prev_cells) - self.get_row_count(curr_cells)) < 5, # Similar number of rows
@@ -154,11 +156,20 @@
prev_block.page_id == block.page_id, # On the same page
(1 - self.vertical_table_height_threshold) < prev_block.polygon.height / block.polygon.height < (1 + self.vertical_table_height_threshold), # Similar height
abs(block.polygon.x_start - prev_block.polygon.x_end) < self.vertical_table_distance_threshold, # Close together in x
abs(block.polygon.y_start - prev_block.polygon.y_start) < self.vertical_table_distance_threshold, # Close together in y
row_match
])

if prev_block is not None and \
(subsequent_page_table or same_page_vertical_table):
same_page_new_column = all([
prev_block.page_id == block.page_id, # On the same page
abs(block.polygon.x_start - prev_block.polygon.x_end) < self.column_gap_threshold,
block.y_start < prev_block.y_end,
block.polygon.width * (1 - self.vertical_table_height_threshold) < prev_block.polygon.width < block.polygon.width * (1 + self.vertical_table_height_threshold), # Similar width
col_match
])
merge_condition = any([subsequent_page_table, same_page_vertical_table, same_page_new_column])

if prev_block is not None and merge_condition:
if prev_block not in table_run:
table_run.append(prev_block)
table_run.append(block)
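The merge logic gains a third trigger, same_page_new_column, which fires when two tables sit side by side on the same page: a small horizontal gap (column_gap_threshold, default 50), the second table starting above the first one's bottom edge, similar widths, and a matching column count. A rough standalone sketch of that check with plain bounding boxes (the tuple layout and the width_tolerance name are assumptions; the real processor reuses vertical_table_height_threshold for the width comparison):

def same_page_new_column(prev, curr, same_page, col_match,
                         column_gap_threshold=50, width_tolerance=0.25):
    # prev and curr are (x_start, y_start, x_end, y_end) boxes in page coordinates.
    prev_x_start, prev_y_start, prev_x_end, prev_y_end = prev
    curr_x_start, curr_y_start, curr_x_end, curr_y_end = curr
    prev_width = prev_x_end - prev_x_start
    curr_width = curr_x_end - curr_x_start
    return all([
        same_page,                                              # on the same page
        abs(curr_x_start - prev_x_end) < column_gap_threshold,  # small horizontal gap
        curr_y_start < prev_y_end,                              # new table starts above prev bottom
        curr_width * (1 - width_tolerance) < prev_width < curr_width * (1 + width_tolerance),  # similar width
        col_match,                                              # similar column count
    ])

# Two tables side by side on one page, similar width and columns -> merge candidates.
prev_box = (50, 100, 300, 700)
curr_box = (320, 100, 570, 690)
print(same_page_new_column(prev_box, curr_box, same_page=True, col_match=True))  # True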
5 changes: 4 additions & 1 deletion marker/processors/table.py
@@ -105,7 +105,7 @@ def __call__(self, document: Document):
colspan=cell.colspan,
row_id=cell.row_id,
col_id=cell.col_id,
is_header=cell.is_header,
is_header=bool(cell.is_header),
page_id=page.page_id,
)
page.add_full_block(cell_block)
@@ -133,6 +133,9 @@ def normalize_spaces(text):

def split_combined_rows(self, tables: List[TableResult]):
for table in tables:
if len(table.cells) == 0:
# Skip empty tables
continue
unique_rows = sorted(list(set([c.row_id for c in table.cells])))
new_cells = []
shift_up = 0
3 changes: 3 additions & 0 deletions marker/renderers/json.py
@@ -14,6 +14,7 @@ class JSONBlockOutput(BaseModel):
block_type: str
html: str
polygon: List[List[float]]
bbox: List[float]
children: List['JSONBlockOutput'] | None = None
section_hierarchy: Dict[int, str] | None = None
images: dict | None = None
@@ -52,6 +53,7 @@ def extract_json(self, document: Document, block_output: BlockOutput):
return JSONBlockOutput(
html=html,
polygon=block_output.polygon.polygon,
bbox=block_output.polygon.bbox,
id=str(block_output.id),
block_type=str(block_output.id.block_type),
images=images,
@@ -66,6 +68,7 @@ def extract_json(self, document: Document, block_output: BlockOutput):
return JSONBlockOutput(
html=block_output.html,
polygon=block_output.polygon.polygon,
bbox=block_output.polygon.bbox,
id=str(block_output.id),
block_type=str(block_output.id.block_type),
children=children,
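Every JSON block now carries a flat bbox in addition to its polygon, which is what the updated table benchmark traverses recursively via extract_tables. A simplified sketch of that recursive output shape and traversal (the model below keeps only the fields needed for the illustration):

from typing import List, Optional
from pydantic import BaseModel

class JSONBlockOutput(BaseModel):      # simplified from marker.renderers.json
    id: str
    block_type: str
    bbox: List[float]                  # [x0, y0, x1, y1] in page coordinates
    children: Optional[List["JSONBlockOutput"]] = None

JSONBlockOutput.model_rebuild()        # resolve the self-reference

def extract_tables(children: List[JSONBlockOutput]) -> List[JSONBlockOutput]:
    # Collect every Table block, recursing into nested children.
    tables = []
    for child in children:
        if child.block_type == "Table":
            tables.append(child)
        elif child.children:
            tables.extend(extract_tables(child.children))
    return tables

page = JSONBlockOutput(
    id="/page/0", block_type="Page", bbox=[0, 0, 612, 792],
    children=[
        JSONBlockOutput(id="/page/0/Table/1", block_type="Table",
                        bbox=[72, 100, 540, 300]),
        JSONBlockOutput(id="/page/0/Group/2", block_type="Group",
                        bbox=[72, 320, 540, 700],
                        children=[JSONBlockOutput(id="/page/0/Table/3",
                                                  block_type="Table",
                                                  bbox=[72, 330, 540, 500])]),
    ],
)
print([t.id for t in extract_tables(page.children)])  # both tables found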
1 change: 1 addition & 0 deletions marker/schema/blocks/base.py
@@ -71,6 +71,7 @@ def to_path(self):

class Block(BaseModel):
polygon: PolygonBox
block_description: str
block_type: Optional[BlockTypes] = None
block_id: Optional[int] = None
page_id: Optional[int] = None
1 change: 1 addition & 0 deletions marker/schema/blocks/caption.py
@@ -4,6 +4,7 @@

class Caption(Block):
block_type: BlockTypes = BlockTypes.Caption
block_description: str = "A text caption that is directly above or below an image or table. Only used for text describing the image or table. "

def assemble_html(self, document, child_blocks, parent_structure):
template = super().assemble_html(document, child_blocks, parent_structure)
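block_description becomes a required field on the base Block model, with each concrete block type supplying its own default; the LLM layout builder reads those descriptions to fill {potential_labels}. A minimal sketch of the pattern with simplified Pydantic models (the Table description here is invented for illustration):

from pydantic import BaseModel

class Block(BaseModel):                # simplified from marker.schema.blocks.base
    block_type: str
    block_description: str             # required on the base, defaulted per subclass

class Caption(Block):
    block_type: str = "Caption"
    block_description: str = "A text caption that is directly above or below an image or table."

class Table(Block):                    # hypothetical sibling, for illustration
    block_type: str = "Table"
    block_description: str = "A grid of rows and columns containing data."

# A builder can then list descriptions per candidate type without instantiating blocks:
for cls in (Caption, Table):
    print(f"- `{cls.model_fields['block_type'].default}` - {cls.model_fields['block_description'].default}")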
