Skip to content

Commit

Permalink
Merge pull request #513 from VikParuchuri/texify
Browse files Browse the repository at this point in the history
Texify
  • Loading branch information
VikParuchuri authored Jan 29, 2025
2 parents 98bdbbb + b90ad5d commit 099a493
Show file tree
Hide file tree
Showing 10 changed files with 76 additions and 133 deletions.
3 changes: 1 addition & 2 deletions .github/workflows/benchmark.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,10 @@ on: [push]

env:
TORCH_DEVICE: "cpu"
OCR_ENGINE: "surya"

jobs:
benchmark:
runs-on: ubuntu-latest
runs-on: [ubuntu-latest, windows-latest]
steps:
- uses: actions/checkout@v3
- name: Set up Python 3.11
Expand Down
33 changes: 2 additions & 31 deletions marker/models.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,12 @@
import os

from marker.settings import settings

os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # Transformers uses .isin for a simple op, which is not supported on MPS

from typing import List
from PIL import Image
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # Transformers uses .isin for an op, which is not supported on MPS

from surya.detection import DetectionPredictor
from surya.layout import LayoutPredictor
from surya.ocr_error import OCRErrorPredictor
from surya.recognition import RecognitionPredictor
from surya.table_rec import TableRecPredictor

from texify.model.model import load_model as load_texify_model
from texify.model.processor import load_processor as load_texify_processor
from texify.inference import batch_inference

class TexifyPredictor:
def __init__(self, device=None, dtype=None):
if not device:
device = settings.TORCH_DEVICE_MODEL
if not dtype:
dtype = settings.TEXIFY_DTYPE

self.model = load_texify_model(checkpoint=settings.TEXIFY_MODEL_NAME, device=device, dtype=dtype)
self.processor = load_texify_processor()
self.device = device
self.dtype = dtype

def __call__(self, batch_images: List[Image.Image], max_tokens: int):
return batch_inference(
batch_images,
self.model,
self.processor,
max_tokens=max_tokens
)
from surya.texify import TexifyPredictor


def create_model_dict(device=None, dtype=None) -> dict:
Expand Down
2 changes: 1 addition & 1 deletion marker/processors/blockquote.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,4 +63,4 @@ def __call__(self, document: Document):
next_block.blockquote_level += 1
elif len(next_block.structure) >= 2 and (x_indent and y_indent):
next_block.blockquote = True
next_block.blockquote_level = 1
next_block.blockquote_level = 1
85 changes: 8 additions & 77 deletions marker/processors/equation.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class EquationProcessor(BaseProcessor):
model_max_length: Annotated[
int,
"The maximum number of tokens to allow for the Texify model.",
] = 384
] = 768
texify_batch_size: Annotated[
Optional[int],
"The batch size to use for the Texify model.",
Expand Down Expand Up @@ -65,27 +65,7 @@ def __call__(self, document: Document):
continue

block = document.get_block(equation_d["block_id"])
block.html = self.parse_latex_to_html(prediction)

def parse_latex_to_html(self, latex: str):
html_out = ""
try:
latex = self.parse_latex(latex)
except ValueError as e:
# If we have mismatched delimiters, we'll treat it as a single block
# Strip the $'s from the latex
latex = [
{"class": "block", "content": latex.replace("$", "")}
]

for el in latex:
if el["class"] == "block":
html_out += f'<math display="block">{el["content"]}</math>'
elif el["class"] == "inline":
html_out += f'<math display="inline">{el["content"]}</math>'
else:
html_out += f" {el['content']} "
return html_out.strip()
block.html = prediction

def get_batch_size(self):
if self.texify_batch_size is not None:
Expand All @@ -106,71 +86,22 @@ def get_latex_batched(self, equation_data: List[dict]):
max_idx = min(min_idx + batch_size, len(equation_data))

batch_equations = equation_data[min_idx:max_idx]
max_length = max([eq["token_count"] for eq in batch_equations])
max_length = min(max_length, self.model_max_length)
max_length += self.token_buffer

batch_images = [eq["image"] for eq in batch_equations]

model_output = self.texify_model(
batch_images,
max_tokens=max_length
batch_images
)

for j, output in enumerate(model_output):
token_count = self.get_total_texify_tokens(output)
if token_count >= max_length - 1:
output = ""
token_count = self.get_total_texify_tokens(output.text)
if token_count >= self.model_max_length - 1:
output.text = ""

image_idx = i + j
predictions[image_idx] = output
predictions[image_idx] = output.text
return predictions

def get_total_texify_tokens(self, text):
tokenizer = self.texify_model.processor.tokenizer
tokens = tokenizer(text)
return len(tokens["input_ids"])


@staticmethod
def parse_latex(text: str):
if text.count("$") % 2 != 0:
raise ValueError("Mismatched delimiters in LaTeX")

DELIMITERS = [
("$$", "block"),
("$", "inline")
]

text = text.replace("\n", "<br>") # we can't handle \n's inside <p> properly if we don't do this

i = 0
stack = []
result = []
buffer = ""

while i < len(text):
for delim, class_name in DELIMITERS:
if text[i:].startswith(delim):
if stack and stack[-1] == delim: # Closing
stack.pop()
result.append({"class": class_name, "content": buffer})
buffer = ""
i += len(delim)
break
elif not stack: # Opening
if buffer:
result.append({"class": "text", "content": buffer})
stack.append(delim)
buffer = ""
i += len(delim)
break
else:
raise ValueError(f"Nested {class_name} delimiters not supported")
else: # No delimiter match
buffer += text[i]
i += 1

if buffer:
result.append({"class": "text", "content": buffer})
return result
return len(tokens["input_ids"])
1 change: 0 additions & 1 deletion marker/processors/reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from marker.schema.groups.list import ListGroup
from marker.schema.groups.table import TableGroup
from marker.schema.registry import get_block_class
from marker.schema.groups.picture import PictureGroup
from marker.schema.groups.figure import FigureGroup


Expand Down
37 changes: 27 additions & 10 deletions marker/processors/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from copy import deepcopy
from typing import Annotated, List
from collections import Counter
from PIL import ImageDraw

from ftfy import fix_text
from surya.detection import DetectionPredictor
Expand Down Expand Up @@ -53,6 +52,10 @@ class TableProcessor(BaseProcessor):
int,
"The number of workers to use for pdftext.",
] = 4
row_split_threshold: Annotated[
float,
"The percentage of rows that need to be split across the table before row splitting is active.",
] = 0.5

def __init__(
self,
Expand Down Expand Up @@ -171,10 +174,7 @@ def split_combined_rows(self, tables: List[TableResult]):
# Skip empty tables
continue
unique_rows = sorted(list(set([c.row_id for c in table.cells])))
new_cells = []
shift_up = 0
max_cell_id = max([c.cell_id for c in table.cells])
new_cell_count = 0
row_info = []
for row in unique_rows:
# Cells in this row
# Deepcopy is because we do an in-place mutation later, and that can cause rows to shift to match rows in unique_rows
Expand All @@ -201,9 +201,25 @@ def split_combined_rows(self, tables: List[TableResult]):
len(line_lens_counter) == 2 and counter_keys[0] <= 1 and counter_keys[1] > 1 and line_lens_counter[counter_keys[0]] == 1, # Allow a single column with a single line - keys are the line lens, values are the counts
])
should_split = should_split_entire_row or should_split_partial_row
if should_split:
for i in range(0, max(line_lens)):
for cell in row_cells:
row_info.append({
"should_split": should_split,
"row_cells": row_cells,
"line_lens": line_lens
})

# Don't split if we're not splitting most of the rows in the table. This avoids splitting stray multiline rows.
if sum([r["should_split"] for r in row_info]) / len(row_info) < self.row_split_threshold:
continue

new_cells = []
shift_up = 0
max_cell_id = max([c.cell_id for c in table.cells])
new_cell_count = 0
for row, item_info in zip(unique_rows, row_info):
max_lines = max(item_info["line_lens"])
if item_info["should_split"]:
for i in range(0, max_lines):
for cell in item_info["row_cells"]:
# Calculate height based on number of splits
split_height = cell.bbox[3] - cell.bbox[1]
current_bbox = [cell.bbox[0], cell.bbox[1] + i * split_height, cell.bbox[2], cell.bbox[1] + (i + 1) * split_height]
Expand All @@ -226,9 +242,10 @@ def split_combined_rows(self, tables: List[TableResult]):
new_cell_count += 1

# For each new row we add, shift up subsequent rows
shift_up += line_lens[0] - 1
# The max is to account for partial rows
shift_up += max_lines - 1
else:
for cell in row_cells:
for cell in item_info["row_cells"]:
cell.row_id += shift_up
new_cells.append(cell)

Expand Down
18 changes: 15 additions & 3 deletions marker/renderers/markdown.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,16 @@
from marker.schema.document import Document


def escape_dollars(text):
    r"""Return ``text`` with every ``$`` character replaced by ``\$``.

    No attempt is made to detect dollars that are already escaped: each
    ``$`` gains exactly one backslash regardless of what precedes it.
    """
    return text.replace("$", r"\$")

def cleanup_text(full_text):
    """Collapse excess blank lines in ``full_text`` and strip both ends.

    Two passes are applied, in this order:

    1. Any run of three or more consecutive newlines becomes two newlines.
    2. Any run of three or more "newline followed by one whitespace
       character" pairs becomes two newlines (this catches blank lines
       that contain a single space or tab).

    Leading and trailing whitespace is removed from the result.
    """
    full_text = re.sub(r'\n{3,}', '\n\n', full_text)
    full_text = re.sub(r'(\n\s){3,}', '\n\n', full_text)
    return full_text.strip()

def get_formatted_table_text(element):

text = []
for content in element.contents:
if content is None:
Expand All @@ -26,13 +30,14 @@ def get_formatted_table_text(element):
if isinstance(content, NavigableString):
stripped = content.strip()
if stripped:
text.append(stripped)
text.append(escape_dollars(stripped))
elif content.name == 'br':
text.append('<br>')
elif content.name == "math":
text.append("$" + content.text + "$")
else:
text.append(str(content))
content_str = escape_dollars(str(content))
text.append(content_str)

full_text = ""
for i, t in enumerate(text):
Expand Down Expand Up @@ -120,7 +125,7 @@ def convert_table(self, el, text, convert_as_inline):
if r == 0 and c == 0:
grid[row_idx][col_idx] = value
else:
grid[row_idx + r][col_idx + c] = ''
grid[row_idx + r][col_idx + c] = '' # Empty cell due to rowspan/colspan
except IndexError:
# Sometimes the colspan/rowspan predictions can overflow
print(f"Overflow in columns: {col_idx + c} >= {total_cols}")
Expand Down Expand Up @@ -176,6 +181,12 @@ def convert_span(self, el, text, convert_as_inline):
else:
return text

def escape(self, text):
    r"""Escape ``text`` via the parent class, then optionally escape ``$``.

    The parent's ``escape`` runs first. Afterwards, when the
    ``escape_dollars`` entry of ``self.options`` is truthy, every ``$`` in
    the result is replaced by ``\$``.
    """
    text = super().escape(text)
    # Indexing (rather than .get) means a converter built without an
    # 'escape_dollars' option raises KeyError here.
    if self.options['escape_dollars']:
        text = text.replace('$', r'\$')
    return text

class MarkdownOutput(BaseModel):
markdown: str
images: dict
Expand All @@ -198,6 +209,7 @@ def __call__(self, document: Document) -> MarkdownOutput:
escape_misc=False,
escape_underscores=False,
escape_asterisks=False,
escape_dollars=True,
sub_symbol="<sub>",
sup_symbol="<sup>",
inline_math_delimiters=self.inline_math_delimiters,
Expand Down
14 changes: 7 additions & 7 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ tqdm = "^4.66.1"
ftfy = "^6.1.1"
texify = "^0.2.1"
rapidfuzz = "^3.8.1"
surya-ocr = "~0.9.3"
surya-ocr = "~0.10.0"
regex = "^2024.4.28"
pdftext = "~0.5.1"
markdownify = "^0.13.1"
Expand Down
14 changes: 14 additions & 0 deletions tests/processors/test_table_processor.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import List

import pytest
from marker.renderers.json import JSONRenderer

Expand Down Expand Up @@ -63,3 +65,15 @@ def test_ocr_table(pdf_document, detection_model, recognition_model, table_rec_m
table_output = renderer(pdf_document)
assert "1.2E-38" in table_output.markdown


@pytest.mark.config({"page_range": [11]})
def test_split_rows(pdf_document, detection_model, recognition_model, table_rec_model):
    """Check the row count of the last table after running ``TableProcessor``.

    The document is restricted by the ``page_range`` config marker above.
    After processing, the last ``Table`` block's cells must span exactly
    six distinct ``row_id`` values.
    NOTE(review): presumably this pins the row-splitting behaviour of the
    processor for this fixture page; confirm against the fixture PDF.
    """
    processor = TableProcessor(detection_model, recognition_model, table_rec_model)
    processor(pdf_document)

    # Only the final table on the selected page(s) is inspected.
    table = pdf_document.contained_blocks((BlockTypes.Table,))[-1]
    cells: List[TableCell] = table.contained_blocks(pdf_document, (BlockTypes.TableCell,))
    unique_rows = len(set([cell.row_id for cell in cells]))
    assert unique_rows == 6


0 comments on commit 099a493

Please sign in to comment.