diff --git a/requirements.txt b/requirements.txt
index 3911a3d..87a1598 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,8 +1,8 @@
-llama-cpp-python==0.2.78
-chromadb~=0.5
-langchain~=0.2.4
-langchain-community~=0.2.4
-langchain-openai~=0.1.8
+llama-cpp-python==0.2.88
+chromadb~=0.5.5
+langchain~=0.2.14
+langchain-community~=0.2.12
+langchain-openai~=0.1.22
langchain-huggingface~=0.0.3
pydantic~=2.7
transformers~=4.41
@@ -16,14 +16,14 @@ python-dotenv
accelerate~=0.33
protobuf==3.20.2
termcolor
-openai~=1.34.0
+openai~=1.41
einops # required for Mosaic models
click
bitsandbytes==0.43.1
# auto-gptq==0.2.0
InstructorEmbedding==1.0.1
unstructured~=0.14.5
-pymupdf==1.22.5
+pymupdf==1.24.9
streamlit~=1.28
python-docx~=1.1
six==1.16.0 ; python_version >= "3.10" and python_version < "4.0"
@@ -36,4 +36,6 @@ threadpoolctl==3.1.0 ; python_version >= "3.10" and python_version < "4.0"
tiktoken==0.7.0 ; python_version >= "3.10" and python_version < "4.0"
tokenizers==0.19.1; python_version >= "3.10" and python_version < "4.0"
tqdm==4.65.0 ; python_version >= "3.10" and python_version < "4.0"
-# transformers==4.29.2 ; python_version >= "3.10" and python_version < "4.0"
\ No newline at end of file
+# transformers==4.29.2 ; python_version >= "3.10" and python_version < "4.0"
+gmft==0.2.1
+google-generativeai~=0.7
\ No newline at end of file
diff --git a/src/llmsearch/config.py b/src/llmsearch/config.py
index 33b75ec..ac40bbe 100644
--- a/src/llmsearch/config.py
+++ b/src/llmsearch/config.py
@@ -62,7 +62,16 @@ class PDFTableParser(str, Enum):
class PDFImageParser(str, Enum):
GEMINI_15_FLASH = "gemini-1.5-flash"
+ GEMINI_15_PRO= "gemini-1.5-pro"
+class PDFImageParseSettings(BaseModel):
+ image_parser: PDFImageParser
+ system_instruction: str = """You are an research assistant. You analyze the image to extract detailed information. Response must be a Markdown string in the follwing format:
+- First line is a heading with image caption, starting with '# '
+- Second line is empty
+- From the third line on - detailed data points and related metadata, extracted from the image, in Markdown format. Don't use Markdown tables.
+"""
+ user_instruction: str = """From the image, extract detailed quantitative and qualitative data points."""
class EmbeddingModelType(str, Enum):
huggingface = "huggingface"
@@ -92,7 +101,7 @@ class DocumentPathSettings(BaseModel):
pdf_table_parser: Optional[PDFTableParser] = None
"""If enabled, will parse tables in pdf files using a specific of a parser."""
- pdf_image_parser: Optional[PDFImageParser] = None
+ pdf_image_parser: Optional[PDFImageParseSettings] = None
"""If enabled, will parse images in pdf files using a specific of a parser."""
additional_parser_settings: Dict[str, Any] = Field(default_factory=dict)
diff --git a/src/llmsearch/parsers/images/gemini_parser.py b/src/llmsearch/parsers/images/gemini_parser.py
index dce8de8..5033057 100644
--- a/src/llmsearch/parsers/images/gemini_parser.py
+++ b/src/llmsearch/parsers/images/gemini_parser.py
@@ -19,18 +19,15 @@ class GeminiImageAnalyzer:
def __init__(
self,
model_name: str,
- instruction: str = """From the image, extract detailed quantitative and qualitative data points.""",
+ system_instruction: str,
+ user_instruction: str
):
self.model_name = model_name
- self.instruction = instruction
+ self.instruction = user_instruction
+ print(system_instruction, user_instruction)
self.model = genai.GenerativeModel(
model_name,
- system_instruction="""You are an research assistant. You analyze the image to extract detailed information. Response must be a Markdown string in the follwing format:
-
-- First line is a heading with image caption, starting with '# '
-- Second line is empty
-- From the third line on - detailed data points and related metadata, extracted from the image, in Markdown format. Don't use Markdown tables.
-""",
+ system_instruction = system_instruction,
generation_config=genai.types.GenerationConfig(
# Only one candidate for now.
candidate_count=1,
diff --git a/src/llmsearch/parsers/images/generic.py b/src/llmsearch/parsers/images/generic.py
index f9c2943..3b59764 100644
--- a/src/llmsearch/parsers/images/generic.py
+++ b/src/llmsearch/parsers/images/generic.py
@@ -1,8 +1,9 @@
from collections import defaultdict
+import importlib
import io
from multiprocessing.pool import ThreadPool
from pathlib import Path
-from typing import Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple, Callable
import PIL.Image
import pymupdf
@@ -10,9 +11,40 @@
from pydantic import BaseModel
from tenacity import retry, stop_after_attempt, wait_random_exponential
-from llmsearch.config import PDFImageParser
+from llmsearch.config import PDFImageParseSettings, PDFImageParser
from llmsearch.parsers.markdown import markdown_splitter
+# Define a mapping of PDFImageParser to corresponding analyzer classes and config
+ANALYZER_MAPPING: Dict[PDFImageParser, Any] = {
+ PDFImageParser.GEMINI_15_FLASH: {
+ "import_path": "llmsearch.parsers.images.gemini_parser", # Import path for lazy loading
+ "class_name": "GeminiImageAnalyzer",
+ "params": {"model_name": "gemini-1.5-flash"},
+ },
+
+ PDFImageParser.GEMINI_15_PRO: {
+ "import_path": "llmsearch.parsers.images.gemini_parser", # Import path for lazy loading
+ "class_name": "GeminiImageAnalyzer",
+ "params": {"model_name": "gemini-1.5-pro"},
+ },
+}
+
+
+def create_analyzer(image_analyzer: PDFImageParser, **additional_params):
+ analyzer_info = ANALYZER_MAPPING.get(image_analyzer)
+
+ if analyzer_info is None:
+ raise ValueError(f"Unsupported image analyzer type: {image_analyzer}")
+
+ # Lazy load the module
+ module = importlib.import_module(analyzer_info["import_path"])
+ analyzer_class = getattr(module, analyzer_info["class_name"])
+ analyzer_params = analyzer_info["params"]
+
+ params = {**analyzer_params, **additional_params}
+
+ return analyzer_class(**params)
+
class PDFImage(BaseModel):
image_fn: Path
@@ -26,49 +58,41 @@ def __init__(
self,
pdf_fn: Path,
temp_folder: Path,
- image_analyzer,
- save_output=True,
+ image_analyzer: Callable,
+ save_output: bool = True,
max_base_width: int = 1280,
min_width: int = 640,
min_height: int = 200,
):
self.pdf_fn = pdf_fn
- self.max_base_width = max_base_width
self.temp_folder = temp_folder
- self.min_width = min_width
- self.min_height = min_height
self.image_analyzer = image_analyzer
self.save_output = save_output
+ self.max_base_width = max_base_width
+ self.min_width = min_width
+ self.min_height = min_height
def prepare_and_clean_folder(self):
- # Check if the folder exists
if not self.temp_folder.exists():
- # Create the folder if it doesn't exist
self.temp_folder.mkdir(parents=True, exist_ok=True)
logger.info(f"Created folder: {self.temp_folder}")
else:
for file in self.temp_folder.iterdir():
if file.is_file():
- file.unlink() # Delete the file
- logger.info(f"Deleted file: {file}")
+ file.unlink()
+ logger.debug(f"Deleted file: {file}")
def extract_images(self) -> List[PDFImage]:
self.prepare_and_clean_folder()
-
doc = pymupdf.open(self.pdf_fn)
out_images = []
for page in doc:
- page_images = page.get_images()
- for img in page_images:
+ for img in page.get_images():
xref = img[0]
- data = doc.extract_image(xref=xref)
- out_fn = self._resize_and_save_image(
- data=data,
- page_num=page.number,
- xref_num=xref,
- )
- if out_fn is not None:
+ data = doc.extract_image(xref)
+ out_fn = self._resize_and_save_image(data, page.number, xref)
+ if out_fn:
out_images.append(
PDFImage(
image_fn=out_fn,
@@ -76,7 +100,6 @@ def extract_images(self) -> List[PDFImage]:
bbox=(img[1], img[2], img[3], img[4]),
)
)
-
return out_images
def _resize_and_save_image(
@@ -85,30 +108,32 @@ def _resize_and_save_image(
page_num: int,
xref_num: int,
) -> Optional[Path]:
-
- image = data.get("image", None)
- if image is None:
+ image_data = data.get("image")
+ if not image_data:
return
- with PIL.Image.open(io.BytesIO(image)) as img:
+ with PIL.Image.open(io.BytesIO(image_data)) as img:
if img.size[1] < self.min_height or img.size[0] < self.min_width:
- logger.info(
+ logger.debug(
f"Image on page {page_num}, xref {xref_num} is too small. Skipping extraction..."
)
return None
- wpercent = self.max_base_width / float(img.size[0])
- # Resize the image, if needed
+ wpercent = self.max_base_width / float(img.size[0])
if wpercent < 1:
- hsize = int((float(img.size[1]) * float(wpercent)))
+ hsize = int(float(img.size[1]) * wpercent)
img = img.resize(
(self.max_base_width, hsize), PIL.Image.Resampling.LANCZOS
)
- out_fn = self.temp_folder / (str(self.pdf_fn.stem) + f"_page_{page_num}_xref_{xref_num}.png")
- logger.info(f"Saving file: {out_fn}")
- img.save(out_fn, mode="wb")
- return Path(out_fn)
+ out_fn = (
+ self.temp_folder
+ / f"{self.pdf_fn.stem}_page_{page_num}_xref_{xref_num}.png"
+ )
+ logger.debug(f"Saving file: {out_fn}")
+ img.save(out_fn)
+
+ return out_fn
def analyze_images_threaded(
self, extracted_images: List[PDFImage], max_threads: int = 10
@@ -117,22 +142,25 @@ def analyze_images_threaded(
results = pool.starmap(
analyze_single_image,
[
- (pdf_image, self.image_analyzer, i)
- for i, pdf_image in enumerate(extracted_images)
+ (img, self.image_analyzer, i)
+ for i, img in enumerate(extracted_images)
],
)
if self.save_output:
- for r in results:
- with open(str(r.image_fn)[:-3] + ".md", "w") as file:
- file.write(r.markdown)
+ for result in results:
+ with open(str(result.image_fn).replace(".png", ".md"), "w") as file:
+ file.write(result.markdown)
return results
def log_attempt_number(retry_state):
- """return the result of the last call attempt"""
- logger.error(f"API call attempt failed. Retrying: {retry_state.attempt_number}...")
+ error_message = str(retry_state.outcome.exception())
+ logger.error(
+ f"API call attempt {retry_state.attempt_number} failed with error: {error_message}. Retrying..."
+ )
+ # logger.error(f"API call attempt failed. Retrying: {retry_state.attempt_number}...")
@retry(
@@ -140,26 +168,30 @@ def log_attempt_number(retry_state):
stop=stop_after_attempt(6),
after=log_attempt_number,
)
-def analyze_single_image(pdf_image: PDFImage, image_analyzer, i: int) -> PDFImage:
- fn = pdf_image.image_fn
- pdf_image.markdown = image_analyzer.analyze(fn)
+def analyze_single_image(
+ pdf_image: PDFImage, image_analyzer: Callable, i: int
+) -> PDFImage:
+ pdf_image.markdown = image_analyzer.analyze(pdf_image.image_fn)
return pdf_image
def get_image_chunks(
- path: Path, max_size: int, image_analyzer: PDFImageParser, cache_folder: Path
+ path: Path,
+ max_size: int,
+ image_parse_setting: PDFImageParseSettings,
+ cache_folder: Path,
) -> Tuple[List[dict], Dict[int, List[Tuple[float]]]]:
- if image_analyzer is PDFImageParser.GEMINI_15_FLASH:
- from llmsearch.parsers.images.gemini_parser import GeminiImageAnalyzer
- analyzer = GeminiImageAnalyzer(model_name="gemini-1.5-flash")
+ analyzer = create_analyzer(
+ image_parse_setting.image_parser,
+ system_instruction=image_parse_setting.system_instruction,
+ user_instruction=image_parse_setting.user_instruction,
+ )
image_parser = GenericPDFImageParser(
pdf_fn=path,
temp_folder=cache_folder / "pdf_images_temp",
image_analyzer=analyzer,
- # image_analyzer=GeminiImageAnalyzer(model_name="gemini-1.5-pro-exp-0801")
)
-
extracted_images = image_parser.extract_images()
parsed_images = image_parser.analyze_images_threaded(extracted_images)
@@ -167,8 +199,9 @@ def get_image_chunks(
img_bboxes = defaultdict(list)
for img in parsed_images:
- print(str(img.image_fn) + ".md")
- out_blocks += markdown_splitter(path=str(img.image_fn)[:-3] + ".md", max_chunk_size=max_size)
+ out_blocks += markdown_splitter(
+ path=str(img.image_fn).replace(".png", ".md"), max_chunk_size=max_size
+ )
img_bboxes[img.page_num].append(img.bbox)
return out_blocks, img_bboxes
@@ -179,7 +212,7 @@ def get_image_chunks(
res = get_image_chunks(
path=Path("/home/snexus/Downloads/Graph_Example2.pdf"),
max_size=1024,
- image_analyzer=PDFImageParser.GEMINI_15_FLASH,
+ image_parse_setting=PDFImageParseSettings(image_parser= PDFImageParser.GEMINI_15_PRO),
cache_folder=Path("./output_images"),
)
diff --git a/src/llmsearch/parsers/pdf.py b/src/llmsearch/parsers/pdf.py
index 419baf8..911405f 100644
--- a/src/llmsearch/parsers/pdf.py
+++ b/src/llmsearch/parsers/pdf.py
@@ -6,7 +6,7 @@
from loguru import logger
from langchain_text_splitters import CharacterTextSplitter
-from llmsearch.parsers.tables.generic import boxes_intersect
+from llmsearch.parsers.tables.generic import do_boxes_intersect
class PDFSplitter:
@@ -149,7 +149,7 @@ def filter_blocks(blocks: List[Tuple[float, float, float, float, str]],
skip_block = False
for filter_bbox in page_table_bboxes:
- if boxes_intersect(filter_bbox, block_bbox):
+ if do_boxes_intersect(filter_bbox, block_bbox):
# We found an intersection, set the flag and break the inner loop
skip_block = True
# print(f"SKipping block: {block}")
diff --git a/src/llmsearch/parsers/tables/generic.py b/src/llmsearch/parsers/tables/generic.py
index fba65c5..fd1f62c 100644
--- a/src/llmsearch/parsers/tables/generic.py
+++ b/src/llmsearch/parsers/tables/generic.py
@@ -1,13 +1,37 @@
from collections import defaultdict
+import importlib
from pathlib import Path
-from typing import Dict, List, Tuple
+from typing import Dict, List, Tuple, Any
import pandas as pd
from loguru import logger
-
from abc import ABC, abstractmethod
-
from llmsearch.config import PDFTableParser
+# Define a mapping of PDFImageParser to corresponding analyzer classes and config
+PARSER_MAPPING: Dict[PDFTableParser, Any] = {
+ PDFTableParser.GMFT: {
+ "import_path": "llmsearch.parsers.tables.gmft_parser", # Import path for lazy loading
+ "class_name": "GMFTParser",
+ "params": {},
+ },
+ # Add more analyzers here as needed
+ # PDFImageParser.ANOTHER_TYPE: {'import_path': 'another.module.path', 'class_name': 'AnotherAnalyzer', 'params': {'param1': value1, 'param2': value2}},
+}
+
+
+def create_table_parser(table_parser: PDFTableParser, filename: Path):
+ parser_info = PARSER_MAPPING.get(table_parser)
+
+ if parser_info is None:
+ raise ValueError(f"Unsupported table parser type: {table_parser}")
+
+ # Lazy load the module
+ module = importlib.import_module(parser_info["import_path"])
+ parser_class = getattr(module, parser_info["class_name"])
+ additional_parser_params = parser_info["params"]
+
+ return parser_class(fn = filename, **additional_parser_params)
+
class GenericParsedTable(ABC):
def __init__(self, page_number: int, bbox: Tuple[float, float, float, float]):
@@ -17,42 +41,41 @@ def __init__(self, page_number: int, bbox: Tuple[float, float, float, float]):
@property
@abstractmethod
def df(self) -> pd.DataFrame:
- """Returns Pandas DF corresponding to a table"""
+ """Returns a Pandas DataFrame corresponding to a table."""
pass
@property
@abstractmethod
def caption(self) -> str:
- """Returns caption of the table"""
+ """Returns the caption of the table."""
pass
@property
@abstractmethod
def xml(self) -> List[str]:
- """Returns xml representation of the table"""
+ """Returns XML representation of the table."""
pass
def pandas_df_to_xml(df: pd.DataFrame) -> List[str]:
- """Converts Pandas df to a simplified xml representation digestible by LLMs
+ """Converts a Pandas DataFrame to a simplified XML representation.
Args:
- df (pd.DataFrame): Pandas df
+ df (pd.DataFrame): The DataFrame to convert.
Returns:
- str: List of xml row strings representing the dataframe
+ List[str]: List of XML row strings representing the DataFrame.
"""
def func(row):
xml = [""]
for field in row.index:
- xml.append(' {1}'.format(field, row[field]))
+ xml.append(f' {row[field]}')
xml.append("
")
return "\n".join(xml)
items = df.apply(func, axis=1).tolist()
return items
- # return "\n".join(items)
def pdf_table_splitter(
@@ -60,15 +83,25 @@ def pdf_table_splitter(
max_size: int,
include_caption: bool = True,
max_caption_size_ratio: int = 4,
-):
+) -> List[Dict[str, Any]]:
+ """Splits a parsed table into manageable chunks.
+
+ Args:
+ parsed_table (GenericParsedTable): The parsed table instance.
+ max_size (int): Maximum size for each chunk.
+ include_caption (bool): Whether to include the table caption.
+ max_caption_size_ratio (int): Ratio to determine allowable caption size.
+
+ Returns:
+ List[Dict[str, Any]]: List of text chunks with metadata.
+ """
xml_elements = parsed_table.xml
caption = parsed_table.caption
metadata = {"page": parsed_table.page_num, "source_chunk_type": "table"}
-
all_chunks = []
- # If caption is too long, trim it down, so there is some space for actual data
+ # Trim caption if it's too long
if len(caption) > max_size / max_caption_size_ratio:
logger.warning(
"Caption is too large compared to max char size, trimming down..."
@@ -79,25 +112,19 @@ def pdf_table_splitter(
if include_caption and caption:
header = f"Table below contains information about: {caption}\n" + header
- footer = f"```"
-
+ footer = "```"
current_text = header
for el in xml_elements:
-
- # If new element is too big, trim it (shouldn't happen)
if len(el) > max_size:
logger.warning(
- "xml element is larger than allowed max char size. Flushing.."
- )
- # el = el[:max_size-len(header)-3]
- all_chunks.append(
- {"text": current_text + el + footer, "metadata": metadata}
+ "XML element is larger than allowed max char size. Flushing.."
)
+ all_chunks.append({"text": current_text + footer, "metadata": metadata})
+ all_chunks.append({"text": header + el + footer, "metadata": metadata})
current_text = header
-
- # if current text is already large and doesn't fit the new element, flush it
elif len(current_text + el) >= max_size:
- all_chunks.append({"text": current_text + footer, "metadata": metadata})
+ if current_text != header:
+ all_chunks.append({"text": current_text + footer, "metadata": metadata})
current_text = header + el + "\n"
else:
current_text += el + "\n"
@@ -106,66 +133,63 @@ def pdf_table_splitter(
all_chunks.append({"text": current_text + footer, "metadata": metadata})
return all_chunks
-def boxes_intersect(box1: Tuple[float, float, float, float], box2: Tuple[float, float, float, float]) -> bool:
- """
- Check if two bounding boxes intersect.
- Parameters:
- box1: Tuple (x1_min, y1_min, x1_max, y1_max)
- box2: Tuple (x2_min, y2_min, x2_max, y2_max)
+def do_boxes_intersect(
+ box1: Tuple[float, float, float, float], box2: Tuple[float, float, float, float]
+) -> bool:
+ """Check if two bounding boxes intersect.
+
+ Args:
+ box1 (Tuple[float, float, float, float]): First bounding box.
+ box2 (Tuple[float, float, float, float]): Second bounding box.
Returns:
- True if the boxes intersect, False otherwise.
+ bool: True if the boxes intersect, False otherwise.
"""
-
- # Unpack the box coordinates
x1_min, y1_min, x1_max, y1_max = box1
x2_min, y2_min, x2_max, y2_max = box2
- # Check for non-intersection
- if x1_max < x2_min or x2_max < x1_min:
- return False
- if y1_max < y2_min or y2_max < y1_min:
- return False
+ return not (
+ x1_max < x2_min or x2_max < x1_min or y1_max < y2_min or y2_max < y1_min
+ )
- # If none of the non-intersection conditions are met, they must intersect
- return True
def get_table_chunks(
- path: Path, max_size: int, table_parser: PDFTableParser, format_extensions = ("pdf",)
-) -> Tuple[List[dict], Dict[int, List[Tuple[float]]]]:
- """Parses tables from the document using specified table_splitter
+ path: Path,
+ max_size: int,
+ table_parser: PDFTableParser,
+ format_extensions: Tuple[str, ...] = (".pdf",),
+) -> Tuple[List[Dict[str, Any]], Dict[int, List[Tuple[float, float, float, float]]]]:
+ """Parses tables from a document and splits them into chunks.
Args:
- path (Path): document path
- max_size (int): Maximum chunk size to split by
- table_splitter (PDFTableParser): name of the table splitter
- """
+ path (Path): Document path.
+ max_size (int): Maximum chunk size to split by.
+ table_parser (PDFTableParser): Table parser to use.
+ format_extensions (Tuple[str, ...]): Supported file formats for parsing.
+ Returns:
+ Tuple[List[Dict[str, Any]], Dict[int, List[Tuple[float, float, float, float]]]]:
+ A tuple with the list of table chunks and a dictionary of bounding boxes.
+ """
table_chunks = []
- extension = str(path).strip("/")[-3:]
- if extension not in format_extensions:
+ extension = path.suffix.lower()
+ if extension not in format_extensions:
logger.info(f"Format {extension} doesn't support table parsing..Skipping..")
- return list(), dict()
+ return [], {}
- if table_parser is PDFTableParser.GMFT:
- from llmsearch.parsers.tables.gmft_parser import GMFTParser
- parser = GMFTParser(fn=path)
- splitter = pdf_table_splitter
- else:
- raise TypeError(f"Unknown table parser: {table_parser}")
+ parser = create_table_parser(table_parser, filename=path)
logger.info("Parsing tables..")
-
parsed_tables = parser.parsed_tables
logger.info(f"Parsed {len(parsed_tables)} tables. Chunking...")
for parsed_table in parsed_tables:
- table_chunks += splitter(parsed_table, max_size=max_size)
+ table_chunks.extend(pdf_table_splitter(parsed_table, max_size=max_size))
- # Extract tables bounding boxes and store in a convenient data structure.
+ # Extract bounding boxes
table_bboxes = defaultdict(list)
for table in parsed_tables:
table_bboxes[table.page_num].append(table.bbox)
- return table_chunks, table_bboxes
\ No newline at end of file
+ return table_chunks, table_bboxes
diff --git a/src/llmsearch/parsers/tables/gmft_parser.py b/src/llmsearch/parsers/tables/gmft_parser.py
index e1d29bb..2c11226 100644
--- a/src/llmsearch/parsers/tables/gmft_parser.py
+++ b/src/llmsearch/parsers/tables/gmft_parser.py
@@ -5,7 +5,6 @@
from gmft import (
CroppedTable,
TableDetector,
- AutoFormatConfig,
AutoTableFormatter,
)
from pathlib import Path
@@ -15,132 +14,224 @@
from llmsearch.parsers.tables.generic import (
pandas_df_to_xml,
GenericParsedTable,
- pdf_table_splitter,
)
-class TableFormatterSingleton:
- """Singleton for table formatter"""
+class XMLConverter:
+ """Converts Pandas DataFrames to XML format."""
+
+ @staticmethod
+ def convert(df: pd.DataFrame) -> List[str]:
+ """Converts a DataFrame to a list of XML strings.
+
+ Args:
+ df (pd.DataFrame): The DataFrame to convert.
+
+ Returns:
+ List[str]: A list of XML strings representing the DataFrame.
+ """
+ return pandas_df_to_xml(df)
+
- _instance = None
+class ExtractionError(Exception):
+ """Custom exception for extraction failures."""
+ pass
+
+
+@dataclass
+class PageTables:
+ """Holds cropped tables extracted from a specific page of a document."""
+ page_num: int
+ cropped_tables: List[CroppedTable]
+
+ @property
+ def n_tables(self) -> int:
+ """Returns the number of cropped tables extracted from the page."""
+ return len(self.cropped_tables)
+
+
+class TableFormatterSingleton:
+ """Singleton class for managing a single instance of AutoTableFormatter."""
+
+ _instance: Optional['TableFormatterSingleton'] = None
formatter = None
def __new__(cls, *args, **kwargs):
- if not cls._instance:
+ """Creates a new instance if one does not already exist."""
+ if cls._instance is None:
logger.info("Initializing AutoTableFormatter...")
- cls._instance = super(TableFormatterSingleton, cls).__new__(cls)
+ cls._instance = super().__new__(cls)
cls._instance.formatter = AutoTableFormatter()
return cls._instance
class GMFTParsedTable(GenericParsedTable):
- def __init__(self, table: CroppedTable, page_num: int) -> None:
- super().__init__(
- page_number=page_num, bbox=table.bbox
- ) # Initialize the field from the abstract class
- self._table = table
- self.failed = False
- self.formatter: AutoTableFormatter = TableFormatterSingleton().formatter
+ """Represents a parsed table with its metadata and data extraction logic."""
- # Formatter is passed externally
- # self.formatter = formatter
+ def __init__(self, table: CroppedTable, page_num: int, formatter: AutoTableFormatter) -> None:
+ """Initializes the parsed table with a cropped table, page number, and formatter.
+
+ Args:
+ table (CroppedTable): The cropped table to parse.
+ page_num (int): The page number where the table is found.
+ formatter (AutoTableFormatter): The formatter to be used for extraction.
+ """
+ super().__init__(page_number=page_num, bbox=table.bbox)
+ self._table = table # Store the cropped table
+ self.failed = False # Track extraction failures
+ self.formatter = formatter # Formatter for extracting data
@cached_property
def _captions(self) -> List[str]:
- # return ""
+ """Caches and returns a list of non-empty captions from the table."""
return [c for c in self._table.captions() if c.strip()]
@cached_property
def caption(self) -> str:
+ """Returns a unique string of all captions, combined into one."""
return "\n".join(set(self._captions))
@property
def df(self) -> Optional[pd.DataFrame]:
- ft = self.formatter.extract(self._table)
+ """Attempts to extract a DataFrame from the cropped table.
+
+ Returns:
+ Optional[pd.DataFrame]: The extracted DataFrame or None if extraction fails.
+
+ Raises:
+ ExtractionError: If extraction fails, this error will be raised.
+ """
+ ft = self.formatter.extract(self._table) # Use the formatter to extract the table
try:
- df = ft.df()
+ return ft.df() # Return the DataFrame
except ValueError as ex:
logger.error(f"Couldn't extract df on page {self.page_num}: {str(ex)}")
self.failed = True
return None
-
- # config = AutoFormatConfig()
- # config.total_overlap_reject_threshold = 0.8
- # config.large_table_threshold = 0
-
- # try:
- # logger.info("\tTrying to reover")
- # df = ft.df(config_overrides = config)
- # except ValueError:
- # logger.error(f"\tCouldn't recover, page {self.page_num}: {str(ex)}")
- # return None
-
- return df
+ # raise ExtractionError(f"Extraction failed on page {self.page_num}")
@property
def xml(self) -> List[str]:
+ """Converts the extracted DataFrame to XML format.
+
+ Returns:
+ List[str]: A list of XML strings. Returns an empty list if df extraction failed.
+ """
if self.df is None:
- return list()
- return pandas_df_to_xml(self.df)
+ return []
+ return XMLConverter.convert(self.df)
-@dataclass
-class PageTables:
- page_num: int
- cropped_tables: List[CroppedTable]
+class DocumentHandler:
+ """Handles loading a PDF document and providing access to its pages."""
- @property
- def n_tables(self):
- return len(self.cropped_tables)
+ def __init__(self, path: Path):
+ """Initializes the DocumentHandler with a path to a PDF.
+ Args:
+ path (Path): The file path to the PDF document.
+ """
+ self.doc = PyPDFium2Document(path) # Load the document using PyPDFium2
-class GMFTParser:
- def __init__(self, fn: Path) -> None:
- self.fn = fn
- self._doc = None
- self._parsed_tables = None
+ def get_pages(self) -> Any:
+ """Returns an iterable of pages from the loaded document."""
+ return self.doc
- # logger.info("Initializing Table Formatter.")
- # self.formatter = AutoTableFormatter()
- def detect_page_tables(self) -> Tuple[List[PageTables], Any]:
- """Detects tables in a document and returns list of page tables"""
+class TableDetectorHelper:
+ """Facilitates detection of tables within document pages."""
- logger.info("Detecting tables...")
- doc = PyPDFium2Document(self.fn)
- detector = TableDetector()
- pt = []
+ def __init__(self):
+ """Initializes the TableDetector to find tables."""
+ self.detector = TableDetector()
- for page in doc:
- pt.append(
- PageTables(
- page_num=page.page_number, cropped_tables=detector.extract(page)
- )
- )
+ def detect_tables(self, page: Any) -> List[CroppedTable]:
+ """Detects and returns cropped tables from a given page.
- return pt, doc
+ Args:
+ page (Any): The page from which to detect tables.
- @property
- def parsed_tables(self) -> List[GenericParsedTable]:
- if self._parsed_tables is None:
- page_tables, self._doc = self.detect_page_tables()
- logger.info("Parsing tables ...")
+ Returns:
+ List[CroppedTable]: A list of detected cropped tables.
+ """
+ return self.detector.extract(page)
+
+
+class TableParser:
+ """Parses cropped tables into GMFTParsedTable objects."""
+
+ def __init__(self, formatter: AutoTableFormatter):
+ """Initializes the TableParser with a formatter.
+
+ Args:
+ formatter (AutoTableFormatter): Formatter used for parsing tables.
+ """
+ self.formatter = formatter
- out_tables = []
+ def parse(self, cropped_table: CroppedTable, page_num: int) -> GMFTParsedTable:
+ """Parses a cropped table into a GMFTParsedTable instance.
- for page_table in page_tables:
- for cropped_table in page_table.cropped_tables:
- out_tables.append(
- GMFTParsedTable(cropped_table, page_table.page_num)
- )
- self._parsed_tables = out_tables
+ Args:
+ cropped_table (CroppedTable): The cropped table to parse.
+ page_num (int): The page number where the table is found.
+
+ Returns:
+ GMFTParsedTable: An instance of GMFTParsedTable containing the parsed data.
+ """
+ return GMFTParsedTable(cropped_table, page_num, self.formatter)
+
+
+class GMFTParser:
+ """Main class for handling the parsing of tables from a PDF document."""
+
+ def __init__(self, fn: Path) -> None:
+ """Initializes the parser with a PDF file path and prepares components.
+
+ Args:
+ fn (Path): The file path to the PDF document.
+ """
+ self.fn = fn
+ self.document_handler = DocumentHandler(fn) # Load the document
+ self.formatter = TableFormatterSingleton().formatter # Get the formatter
+ self.table_detector = TableDetectorHelper() # Initialize table detector
+ self.table_parser = TableParser(self.formatter) # Initialize table parser
+ self._parsed_tables: Optional[List[GMFTParsedTable]] = None # Cache for parsed tables
+
+ def detect_and_parse_tables(self) -> List[GMFTParsedTable]:
+ """Detects and parses tables from the PDF document.
+
+ Returns:
+ List[GMFTParsedTable]: A list of parsed tables.
+ """
+ logger.info("Detecting and parsing tables...")
+ detected_tables = []
+
+ # Iterate through the pages in the document
+ for page in self.document_handler.get_pages():
+ cropped_tables = self.table_detector.detect_tables(page) # Detect tables on the page
+ # Parse each cropped table found on the page
+ for cropped_table in cropped_tables:
+ parsed_table = self.table_parser.parse(cropped_table, page.page_number)
+ detected_tables.append(parsed_table) # Store the parsed table
+
+ return detected_tables
+
+ @property
+ def parsed_tables(self) -> List[GMFTParsedTable]:
+ """Lazy-loads the parsed tables when requested.
+
+ Returns:
+ List[GMFTParsedTable]: A list of parsed tables from the document.
+ """
+ if self._parsed_tables is None:
+ self._parsed_tables = self.detect_and_parse_tables() # Detect and parse tables if not done already
return self._parsed_tables
if __name__ == "__main__":
# fn = Path("/home/snexus/Downloads/ws90.pdf")
# fn = Path("/home/snexus/Downloads/SSRN-id2741701.pdf")
- fn = Path("/home/snexus/Downloads/Table_Example1.pdf")
+ fn = Path("/home/snexus/Downloads/ws90.pdf")
parser = GMFTParser(fn=fn)
for p in parser.parsed_tables:
diff --git a/tests/test_table_splitting.py b/tests/test_table_splitting.py
new file mode 100644
index 0000000..b01e837
--- /dev/null
+++ b/tests/test_table_splitting.py
@@ -0,0 +1,105 @@
+import pytest
+from unittest.mock import MagicMock
+from llmsearch.parsers.tables.generic import pdf_table_splitter # Replace with the actual module name
+
+@pytest.fixture
+def setup_parsed_table():
+ """Fixture to create a mock parsed table for testing."""
+ dummy_bbox = (0.0, 0.0, 100.0, 100.0) # Dummy bounding box
+ parsed_table = MagicMock()
+ parsed_table.page_num = 1
+ parsed_table.bbox = dummy_bbox
+ parsed_table.caption = ""
+ return parsed_table
+
+def test_basic_functionality(setup_parsed_table):
+ parsed_table = setup_parsed_table
+ parsed_table.xml = [
+ "1
",
+ "2
"
+ ]
+ expected_output = [
+ {
+ "text": "```xml table:\n1
\n2
\n```",
+ "metadata": {"page": 1, "source_chunk_type": "table"}
+ }
+ ]
+
+ result = pdf_table_splitter(parsed_table, max_size=100) # Adjust max size as needed
+ print(result)
+ assert result == expected_output
+
+def test_caption_inclusion(setup_parsed_table):
+ parsed_table = setup_parsed_table
+ parsed_table.xml = ["1
"]
+ parsed_table.caption = "This is a test caption."
+
+ expected_output = [
+ {
+ "text": "Table below contains information about: This is a test caption.\n```xml table:\n1
\n```",
+ "metadata": {"page": 1, "source_chunk_type": "table"}
+ }
+ ]
+
+ result = pdf_table_splitter(parsed_table, max_size=100)
+ assert result == expected_output
+
+def test_caption_trimming(setup_parsed_table):
+ parsed_table = setup_parsed_table
+ parsed_table.xml = ["1
"]
+ parsed_table.caption = "A very long caption that exceeds the size limit."
+
+ expected_output = [
+ {
+ "text": "Table below contains information about: A very long capt\n```xml table:\n1
\n```",
+ "metadata": {"page": 1, "source_chunk_type": "table"}
+ }
+ ]
+
+ result = pdf_table_splitter(parsed_table, max_size=50, max_caption_size_ratio=3)
+ print(result)
+ assert result == expected_output
+
+def test_element_larger_than_max_size(setup_parsed_table):
+ parsed_table = setup_parsed_table
+ parsed_table.xml = [
+ "1
",
+ "2
"
+ ]
+ long_element = "" + "" + "X" * 200 + "
" # Very long element
+ parsed_table.xml.append(long_element)
+
+ result = pdf_table_splitter(parsed_table, max_size=100)
+ print(result)
+ # There should be one chunk for the first two elements and a separate chunk for the long element
+ assert len(result) == 3
+
+def test_empty_input(setup_parsed_table):
+ parsed_table = setup_parsed_table
+ parsed_table.xml = []
+ parsed_table.caption = ""
+
+ result = pdf_table_splitter(parsed_table, max_size=100)
+ print(result)
+ assert result == [
+ {
+ "text": "```xml table:\n```",
+ "metadata": {"page": 1, "source_chunk_type": "table"}
+ }
+ ]
+
+def test_single_element(setup_parsed_table):
+ parsed_table = setup_parsed_table
+ parsed_table.xml = ["1
"]
+
+ result = pdf_table_splitter(parsed_table, max_size=150)
+ assert len(result) == 1
+
+def test_multiple_elements_within_limit(setup_parsed_table):
+ parsed_table = setup_parsed_table
+ parsed_table.xml = [
+ "1
",
+ "2
"
+ ]
+ result = pdf_table_splitter(parsed_table, max_size=250)
+ assert len(result) == 1
\ No newline at end of file