Add support for Gemini 1.5 Pro image parser
DL committed Aug 19, 2024
1 parent 914a415 commit a944f79
Showing 8 changed files with 469 additions and 208 deletions.
18 changes: 10 additions & 8 deletions requirements.txt
@@ -1,8 +1,8 @@
llama-cpp-python==0.2.78
chromadb~=0.5
langchain~=0.2.4
langchain-community~=0.2.4
langchain-openai~=0.1.8
llama-cpp-python==0.2.88
chromadb~=0.5.5
langchain~=0.2.14
langchain-community~=0.2.12
langchain-openai~=0.1.22
langchain-huggingface~=0.0.3
pydantic~=2.7
transformers~=4.41
@@ -16,14 +16,14 @@ python-dotenv
accelerate~=0.33
protobuf==3.20.2
termcolor
openai~=1.34.0
openai~=1.41
einops # required for Mosaic models
click
bitsandbytes==0.43.1
# auto-gptq==0.2.0
InstructorEmbedding==1.0.1
unstructured~=0.14.5
pymupdf==1.22.5
pymupdf==1.24.9
streamlit~=1.28
python-docx~=1.1
six==1.16.0 ; python_version >= "3.10" and python_version < "4.0"
@@ -36,4 +36,6 @@ threadpoolctl==3.1.0 ; python_version >= "3.10" and python_version < "4.0"
tiktoken==0.7.0 ; python_version >= "3.10" and python_version < "4.0"
tokenizers==0.19.1; python_version >= "3.10" and python_version < "4.0"
tqdm==4.65.0 ; python_version >= "3.10" and python_version < "4.0"
# transformers==4.29.2 ; python_version >= "3.10" and python_version < "4.0"
# transformers==4.29.2 ; python_version >= "3.10" and python_version < "4.0"
gmft==0.2.1
google-generativeai~=0.7
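
The new google-generativeai dependency needs an API key before any of the Gemini parsers below can run. A minimal setup sketch, assuming the key is supplied via an environment variable (the variable name is an assumption, not something this commit shows):

import os
import google.generativeai as genai  # installed via google-generativeai~=0.7

# Assumption: the key comes from the GOOGLE_API_KEY environment variable.
genai.configure(api_key=os.environ["GOOGLE_API_KEY"])

# The two model names wired up by this commit.
flash = genai.GenerativeModel("gemini-1.5-flash")
pro = genai.GenerativeModel("gemini-1.5-pro")
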
11 changes: 10 additions & 1 deletion src/llmsearch/config.py
@@ -62,7 +62,16 @@ class PDFTableParser(str, Enum):

class PDFImageParser(str, Enum):
    GEMINI_15_FLASH = "gemini-1.5-flash"
    GEMINI_15_PRO = "gemini-1.5-pro"

class PDFImageParseSettings(BaseModel):
    image_parser: PDFImageParser
    system_instruction: str = """You are a research assistant. You analyze the image to extract detailed information. Response must be a Markdown string in the following format:
    - First line is a heading with image caption, starting with '# '
    - Second line is empty
    - From the third line on - detailed data points and related metadata, extracted from the image, in Markdown format. Don't use Markdown tables.
    """
    user_instruction: str = """From the image, extract detailed quantitative and qualitative data points."""

class EmbeddingModelType(str, Enum):
    huggingface = "huggingface"
@@ -92,7 +101,7 @@ class DocumentPathSettings(BaseModel):
    pdf_table_parser: Optional[PDFTableParser] = None
    """If enabled, will parse tables in pdf files using the specified parser."""

    pdf_image_parser: Optional[PDFImageParser] = None
    pdf_image_parser: Optional[PDFImageParseSettings] = None
    """If enabled, will parse images in pdf files using the specified parser."""

    additional_parser_settings: Dict[str, Any] = Field(default_factory=dict)
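
For reference, a short sketch of how the new settings model can be constructed by callers; the default prompts come from the class above, and the override text here is purely illustrative:

from llmsearch.config import PDFImageParseSettings, PDFImageParser

# Use the built-in system/user instructions.
settings = PDFImageParseSettings(image_parser=PDFImageParser.GEMINI_15_PRO)

# Or override the user instruction (illustrative prompt, not from the commit).
custom = PDFImageParseSettings(
    image_parser=PDFImageParser.GEMINI_15_FLASH,
    user_instruction="Summarize the chart axes, units, and key trends.",
)
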
13 changes: 5 additions & 8 deletions src/llmsearch/parsers/images/gemini_parser.py
@@ -19,18 +19,15 @@ class GeminiImageAnalyzer:
    def __init__(
        self,
        model_name: str,
        instruction: str = """From the image, extract detailed quantitative and qualitative data points.""",
        system_instruction: str,
        user_instruction: str
    ):
        self.model_name = model_name
        self.instruction = instruction
        self.instruction = user_instruction
        print(system_instruction, user_instruction)
        self.model = genai.GenerativeModel(
            model_name,
            system_instruction="""You are a research assistant. You analyze the image to extract detailed information. Response must be a Markdown string in the following format:
            - First line is a heading with image caption, starting with '# '
            - Second line is empty
            - From the third line on - detailed data points and related metadata, extracted from the image, in Markdown format. Don't use Markdown tables.
            """,
            system_instruction=system_instruction,
            generation_config=genai.types.GenerationConfig(
                # Only one candidate for now.
                candidate_count=1,
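
The analyze method itself is outside this hunk. Based on how generic.py calls it (image_analyzer.analyze(fn)) and on the google-generativeai API, its shape is roughly the sketch below; treat it as an assumption rather than the committed implementation:

import PIL.Image
import google.generativeai as genai

def analyze_image(model: genai.GenerativeModel, user_instruction: str, image_fn: str) -> str:
    # Send the user instruction together with the image; the system
    # instruction is already attached to the model at construction time.
    with PIL.Image.open(image_fn) as img:
        response = model.generate_content([user_instruction, img])
    return response.text  # Markdown produced by the model
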
139 changes: 86 additions & 53 deletions src/llmsearch/parsers/images/generic.py
@@ -1,18 +1,50 @@
from collections import defaultdict
import importlib
import io
from multiprocessing.pool import ThreadPool
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple, Callable

import PIL.Image
import pymupdf
from loguru import logger
from pydantic import BaseModel
from tenacity import retry, stop_after_attempt, wait_random_exponential

from llmsearch.config import PDFImageParser
from llmsearch.config import PDFImageParseSettings, PDFImageParser
from llmsearch.parsers.markdown import markdown_splitter

# Define a mapping of PDFImageParser to corresponding analyzer classes and config
ANALYZER_MAPPING: Dict[PDFImageParser, Any] = {
    PDFImageParser.GEMINI_15_FLASH: {
        "import_path": "llmsearch.parsers.images.gemini_parser",  # Import path for lazy loading
        "class_name": "GeminiImageAnalyzer",
        "params": {"model_name": "gemini-1.5-flash"},
    },

    PDFImageParser.GEMINI_15_PRO: {
        "import_path": "llmsearch.parsers.images.gemini_parser",  # Import path for lazy loading
        "class_name": "GeminiImageAnalyzer",
        "params": {"model_name": "gemini-1.5-pro"},
    },
}


def create_analyzer(image_analyzer: PDFImageParser, **additional_params):
    analyzer_info = ANALYZER_MAPPING.get(image_analyzer)

    if analyzer_info is None:
        raise ValueError(f"Unsupported image analyzer type: {image_analyzer}")

    # Lazy load the module
    module = importlib.import_module(analyzer_info["import_path"])
    analyzer_class = getattr(module, analyzer_info["class_name"])
    analyzer_params = analyzer_info["params"]

    params = {**analyzer_params, **additional_params}

    return analyzer_class(**params)


class PDFImage(BaseModel):
    image_fn: Path
@@ -26,57 +58,48 @@ def __init__(
        self,
        pdf_fn: Path,
        temp_folder: Path,
        image_analyzer,
        save_output=True,
        image_analyzer: Callable,
        save_output: bool = True,
        max_base_width: int = 1280,
        min_width: int = 640,
        min_height: int = 200,
    ):
        self.pdf_fn = pdf_fn
        self.max_base_width = max_base_width
        self.temp_folder = temp_folder
        self.min_width = min_width
        self.min_height = min_height
        self.image_analyzer = image_analyzer
        self.save_output = save_output
        self.max_base_width = max_base_width
        self.min_width = min_width
        self.min_height = min_height

    def prepare_and_clean_folder(self):
        # Check if the folder exists
        if not self.temp_folder.exists():
            # Create the folder if it doesn't exist
            self.temp_folder.mkdir(parents=True, exist_ok=True)
            logger.info(f"Created folder: {self.temp_folder}")
        else:
            for file in self.temp_folder.iterdir():
                if file.is_file():
                    file.unlink()  # Delete the file
                    logger.info(f"Deleted file: {file}")
                    file.unlink()
                    logger.debug(f"Deleted file: {file}")

    def extract_images(self) -> List[PDFImage]:
        self.prepare_and_clean_folder()

        doc = pymupdf.open(self.pdf_fn)
        out_images = []

        for page in doc:
            page_images = page.get_images()
            for img in page_images:
            for img in page.get_images():
                xref = img[0]
                data = doc.extract_image(xref=xref)
                out_fn = self._resize_and_save_image(
                    data=data,
                    page_num=page.number,
                    xref_num=xref,
                )
                if out_fn is not None:
                data = doc.extract_image(xref)
                out_fn = self._resize_and_save_image(data, page.number, xref)
                if out_fn:
                    out_images.append(
                        PDFImage(
                            image_fn=out_fn,
                            page_num=page.number,
                            bbox=(img[1], img[2], img[3], img[4]),
                        )
                    )

        return out_images

    def _resize_and_save_image(
@@ -85,30 +108,32 @@ def _resize_and_save_image(
        page_num: int,
        xref_num: int,
    ) -> Optional[Path]:

        image = data.get("image", None)
        if image is None:
        image_data = data.get("image")
        if not image_data:
            return

        with PIL.Image.open(io.BytesIO(image)) as img:
        with PIL.Image.open(io.BytesIO(image_data)) as img:
            if img.size[1] < self.min_height or img.size[0] < self.min_width:
                logger.info(
                logger.debug(
                    f"Image on page {page_num}, xref {xref_num} is too small. Skipping extraction..."
                )
                return None
            wpercent = self.max_base_width / float(img.size[0])

            # Resize the image, if needed
            wpercent = self.max_base_width / float(img.size[0])
            if wpercent < 1:
                hsize = int((float(img.size[1]) * float(wpercent)))
                hsize = int(float(img.size[1]) * wpercent)
                img = img.resize(
                    (self.max_base_width, hsize), PIL.Image.Resampling.LANCZOS
                )

            out_fn = self.temp_folder / (str(self.pdf_fn.stem) + f"_page_{page_num}_xref_{xref_num}.png")
            logger.info(f"Saving file: {out_fn}")
            img.save(out_fn, mode="wb")
            return Path(out_fn)
            out_fn = (
                self.temp_folder
                / f"{self.pdf_fn.stem}_page_{page_num}_xref_{xref_num}.png"
            )
            logger.debug(f"Saving file: {out_fn}")
            img.save(out_fn)

        return out_fn

    def analyze_images_threaded(
        self, extracted_images: List[PDFImage], max_threads: int = 10
@@ -117,58 +142,66 @@ def analyze_images_threaded(
            results = pool.starmap(
                analyze_single_image,
                [
                    (pdf_image, self.image_analyzer, i)
                    for i, pdf_image in enumerate(extracted_images)
                    (img, self.image_analyzer, i)
                    for i, img in enumerate(extracted_images)
                ],
            )

        if self.save_output:
            for r in results:
                with open(str(r.image_fn)[:-3] + ".md", "w") as file:
                    file.write(r.markdown)
            for result in results:
                with open(str(result.image_fn).replace(".png", ".md"), "w") as file:
                    file.write(result.markdown)

        return results


def log_attempt_number(retry_state):
    """Log the outcome of the last call attempt."""
    logger.error(f"API call attempt failed. Retrying: {retry_state.attempt_number}...")
    error_message = str(retry_state.outcome.exception())
    logger.error(
        f"API call attempt {retry_state.attempt_number} failed with error: {error_message}. Retrying..."
    )
    # logger.error(f"API call attempt failed. Retrying: {retry_state.attempt_number}...")


@retry(
    wait=wait_random_exponential(min=5, max=60),
    stop=stop_after_attempt(6),
    after=log_attempt_number,
)
def analyze_single_image(pdf_image: PDFImage, image_analyzer, i: int) -> PDFImage:
    fn = pdf_image.image_fn
    pdf_image.markdown = image_analyzer.analyze(fn)
def analyze_single_image(
    pdf_image: PDFImage, image_analyzer: Callable, i: int
) -> PDFImage:
    pdf_image.markdown = image_analyzer.analyze(pdf_image.image_fn)
    return pdf_image


def get_image_chunks(
    path: Path, max_size: int, image_analyzer: PDFImageParser, cache_folder: Path
    path: Path,
    max_size: int,
    image_parse_setting: PDFImageParseSettings,
    cache_folder: Path,
) -> Tuple[List[dict], Dict[int, List[Tuple[float]]]]:
    if image_analyzer is PDFImageParser.GEMINI_15_FLASH:
        from llmsearch.parsers.images.gemini_parser import GeminiImageAnalyzer
        analyzer = GeminiImageAnalyzer(model_name="gemini-1.5-flash")

    analyzer = create_analyzer(
        image_parse_setting.image_parser,
        system_instruction=image_parse_setting.system_instruction,
        user_instruction=image_parse_setting.user_instruction,
    )
    image_parser = GenericPDFImageParser(
        pdf_fn=path,
        temp_folder=cache_folder / "pdf_images_temp",
        image_analyzer=analyzer,
        # image_analyzer=GeminiImageAnalyzer(model_name="gemini-1.5-pro-exp-0801")
    )

    extracted_images = image_parser.extract_images()
    parsed_images = image_parser.analyze_images_threaded(extracted_images)

    out_blocks = []
    img_bboxes = defaultdict(list)

    for img in parsed_images:
        print(str(img.image_fn) + ".md")
        out_blocks += markdown_splitter(path=str(img.image_fn)[:-3] + ".md", max_chunk_size=max_size)
        out_blocks += markdown_splitter(
            path=str(img.image_fn).replace(".png", ".md"), max_chunk_size=max_size
        )
        img_bboxes[img.page_num].append(img.bbox)

    return out_blocks, img_bboxes
@@ -179,7 +212,7 @@ def get_image_chunks(
    res = get_image_chunks(
        path=Path("/home/snexus/Downloads/Graph_Example2.pdf"),
        max_size=1024,
        image_analyzer=PDFImageParser.GEMINI_15_FLASH,
        image_parse_setting=PDFImageParseSettings(image_parser=PDFImageParser.GEMINI_15_PRO),
        cache_folder=Path("./output_images"),
    )

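
ANALYZER_MAPPING together with create_analyzer acts as a small registry with lazy imports: the Gemini module is imported only when a Gemini parser is actually selected, and extra keyword arguments are forwarded to the analyzer constructor. A minimal sketch of calling it directly (the prompt strings are shortened placeholders; in the commit they come from PDFImageParseSettings):

from llmsearch.config import PDFImageParser
from llmsearch.parsers.images.generic import create_analyzer

# Lazily imports llmsearch.parsers.images.gemini_parser and builds
# GeminiImageAnalyzer(model_name="gemini-1.5-pro", ...).
analyzer = create_analyzer(
    PDFImageParser.GEMINI_15_PRO,
    system_instruction="You are a research assistant.",  # placeholder prompt
    user_instruction="Extract detailed data points from the image.",  # placeholder prompt
)

Supporting a further backend would then amount to adding a PDFImageParser value and a matching ANALYZER_MAPPING entry.
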
4 changes: 2 additions & 2 deletions src/llmsearch/parsers/pdf.py
@@ -6,7 +6,7 @@
from loguru import logger
from langchain_text_splitters import CharacterTextSplitter

from llmsearch.parsers.tables.generic import boxes_intersect
from llmsearch.parsers.tables.generic import do_boxes_intersect


class PDFSplitter:
@@ -149,7 +149,7 @@ def filter_blocks(blocks: List[Tuple[float, float, float, float, str]],
        skip_block = False

        for filter_bbox in page_table_bboxes:
            if boxes_intersect(filter_bbox, block_bbox):
            if do_boxes_intersect(filter_bbox, block_bbox):
                # We found an intersection, set the flag and break the inner loop
                skip_block = True
                # print(f"Skipping block: {block}")
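
do_boxes_intersect lives in llmsearch.parsers.tables.generic and is not shown in this diff; for the axis-aligned (x0, y0, x1, y1) boxes used here, the overlap test it presumably performs looks like this sketch (assumed implementation, not the committed one):

from typing import Tuple

Box = Tuple[float, float, float, float]  # (x0, y0, x1, y1)

def do_boxes_intersect(a: Box, b: Box) -> bool:
    # Two axis-aligned rectangles overlap unless one lies entirely
    # to the left of, right of, above, or below the other.
    return not (a[2] < b[0] or b[2] < a[0] or a[3] < b[1] or b[3] < a[1])
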