diff --git a/marker/processors/table.py b/marker/processors/table.py index 783a73ce..391ab4b3 100644 --- a/marker/processors/table.py +++ b/marker/processors/table.py @@ -53,6 +53,10 @@ class TableProcessor(BaseProcessor): int, "The number of workers to use for pdftext.", ] = 4 + row_split_threshold: Annotated[ + float, + "The percentage of rows that need to be split across the table before row splitting is active.", + ] = 0.5 def __init__( self, @@ -171,10 +175,7 @@ def split_combined_rows(self, tables: List[TableResult]): # Skip empty tables continue unique_rows = sorted(list(set([c.row_id for c in table.cells]))) - new_cells = [] - shift_up = 0 - max_cell_id = max([c.cell_id for c in table.cells]) - new_cell_count = 0 + row_info = [] for row in unique_rows: # Cells in this row # Deepcopy is because we do an in-place mutation later, and that can cause rows to shift to match rows in unique_rows @@ -201,9 +202,25 @@ def split_combined_rows(self, tables: List[TableResult]): len(line_lens_counter) == 2 and counter_keys[0] <= 1 and counter_keys[1] > 1 and line_lens_counter[counter_keys[0]] == 1, # Allow a single column with a single line - keys are the line lens, values are the counts ]) should_split = should_split_entire_row or should_split_partial_row - if should_split: - for i in range(0, max(line_lens)): - for cell in row_cells: + row_info.append({ + "should_split": should_split, + "row_cells": row_cells, + "line_lens": line_lens + }) + + # Don't split if we're not splitting most of the rows in the table. This avoids splitting stray multiline rows. + if sum([r["should_split"] for r in row_info]) / len(row_info) < self.row_split_threshold: + continue + + new_cells = [] + shift_up = 0 + max_cell_id = max([c.cell_id for c in table.cells]) + new_cell_count = 0 + for row, item_info in zip(unique_rows, row_info): + max_lines = max(item_info["line_lens"]) + if item_info["should_split"]: + for i in range(0, max_lines): + for cell in item_info["row_cells"]: # Calculate height based on number of splits split_height = cell.bbox[3] - cell.bbox[1] current_bbox = [cell.bbox[0], cell.bbox[1] + i * split_height, cell.bbox[2], cell.bbox[1] + (i + 1) * split_height] @@ -227,9 +244,9 @@ def split_combined_rows(self, tables: List[TableResult]): # For each new row we add, shift up subsequent rows # The max is to account for partial rows - shift_up += max(line_lens) - 1 + shift_up += max_lines - 1 else: - for cell in row_cells: + for cell in item_info["row_cells"]: cell.row_id += shift_up new_cells.append(cell)