diff --git a/marker/v2/converters/pdf.py b/marker/v2/converters/pdf.py index cf4fbbfa..48741385 100644 --- a/marker/v2/converters/pdf.py +++ b/marker/v2/converters/pdf.py @@ -1,30 +1,39 @@ +from marker.v2.providers.pdf import PdfProvider import os + os.environ["TOKENIZERS_PARALLELISM"] = "false" # disables a tokenizers warning -from marker.v2.processors.sectionheader import SectionHeaderProcessor -from marker.v2.providers.pdf import PdfProvider import tempfile -from typing import List, Optional +from collections import defaultdict +from typing import Dict, Type import click import datasets -from pydantic import BaseModel from marker.v2.builders.document import DocumentBuilder from marker.v2.builders.layout import LayoutBuilder from marker.v2.builders.ocr import OcrBuilder from marker.v2.builders.structure import StructureBuilder from marker.v2.converters import BaseConverter +from marker.v2.models import setup_detection_model, setup_layout_model, \ + setup_recognition_model, setup_table_rec_model, setup_texify_model from marker.v2.processors.equation import EquationProcessor +from marker.v2.processors.sectionheader import SectionHeaderProcessor from marker.v2.processors.table import TableProcessor -from marker.v2.models import setup_layout_model, setup_texify_model, setup_recognition_model, setup_table_rec_model, \ - setup_detection_model from marker.v2.renderers.markdown import MarkdownRenderer +from marker.v2.schema import BlockTypes +from marker.v2.schema.blocks import Block +from marker.v2.schema.registry import BLOCK_REGISTRY class PdfConverter(BaseConverter): + override_map: Dict[BlockTypes, Type[Block]] = defaultdict() + def __init__(self, config=None): super().__init__(config) + + for block_type, override_block_type in self.override_map.items(): + BLOCK_REGISTRY[block_type] = override_block_type self.layout_model = setup_layout_model() self.texify_model = setup_texify_model() diff --git a/poetry.lock b/poetry.lock index e4e8dfcb..c4c900ca 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1716,6 +1716,21 @@ profiling = ["gprof2dot"] rtd = ["jupyter_sphinx", "mdit-py-plugins", "myst-parser", "pyyaml", "sphinx", "sphinx-copybutton", "sphinx-design", "sphinx_book_theme"] testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"] +[[package]] +name = "markdownify" +version = "0.13.1" +description = "Convert HTML to markdown." +optional = false +python-versions = "*" +files = [ + {file = "markdownify-0.13.1-py3-none-any.whl", hash = "sha256:1d181d43d20902bcc69d7be85b5316ed174d0dda72ff56e14ae4c95a4a407d22"}, + {file = "markdownify-0.13.1.tar.gz", hash = "sha256:ab257f9e6bd4075118828a28c9d02f8a4bfeb7421f558834aa79b2dfeb32a098"}, +] + +[package.dependencies] +beautifulsoup4 = ">=4.9,<5" +six = ">=1.15,<2" + [[package]] name = "markupsafe" version = "3.0.2" @@ -5208,4 +5223,4 @@ propcache = ">=0.2.0" [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "87f2bf6e84db6c7100db24777a4104096cb963d9bea7380e520dbc1bd8f2feb7" +content-hash = "4cb4a1d2f40994498d5657bcb5ab37e7ec9c4bb867015d34131c53000eacfbee" diff --git a/pyproject.toml b/pyproject.toml index d2b49103..233e6ecd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,6 +38,7 @@ filetype = "^1.2.0" regex = "^2024.4.28" pdftext = "^0.3.18" tabled-pdf = { git = "https://github.com/VikParuchuri/tabled.git", branch = "dev-mose/compilation-updates" } +markdownify = "^0.13.1" [tool.poetry.group.dev.dependencies] jupyter = "^1.0.0" diff --git a/tests/conftest.py b/tests/conftest.py index 178470a6..2c63a934 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,13 +3,17 @@ import datasets import pytest +from typing import Dict, Type +from marker.v2.schema import BlockTypes +from marker.v2.schema.blocks import Block from marker.v2.models import setup_layout_model, setup_texify_model, setup_recognition_model, setup_table_rec_model, \ setup_detection_model from marker.v2.builders.document import DocumentBuilder from marker.v2.builders.layout import LayoutBuilder from marker.v2.builders.ocr import OcrBuilder from marker.v2.schema.document import Document +from marker.v2.schema.registry import BLOCK_REGISTRY @pytest.fixture(scope="session") @@ -48,13 +52,22 @@ def table_rec_model(): @pytest.fixture(scope="function") -def pdf_provider(request): +def config(request): + config_mark = request.node.get_closest_marker("config") + config = config_mark.args[0] if config_mark else {} + + override_map: Dict[BlockTypes, Type[Block]] = config.get("override_map", {}) + for block_type, override_block_type in override_map.items(): + BLOCK_REGISTRY[block_type] = override_block_type + + return config + + +@pytest.fixture(scope="function") +def pdf_provider(request, config): filename_mark = request.node.get_closest_marker("filename") filename = filename_mark.args[0] if filename_mark else "adversarial.pdf" - config_mark = request.node.get_closest_marker("config") - config = config_mark.args[0] if config_mark else None - dataset = datasets.load_dataset("datalab-to/pdfs", split="train") idx = dataset['filename'].index(filename) @@ -65,10 +78,7 @@ def pdf_provider(request): @pytest.fixture(scope="function") -def pdf_document(request, pdf_provider, layout_model, recognition_model, detection_model) -> Document: - config_mark = request.node.get_closest_marker("config") - config = config_mark.args[0] if config_mark else None - +def pdf_document(request, config, pdf_provider, layout_model, recognition_model, detection_model) -> Document: layout_builder = LayoutBuilder(layout_model, config) ocr_builder = OcrBuilder(detection_model, recognition_model, config) builder = DocumentBuilder(config) diff --git a/tests/test_overriding.py b/tests/test_overriding.py new file mode 100644 index 00000000..b70bb1c7 --- /dev/null +++ b/tests/test_overriding.py @@ -0,0 +1,18 @@ +import pytest + +from marker.v2.schema import BlockTypes +from marker.v2.schema.document import Document +from marker.v2.schema.blocks import SectionHeader + + +class NewSectionHeader(SectionHeader): + pass + + +@pytest.mark.config({ + "page_range": [0], + "override_map": {BlockTypes.SectionHeader: NewSectionHeader} +}) +def test_overriding(pdf_document: Document): + assert pdf_document.pages[0]\ + .get_block(pdf_document.pages[0].structure[0]).__class__ == NewSectionHeader