
Merge remote-tracking branch 'origin/texify' into inline-math
tarun-menta committed Feb 2, 2025
2 parents ce867b8 + 597db72 commit af2c09e
Showing 4 changed files with 23 additions and 113 deletions.
33 changes: 2 additions & 31 deletions marker/models.py
@@ -1,41 +1,12 @@
 import os
 
-from marker.settings import settings
-
-os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # Transformers uses .isin for a simple op, which is not supported on MPS
-
-from typing import List
-from PIL import Image
+os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # Transformers uses .isin for an op, which is not supported on MPS
 
 from surya.detection import DetectionPredictor, InlineDetectionPredictor
 from surya.layout import LayoutPredictor
 from surya.ocr_error import OCRErrorPredictor
 from surya.recognition import RecognitionPredictor
 from surya.table_rec import TableRecPredictor
-
-from texify.model.model import load_model as load_texify_model
-from texify.model.processor import load_processor as load_texify_processor
-from texify.inference import batch_inference
-
-class TexifyPredictor:
-    def __init__(self, device=None, dtype=None):
-        if not device:
-            device = settings.TORCH_DEVICE_MODEL
-        if not dtype:
-            dtype = settings.TEXIFY_DTYPE
-
-        self.model = load_texify_model(checkpoint=settings.TEXIFY_MODEL_NAME, device=device, dtype=dtype)
-        self.processor = load_texify_processor()
-        self.device = device
-        self.dtype = dtype
-
-    def __call__(self, batch_images: List[Image.Image], max_tokens: int):
-        return batch_inference(
-            batch_images,
-            self.model,
-            self.processor,
-            max_tokens=max_tokens
-        )
+from surya.texify import TexifyPredictor
 
 
 def create_model_dict(device=None, dtype=None) -> dict:
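Note: the hand-rolled TexifyPredictor wrapper (and its texify model/processor loading) is replaced by the predictor surya now ships. A minimal usage sketch, inferred from this diff rather than from surya docs — the no-argument constructor and the `.text` field on results are assumptions based on how equation.py calls the predictor below:

from PIL import Image
from surya.texify import TexifyPredictor

predictor = TexifyPredictor()  # assumption: loads its own model/processor internally
images = [Image.open("equation_crop.png")]  # hypothetical cropped-equation image
results = predictor(images)  # called with just a list of images, as in equation.py
print(results[0].text)       # results expose .text, per the new processor code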
90 changes: 12 additions & 78 deletions marker/processors/equation.py
@@ -1,3 +1,4 @@
+import re
 from typing import Annotated, List, Optional, Tuple
 
 from texify.inference import batch_inference
@@ -24,7 +25,7 @@ class EquationProcessor(BaseProcessor):
     model_max_length: Annotated[
         int,
         "The maximum number of tokens to allow for the Texify model.",
-    ] = 384
+    ] = 768
     texify_batch_size: Annotated[
         Optional[int],
         "The batch size to use for the Texify model.",
@@ -69,29 +70,11 @@ def __call__(self, document: Document):

             block = document.get_block(equation_d["block_id"])
             if isinstance(block, Equation):
-                block.html = self.parse_latex_to_html(prediction)
-            elif isinstance(block, Span):
-                block.text = prediction.replace('$$', '$')
-
-    def parse_latex_to_html(self, latex: str):
-        html_out = ""
-        try:
-            latex = self.parse_latex(latex)
-        except ValueError as e:
-            # If we have mismatched delimiters, we'll treat it as a single block
-            # Strip the $'s from the latex
-            latex = [
-                {"class": "block", "content": latex.replace("$", "")}
-            ]
-
-        for el in latex:
-            if el["class"] == "block":
-                html_out += f'<math display="block">{el["content"]}</math>'
-            elif el["class"] == "inline":
-                html_out += f'<math display="inline">{el["content"]}</math>'
+                block.html = prediction
+            elif isinstance(block, Span) and 'math' in block.formats:
+                block.text = re.sub(r"<math[^>]*>(.*?)</math>", r"$\1$", prediction)
             else:
-                html_out += f" {el['content']} "
-        return html_out.strip()
+                raise ValueError(f"Unexpected block of type {block.block_type}")
 
     def get_batch_size(self):
         if self.texify_batch_size is not None:
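Note on the rewritten __call__: Equation blocks now keep the model's HTML output verbatim, while inline math Spans have their <math> elements rewritten into $ delimiters by the added re.sub. A quick standalone illustration (the sample prediction string is invented):

import re

prediction = 'Identity: <math display="inline">e^{i\\pi} + 1 = 0</math>'
print(re.sub(r"<math[^>]*>(.*?)</math>", r"$\1$", prediction))
# Identity: $e^{i\pi} + 1 = 0$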
@@ -112,71 +95,22 @@ def get_latex_batched(self, equation_data: List[dict]):
             max_idx = min(min_idx + batch_size, len(equation_data))
 
             batch_equations = equation_data[min_idx:max_idx]
-            max_length = max([eq["token_count"] for eq in batch_equations])
-            max_length = min(max_length, self.model_max_length)
-            max_length += self.token_buffer
-
             batch_images = [eq["image"] for eq in batch_equations]
 
             model_output = self.texify_model(
-                batch_images,
-                max_tokens=max_length
+                batch_images
             )
 
             for j, output in enumerate(model_output):
-                token_count = self.get_total_texify_tokens(output)
-                if token_count >= max_length - 1:
-                    output = ""
+                token_count = self.get_total_texify_tokens(output.text)
+                if token_count >= self.model_max_length - 1:
+                    output.text = ""
 
                 image_idx = i + j
-                predictions[image_idx] = output
+                predictions[image_idx] = output.text
         return predictions
 
     def get_total_texify_tokens(self, text):
         tokenizer = self.texify_model.processor.tokenizer
         tokens = tokenizer(text)
-        return len(tokens["input_ids"])
-
-
-    @staticmethod
-    def parse_latex(text: str):
-        if text.count("$") % 2 != 0:
-            raise ValueError("Mismatched delimiters in LaTeX")
-
-        DELIMITERS = [
-            ("$$", "block"),
-            ("$", "inline")
-        ]
-
-        text = text.replace("\n", "<br>") # we can't handle \n's inside <p> properly if we don't do this
-
-        i = 0
-        stack = []
-        result = []
-        buffer = ""
-
-        while i < len(text):
-            for delim, class_name in DELIMITERS:
-                if text[i:].startswith(delim):
-                    if stack and stack[-1] == delim: # Closing
-                        stack.pop()
-                        result.append({"class": class_name, "content": buffer})
-                        buffer = ""
-                        i += len(delim)
-                        break
-                    elif not stack: # Opening
-                        if buffer:
-                            result.append({"class": "text", "content": buffer})
-                        stack.append(delim)
-                        buffer = ""
-                        i += len(delim)
-                        break
-                    else:
-                        raise ValueError(f"Nested {class_name} delimiters not supported")
-            else: # No delimiter match
-                buffer += text[i]
-                i += 1
-
-        if buffer:
-            result.append({"class": "text", "content": buffer})
-        return result
+        return len(tokens["input_ids"])
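With surya handling generation limits, the per-batch max_tokens arithmetic and the old $-delimiter parser are both gone; the only length logic left is the guard that blanks an output which consumed the whole token budget. A compressed sketch of that guard (the function name and standalone form are mine; 768 mirrors the new model_max_length default):

def keep_or_blank(decoded_text: str, token_count: int, model_max_length: int = 768) -> str:
    # A count at or above the cap suggests generation was cut off
    # mid-equation, so dropping the text beats keeping truncated LaTeX.
    if token_count >= model_max_length - 1:
        return ""
    return decoded_text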
11 changes: 8 additions & 3 deletions marker/renderers/markdown.py
@@ -12,12 +12,16 @@
 from marker.schema.document import Document
 
 
+def escape_dollars(text):
+    return text.replace("$", r"\$")
+
 def cleanup_text(full_text):
     full_text = re.sub(r'\n{3,}', '\n\n', full_text)
     full_text = re.sub(r'(\n\s){3,}', '\n\n', full_text)
     return full_text.strip()
 
 def get_formatted_table_text(element):
+
     text = []
     for content in element.contents:
         if content is None:
@@ -26,13 +30,14 @@ def get_formatted_table_text(element):
         if isinstance(content, NavigableString):
             stripped = content.strip()
             if stripped:
-                text.append(stripped)
+                text.append(escape_dollars(stripped))
         elif content.name == 'br':
             text.append('<br>')
+        elif content.name == "math":
+            text.append("$" + content.text + "$")
         else:
-            text.append(str(content))
+            content_str = escape_dollars(str(content))
+            text.append(content_str)
 
     full_text = ""
     for i, t in enumerate(text):
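Note: table cell text now distinguishes literal dollar signs from math. Plain strings are escaped so a downstream Markdown/math renderer will not read stray $ as delimiters, while <math> nodes are wrapped in real ones. A tiny demonstration (sample strings invented):

def escape_dollars(text):
    return text.replace("$", r"\$")

print(escape_dollars("Price: $5, tax $0.40"))  # Price: \$5, tax \$0.40
print("$" + "x^2" + "$")                       # $x^2$  (real inline math)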
@@ -120,7 +125,7 @@ def convert_table(self, el, text, convert_as_inline):
                         if r == 0 and c == 0:
                             grid[row_idx][col_idx] = value
                         else:
-                            grid[row_idx + r][col_idx + c] = ''
+                            grid[row_idx + r][col_idx + c] = '' # Empty cell due to rowspan/colspan
                     except IndexError:
                         # Sometimes the colspan/rowspan predictions can overflow
                         print(f"Overflow in columns: {col_idx + c} >= {total_cols}")
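For context, the newly commented line sits inside convert_table's span-expansion loop: a merged cell's value goes in its top-left grid slot, and every other slot covered by its rowspan/colspan is blanked so each row keeps the same width. A standalone sketch of that fill (all values invented):

rows, cols = 2, 3
grid = [[None] * cols for _ in range(rows)]
value, row_idx, col_idx = "header", 0, 0
rowspan, colspan = 1, 2  # this cell spans two columns

for r in range(rowspan):
    for c in range(colspan):
        if r == 0 and c == 0:
            grid[row_idx][col_idx] = value
        else:
            grid[row_idx + r][col_idx + c] = ''  # slot covered by the span

print(grid)  # [['header', '', None], [None, None, None]]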
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -27,7 +27,7 @@ tqdm = "^4.66.1"
 ftfy = "^6.1.1"
 texify = "^0.2.1"
 rapidfuzz = "^3.8.1"
-surya-ocr = "~0.9.3"
+surya-ocr = "~0.10.0"
 regex = "^2024.4.28"
 pdftext = "~0.5.1"
 markdownify = "^0.13.1"
