Skip to content

Commit

Permalink
Minor cleanups
Browse files Browse the repository at this point in the history
  • Loading branch information
VikParuchuri committed Jan 15, 2025
1 parent 50b5573 commit 04bb7ad
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 29 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,7 @@ python benchmarks/overall.py data/pdfs data/references report.json
The processed FinTabNet dataset is hosted [here](https://huggingface.co/datasets/datalab-to/fintabnet-test) and is automatically downloaded. Run the benchmark with:

```shell
python benchmarks/table/table.py table_report.json --max 1000
python benchmarks/table/table.py table_report.json --max_rows 1000
```

# Thanks
Expand Down
33 changes: 14 additions & 19 deletions benchmarks/table/scoring.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,12 @@
'''
""""
TEDS Code Adapted from https://github.com/ibm-aur-nlp/EDD
'''
"""

from typing import List

from tqdm import tqdm
import distance
from apted import APTED, Config
from apted.helpers import Tree
from lxml import html
from collections import deque
import numpy as np

def wrap_table_html(table_html:str)->str:
return f'<html><body>{table_html}</body></html>'
Expand All @@ -21,7 +17,9 @@ def __init__(self, tag, colspan=None, rowspan=None, content=None, *children):
self.colspan = colspan
self.rowspan = rowspan
self.content = content
self.children = list(children)

# Sets self.name and self.children
super().__init__(tag, *children)

def bracket(self):
"""Show tree using brackets notation"""
Expand All @@ -37,17 +35,12 @@ def bracket(self):
class CustomConfig(Config):
@staticmethod
def maximum(*sequences):
"""Get maximum possible value
"""
return max(map(len, sequences))

def normalized_distance(self, *sequences):
"""Get distance from 0 to 1
"""
return float(distance.levenshtein(*sequences)) / self.maximum(*sequences)

def rename(self, node1, node2):
"""Compares attributes of trees"""
if (node1.tag != node2.tag) or (node1.colspan != node2.colspan) or (node1.rowspan != node2.rowspan):
return 1.
if node1.tag == 'td':
Expand All @@ -56,8 +49,9 @@ def rename(self, node1, node2):
return 0.

def tokenize(node):
''' Tokenizes table cells
'''
"""
Tokenizes table cells
"""
global __tokens__
__tokens__.append('<%s>' % node.tag)
if node.text is not None:
Expand All @@ -70,8 +64,9 @@ def tokenize(node):
__tokens__ += list(node.tail)

def tree_convert_html(node, convert_cell=False, parent=None):
''' Converts HTML tree to the format required by apted
'''
"""
Converts HTML tree to the format required by apted
"""
global __tokens__
if node.tag == 'td':
if convert_cell:
Expand All @@ -95,9 +90,9 @@ def tree_convert_html(node, convert_cell=False, parent=None):
return new_node

def similarity_eval_html(pred, true, structure_only=False):
''' Computes TEDS score between the prediction and the ground truth of a
given samples
'''
"""
Computes TEDS score between the prediction and the ground truth of a given samples
"""
pred, true = html.fromstring(pred), html.fromstring(true)
if pred.xpath('body/table') and true.xpath('body/table'):
pred = pred.xpath('body/table')[0]
Expand Down
18 changes: 9 additions & 9 deletions benchmarks/table/table.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import base64
import os
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # Transformers uses .isin for a simple op, which is not supported on MPS

import base64
import time
import datasets
from tqdm import tqdm
Expand All @@ -11,8 +13,6 @@
from concurrent.futures import ThreadPoolExecutor
from pypdfium2._helpers.misc import PdfiumError

os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # Transformers uses .isin for a simple op, which is not supported on MPS

from marker.config.parser import ConfigParser
from marker.converters.table import TableConverter
from marker.models import create_model_dict
Expand All @@ -30,19 +30,19 @@ def update_teds_score(result):
@click.command(help="Benchmark Table to HTML Conversion")
@click.argument("out_file", type=str)
@click.option("--dataset", type=str, default="datalab-to/fintabnet-test", help="Dataset to use")
@click.option("--max", type=int, default=None, help="Maximum number of PDFs to process")
def main(out_file, dataset, max):
@click.option("--max_rows", type=int, default=None, help="Maximum number of PDFs to process")
def main(out_file: str, dataset: str, max_rows: int):
models = create_model_dict()
config_parser = ConfigParser({})
config_parser = ConfigParser({'output_format': 'html'})
start = time.time()


dataset = datasets.load_dataset(dataset, split='train')
dataset = dataset.shuffle(seed=0)

iterations = len(dataset)
if max is not None:
iterations = min(max, len(dataset))
if max_rows is not None:
iterations = min(max_rows, len(dataset))

results = []
for i in tqdm(range(iterations), desc='Converting Tables'):
Expand All @@ -55,7 +55,7 @@ def main(out_file, dataset, max):
config=config_parser.generate_config_dict(),
artifact_dict=models,
processor_list=config_parser.get_processors(),
renderer='marker.renderers.html.HTMLRenderer'
renderer=config_parser.get_renderer()
)

with tempfile.NamedTemporaryFile(suffix=".pdf", mode="wb") as temp_pdf_file:
Expand Down

0 comments on commit 04bb7ad

Please sign in to comment.