Skip to content

Commit

Permalink
Additional fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
VikParuchuri committed Jan 30, 2025
1 parent bbf4161 commit 9a8da13
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 30 deletions.
2 changes: 1 addition & 1 deletion benchmarks/overall/overall.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def get_method_scores(ds, model_dict, max_rows=None, score_func=marker_scoring_f
doc_type = sample["classification"]

try:
gt_html = [block["html"] for block in gt_blocks]
gt_html = [block["html"] for block in gt_blocks if len(block["html"]) > 0]
scores = score_func(model_dict, sample, gt_html, **kwargs)
except ValueError as e:
print(f"Error with sample {idx}: {e}")
Expand Down
33 changes: 30 additions & 3 deletions benchmarks/overall/scoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,10 @@ def standardize_markdown(markdown):
markdown = re.sub(pattern, standardize_math, markdown)

# Replace image urls
pattern = r'!\[(.*?)\]\((.*?)(?:\?.*?width=(\d+).*?height=(\d+).*?)\)'
markdown = re.sub(pattern, r'![/api/placeholder]', markdown)
pattern = r'!\[(.*?)\]\((https?://[^\s\)]+)\)'
markdown = re.sub(pattern, r'![link]', markdown)
markdown = strip_latex_symbols(markdown)
markdown = replace_centered_lines(markdown)

# Clean up html tags
markdown = markdown.replace("<br>", "\n")
Expand All @@ -84,10 +86,35 @@ def standardize_markdown(markdown):
markdown = re.sub("\\.+", ".", markdown) # Replace repeated periods with a single period, like in table of contents
markdown = re.sub("#+", "#", markdown) # Replace repeated headers with a single header
markdown = re.sub(r"\$", "", markdown) # Remove equation delimiters
markdown = markdown.encode().decode('unicode-escape') # Decode unicode characters properly
markdown = markdown.encode().decode('unicode-escape', errors="ignore") # Decode unicode characters properly
return markdown.strip().lower()


def replace_centered_lines(text):
def replace_match(m):
content = m.group(0)
dash_count = content.count('-')
return '-' * dash_count

pattern = r':-+:'
return re.sub(pattern, replace_match, text)


def strip_latex_symbols(text):
# Handle short math mode sequences first - only match $ $ with brief content
text = re.sub(r'\$\s*\\?[a-zA-Z]+\d?\s*\$', '', text)

# Handle common patterns inside remaining math mode
patterns = [
r'\$\s*\\?[a-zA-Z]+\d?\s*\$', # \alpha or \alpha2 in math mode
r'\$\s*\d+\\[a-zA-Z]+\s*\$', # 45\circ in math mode
r'\$\s*[a-zA-Z0-9]\\[a-zA-Z]+\s*\$' # x\dagger in math mode
]

pattern = '|'.join(patterns)
return re.sub(pattern, '', text)


def standardize_math(match):
try:
delim = "$$" if match.group(0).startswith('$$') else "$"
Expand Down
34 changes: 26 additions & 8 deletions benchmarks/table/inference.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import datasets
from typing import List

import numpy as np
from bs4 import BeautifulSoup
import pypdfium2 as pdfium
Expand All @@ -10,18 +11,27 @@
from marker.config.parser import ConfigParser
from marker.converters.table import TableConverter
from marker.models import create_model_dict
from marker.renderers.json import JSONBlockOutput
from marker.schema.polygon import PolygonBox
from marker.util import matrix_intersection_area


def extract_tables(children: List[JSONBlockOutput]):
tables = []
for child in children:
if child.block_type == 'Table':
tables.append(child)
elif child.children:
tables.extend(extract_tables(child.children))
return tables


def inference_tables(dataset, use_llm: bool, table_rec_batch_size: int | None, max_rows: int, use_gemini: bool):
models = create_model_dict()
config_parser = ConfigParser({'output_format': 'json', "use_llm": use_llm, "table_rec_batch_size": table_rec_batch_size, "disable_tqdm": True})
total_unaligned = 0
results = []

dataset = datasets.load_dataset(dataset, split='train')
dataset = dataset.shuffle(seed=0)

iterations = len(dataset)
if max_rows is not None:
iterations = min(max_rows, len(dataset))
Expand All @@ -45,7 +55,8 @@ def inference_tables(dataset, use_llm: bool, table_rec_batch_size: int | None, m
marker_json = converter(temp_pdf_file.name).children

doc = pdfium.PdfDocument(temp_pdf_file.name)
page_image = doc[0].render(scale=92 / 72).to_pil()
page_image = doc[0].render(scale=96/72).to_pil()
doc.close()

if len(marker_json) == 0 or len(gt_tables) == 0:
print(f'No tables detected, skipping...')
Expand All @@ -55,10 +66,17 @@ def inference_tables(dataset, use_llm: bool, table_rec_batch_size: int | None, m
marker_tables = extract_tables(marker_json)
marker_table_boxes = [table.bbox for table in marker_tables]
page_bbox = marker_json[0].bbox
w_scaler, h_scaler = page_image.width / page_bbox[2], page_image.height / page_bbox[3]

table_images = [
page_image.crop([bbox[0] * w_scaler, bbox[1] * h_scaler, bbox[2] * w_scaler, bbox[3] * h_scaler]) for bbox
in marker_table_boxes]
page_image.crop(
PolygonBox.from_bbox(bbox)
.rescale(
(page_bbox[2], page_bbox[3]), (page_image.width, page_image.height)
).bbox
)
for bbox
in marker_table_boxes
]

# Normalize the bboxes
for bbox in marker_table_boxes:
Expand Down
19 changes: 2 additions & 17 deletions benchmarks/table/table.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
import os

from benchmarks.table.inference import inference_tables

os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # Transformers uses .isin for an op, which is not supported on MPS

from pathlib import Path
Expand All @@ -15,11 +12,9 @@
from tabulate import tabulate
import json
from concurrent.futures import ProcessPoolExecutor
from marker.renderers.json import JSONBlockOutput
from marker.settings import settings

from marker.config.parser import ConfigParser
from marker.models import create_model_dict
from marker.settings import settings
from benchmarks.table.inference import inference_tables

from scoring import wrap_table_html, similarity_eval_html

Expand All @@ -31,16 +26,6 @@ def update_teds_score(result, prefix: str = "marker"):
return result


def extract_tables(children: List[JSONBlockOutput]):
tables = []
for child in children:
if child.block_type == 'Table':
tables.append(child)
elif child.children:
tables.extend(extract_tables(child.children))
return tables


@click.command(help="Benchmark Table to HTML Conversion")
@click.option("--result_path", type=str, default=os.path.join(settings.OUTPUT_DIR, "benchmark", "table"), help="Output path for results.")
@click.option("--dataset", type=str, default="datalab-to/fintabnet_bench_marker", help="Dataset to use")
Expand Down
2 changes: 1 addition & 1 deletion marker/renderers/markdown.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def convert_table(self, el, text, convert_as_inline):
grid[row_idx + r][col_idx + c] = '' # Empty cell due to rowspan/colspan
except IndexError:
# Sometimes the colspan/rowspan predictions can overflow
print(f"Overflow in columns: {col_idx + c} >= {total_cols}")
print(f"Overflow in columns: {col_idx + c} >= {total_cols} or rows: {row_idx + r} >= {total_rows}")
continue

col_idx += colspan
Expand Down

0 comments on commit 9a8da13

Please sign in to comment.