Skip to content

Commit

Permalink
add batch processing pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
iammosespaulr committed Jan 29, 2025
1 parent 4ca60f4 commit 0164959
Show file tree
Hide file tree
Showing 10 changed files with 428 additions and 0 deletions.
30 changes: 30 additions & 0 deletions marker/batch/cli_dynamic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import click

from marker.batch.messages import MESSAGE_QUEUES
from marker.batch.storage import STORAGE_CLIENTS


class DynamicOptionCommand(click.Command):
    """Click command that registers backend-specific options at parse time.

    The statically declared ``--storage-client`` and ``--message-queue``
    options select a backend class; each selected class contributes extra
    options through its ``config_options()`` hook before the real parse runs.
    """

    def parse_args(self, ctx, args):
        """Add the selected backends' options, then parse ``args`` normally.

        Args:
            ctx: The click context for this invocation.
            args: The raw command-line arguments.

        Returns:
            Whatever ``click.Command.parse_args`` returns for the full
            option set.
        """
        # First parse with the currently registered options only, to learn
        # which storage client and message queue were selected.
        parser = self.make_parser(ctx)
        opts, _, _ = parser.parse_args(args=list(args))

        # Storage options are registered before queue options.
        selected_backends = (
            (opts.get('storage_client'), STORAGE_CLIENTS),
            (opts.get('message_queue'), MESSAGE_QUEUES),
        )

        # parse_args can run more than once on the same command object (for
        # example when the command is invoked repeatedly in one process), so
        # skip options whose name is already registered instead of appending
        # duplicates to self.params.
        registered_names = {param.name for param in self.params}
        for selection, registry in selected_backends:
            if not selection or selection not in registry:
                continue
            for option in registry[selection].config_options():
                if option.name in registered_names:
                    continue
                self.params.append(option)
                registered_names.add(option.name)

        # Re-parse with all options including the dynamic ones
        return super().parse_args(ctx, args)
119 changes: 119 additions & 0 deletions marker/batch/consume.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
import multiprocessing as mp
import sys
import tempfile
import traceback
from concurrent.futures import ProcessPoolExecutor
from typing import Optional

import click

from marker.batch.cli_dynamic import DynamicOptionCommand
from marker.batch.messages import MESSAGE_QUEUES
from marker.batch.schema import Message
from marker.batch.storage import STORAGE_CLIENTS
from marker.batch.storage.base import StorageClient
from marker.config.parser import ConfigParser
from marker.models import create_model_dict
from marker.output import save_output

# Per-process state: both start as None and are assigned by initialize_worker().
# process_message() refuses to run while either one is still None.
_worker_storage_client: Optional[StorageClient] = None
_model_refs: Optional[dict] = None


def initialize_worker(kwargs: dict) -> None:
    """Set up the module-level storage client and model dict for this process.

    Args:
        kwargs: Keyword arguments forwarded to the storage client constructor.
            Must contain ``storage_client_name``, the key used to look the
            client class up in ``STORAGE_CLIENTS``.

    If either step fails, the error is printed and the process exits with
    status 1.
    """
    global _worker_storage_client, _model_refs

    try:
        client_cls = STORAGE_CLIENTS[kwargs['storage_client_name']]
        _worker_storage_client = client_cls(**kwargs)
    except Exception as storage_err:
        print(f"Failed to initialize storage client in worker: {storage_err}")
        sys.exit(1)

    try:
        _model_refs = create_model_dict()
    except Exception as model_err:
        print(f"Failed to initialize model in worker: {model_err}")
        sys.exit(1)


def process_single_pdf(key, input_path, output_path, cli_options):
    """Convert one input file and save the rendered result.

    Args:
        key: Name passed to ``save_output`` for the rendered output.
        input_path: Local path of the file to convert.
        output_path: Local directory passed to ``save_output``.
        cli_options: Option dict used to build the ``ConfigParser``.

    Raises:
        Exception: Any error raised while building the converter, converting
            or saving is logged and then re-raised.
    """
    config_parser = ConfigParser(cli_options)
    converter_cls = config_parser.get_converter_cls()

    try:
        converter = converter_cls(
            config=config_parser.generate_config_dict(),
            artifact_dict=_model_refs,
            processor_list=config_parser.get_processors(),
            renderer=config_parser.get_renderer()
        )
        rendered = converter(input_path)
        save_output(rendered, output_path, key)
    except Exception as e:
        print(f"Error converting {input_path}: {e}")
        print(traceback.format_exc())
        # Propagate the failure: swallowing it here would let the caller
        # upload whatever is in the output directory and report the job as
        # successful even though the conversion failed.
        raise


def process_message(message: Message) -> None:
    """Download, convert and upload the document described by ``message``.

    The input at ``message.input_path`` is downloaded to a temporary file,
    converted to markdown into a temporary directory, and that directory is
    uploaded to ``message.output_path``.

    Args:
        message: The job to process.

    Raises:
        RuntimeError: If the storage client or the model dict has not been
            initialized in this process (see ``initialize_worker``).
    """
    if _worker_storage_client is None:
        raise RuntimeError("Storage client is not initialized in the worker process.")
    if _model_refs is None:
        raise RuntimeError("Model is not initialized in the worker process.")

    print(f"Processing message: {message}")
    with tempfile.NamedTemporaryFile(delete=True) as tmp_file:
        temp_file_path = tmp_file.name
        _worker_storage_client.download_file(message.input_path, temp_file_path)

        # Use the default self-cleaning directory: the upload happens while
        # the directory still exists, so nothing is gained by keeping it, and
        # a long-running worker would otherwise leave one directory per
        # message behind. This also avoids the ``delete`` keyword, which
        # only exists on Python 3.12+.
        with tempfile.TemporaryDirectory() as temp_dir:
            process_single_pdf(message.key, temp_file_path, temp_dir, {
                "output_format": "markdown",
                "disable_multiprocessing": True
            })

            print(f"Uploading output to {message.output_path}")
            _worker_storage_client.upload_folder(temp_dir, message.output_path)

    print(f"Finished processing message: {message}")


def _process_raw_message(body) -> None:
    """Decode a raw queue payload into a Message and process it.

    Submitted to the process pool, so a payload that cannot be decoded fails
    the job's future instead of raising inside the queue consumer callback.
    """
    process_message(Message.from_json(body))


@click.command(cls=DynamicOptionCommand, context_settings=dict(
    ignore_unknown_options=True,
    allow_extra_args=True,
))
@click.option("--storage-client", type=click.Choice(STORAGE_CLIENTS.keys()), help="Storage client to use", required=True)  # type: ignore
@click.option("--message-queue", type=click.Choice(MESSAGE_QUEUES.keys()), help="Message queue to use", required=True)  # type: ignore
@click.option("--job-id", type=str, help="Job ID", required=True)
@click.option("--worker-count", type=int, help="Number of workers to use", default=1)
def cli(storage_client, message_queue, worker_count, **kwargs):
    """Consume conversion jobs from a message queue using a pool of workers."""
    # kwargs carries --job-id plus every backend-specific option; the queue
    # and the storage client each receive the whole set.
    message_queue = MESSAGE_QUEUES[message_queue](**kwargs)

    initializer_kwargs = {
        'storage_client_name': storage_client,
        **kwargs
    }

    try:
        mp.set_start_method('spawn')  # Required for CUDA, forkserver doesn't work
    except RuntimeError as err:
        raise RuntimeError("Set start method to spawn twice. This may be a temporary issue with the script. Please try running it again.") from err

    # NOTE(review): each worker builds its own model dict in
    # initialize_worker(); nothing visible here hands this parent-side copy
    # to the workers. Confirm whether it is still needed.
    model_dict = create_model_dict()
    for model_ref in model_dict.values():
        model_ref.model.share_memory()

    with ProcessPoolExecutor(
        max_workers=worker_count,
        initializer=initialize_worker,
        initargs=(initializer_kwargs,)
    ) as executor:

        def callback(msg):
            # Decoding happens inside the submitted task, so a malformed
            # payload is reported through the future and settled by
            # message_queue.ack() like any other failed job, rather than
            # raising out of the consume loop.
            future = executor.submit(_process_raw_message, msg.body)
            future.add_done_callback(lambda f: message_queue.ack(msg, f))

        # worker_count doubles as the prefetch count.
        message_queue.consume(callback, worker_count)


# Run the consumer CLI when this module is executed as a script.
if __name__ == "__main__":
    cli()
5 changes: 5 additions & 0 deletions marker/batch/messages/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from marker.batch.messages.rabbitmq import RabbitMQMessageQueue

# Registry of available message-queue backends, keyed by each class's `name`.
_QUEUE_BACKENDS = (RabbitMQMessageQueue,)
MESSAGE_QUEUES = {backend.name: backend for backend in _QUEUE_BACKENDS}
16 changes: 16 additions & 0 deletions marker/batch/messages/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from abc import ABC, abstractmethod


class MessageQueue(ABC):
    """Abstract interface for the message-queue backends of the batch CLIs.

    The signatures mirror how the publisher and consumer CLIs call a queue:
    ``publish(message)``, ``consume(callback, prefetch_count)`` and
    ``ack(msg, f)``. Backends are constructed with the parsed CLI options as
    keyword arguments.
    """

    # Registry key of the backend; the backend registry is keyed by it.
    name: str

    @staticmethod
    @abstractmethod
    def config_options():
        """Return the backend-specific click options to add to the CLI."""
        pass

    @abstractmethod
    def publish(self, message):
        """Publish one message to the backend's queue."""
        pass

    @abstractmethod
    def consume(self, callback, prefetch_count=1):
        """
        Consume messages from the backend's queue.
        A callback function is called for each message.
        ``prefetch_count`` is the backend's prefetch setting.
        """
        pass

    @abstractmethod
    def ack(self, msg, f):
        """
        Settle a delivered message once its processing future is done.
        ``msg`` is the delivery passed to the consume callback and ``f`` is
        the finished future for that delivery.
        """
        pass
80 changes: 80 additions & 0 deletions marker/batch/messages/rabbitmq.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
from concurrent.futures import Future

import amqpstorm # type: ignore
import click

from marker.batch.messages.base import MessageQueue
from marker.batch.schema import Message


class RabbitMQMessageQueue(MessageQueue):
    """Message queue backend that talks to RabbitMQ through amqpstorm.

    Each job uses its own durable queue named ``"{job_id}-queue"``.
    """

    name = "rabbitmq"

    @staticmethod
    def config_options():
        """Return the ``--rabbitmq-*`` click options for this backend."""
        return [
            click.Option(["--rabbitmq-host"], required=True, help="RabbitMQ host"),
            click.Option(["--rabbitmq-port"], required=False, help="RabbitMQ port", default=5672),
            click.Option(["--rabbitmq-username"], required=False, help="RabbitMQ username", default=None),
            click.Option(["--rabbitmq-password"], required=False, help="RabbitMQ password", default=None),
            click.Option(["--rabbitmq-vhost"], required=False, help="RabbitMQ vhost", default='/'),
            click.Option(["--rabbitmq-heartbeat"], required=False, help="RabbitMQ heartbeat", default=60),
            click.Option(["--rabbitmq-use-ssl"], is_flag=True, help="Use SSL", default=False),
        ]

    def __init__(
        self,
        rabbitmq_host,
        rabbitmq_port,
        rabbitmq_username,
        rabbitmq_password,
        rabbitmq_vhost,
        rabbitmq_heartbeat,
        rabbitmq_use_ssl,
        job_id,
        **kwargs
    ):
        """Connect to RabbitMQ and declare the durable queue for ``job_id``.

        The ``rabbitmq_*`` arguments correspond to the options returned by
        ``config_options``. Extra keyword arguments are accepted and ignored
        so the full set of parsed CLI options can be passed in.
        """
        super().__init__()
        self.connection = amqpstorm.Connection(
            hostname=rabbitmq_host,
            username=rabbitmq_username,
            password=rabbitmq_password,
            port=rabbitmq_port,
            heartbeat=rabbitmq_heartbeat,
            virtual_host=rabbitmq_vhost,
            ssl=rabbitmq_use_ssl,
        )
        self.channel = self.connection.channel()

        self.queue_name = f"{job_id}-queue"
        self.channel.queue.declare(queue=self.queue_name, durable=True)

        # The password is deliberately left out of this log line.
        print(f"Initializing RabbitMQ message queue with host={rabbitmq_host}, port={rabbitmq_port}, username={rabbitmq_username}, vhost={rabbitmq_vhost}")

    def publish(self, message: Message) -> None:
        """Publish ``message`` as JSON to this job's queue."""
        # delivery_mode 2 marks the message as persistent in AMQP.
        msg = amqpstorm.Message.create(self.channel, message.to_json(), {"delivery_mode": 2, "content_type": "application/json"})
        msg.publish(routing_key=self.queue_name)

    def consume(self, callback, prefetch_count=1) -> None:
        """Consume this job's queue, calling ``callback`` for each delivery.

        Deliveries are not auto-acknowledged (``no_ack=False``); ``ack`` is
        the method that settles them. The call runs the consume loop until
        it is interrupted, and always closes the connection on the way out.
        """
        self.channel.basic.qos(prefetch_count=prefetch_count)
        try:
            self.channel.basic.consume(
                callback=callback,
                queue=self.queue_name,
                no_ack=False
            )
            print('[*] Waiting for messages. To exit press CTRL+C')
            self.channel.start_consuming()
        except KeyboardInterrupt:
            print("Interrupted by user")
            self.channel.stop_consuming()
        finally:
            self.connection.close()

    def ack(self, msg: amqpstorm.Message, f: Future) -> None:
        """Acknowledge ``msg`` if ``f`` succeeded, otherwise reject it.

        Rejected messages are not requeued.
        """
        try:
            f.result()
        except Exception as e:
            print(f"Failed message: {e}")
            msg.reject(requeue=False)
        else:
            # Kept outside the try block so an error raised while
            # acknowledging is not reported as a failed job and followed by
            # a reject for the same delivery.
            msg.ack()
43 changes: 43 additions & 0 deletions marker/batch/publish.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from pathlib import Path

import click
import tqdm # type: ignore

from marker.batch.cli_dynamic import DynamicOptionCommand
from marker.batch.messages import MESSAGE_QUEUES
from marker.batch.schema import Message
from marker.batch.storage import STORAGE_CLIENTS


@click.command(cls=DynamicOptionCommand, context_settings=dict(
    ignore_unknown_options=True,
    allow_extra_args=True,
))
@click.option("--storage-client", type=click.Choice(STORAGE_CLIENTS.keys()), help="Storage client to use", required=True)  # type: ignore
@click.option("--message-queue", type=click.Choice(MESSAGE_QUEUES.keys()), help="Message queue to use", required=True)  # type: ignore
@click.option("--input-path", type=str, help="List messages with this prefix, example s3://bucket/prefix", required=True)
@click.option("--output-path", type=str, help="Write the outputs to this prefix, example s3://bucket/prefix", required=True)
@click.option("--job-id", type=str, help="Job ID", required=True)
def cli(storage_client, message_queue, input_path, output_path, **kwargs):
    """Publish one conversion job per file found under the input prefix."""
    # --input-path and --output-path are required: both are dereferenced
    # unconditionally below, so a missing value would otherwise surface as an
    # AttributeError on None instead of a usage error.
    storage_client = STORAGE_CLIENTS[storage_client](**kwargs)
    message_queue = MESSAGE_QUEUES[message_queue](**kwargs)
    storage_prefix = storage_client.name + "://"

    # Loop-invariant: both roots with the "<name>://" scheme stripped.
    input_root = input_path.removeprefix(storage_prefix)
    output_root = Path(output_path.removeprefix(storage_prefix))

    for msg_input_path in tqdm.tqdm(storage_client.list_files(input_path), desc="Publishing messages"):
        msg_path = Path(msg_input_path.removeprefix(storage_prefix))
        msg_key = msg_path.stem

        # Mirror the input's directory layout (relative to the input root)
        # under the output root, with one folder per document.
        relative_path = msg_path.parent.relative_to(input_root)
        msg_output_path = output_root / relative_path / msg_key

        # as_posix() keeps "/" separators in the storage URIs regardless of
        # the operating system the publisher runs on.
        message_queue.publish(
            Message(
                key=msg_key,
                input_path=storage_prefix + msg_path.as_posix(),
                output_path=storage_prefix + msg_output_path.as_posix() + "/"
            )
        )


# Run the publisher CLI when this module is executed as a script.
if __name__ == "__main__":
    cli()
10 changes: 10 additions & 0 deletions marker/batch/schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from dataclasses import dataclass
from dataclasses_json import dataclass_json


@dataclass_json
@dataclass
class Message:
    """One conversion job exchanged between the publisher and the consumer."""

    # Name for the rendered output; the publisher uses the input file's stem.
    key: str
    # Storage location of the input file, downloaded by the consumer.
    input_path: str
    # Storage location the consumer uploads the output folder to.
    output_path: str
5 changes: 5 additions & 0 deletions marker/batch/storage/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from marker.batch.storage.s3 import S3StorageClient

# Registry of available storage backends, keyed by each class's `name`.
_STORAGE_BACKENDS = (S3StorageClient,)
STORAGE_CLIENTS = {backend.name: backend for backend in _STORAGE_BACKENDS}
28 changes: 28 additions & 0 deletions marker/batch/storage/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from abc import ABC, abstractmethod

from marker.batch.schema import Message


class StorageClient(ABC):
    """Abstract interface for the storage backends used by the batch CLIs.

    Backends are constructed with the parsed CLI options as keyword
    arguments.
    """

    # Registry key of the backend. The publisher also uses it as the URI
    # scheme, building the prefix "<name>://".
    name: str

    @staticmethod
    @abstractmethod
    def config_options():
        """Return the backend-specific click options to add to the CLI."""
        pass

    @abstractmethod
    def upload_file(self, file_path: str, upload_path: str) -> None:
        """Upload the file at ``file_path`` to ``upload_path``."""
        pass

    @abstractmethod
    def download_file(self, file_path: str, download_path: str) -> None:
        """Download ``file_path`` to ``download_path``."""
        pass

    @abstractmethod
    def upload_folder(self, folder_path: str, upload_path: str) -> None:
        """Upload the folder at ``folder_path`` to ``upload_path``."""
        pass

    @abstractmethod
    def list_files(self, path: str) -> list[str]:
        """Return the paths of the files found under ``path``.

        The publisher treats each item as a string path (it strips the
        "<name>://" prefix from it), so the items are strings, not Message
        objects.
        """
        pass
Loading

0 comments on commit 0164959

Please sign in to comment.