Skip to content

Commit

Permalink
Integrate table extraction, tests
Browse files Browse the repository at this point in the history
  • Loading branch information
VikParuchuri committed Jan 8, 2025
1 parent 816c3d9 commit eccb7d2
Show file tree
Hide file tree
Showing 13 changed files with 408 additions and 32 deletions.
19 changes: 19 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
name: Unit tests

on: [push]

jobs:
build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Python 3.11
uses: actions/setup-python@v4
with:
python-version: 3.11
- name: Install python dependencies
run: |
pip install poetry
poetry install
- name: Run tests
run: poetry run pytest
File renamed without changes.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ pdftext PDF_PATH --out_path output.txt
- `--out_path` path to the output txt file. If not specified, will write to stdout.
- `--sort` will attempt to sort in reading order if specified.
- `--keep_hyphens` will keep hyphens in the output (they will be stripped and words joined otherwise)
- `--pages` will specify pages (comma separated) to extract
- `--page_range` will specify pages (comma separated) to extract. Like `0,5-10,12`.
- `--workers` specifies the number of parallel workers to use
- `--flatten_pdf` merges form fields into the PDF

Expand Down
64 changes: 40 additions & 24 deletions extract_text.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,57 @@
import argparse
import json
from pathlib import Path
from typing import List

import click
import pypdfium2 as pdfium

from pdftext.extraction import plain_text_output, dictionary_output


def main():
parser = argparse.ArgumentParser(description="Extract plain text from PDF. Not guaranteed to be in order.")
parser.add_argument("pdf_path", type=str, help="Path to the PDF file")
parser.add_argument("--out_path", type=str, help="Path to the output text file, defaults to stdout", default=None)
parser.add_argument("--json", action="store_true", help="Output json instead of plain text", default=False)
parser.add_argument("--sort", action="store_true", help="Attempt to sort the text by reading order", default=False)
parser.add_argument("--keep_hyphens", action="store_true", help="Keep hyphens in words", default=False)
parser.add_argument("--pages", type=str, help="Comma separated pages to extract, like 1,2,3", default=None)
parser.add_argument("--flatten_pdf", action="store_true", help="Flatten form fields and annotations into page contents", default=False)
parser.add_argument("--keep_chars", action="store_true", help="Keep character level information", default=False)
parser.add_argument("--workers", type=int, help="Number of workers to use for parallel processing", default=None)
args = parser.parse_args()

def parse_range_str(range_str: str) -> List[int]:
range_lst = range_str.split(",")
page_lst = []
for i in range_lst:
if "-" in i:
start, end = i.split("-")
page_lst += list(range(int(start), int(end) + 1))
else:
page_lst.append(int(i))
page_lst = sorted(list(set(page_lst))) # Deduplicate page numbers and sort in order
return page_lst

@click.command(help="Extract plain text or JSON from PDF.")
@click.argument("pdf_path", type=click.Path(exists=True))
@click.option("--out_path", type=click.Path(exists=False), help="Path to the output text file, defaults to stdout")
@click.option("--json", is_flag=True, help="Output json instead of plain text", default=False)
@click.option("--sort", is_flag=True, help="Attempt to sort the text by reading order", default=False)
@click.option("--keep_hyphens", is_flag=True, help="Keep hyphens in words", default=False)
@click.option("--page_range", type=str, help="Page numbers or ranges to extract, comma separated like 1,2-4,10", default=None)
@click.option("--flatten_pdf", is_flag=True, help="Flatten form fields and annotations into page contents", default=False)
@click.option("--keep_chars", is_flag=True, help="Keep character level information", default=False)
@click.option("--workers", type=int, help="Number of workers to use for parallel processing", default=None)
def main(
pdf_path: Path,
out_path: Path | None,
**kwargs
):
pages = None
if args.pages is not None:
pdf_doc = pdfium.PdfDocument(args.pdf_path)
pages = [int(p) for p in args.pages.split(",")]
if kwargs["page_range"] is not None:
pdf_doc = pdfium.PdfDocument(pdf_path)
pages = parse_range_str(kwargs["page_range"])
doc_len = len(pdf_doc)
pdf_doc.close()
assert all(p <= doc_len for p in pages), "Invalid page number(s) provided"
assert all(0 <= p <= doc_len for p in pages), "Invalid page number(s) provided"

if args.json:
text = dictionary_output(args.pdf_path, sort=args.sort, page_range=pages, flatten_pdf=args.flatten_pdf, keep_chars=args.keep_chars, workers=args.workers)
if kwargs["json"]:
text = dictionary_output(pdf_path, sort=kwargs["sort"], page_range=pages, flatten_pdf=kwargs["flatten_pdf"], keep_chars=kwargs["keep_chars"], workers=kwargs["workers"])
text = json.dumps(text)
else:
text = plain_text_output(args.pdf_path, sort=args.sort, hyphens=args.keep_hyphens, page_range=pages, flatten_pdf=args.flatten_pdf, workers=args.workers)
text = plain_text_output(pdf_path, sort=kwargs["sort"], hyphens=kwargs["keep_hyphens"], page_range=pages, flatten_pdf=kwargs["flatten_pdf"], workers=kwargs["workers"])

if args.out_path is None:
if out_path is None:
print(text)
else:
with open(args.out_path, "w+") as f:
with open(out_path, "w+") as f:
f.write(text)


Expand Down
44 changes: 39 additions & 5 deletions pdftext/extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@

from pdftext.pdf.pages import get_pages
from pdftext.postprocessing import handle_hyphens, merge_text, postprocess_text, sort_blocks
from pdftext.schema import Pages, TableInputs, Tables
from pdftext.settings import settings
from pdftext.tables import table_cell_text


def _load_pdf(pdf, flatten_pdf):
Expand All @@ -22,7 +24,7 @@ def _load_pdf(pdf, flatten_pdf):
return pdf


def _get_page_range(page_range, flatten_pdf=False, quote_loosebox=True):
def _get_page_range(page_range, flatten_pdf=False, quote_loosebox=True) -> Pages:
return get_pages(pdf_doc, page_range, flatten_pdf, quote_loosebox)


Expand All @@ -38,7 +40,7 @@ def worker_init(pdf_path, flatten_pdf):
atexit.register(partial(worker_shutdown, pdf_doc))


def _get_pages(pdf_path, page_range=None, flatten_pdf=False, quote_loosebox=True, workers=None):
def _get_pages(pdf_path, page_range=None, flatten_pdf=False, quote_loosebox=True, workers=None) -> Pages:
pdf_doc = _load_pdf(pdf_path, flatten_pdf)
if page_range is None:
page_range = range(len(pdf_doc))
Expand Down Expand Up @@ -70,7 +72,7 @@ def plain_text_output(pdf_path, sort=False, hyphens=False, page_range=None, flat


def paginated_plain_text_output(pdf_path, sort=False, hyphens=False, page_range=None, flatten_pdf=False, workers=None) -> List[str]:
pages = _get_pages(pdf_path, page_range, workers=workers, flatten_pdf=flatten_pdf)
pages: Pages = _get_pages(pdf_path, page_range, workers=workers, flatten_pdf=flatten_pdf)
text = []
for page in pages:
text.append(merge_text(page, sort=sort, hyphens=hyphens).strip())
Expand All @@ -87,8 +89,16 @@ def _process_span(span, page_width, page_height, keep_chars):
char["bbox"] = char["bbox"].bbox


def dictionary_output(pdf_path, sort=False, page_range=None, keep_chars=False, flatten_pdf=False, quote_loosebox=True, workers=None):
pages = _get_pages(pdf_path, page_range, workers=workers, flatten_pdf=flatten_pdf, quote_loosebox=quote_loosebox)
def dictionary_output(
pdf_path,
sort=False,
page_range=None,
keep_chars=False,
flatten_pdf=False,
quote_loosebox=True,
workers=None
) -> Pages:
pages: Pages = _get_pages(pdf_path, page_range, workers=workers, flatten_pdf=flatten_pdf, quote_loosebox=quote_loosebox)
for page in pages:
page_width, page_height = page["width"], page["height"]
for block in page["blocks"]:
Expand All @@ -111,3 +121,27 @@ def dictionary_output(pdf_path, sort=False, page_range=None, keep_chars=False, f
page["width"], page["height"] = page["height"], page["width"]
page["bbox"] = [page["bbox"][2], page["bbox"][3], page["bbox"][0], page["bbox"][1]]
return pages

def table_output(
pdf_path: str,
table_inputs: TableInputs,
page_range=None,
flatten_pdf=False,
quote_loosebox=True,
workers=None,
pages: Pages | None = None
) -> List[Tables]:
# Extract pages if they don't exist
if not pages:
pages: Pages = dictionary_output(pdf_path, page_range=page_range, flatten_pdf=flatten_pdf, quote_loosebox=quote_loosebox, workers=workers, keep_chars=True)

assert len(pages) == len(table_inputs), "Number of pages and table inputs must match"

# Extract table cells per page
out_tables = []
for page, table_input in zip(pages, table_inputs):
tables = table_cell_text(table_input["tables"], page, table_input["img_size"])
assert len(tables) == len(table_input["tables"]), "Number of tables and table inputs must match"
out_tables.append(tables)
return out_tables

28 changes: 27 additions & 1 deletion pdftext/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ class Bbox:
def __init__(self, bbox: List[float]):
self.bbox = bbox

def __getitem__(self, item):
return self.bbox[item]

@property
def height(self):
return self.bbox[3] - self.bbox[1]
Expand Down Expand Up @@ -101,6 +104,18 @@ def rotate(self, page_width: float, page_height: float, rotation: int) -> Bbox:

return Bbox(rotated_bbox)

def rescale(self, img_size: List[int], page: Page) -> Bbox:
w_scale = img_size[0] / page["width"]
h_scale = img_size[1] / page["height"]
new_bbox = [
self.bbox[0] * w_scale,
self.bbox[1] * h_scale,
self.bbox[2] * w_scale,
self.bbox[3] * h_scale
]

return Bbox(new_bbox)


class Char(TypedDict):
bbox: Bbox
Expand All @@ -116,7 +131,7 @@ class Span(TypedDict):
font: Dict[str, Union[Any, str]]
font_weight: float
font_size: float
chars: List[Char]
chars: List[Char] | None
char_start_idx: int
char_end_idx: int

Expand All @@ -137,10 +152,21 @@ class Page(TypedDict):
width: int
height: int
blocks: List[Block]
rotation: int

class TableCell(TypedDict):
text: str
bbox: Bbox

class TableInput(TypedDict):
tables: List[List[int]]
img_size: List[int]


Chars = List[Char]
Spans = List[Span]
Lines = List[Line]
Blocks = List[Block]
Pages = List[Page]
Tables = List[List[TableCell]]
TableInputs = List[TableInput]
129 changes: 129 additions & 0 deletions pdftext/tables.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
from typing import List
import numpy as np

from pdftext.schema import Pages, Page, Bbox, Tables


def sort_text_lines(lines: List[dict], tolerance=1.25):
# Sorts in reading order. Not 100% accurate, this should only
# be used as a starting point for more advanced sorting.
vertical_groups = {}
for line in lines:
group_key = (line["bbox"][1] / tolerance) * tolerance
if group_key not in vertical_groups:
vertical_groups[group_key] = []
vertical_groups[group_key].append(line)

# Sort each group horizontally and flatten the groups into a single list
sorted_lines = []
for _, group in sorted(vertical_groups.items()):
sorted_group = sorted(group, key=lambda x: x["bbox"][0])
sorted_lines.extend(sorted_group)

return sorted_lines


def get_dynamic_gap_thresh(page: Page, img_size: list, default_thresh=.01, min_chars=100):
space_dists = []
for block in page["blocks"]:
for line in block["lines"]:
for span in line["spans"]:
for i in range(1, len(span["chars"])):
char1 = span["chars"][i - 1]
char2 = span["chars"][i]
if page["rotation"] == 90:
space_dists.append((char2["bbox"][0] - char1["bbox"][2]) / img_size[0])
elif page["rotation"] == 180:
space_dists.append((char2["bbox"][1] - char1["bbox"][3]) / img_size[1])
elif page["rotation"] == 270:
space_dists.append((char1["bbox"][0] - char2["bbox"][2]) / img_size[0])
else:
space_dists.append((char1["bbox"][1] - char2["bbox"][3]) / img_size[1])
cell_gap_thresh = np.percentile(space_dists, 80) if len(space_dists) > min_chars else default_thresh
return cell_gap_thresh


def is_same_span(char, curr_box, img_size, space_thresh, rotation):
def normalized_diff(a, b, dimension, mult=1, use_abs=True):
func = abs if use_abs else lambda x: x
return func(a - b) / img_size[dimension] < space_thresh * mult

bbox = char["bbox"]
if rotation == 90:
return all([
normalized_diff(bbox[0], curr_box[0], 0, use_abs=False),
normalized_diff(bbox[1], curr_box[3], 1),
normalized_diff(bbox[0], curr_box[0], 0, mult=5)
])
elif rotation == 180:
return all([
normalized_diff(bbox[2], curr_box[0], 0, use_abs=False),
normalized_diff(bbox[1], curr_box[1], 1),
normalized_diff(bbox[2], curr_box[0], 1, mult=5)
])
elif rotation == 270:
return all([
normalized_diff(bbox[0], curr_box[0], 0, use_abs=False),
normalized_diff(bbox[3], curr_box[1], 1),
normalized_diff(bbox[0], curr_box[0], 1, mult=5)
])
else: # 0 or default case
return all([
normalized_diff(bbox[0], curr_box[2], 0, use_abs=False),
normalized_diff(bbox[1], curr_box[1], 1),
normalized_diff(bbox[0], curr_box[2], 1, mult=5)
])


def table_cell_text(tables: List[List[int]], page: Page, img_size: list, table_thresh=.8, space_thresh=.01) -> Tables:
# Note: table is a list of 4 ints representing the bounding box of the table. This is against the image dims - this can be different from the page dims.
# We rescale the characters below to account for this.
assert all(len(table) == 4 for table in tables), "Tables must be a list of 4 ints representing the bounding box of the table"
assert len(img_size) == 2, "img_size must be a list of 2 ints representing the image dimensions width, height"

table_texts = []
space_thresh = max(space_thresh, get_dynamic_gap_thresh(page, img_size, default_thresh=space_thresh))
for table in tables:
table_poly = Bbox(bbox=table)
table_text = []
rotation = page["rotation"]

for block in page["blocks"]:
for line in block["lines"]:
if line["bbox"].intersection_pct(table_poly) < table_thresh:
continue
curr_span = None
curr_box = None
for span in line["spans"]:
for char in span["chars"]:
char["bbox"] = char["bbox"].rescale(img_size, page) # Rescale to match image dimensions
same_span = False
if curr_span:
same_span = is_same_span(char, curr_box, img_size, space_thresh, rotation)

if curr_span is None:
curr_span = char["char"]
curr_box = char["bbox"]
elif same_span:
curr_span += char["char"]
curr_box = [min(curr_box[0], char["bbox"][0]), min(curr_box[1], char["bbox"][1]),
max(curr_box[2], char["bbox"][2]), max(curr_box[3], char["bbox"][3])]
else:
if curr_span.strip():
table_text.append({"text": curr_span, "bbox": curr_box})
curr_span = char["char"]
curr_box = char["bbox"]
if curr_span is not None and curr_span.strip():
table_text.append({"text": curr_span, "bbox": curr_box})
# Adjust to be relative to input table
for item in table_text:
item["bbox"] = [
item["bbox"][0] - table[0],
item["bbox"][1] - table[1],
item["bbox"][2] - table[0],
item["bbox"][3] - table[1]
]
item["bbox"] = Bbox(bbox=item["bbox"])
table_text = sort_text_lines(table_text)
table_texts.append(table_text)
return table_texts
Loading

0 comments on commit eccb7d2

Please sign in to comment.