-
Notifications
You must be signed in to change notification settings - Fork 1.2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #480 from tarun-menta/table_benchmarks
Table benchmarks
- Loading branch information
Showing
5 changed files
with
418 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
''' | ||
TEDS Code Adapted from https://github.com/ibm-aur-nlp/EDD | ||
''' | ||
|
||
from typing import List | ||
|
||
from tqdm import tqdm | ||
import distance | ||
from apted import APTED, Config | ||
from apted.helpers import Tree | ||
from lxml import html | ||
from collections import deque | ||
import numpy as np | ||
|
||
def wrap_table_html(table_html:str)->str: | ||
return f'<html><body>{table_html}</body></html>' | ||
|
||
class TableTree(Tree): | ||
def __init__(self, tag, colspan=None, rowspan=None, content=None, *children): | ||
self.tag = tag | ||
self.colspan = colspan | ||
self.rowspan = rowspan | ||
self.content = content | ||
self.children = list(children) | ||
|
||
def bracket(self): | ||
"""Show tree using brackets notation""" | ||
if self.tag == 'td': | ||
result = '"tag": %s, "colspan": %d, "rowspan": %d, "text": %s' % \ | ||
(self.tag, self.colspan, self.rowspan, self.content) | ||
else: | ||
result = '"tag": %s' % self.tag | ||
for child in self.children: | ||
result += child.bracket() | ||
return "{{{}}}".format(result) | ||
|
||
class CustomConfig(Config): | ||
@staticmethod | ||
def maximum(*sequences): | ||
"""Get maximum possible value | ||
""" | ||
return max(map(len, sequences)) | ||
|
||
def normalized_distance(self, *sequences): | ||
"""Get distance from 0 to 1 | ||
""" | ||
return float(distance.levenshtein(*sequences)) / self.maximum(*sequences) | ||
|
||
def rename(self, node1, node2): | ||
"""Compares attributes of trees""" | ||
if (node1.tag != node2.tag) or (node1.colspan != node2.colspan) or (node1.rowspan != node2.rowspan): | ||
return 1. | ||
if node1.tag == 'td': | ||
if node1.content or node2.content: | ||
return self.normalized_distance(node1.content, node2.content) | ||
return 0. | ||
|
||
def tokenize(node): | ||
''' Tokenizes table cells | ||
''' | ||
global __tokens__ | ||
__tokens__.append('<%s>' % node.tag) | ||
if node.text is not None: | ||
__tokens__ += list(node.text) | ||
for n in node.getchildren(): | ||
tokenize(n) | ||
if node.tag != 'unk': | ||
__tokens__.append('</%s>' % node.tag) | ||
if node.tag != 'td' and node.tail is not None: | ||
__tokens__ += list(node.tail) | ||
|
||
def tree_convert_html(node, convert_cell=False, parent=None): | ||
''' Converts HTML tree to the format required by apted | ||
''' | ||
global __tokens__ | ||
if node.tag == 'td': | ||
if convert_cell: | ||
__tokens__ = [] | ||
tokenize(node) | ||
cell = __tokens__[1:-1].copy() | ||
else: | ||
cell = [] | ||
new_node = TableTree(node.tag, | ||
int(node.attrib.get('colspan', '1')), | ||
int(node.attrib.get('rowspan', '1')), | ||
cell, *deque()) | ||
else: | ||
new_node = TableTree(node.tag, None, None, None, *deque()) | ||
if parent is not None: | ||
parent.children.append(new_node) | ||
if node.tag != 'td': | ||
for n in node.getchildren(): | ||
tree_convert_html(n, convert_cell, new_node) | ||
if parent is None: | ||
return new_node | ||
|
||
def similarity_eval_html(pred, true, structure_only=False): | ||
''' Computes TEDS score between the prediction and the ground truth of a | ||
given samples | ||
''' | ||
pred, true = html.fromstring(pred), html.fromstring(true) | ||
if pred.xpath('body/table') and true.xpath('body/table'): | ||
pred = pred.xpath('body/table')[0] | ||
true = true.xpath('body/table')[0] | ||
n_nodes_pred = len(pred.xpath(".//*")) | ||
n_nodes_true = len(true.xpath(".//*")) | ||
tree_pred = tree_convert_html(pred, convert_cell=not structure_only) | ||
tree_true = tree_convert_html(true, convert_cell=not structure_only) | ||
n_nodes = max(n_nodes_pred, n_nodes_true) | ||
distance = APTED(tree_pred, tree_true, CustomConfig()).compute_edit_distance() | ||
return 1.0 - (float(distance) / n_nodes) | ||
else: | ||
return 0.0 | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
import base64 | ||
import os | ||
import time | ||
import datasets | ||
from tqdm import tqdm | ||
import tempfile | ||
import click | ||
from tabulate import tabulate | ||
import json | ||
from bs4 import BeautifulSoup | ||
from concurrent.futures import ThreadPoolExecutor | ||
from pypdfium2._helpers.misc import PdfiumError | ||
|
||
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # Transformers uses .isin for a simple op, which is not supported on MPS | ||
|
||
from marker.config.parser import ConfigParser | ||
from marker.converters.table import TableConverter | ||
from marker.models import create_model_dict | ||
|
||
from scoring import wrap_table_html, similarity_eval_html | ||
|
||
def update_teds_score(result): | ||
prediction, ground_truth = result['marker_table'], result['gt_table'] | ||
prediction, ground_truth = wrap_table_html(prediction), wrap_table_html(ground_truth) | ||
score = similarity_eval_html(prediction, ground_truth) | ||
result.update({'score':score}) | ||
return result | ||
|
||
|
||
@click.command(help="Benchmark Table to HTML Conversion") | ||
@click.argument("out_file", type=str) | ||
@click.option("--dataset", type=str, default="datalab-to/fintabnet-test", help="Dataset to use") | ||
@click.option("--max", type=int, default=None, help="Maximum number of PDFs to process") | ||
def main(out_file, dataset, max): | ||
models = create_model_dict() | ||
config_parser = ConfigParser({}) | ||
start = time.time() | ||
|
||
|
||
dataset = datasets.load_dataset(dataset, split='train') | ||
dataset = dataset.shuffle(seed=0) | ||
|
||
iterations = len(dataset) | ||
if max is not None: | ||
iterations = min(max, len(dataset)) | ||
|
||
results = [] | ||
for i in tqdm(range(iterations), desc='Converting Tables'): | ||
try: | ||
row = dataset[i] | ||
pdf_binary = base64.b64decode(row['pdf']) | ||
gt_tables = row['tables'] #Already sorted by reading order, which is what marker returns | ||
|
||
converter = TableConverter( | ||
config=config_parser.generate_config_dict(), | ||
artifact_dict=models, | ||
processor_list=config_parser.get_processors(), | ||
renderer='marker.renderers.html.HTMLRenderer' | ||
) | ||
|
||
with tempfile.NamedTemporaryFile(suffix=".pdf", mode="wb") as temp_pdf_file: | ||
temp_pdf_file.write(pdf_binary) | ||
temp_pdf_file.seek(0) | ||
marker_table_html = converter(temp_pdf_file.name).html | ||
|
||
marker_table_soup = BeautifulSoup(marker_table_html, 'html.parser') | ||
marker_detected_tables = marker_table_soup.find_all('table') | ||
if len(marker_detected_tables)==0: | ||
print(f'No tables detected, skipping...') | ||
|
||
for marker_table_soup, gt_table in zip(marker_detected_tables, gt_tables): | ||
gt_table_html = gt_table['html'] | ||
|
||
#marker wraps the table in <tbody> which fintabnet data doesn't | ||
marker_table_soup.find('tbody').unwrap() | ||
#Fintabnet doesn't use th tags, need to be replaced for fair comparison | ||
for th_tag in marker_table_soup.find_all('th'): | ||
th_tag.name = 'td' | ||
marker_table_html = str(marker_table_soup) | ||
|
||
results.append({ | ||
"marker_table": marker_table_html, | ||
"gt_table": gt_table_html | ||
}) | ||
except PdfiumError: | ||
print('Broken PDF, Skipping...') | ||
continue | ||
|
||
print(f"Total time: {time.time() - start}") | ||
|
||
with ThreadPoolExecutor(max_workers=16) as executor: | ||
results = list(tqdm(executor.map(update_teds_score, results), desc='Computing alignment scores', total=len(results))) | ||
avg_score = sum([r["score"] for r in results]) / len(results) | ||
|
||
headers = ["Avg score", "Total tables"] | ||
data = [f"{avg_score:.3f}", len(results)] | ||
table = tabulate([data], headers=headers, tablefmt="github") | ||
print(table) | ||
print("Avg score computed by comparing marker predicted HTML with original HTML") | ||
|
||
with open(out_file, "w+") as f: | ||
json.dump(results, f, indent=2) | ||
|
||
if __name__ == '__main__': | ||
main() |
Oops, something went wrong.