diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml
index 0fcb2380..5d49aa1c 100644
--- a/.github/workflows/benchmark.yml
+++ b/.github/workflows/benchmark.yml
@@ -4,11 +4,10 @@ on: [push]
env:
TORCH_DEVICE: "cpu"
- OCR_ENGINE: "surya"
jobs:
benchmark:
- runs-on: ubuntu-latest
+ runs-on: [ubuntu-latest, windows-latest]
steps:
- uses: actions/checkout@v3
- name: Set up Python 3.11
diff --git a/marker/models.py b/marker/models.py
index 908fb863..80dc254c 100644
--- a/marker/models.py
+++ b/marker/models.py
@@ -1,41 +1,12 @@
import os
-
-from marker.settings import settings
-
-os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # Transformers uses .isin for a simple op, which is not supported on MPS
-
-from typing import List
-from PIL import Image
+os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # Transformers uses .isin for an op, which is not supported on MPS
from surya.detection import DetectionPredictor
from surya.layout import LayoutPredictor
from surya.ocr_error import OCRErrorPredictor
from surya.recognition import RecognitionPredictor
from surya.table_rec import TableRecPredictor
-
-from texify.model.model import load_model as load_texify_model
-from texify.model.processor import load_processor as load_texify_processor
-from texify.inference import batch_inference
-
-class TexifyPredictor:
- def __init__(self, device=None, dtype=None):
- if not device:
- device = settings.TORCH_DEVICE_MODEL
- if not dtype:
- dtype = settings.TEXIFY_DTYPE
-
- self.model = load_texify_model(checkpoint=settings.TEXIFY_MODEL_NAME, device=device, dtype=dtype)
- self.processor = load_texify_processor()
- self.device = device
- self.dtype = dtype
-
- def __call__(self, batch_images: List[Image.Image], max_tokens: int):
- return batch_inference(
- batch_images,
- self.model,
- self.processor,
- max_tokens=max_tokens
- )
+from surya.texify import TexifyPredictor
def create_model_dict(device=None, dtype=None) -> dict:
diff --git a/marker/processors/blockquote.py b/marker/processors/blockquote.py
index bc0b5bec..018a7ecf 100644
--- a/marker/processors/blockquote.py
+++ b/marker/processors/blockquote.py
@@ -63,4 +63,4 @@ def __call__(self, document: Document):
next_block.blockquote_level += 1
elif len(next_block.structure) >= 2 and (x_indent and y_indent):
next_block.blockquote = True
- next_block.blockquote_level = 1
+ next_block.blockquote_level = 1
\ No newline at end of file
diff --git a/marker/processors/equation.py b/marker/processors/equation.py
index 868f98a2..5f5be17c 100644
--- a/marker/processors/equation.py
+++ b/marker/processors/equation.py
@@ -22,7 +22,7 @@ class EquationProcessor(BaseProcessor):
model_max_length: Annotated[
int,
"The maximum number of tokens to allow for the Texify model.",
- ] = 384
+ ] = 768
texify_batch_size: Annotated[
Optional[int],
"The batch size to use for the Texify model.",
@@ -65,27 +65,7 @@ def __call__(self, document: Document):
continue
block = document.get_block(equation_d["block_id"])
- block.html = self.parse_latex_to_html(prediction)
-
- def parse_latex_to_html(self, latex: str):
- html_out = ""
- try:
- latex = self.parse_latex(latex)
- except ValueError as e:
- # If we have mismatched delimiters, we'll treat it as a single block
- # Strip the $'s from the latex
- latex = [
- {"class": "block", "content": latex.replace("$", "")}
- ]
-
- for el in latex:
- if el["class"] == "block":
- html_out += f''
- elif el["class"] == "inline":
- html_out += f''
- else:
- html_out += f" {el['content']} "
- return html_out.strip()
+ block.html = prediction
def get_batch_size(self):
if self.texify_batch_size is not None:
@@ -106,71 +86,22 @@ def get_latex_batched(self, equation_data: List[dict]):
max_idx = min(min_idx + batch_size, len(equation_data))
batch_equations = equation_data[min_idx:max_idx]
- max_length = max([eq["token_count"] for eq in batch_equations])
- max_length = min(max_length, self.model_max_length)
- max_length += self.token_buffer
-
batch_images = [eq["image"] for eq in batch_equations]
model_output = self.texify_model(
- batch_images,
- max_tokens=max_length
+ batch_images
)
for j, output in enumerate(model_output):
- token_count = self.get_total_texify_tokens(output)
- if token_count >= max_length - 1:
- output = ""
+ token_count = self.get_total_texify_tokens(output.text)
+ if token_count >= self.model_max_length - 1:
+ output.text = ""
image_idx = i + j
- predictions[image_idx] = output
+ predictions[image_idx] = output.text
return predictions
def get_total_texify_tokens(self, text):
tokenizer = self.texify_model.processor.tokenizer
tokens = tokenizer(text)
- return len(tokens["input_ids"])
-
-
- @staticmethod
- def parse_latex(text: str):
- if text.count("$") % 2 != 0:
- raise ValueError("Mismatched delimiters in LaTeX")
-
- DELIMITERS = [
- ("$$", "block"),
- ("$", "inline")
- ]
-
- text = text.replace("\n", "
") # we can't handle \n's inside
properly if we don't do this
-
- i = 0
- stack = []
- result = []
- buffer = ""
-
- while i < len(text):
- for delim, class_name in DELIMITERS:
- if text[i:].startswith(delim):
- if stack and stack[-1] == delim: # Closing
- stack.pop()
- result.append({"class": class_name, "content": buffer})
- buffer = ""
- i += len(delim)
- break
- elif not stack: # Opening
- if buffer:
- result.append({"class": "text", "content": buffer})
- stack.append(delim)
- buffer = ""
- i += len(delim)
- break
- else:
- raise ValueError(f"Nested {class_name} delimiters not supported")
- else: # No delimiter match
- buffer += text[i]
- i += 1
-
- if buffer:
- result.append({"class": "text", "content": buffer})
- return result
+ return len(tokens["input_ids"])
\ No newline at end of file
diff --git a/marker/processors/reference.py b/marker/processors/reference.py
index f3eb2d06..e38b55a9 100644
--- a/marker/processors/reference.py
+++ b/marker/processors/reference.py
@@ -7,7 +7,6 @@
from marker.schema.groups.list import ListGroup
from marker.schema.groups.table import TableGroup
from marker.schema.registry import get_block_class
-from marker.schema.groups.picture import PictureGroup
from marker.schema.groups.figure import FigureGroup
diff --git a/marker/processors/table.py b/marker/processors/table.py
index 1f455848..75b723c0 100644
--- a/marker/processors/table.py
+++ b/marker/processors/table.py
@@ -3,7 +3,6 @@
from copy import deepcopy
from typing import Annotated, List
from collections import Counter
-from PIL import ImageDraw
from ftfy import fix_text
from surya.detection import DetectionPredictor
@@ -53,6 +52,10 @@ class TableProcessor(BaseProcessor):
int,
"The number of workers to use for pdftext.",
] = 4
+ row_split_threshold: Annotated[
+ float,
+ "The percentage of rows that need to be split across the table before row splitting is active.",
+ ] = 0.5
def __init__(
self,
@@ -171,10 +174,7 @@ def split_combined_rows(self, tables: List[TableResult]):
# Skip empty tables
continue
unique_rows = sorted(list(set([c.row_id for c in table.cells])))
- new_cells = []
- shift_up = 0
- max_cell_id = max([c.cell_id for c in table.cells])
- new_cell_count = 0
+ row_info = []
for row in unique_rows:
# Cells in this row
# Deepcopy is because we do an in-place mutation later, and that can cause rows to shift to match rows in unique_rows
@@ -201,9 +201,25 @@ def split_combined_rows(self, tables: List[TableResult]):
len(line_lens_counter) == 2 and counter_keys[0] <= 1 and counter_keys[1] > 1 and line_lens_counter[counter_keys[0]] == 1, # Allow a single column with a single line - keys are the line lens, values are the counts
])
should_split = should_split_entire_row or should_split_partial_row
- if should_split:
- for i in range(0, max(line_lens)):
- for cell in row_cells:
+ row_info.append({
+ "should_split": should_split,
+ "row_cells": row_cells,
+ "line_lens": line_lens
+ })
+
+ # Don't split if we're not splitting most of the rows in the table. This avoids splitting stray multiline rows.
+ if sum([r["should_split"] for r in row_info]) / len(row_info) < self.row_split_threshold:
+ continue
+
+ new_cells = []
+ shift_up = 0
+ max_cell_id = max([c.cell_id for c in table.cells])
+ new_cell_count = 0
+ for row, item_info in zip(unique_rows, row_info):
+ max_lines = max(item_info["line_lens"])
+ if item_info["should_split"]:
+ for i in range(0, max_lines):
+ for cell in item_info["row_cells"]:
# Calculate height based on number of splits
split_height = cell.bbox[3] - cell.bbox[1]
current_bbox = [cell.bbox[0], cell.bbox[1] + i * split_height, cell.bbox[2], cell.bbox[1] + (i + 1) * split_height]
@@ -226,9 +242,10 @@ def split_combined_rows(self, tables: List[TableResult]):
new_cell_count += 1
# For each new row we add, shift up subsequent rows
- shift_up += line_lens[0] - 1
+ # The max is to account for partial rows
+ shift_up += max_lines - 1
else:
- for cell in row_cells:
+ for cell in item_info["row_cells"]:
cell.row_id += shift_up
new_cells.append(cell)
diff --git a/marker/renderers/markdown.py b/marker/renderers/markdown.py
index 0762ab3c..9a48fa40 100644
--- a/marker/renderers/markdown.py
+++ b/marker/renderers/markdown.py
@@ -12,12 +12,16 @@
from marker.schema.document import Document
+def escape_dollars(text):
+ return text.replace("$", r"\$")
+
def cleanup_text(full_text):
full_text = re.sub(r'\n{3,}', '\n\n', full_text)
full_text = re.sub(r'(\n\s){3,}', '\n\n', full_text)
return full_text.strip()
def get_formatted_table_text(element):
+
text = []
for content in element.contents:
if content is None:
@@ -26,13 +30,14 @@ def get_formatted_table_text(element):
if isinstance(content, NavigableString):
stripped = content.strip()
if stripped:
- text.append(stripped)
+ text.append(escape_dollars(stripped))
elif content.name == 'br':
text.append('
')
elif content.name == "math":
text.append("$" + content.text + "$")
else:
- text.append(str(content))
+ content_str = escape_dollars(str(content))
+ text.append(content_str)
full_text = ""
for i, t in enumerate(text):
@@ -120,7 +125,7 @@ def convert_table(self, el, text, convert_as_inline):
if r == 0 and c == 0:
grid[row_idx][col_idx] = value
else:
- grid[row_idx + r][col_idx + c] = ''
+ grid[row_idx + r][col_idx + c] = '' # Empty cell due to rowspan/colspan
except IndexError:
# Sometimes the colspan/rowspan predictions can overflow
print(f"Overflow in columns: {col_idx + c} >= {total_cols}")
@@ -176,6 +181,12 @@ def convert_span(self, el, text, convert_as_inline):
else:
return text
+ def escape(self, text):
+ text = super().escape(text)
+ if self.options['escape_dollars']:
+ text = text.replace('$', r'\$')
+ return text
+
class MarkdownOutput(BaseModel):
markdown: str
images: dict
@@ -198,6 +209,7 @@ def __call__(self, document: Document) -> MarkdownOutput:
escape_misc=False,
escape_underscores=False,
escape_asterisks=False,
+ escape_dollars=True,
sub_symbol="",
sup_symbol="",
inline_math_delimiters=self.inline_math_delimiters,
diff --git a/poetry.lock b/poetry.lock
index 1ddc9a0f..652f9c68 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -2371,13 +2371,13 @@ dill = ">=0.3.8"
[[package]]
name = "narwhals"
-version = "1.24.0"
+version = "1.24.1"
description = "Extremely lightweight compatibility layer between dataframe libraries"
optional = false
python-versions = ">=3.8"
files = [
- {file = "narwhals-1.24.0-py3-none-any.whl", hash = "sha256:73ff60578641059221de2e4f337bfdf0260378fb1553f787d27411602cfc5e72"},
- {file = "narwhals-1.24.0.tar.gz", hash = "sha256:23f0a05efbe29864d184842dd6bf11c044210bca1d443d6dbffe7e65a70bf063"},
+ {file = "narwhals-1.24.1-py3-none-any.whl", hash = "sha256:d8983fe14851c95d60576ddca37c094bd4ed24ab9ea98396844fb20ad9aaf184"},
+ {file = "narwhals-1.24.1.tar.gz", hash = "sha256:b09b8253d945f23cdb683a84685abf3afb9f96114d89e9f35dc876e143f65007"},
]
[package.extras]
@@ -4641,13 +4641,13 @@ snowflake = ["snowflake-connector-python (>=2.8.0)", "snowflake-snowpark-python[
[[package]]
name = "surya-ocr"
-version = "0.9.3"
+version = "0.10.0"
description = "OCR, layout, reading order, and table recognition in 90+ languages"
optional = false
python-versions = "<4.0,>=3.10"
files = [
- {file = "surya_ocr-0.9.3-py3-none-any.whl", hash = "sha256:6013131f3af004f93ab5422dfa8c49a83aa72beb2f8120fd59dca04803d98009"},
- {file = "surya_ocr-0.9.3.tar.gz", hash = "sha256:a69347a3c85c04d48e3df62d11f045dc13e22ab8b3efebfdae1dd94f05a25b99"},
+ {file = "surya_ocr-0.10.0-py3-none-any.whl", hash = "sha256:ccad25a308eefd61a21b2c97fc3f5b8364887e09f197a3aaa5fee30c03f81ae1"},
+ {file = "surya_ocr-0.10.0.tar.gz", hash = "sha256:966bc0c1aef346df42e458d2c1cbc95665004ea61020577e1656789107d09119"},
]
[package.dependencies]
@@ -5489,4 +5489,4 @@ propcache = ">=0.2.0"
[metadata]
lock-version = "2.0"
python-versions = "^3.10"
-content-hash = "33297ed1b238e67f880534882876a29012559d3263ceba2ba0cdd738598af00c"
+content-hash = "d43373ff00de4feb00b0aed4fe98d2a84ecb5742d1a916cabbace5104f888d54"
diff --git a/pyproject.toml b/pyproject.toml
index d838b9e0..6878e04e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -27,7 +27,7 @@ tqdm = "^4.66.1"
ftfy = "^6.1.1"
texify = "^0.2.1"
rapidfuzz = "^3.8.1"
-surya-ocr = "~0.9.3"
+surya-ocr = "~0.10.0"
regex = "^2024.4.28"
pdftext = "~0.5.1"
markdownify = "^0.13.1"
diff --git a/tests/processors/test_table_processor.py b/tests/processors/test_table_processor.py
index 79224a58..72e2a04b 100644
--- a/tests/processors/test_table_processor.py
+++ b/tests/processors/test_table_processor.py
@@ -1,3 +1,5 @@
+from typing import List
+
import pytest
from marker.renderers.json import JSONRenderer
@@ -63,3 +65,15 @@ def test_ocr_table(pdf_document, detection_model, recognition_model, table_rec_m
table_output = renderer(pdf_document)
assert "1.2E-38" in table_output.markdown
+
+@pytest.mark.config({"page_range": [11]})
+def test_split_rows(pdf_document, detection_model, recognition_model, table_rec_model):
+ processor = TableProcessor(detection_model, recognition_model, table_rec_model)
+ processor(pdf_document)
+
+ table = pdf_document.contained_blocks((BlockTypes.Table,))[-1]
+ cells: List[TableCell] = table.contained_blocks(pdf_document, (BlockTypes.TableCell,))
+ unique_rows = len(set([cell.row_id for cell in cells]))
+ assert unique_rows == 6
+
+