diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml
index 0fcb2380..5d49aa1c 100644
--- a/.github/workflows/benchmark.yml
+++ b/.github/workflows/benchmark.yml
@@ -4,11 +4,10 @@ on: [push]
 
 env:
   TORCH_DEVICE: "cpu"
-  OCR_ENGINE: "surya"
 
 jobs:
   benchmark:
-    runs-on: ubuntu-latest
+    runs-on: [ubuntu-latest, windows-latest]
     steps:
       - uses: actions/checkout@v3
       - name: Set up Python 3.11
diff --git a/marker/models.py b/marker/models.py
index 908fb863..80dc254c 100644
--- a/marker/models.py
+++ b/marker/models.py
@@ -1,41 +1,12 @@
 import os
-
-from marker.settings import settings
-
-os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # Transformers uses .isin for a simple op, which is not supported on MPS
-
-from typing import List
-from PIL import Image
+os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # Transformers uses .isin for an op, which is not supported on MPS
 
 from surya.detection import DetectionPredictor
 from surya.layout import LayoutPredictor
 from surya.ocr_error import OCRErrorPredictor
 from surya.recognition import RecognitionPredictor
 from surya.table_rec import TableRecPredictor
-
-from texify.model.model import load_model as load_texify_model
-from texify.model.processor import load_processor as load_texify_processor
-from texify.inference import batch_inference
-
-class TexifyPredictor:
-    def __init__(self, device=None, dtype=None):
-        if not device:
-            device = settings.TORCH_DEVICE_MODEL
-        if not dtype:
-            dtype = settings.TEXIFY_DTYPE
-
-        self.model = load_texify_model(checkpoint=settings.TEXIFY_MODEL_NAME, device=device, dtype=dtype)
-        self.processor = load_texify_processor()
-        self.device = device
-        self.dtype = dtype
-
-    def __call__(self, batch_images: List[Image.Image], max_tokens: int):
-        return batch_inference(
-            batch_images,
-            self.model,
-            self.processor,
-            max_tokens=max_tokens
-        )
+from surya.texify import TexifyPredictor
 
 
 def create_model_dict(device=None, dtype=None) -> dict:
diff --git a/marker/processors/blockquote.py b/marker/processors/blockquote.py
index bc0b5bec..018a7ecf 100644
--- a/marker/processors/blockquote.py
+++ b/marker/processors/blockquote.py
@@ -63,4 +63,4 @@ def __call__(self, document: Document):
                     next_block.blockquote_level += 1
                 elif len(next_block.structure) >= 2 and (x_indent and y_indent):
                     next_block.blockquote = True
-                    next_block.blockquote_level = 1
+                    next_block.blockquote_level = 1
\ No newline at end of file
diff --git a/marker/processors/equation.py b/marker/processors/equation.py
index 868f98a2..5f5be17c 100644
--- a/marker/processors/equation.py
+++ b/marker/processors/equation.py
@@ -22,7 +22,7 @@ class EquationProcessor(BaseProcessor):
     model_max_length: Annotated[
         int,
         "The maximum number of tokens to allow for the Texify model.",
-    ] = 384
+    ] = 768
     texify_batch_size: Annotated[
         Optional[int],
         "The batch size to use for the Texify model.",
@@ -65,27 +65,7 @@ def __call__(self, document: Document):
                 continue
 
             block = document.get_block(equation_d["block_id"])
-            block.html = self.parse_latex_to_html(prediction)
-
-    def parse_latex_to_html(self, latex: str):
-        html_out = ""
-        try:
-            latex = self.parse_latex(latex)
-        except ValueError as e:
-            # If we have mismatched delimiters, we'll treat it as a single block
-            # Strip the $'s from the latex
-            latex = [
-                {"class": "block", "content": latex.replace("$", "")}
-            ]
-
-        for el in latex:
-            if el["class"] == "block":
-                html_out += f'<math display="block">{el["content"]}</math>'
-            elif el["class"] == "inline":
-                html_out += f'<math>{el["content"]}</math>'
-            else:
-                html_out += f" {el['content']} "
-        return html_out.strip()
+            block.html = prediction
 
     def get_batch_size(self):
         if self.texify_batch_size is not None:
@@ -106,71 +86,22 @@ def get_latex_batched(self, equation_data: List[dict]):
             max_idx = min(min_idx + batch_size, len(equation_data))
             batch_equations = equation_data[min_idx:max_idx]
-            max_length = max([eq["token_count"] for eq in batch_equations])
-            max_length = min(max_length, self.model_max_length)
-            max_length += self.token_buffer
-
             batch_images = [eq["image"] for eq in batch_equations]
 
             model_output = self.texify_model(
-                batch_images,
-                max_tokens=max_length
+                batch_images
             )
 
             for j, output in enumerate(model_output):
-                token_count = self.get_total_texify_tokens(output)
-                if token_count >= max_length - 1:
-                    output = ""
+                token_count = self.get_total_texify_tokens(output.text)
+                if token_count >= self.model_max_length - 1:
+                    output.text = ""
 
                 image_idx = i + j
-                predictions[image_idx] = output
+                predictions[image_idx] = output.text
         return predictions
 
     def get_total_texify_tokens(self, text):
         tokenizer = self.texify_model.processor.tokenizer
         tokens = tokenizer(text)
-        return len(tokens["input_ids"])
-
-
-    @staticmethod
-    def parse_latex(text: str):
-        if text.count("$") % 2 != 0:
-            raise ValueError("Mismatched delimiters in LaTeX")
-
-        DELIMITERS = [
-            ("$$", "block"),
-            ("$", "inline")
-        ]
-
-        text = text.replace("\n", "<br>")  # we can't handle \n's inside <math> properly if we don't do this
-
-        i = 0
-        stack = []
-        result = []
-        buffer = ""
-
-        while i < len(text):
-            for delim, class_name in DELIMITERS:
-                if text[i:].startswith(delim):
-                    if stack and stack[-1] == delim:  # Closing
-                        stack.pop()
-                        result.append({"class": class_name, "content": buffer})
-                        buffer = ""
-                        i += len(delim)
-                        break
-                    elif not stack:  # Opening
-                        if buffer:
-                            result.append({"class": "text", "content": buffer})
-                        stack.append(delim)
-                        buffer = ""
-                        i += len(delim)
-                        break
-                    else:
-                        raise ValueError(f"Nested {class_name} delimiters not supported")
-            else:  # No delimiter match
-                buffer += text[i]
-                i += 1
-
-        if buffer:
-            result.append({"class": "text", "content": buffer})
-        return result
+        return len(tokens["input_ids"])
\ No newline at end of file
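Note on the equation change above: the hand-rolled `TexifyPredictor` wrapper and the `parse_latex`/`parse_latex_to_html` delimiter handling are deleted because `surya.texify.TexifyPredictor` returns display-ready text that is assigned straight to `block.html`. A minimal sketch of the new call pattern, inferred from the calls in this diff — the bare constructor, the placeholder image path, and the exact result shape are assumptions that may vary across surya versions:

```python
from PIL import Image
from surya.texify import TexifyPredictor

texify = TexifyPredictor()  # device/dtype resolution now happens inside surya

# The processor passes a plain list of PIL images and no max_tokens argument,
# unlike the old texify batch_inference wrapper.
images = [Image.open("equation_crop.png")]  # placeholder path
outputs = texify(images)

# Each result exposes a .text attribute (see the output.text accesses above);
# marker stores it on block.html without any delimiter post-processing.
print(outputs[0].text)
```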
diff --git a/marker/processors/reference.py b/marker/processors/reference.py
index f3eb2d06..e38b55a9 100644
--- a/marker/processors/reference.py
+++ b/marker/processors/reference.py
@@ -7,7 +7,6 @@
 from marker.schema.groups.list import ListGroup
 from marker.schema.groups.table import TableGroup
 from marker.schema.registry import get_block_class
-from marker.schema.groups.picture import PictureGroup
 from marker.schema.groups.figure import FigureGroup
diff --git a/marker/processors/table.py b/marker/processors/table.py
index 1f455848..75b723c0 100644
--- a/marker/processors/table.py
+++ b/marker/processors/table.py
@@ -3,7 +3,6 @@
 from copy import deepcopy
 from typing import Annotated, List
 from collections import Counter
-from PIL import ImageDraw
 
 from ftfy import fix_text
 from surya.detection import DetectionPredictor
@@ -53,6 +52,10 @@ class TableProcessor(BaseProcessor):
         int,
         "The number of workers to use for pdftext.",
     ] = 4
+    row_split_threshold: Annotated[
+        float,
+        "The percentage of rows that need to be split across the table before row splitting is active.",
+    ] = 0.5
 
     def __init__(
         self,
@@ -171,10 +174,7 @@ def split_combined_rows(self, tables: List[TableResult]):
                 # Skip empty tables
                 continue
             unique_rows = sorted(list(set([c.row_id for c in table.cells])))
-            new_cells = []
-            shift_up = 0
-            max_cell_id = max([c.cell_id for c in table.cells])
-            new_cell_count = 0
+            row_info = []
             for row in unique_rows:
                 # Cells in this row
                 # Deepcopy is because we do an in-place mutation later, and that can cause rows to shift to match rows in unique_rows
@@ -201,9 +201,25 @@ def split_combined_rows(self, tables: List[TableResult]):
                     len(line_lens_counter) == 2 and counter_keys[0] <= 1 and counter_keys[1] > 1 and line_lens_counter[counter_keys[0]] == 1, # Allow a single column with a single line - keys are the line lens, values are the counts
                 ])
                 should_split = should_split_entire_row or should_split_partial_row
-                if should_split:
-                    for i in range(0, max(line_lens)):
-                        for cell in row_cells:
+                row_info.append({
+                    "should_split": should_split,
+                    "row_cells": row_cells,
+                    "line_lens": line_lens
+                })
+
+            # Don't split if we're not splitting most of the rows in the table. This avoids splitting stray multiline rows.
+            if sum([r["should_split"] for r in row_info]) / len(row_info) < self.row_split_threshold:
+                continue
+
+            new_cells = []
+            shift_up = 0
+            max_cell_id = max([c.cell_id for c in table.cells])
+            new_cell_count = 0
+            for row, item_info in zip(unique_rows, row_info):
+                max_lines = max(item_info["line_lens"])
+                if item_info["should_split"]:
+                    for i in range(0, max_lines):
+                        for cell in item_info["row_cells"]:
                             # Calculate height based on number of splits
                             split_height = cell.bbox[3] - cell.bbox[1]
                             current_bbox = [cell.bbox[0], cell.bbox[1] + i * split_height, cell.bbox[2], cell.bbox[1] + (i + 1) * split_height]
@@ -226,9 +242,10 @@ def split_combined_rows(self, tables: List[TableResult]):
                             new_cell_count += 1
 
                     # For each new row we add, shift up subsequent rows
-                    shift_up += line_lens[0] - 1
+                    # The max is to account for partial rows
+                    shift_up += max_lines - 1
                 else:
-                    for cell in row_cells:
+                    for cell in item_info["row_cells"]:
                         cell.row_id += shift_up
                         new_cells.append(cell)
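The rewritten `split_combined_rows` now classifies every row before mutating anything, then bails out unless at least `row_split_threshold` (default 0.5) of the rows look splittable, so a lone multiline row no longer drags the whole table through the split path; `shift_up` also uses `max_lines` so partial rows shift correctly. A toy illustration of the gating, with hypothetical row data mirroring the `row_info` dicts built above:

```python
# Hypothetical rows in the shape of the row_info entries collected above.
row_info = [
    {"should_split": True,  "line_lens": [2, 2, 2]},  # every cell wraps to two lines
    {"should_split": True,  "line_lens": [3, 3, 3]},
    {"should_split": False, "line_lens": [1, 1, 1]},  # a normal single-line row
]
row_split_threshold = 0.5

split_fraction = sum(r["should_split"] for r in row_info) / len(row_info)
if split_fraction < row_split_threshold:
    print("skip table: too few splittable rows")  # the `continue` branch above
else:
    print(f"split table: {split_fraction:.0%} of rows qualify")  # prints 67%
```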
+ if sum([r["should_split"] for r in row_info]) / len(row_info) < self.row_split_threshold: + continue + + new_cells = [] + shift_up = 0 + max_cell_id = max([c.cell_id for c in table.cells]) + new_cell_count = 0 + for row, item_info in zip(unique_rows, row_info): + max_lines = max(item_info["line_lens"]) + if item_info["should_split"]: + for i in range(0, max_lines): + for cell in item_info["row_cells"]: # Calculate height based on number of splits split_height = cell.bbox[3] - cell.bbox[1] current_bbox = [cell.bbox[0], cell.bbox[1] + i * split_height, cell.bbox[2], cell.bbox[1] + (i + 1) * split_height] @@ -226,9 +242,10 @@ def split_combined_rows(self, tables: List[TableResult]): new_cell_count += 1 # For each new row we add, shift up subsequent rows - shift_up += line_lens[0] - 1 + # The max is to account for partial rows + shift_up += max_lines - 1 else: - for cell in row_cells: + for cell in item_info["row_cells"]: cell.row_id += shift_up new_cells.append(cell) diff --git a/marker/renderers/markdown.py b/marker/renderers/markdown.py index 0762ab3c..9a48fa40 100644 --- a/marker/renderers/markdown.py +++ b/marker/renderers/markdown.py @@ -12,12 +12,16 @@ from marker.schema.document import Document +def escape_dollars(text): + return text.replace("$", r"\$") + def cleanup_text(full_text): full_text = re.sub(r'\n{3,}', '\n\n', full_text) full_text = re.sub(r'(\n\s){3,}', '\n\n', full_text) return full_text.strip() def get_formatted_table_text(element): + text = [] for content in element.contents: if content is None: @@ -26,13 +30,14 @@ def get_formatted_table_text(element): if isinstance(content, NavigableString): stripped = content.strip() if stripped: - text.append(stripped) + text.append(escape_dollars(stripped)) elif content.name == 'br': text.append('
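The renderer change exists because the markdown output uses `$...$` as inline math delimiters, so a literal dollar sign in body or table text would otherwise open a phantom math span; `escape_dollars=True` is threaded through the markdownify options into the new `escape` override. A quick sketch of the intended behavior:

```python
def escape_dollars(text):
    # Same helper as above: literal dollars are backslash-escaped in markdown.
    return text.replace("$", r"\$")

print(escape_dollars("Price: $5 ($3 on sale)"))  # Price: \$5 (\$3 on sale)

# Math content bypasses the escape and keeps real delimiters:
print("$" + "E = mc^2" + "$")                    # $E = mc^2$
```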
') elif content.name == "math": text.append("$" + content.text + "$") else: - text.append(str(content)) + content_str = escape_dollars(str(content)) + text.append(content_str) full_text = "" for i, t in enumerate(text): @@ -120,7 +125,7 @@ def convert_table(self, el, text, convert_as_inline): if r == 0 and c == 0: grid[row_idx][col_idx] = value else: - grid[row_idx + r][col_idx + c] = '' + grid[row_idx + r][col_idx + c] = '' # Empty cell due to rowspan/colspan except IndexError: # Sometimes the colspan/rowspan predictions can overflow print(f"Overflow in columns: {col_idx + c} >= {total_cols}") @@ -176,6 +181,12 @@ def convert_span(self, el, text, convert_as_inline): else: return text + def escape(self, text): + text = super().escape(text) + if self.options['escape_dollars']: + text = text.replace('$', r'\$') + return text + class MarkdownOutput(BaseModel): markdown: str images: dict @@ -198,6 +209,7 @@ def __call__(self, document: Document) -> MarkdownOutput: escape_misc=False, escape_underscores=False, escape_asterisks=False, + escape_dollars=True, sub_symbol="", sup_symbol="", inline_math_delimiters=self.inline_math_delimiters, diff --git a/poetry.lock b/poetry.lock index 1ddc9a0f..652f9c68 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2371,13 +2371,13 @@ dill = ">=0.3.8" [[package]] name = "narwhals" -version = "1.24.0" +version = "1.24.1" description = "Extremely lightweight compatibility layer between dataframe libraries" optional = false python-versions = ">=3.8" files = [ - {file = "narwhals-1.24.0-py3-none-any.whl", hash = "sha256:73ff60578641059221de2e4f337bfdf0260378fb1553f787d27411602cfc5e72"}, - {file = "narwhals-1.24.0.tar.gz", hash = "sha256:23f0a05efbe29864d184842dd6bf11c044210bca1d443d6dbffe7e65a70bf063"}, + {file = "narwhals-1.24.1-py3-none-any.whl", hash = "sha256:d8983fe14851c95d60576ddca37c094bd4ed24ab9ea98396844fb20ad9aaf184"}, + {file = "narwhals-1.24.1.tar.gz", hash = "sha256:b09b8253d945f23cdb683a84685abf3afb9f96114d89e9f35dc876e143f65007"}, ] [package.extras] @@ -4641,13 +4641,13 @@ snowflake = ["snowflake-connector-python (>=2.8.0)", "snowflake-snowpark-python[ [[package]] name = "surya-ocr" -version = "0.9.3" +version = "0.10.0" description = "OCR, layout, reading order, and table recognition in 90+ languages" optional = false python-versions = "<4.0,>=3.10" files = [ - {file = "surya_ocr-0.9.3-py3-none-any.whl", hash = "sha256:6013131f3af004f93ab5422dfa8c49a83aa72beb2f8120fd59dca04803d98009"}, - {file = "surya_ocr-0.9.3.tar.gz", hash = "sha256:a69347a3c85c04d48e3df62d11f045dc13e22ab8b3efebfdae1dd94f05a25b99"}, + {file = "surya_ocr-0.10.0-py3-none-any.whl", hash = "sha256:ccad25a308eefd61a21b2c97fc3f5b8364887e09f197a3aaa5fee30c03f81ae1"}, + {file = "surya_ocr-0.10.0.tar.gz", hash = "sha256:966bc0c1aef346df42e458d2c1cbc95665004ea61020577e1656789107d09119"}, ] [package.dependencies] @@ -5489,4 +5489,4 @@ propcache = ">=0.2.0" [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "33297ed1b238e67f880534882876a29012559d3263ceba2ba0cdd738598af00c" +content-hash = "d43373ff00de4feb00b0aed4fe98d2a84ecb5742d1a916cabbace5104f888d54" diff --git a/pyproject.toml b/pyproject.toml index d838b9e0..6878e04e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,7 @@ tqdm = "^4.66.1" ftfy = "^6.1.1" texify = "^0.2.1" rapidfuzz = "^3.8.1" -surya-ocr = "~0.9.3" +surya-ocr = "~0.10.0" regex = "^2024.4.28" pdftext = "~0.5.1" markdownify = "^0.13.1" diff --git a/tests/processors/test_table_processor.py b/tests/processors/test_table_processor.py index 
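Dependency note: the `~0.10.0` constraint (Poetry tilde semantics, roughly `>=0.10.0,<0.11.0`) keeps marker on the 0.10.x line that ships `surya.texify`, while the lockfile above pins the exact build.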
diff --git a/tests/processors/test_table_processor.py b/tests/processors/test_table_processor.py
index 79224a58..72e2a04b 100644
--- a/tests/processors/test_table_processor.py
+++ b/tests/processors/test_table_processor.py
@@ -1,3 +1,5 @@
+from typing import List
+
 import pytest
 
 from marker.renderers.json import JSONRenderer
@@ -63,3 +65,15 @@ def test_ocr_table(pdf_document, detection_model, recognition_model, table_rec_model):
 
     table_output = renderer(pdf_document)
     assert "1.2E-38" in table_output.markdown
+
+
+@pytest.mark.config({"page_range": [11]})
+def test_split_rows(pdf_document, detection_model, recognition_model, table_rec_model):
+    processor = TableProcessor(detection_model, recognition_model, table_rec_model)
+    processor(pdf_document)
+
+    table = pdf_document.contained_blocks((BlockTypes.Table,))[-1]
+    cells: List[TableCell] = table.contained_blocks(pdf_document, (BlockTypes.TableCell,))
+    unique_rows = len(set([cell.row_id for cell in cells]))
+    assert unique_rows == 6
+
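The new test pins the default threshold's behavior on page 11 of the test PDF. If a document needs a different sensitivity, the annotated setting should be tunable from user code; a sketch, assuming `TableProcessor` accepts marker's usual `config` dict for its annotated attributes (the kwarg is an assumption, not shown in this diff):

```python
from marker.processors.table import TableProcessor

def run_with_higher_threshold(pdf_document, detection_model, recognition_model, table_rec_model):
    # Hypothetical: only split rows when 80% of a table's rows qualify.
    processor = TableProcessor(
        detection_model,
        recognition_model,
        table_rec_model,
        config={"row_split_threshold": 0.8},  # assumed config plumbing for annotated settings
    )
    processor(pdf_document)
```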