-
Notifications
You must be signed in to change notification settings - Fork 1.2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
4ca60f4
commit 0164959
Showing
10 changed files
with
428 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
import click | ||
|
||
from marker.batch.messages import MESSAGE_QUEUES | ||
from marker.batch.storage import STORAGE_CLIENTS | ||
|
||
|
||
class DynamicOptionCommand(click.Command):
    """Click command that registers backend-specific options at parse time.

    The command line is parsed twice: a first pass only discovers which
    storage client and message queue were selected, then the options
    declared by those backends (via ``config_options()``) are added to the
    command before the regular click parsing runs.
    """

    def parse_args(self, ctx, args):
        """Add the selected backends' options, then parse ``args`` normally.

        Args:
            ctx: The click context for this invocation.
            args: Raw command-line arguments.

        Returns:
            Whatever ``click.Command.parse_args`` returns for ``args``.
        """
        # First parse with the original options to get storage-client and
        # message-queue; a copy of args is used so the real parse below
        # still sees the full argument list.
        parser = self.make_parser(ctx)
        opts, _, _ = parser.parse_args(args=list(args))

        self._add_backend_options(STORAGE_CLIENTS, opts.get('storage_client'))
        self._add_backend_options(MESSAGE_QUEUES, opts.get('message_queue'))

        # Re-parse with all options including the dynamic ones
        return super().parse_args(ctx, args)

    def _add_backend_options(self, registry, backend_name):
        """Append the options of ``registry[backend_name]`` to this command.

        Options whose name is already registered are skipped, so parsing
        with the same command object more than once does not accumulate
        duplicate parameters.
        """
        if not backend_name or backend_name not in registry:
            return

        known_names = {param.name for param in self.params}
        for option in registry[backend_name].config_options():
            if option.name in known_names:
                continue
            self.params.append(option)
            known_names.add(option.name)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
import multiprocessing as mp | ||
import sys | ||
import tempfile | ||
import traceback | ||
from concurrent.futures import ProcessPoolExecutor | ||
from typing import Optional | ||
|
||
import click | ||
|
||
from marker.batch.cli_dynamic import DynamicOptionCommand | ||
from marker.batch.messages import MESSAGE_QUEUES | ||
from marker.batch.schema import Message | ||
from marker.batch.storage import STORAGE_CLIENTS | ||
from marker.batch.storage.base import StorageClient | ||
from marker.config.parser import ConfigParser | ||
from marker.models import create_model_dict | ||
from marker.output import save_output | ||
|
||
# Per-process state. Both values stay None until ``initialize_worker`` has
# run in the current process.
_worker_storage_client: Optional[StorageClient] = None
_model_refs: Optional[dict] = None


def initialize_worker(kwargs: dict) -> None:
    """Prepare the storage client and the models for one worker process.

    Used as the ``initializer`` of the process pool. On any failure the
    error is printed and the process exits with status 1.

    Args:
        kwargs: Options forwarded to the storage client constructor; must
            contain ``'storage_client_name'`` to select the client class.
    """
    global _worker_storage_client, _model_refs

    try:
        client_name = kwargs['storage_client_name']
        client_cls = STORAGE_CLIENTS[client_name]
        _worker_storage_client = client_cls(**kwargs)
    except Exception as exc:
        print(f"Failed to initialize storage client in worker: {exc}")
        sys.exit(1)

    try:
        _model_refs = create_model_dict()
    except Exception as exc:
        print(f"Failed to initialize model in worker: {exc}")
        sys.exit(1)
|
||
|
||
def process_single_pdf(key, input_path, output_path, cli_options):
    """Convert one input file and save the rendered output.

    Args:
        key: Passed to ``save_output`` as the third argument.
        input_path: Local path of the file handed to the converter.
        output_path: Local destination passed to ``save_output``.
        cli_options: Option mapping used to build the ``ConfigParser``.

    Raises:
        Exception: Any error raised while building the converter, converting
            the file or saving the output. The error is printed together
            with its traceback and then re-raised.
    """
    config_parser = ConfigParser(cli_options)
    converter_cls = config_parser.get_converter_cls()

    try:
        converter = converter_cls(
            config=config_parser.generate_config_dict(),
            artifact_dict=_model_refs,
            processor_list=config_parser.get_processors(),
            renderer=config_parser.get_renderer()
        )
        rendered = converter(input_path)
        save_output(rendered, output_path, key)
    except Exception as e:
        print(f"Error converting {input_path}: {e}")
        print(traceback.format_exc())
        # Propagate the failure: swallowing it here would let the caller
        # upload an empty output folder and report the message as done.
        raise
|
||
|
||
def process_message(message: Message) -> None:
    """Download, convert and upload the document described by ``message``.

    The input is downloaded to a temporary file, converted to markdown into
    a temporary directory, and that directory is uploaded to
    ``message.output_path``. Both temporary locations are removed when the
    function returns or raises.

    Args:
        message: The work item; ``input_path``, ``output_path`` and ``key``
            are read from it.

    Raises:
        RuntimeError: If the worker's storage client or models have not
            been initialized in this process.
    """
    if _worker_storage_client is None:
        raise RuntimeError("Storage client is not initialized in the worker process.")
    if _model_refs is None:
        raise RuntimeError("Model is not initialized in the worker process.")

    print(f"Processing message: {message}")
    with tempfile.NamedTemporaryFile(delete=True) as tmp_file:
        temp_file_path = tmp_file.name
        _worker_storage_client.download_file(message.input_path, temp_file_path)

        # Default cleanup is used on purpose: the upload happens inside the
        # ``with`` block, so the directory can be removed afterwards instead
        # of accumulating one leftover folder per processed message.
        with tempfile.TemporaryDirectory() as temp_dir:
            process_single_pdf(message.key, temp_file_path, temp_dir, {
                "output_format": "markdown",
                "disable_multiprocessing": True
            })

            print(f"Uploading output to {message.output_path}")
            _worker_storage_client.upload_folder(temp_dir, message.output_path)

    print(f"Finished processing message: {message}")
|
||
|
||
@click.command(cls=DynamicOptionCommand, context_settings=dict(
    ignore_unknown_options=True,
    allow_extra_args=True,
))
@click.option("--storage-client", type=click.Choice(STORAGE_CLIENTS.keys()), help="Storage client to use", required=True)  # type: ignore
@click.option("--message-queue", type=click.Choice(MESSAGE_QUEUES.keys()), help="Message queue to use", required=True)  # type: ignore
@click.option("--job-id", type=str, help="Job ID", required=True)
@click.option("--worker-count", type=int, help="Number of workers to use", default=1)
def cli(storage_client, message_queue, worker_count, **kwargs):
    """Consume conversion messages and process them in a pool of workers."""
    # Remaining options (job id plus the backend-specific ones) are passed
    # through unchanged to both the queue and the storage client.
    message_queue = MESSAGE_QUEUES[message_queue](**kwargs)

    initializer_kwargs = {
        'storage_client_name': storage_client,
        **kwargs
    }

    try:
        mp.set_start_method('spawn')  # Required for CUDA, forkserver doesn't work
    except RuntimeError as e:
        # Keep the original error attached so the real cause stays visible.
        raise RuntimeError("Set start method to spawn twice. This may be a temporary issue with the script. Please try running it again.") from e

    # NOTE(review): this dict is not handed to the pool; each worker builds
    # its own models in initialize_worker(). Confirm whether this
    # parent-side load is still needed.
    model_dict = create_model_dict()
    for v in model_dict.values():
        v.model.share_memory()

    with ProcessPoolExecutor(
        max_workers=worker_count,
        initializer=initialize_worker,
        initargs=(initializer_kwargs,)
    ) as executor:

        def callback(msg):
            # Hand the decoded message to a worker; the queue message is
            # acked or rejected once that worker's future completes.
            message = Message.from_json(msg.body)
            future = executor.submit(process_message, message)
            future.add_done_callback(lambda f: message_queue.ack(msg, f))

        # worker_count doubles as the prefetch count passed to consume().
        message_queue.consume(callback, worker_count)


if __name__ == "__main__":
    cli()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
from marker.batch.messages.rabbitmq import RabbitMQMessageQueue | ||
|
||
# Registry of the available message-queue backends, keyed by each class's
# ``name`` attribute.
_QUEUE_CLASSES = (RabbitMQMessageQueue,)
MESSAGE_QUEUES = {queue_cls.name: queue_cls for queue_cls in _QUEUE_CLASSES}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
from abc import ABC, abstractmethod | ||
|
||
|
||
class MessageQueue(ABC):
    """Interface every message-queue backend has to implement.

    The signatures mirror how the batch CLI code uses a queue: the publisher
    calls ``publish(message)``, the worker calls
    ``consume(callback, prefetch_count)`` and ``ack(msg, future)``, and the
    dynamic CLI command calls ``config_options()`` on the class.
    """

    # Registry key of the backend; also the value accepted by --message-queue.
    name: str

    @staticmethod
    @abstractmethod
    def config_options():
        """Return the list of ``click.Option`` objects this backend needs."""
        pass

    @abstractmethod
    def publish(self, message):
        """Publish one ``Message`` to the queue."""
        pass

    @abstractmethod
    def consume(self, callback, prefetch_count=1):
        """
        Consume messages from the queue.

        ``callback`` is called for each message. ``prefetch_count`` is the
        value the worker passes as its worker count; how it limits delivery
        is up to the backend.
        """
        pass

    @abstractmethod
    def ack(self, msg, f):
        """
        Settle the queue message ``msg`` once its processing future ``f`` is done.

        Implementations are expected to acknowledge the message when ``f``
        succeeded and reject it when ``f`` raised.
        """
        pass
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
from concurrent.futures import Future | ||
|
||
import amqpstorm # type: ignore | ||
import click | ||
|
||
from marker.batch.messages.base import MessageQueue | ||
from marker.batch.schema import Message | ||
|
||
|
||
class RabbitMQMessageQueue(MessageQueue):
    """Message queue backend that talks to RabbitMQ through ``amqpstorm``.

    Each job uses its own durable queue named ``"<job_id>-queue"``.
    """

    name = "rabbitmq"

    @staticmethod
    def config_options():
        """Return the click options used to configure the RabbitMQ connection."""
        return [
            click.Option(["--rabbitmq-host"], required=True, help="RabbitMQ host"),
            click.Option(["--rabbitmq-port"], required=False, help="RabbitMQ port", default=5672),
            click.Option(["--rabbitmq-username"], required=False, help="RabbitMQ username", default=None),
            click.Option(["--rabbitmq-password"], required=False, help="RabbitMQ password", default=None),
            click.Option(["--rabbitmq-vhost"], required=False, help="RabbitMQ vhost", default='/'),
            click.Option(["--rabbitmq-heartbeat"], required=False, help="RabbitMQ heartbeat", default=60),
            click.Option(["--rabbitmq-use-ssl"], is_flag=True, help="Use SSL", default=False),
        ]

    def __init__(
        self,
        rabbitmq_host,
        rabbitmq_port,
        rabbitmq_username,
        rabbitmq_password,
        rabbitmq_vhost,
        rabbitmq_heartbeat,
        rabbitmq_use_ssl,
        job_id,
        **kwargs
    ):
        """Connect to RabbitMQ and declare the durable queue for ``job_id``.

        The ``rabbitmq_*`` arguments correspond to the options returned by
        ``config_options()``. Extra keyword arguments are accepted and
        ignored so the CLI can pass its full option set.
        """
        super().__init__()
        self.connection = amqpstorm.Connection(
            hostname=rabbitmq_host,
            username=rabbitmq_username,
            password=rabbitmq_password,
            port=rabbitmq_port,
            heartbeat=rabbitmq_heartbeat,
            virtual_host=rabbitmq_vhost,
            ssl=rabbitmq_use_ssl,
        )
        self.channel = self.connection.channel()

        self.queue_name = f"{job_id}-queue"
        self.channel.queue.declare(queue=self.queue_name, durable=True)

        # The password is deliberately left out of this message.
        print(f"Initializing RabbitMQ message queue with host={rabbitmq_host}, port={rabbitmq_port}, username={rabbitmq_username}, vhost={rabbitmq_vhost}")

    def publish(self, message: Message) -> None:
        """Publish ``message`` as JSON to this job's queue."""
        # delivery_mode 2 marks the message as persistent in AMQP.
        msg = amqpstorm.Message.create(self.channel, message.to_json(), {"delivery_mode": 2, "content_type": "application/json"})
        msg.publish(routing_key=self.queue_name)

    def consume(self, callback, prefetch_count=1) -> None:
        """Block and call ``callback`` for every message of this job's queue.

        Messages are consumed with manual acknowledgement (``no_ack=False``).
        A ``KeyboardInterrupt`` stops consuming; the connection is closed on
        every exit path.
        """
        self.channel.basic.qos(prefetch_count=prefetch_count)
        try:
            self.channel.basic.consume(
                callback=callback,
                queue=self.queue_name,
                no_ack=False
            )
            print('[*] Waiting for messages. To exit press CTRL+C')
            self.channel.start_consuming()
        except KeyboardInterrupt:
            print("Interrupted by user")
            self.channel.stop_consuming()
        finally:
            self.connection.close()

    def ack(self, msg: amqpstorm.Message, f: Future) -> None:
        """Acknowledge ``msg`` if ``f`` succeeded, otherwise reject it.

        Rejected messages are not requeued.
        """
        try:
            f.result()
        except Exception as e:
            print(f"Failed message: {e}")
            msg.reject(requeue=False)
        else:
            # Kept outside the try so that a failure to ack is not reported
            # as a failed message and followed by a reject.
            msg.ack()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
from pathlib import Path | ||
|
||
import click | ||
import tqdm # type: ignore | ||
|
||
from marker.batch.cli_dynamic import DynamicOptionCommand | ||
from marker.batch.messages import MESSAGE_QUEUES | ||
from marker.batch.schema import Message | ||
from marker.batch.storage import STORAGE_CLIENTS | ||
|
||
|
||
@click.command(cls=DynamicOptionCommand, context_settings=dict(
    ignore_unknown_options=True,
    allow_extra_args=True,
))
@click.option("--storage-client", type=click.Choice(STORAGE_CLIENTS.keys()), help="Storage client to use", required=True)  # type: ignore
@click.option("--message-queue", type=click.Choice(MESSAGE_QUEUES.keys()), help="Message queue to use", required=True)  # type: ignore
@click.option("--input-path", type=str, help="List messages with this prefix, example s3://bucket/prefix", default=None)
@click.option("--output-path", type=str, help="Write the outputs to this prefix, example s3://bucket/prefix", default=None)
@click.option("--job-id", type=str, help="Job ID", required=True)
def cli(storage_client, message_queue, input_path, output_path, **kwargs):
    """Publish one conversion message per file found under --input-path."""
    # Both paths are dereferenced below; fail with a usage error instead of
    # an AttributeError on None, and before any connection is opened.
    if input_path is None or output_path is None:
        raise click.UsageError("Both --input-path and --output-path are required.")

    storage_client = STORAGE_CLIENTS[storage_client](**kwargs)
    message_queue = MESSAGE_QUEUES[message_queue](**kwargs)
    storage_prefix = storage_client.name + "://"

    # Loop-invariant: the roots with the "<client>://" scheme stripped.
    input_root = input_path.removeprefix(storage_prefix)
    output_root = Path(output_path.removeprefix(storage_prefix))

    for msg_input_path in tqdm.tqdm(storage_client.list_files(input_path), desc="Publishing messages"):
        msg_path = Path(msg_input_path.removeprefix(storage_prefix))
        msg_key = msg_path.stem

        # Mirror the input's directory layout under the output root, with
        # one output folder per input file named after the file's stem.
        relative_path = msg_path.parent.relative_to(input_root)
        msg_output_path = output_root / relative_path / msg_key

        message_queue.publish(
            Message(
                key=msg_key,
                input_path=storage_prefix + str(msg_path),
                output_path=storage_prefix + str(msg_output_path) + "/"
            )
        )


if __name__ == "__main__":
    cli()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
from dataclasses import dataclass | ||
from dataclasses_json import dataclass_json | ||
|
||
|
||
@dataclass_json
@dataclass
class Message:
    """One unit of conversion work exchanged through the message queue.

    ``dataclass_json`` supplies the JSON (de)serialization helpers used to
    put the message on the queue and read it back.
    """

    # Identifier of the document; the publisher sets it to the input file's stem.
    key: str
    # Storage path of the file to convert.
    input_path: str
    # Storage path under which the conversion output is uploaded.
    output_path: str
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
from marker.batch.storage.s3 import S3StorageClient | ||
|
||
# Registry of the available storage backends, keyed by each class's
# ``name`` attribute.
_CLIENT_CLASSES = (S3StorageClient,)
STORAGE_CLIENTS = {client_cls.name: client_cls for client_cls in _CLIENT_CLASSES}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
from abc import ABC, abstractmethod | ||
|
||
from marker.batch.schema import Message | ||
|
||
|
||
class StorageClient(ABC):
    """Interface every storage backend has to implement."""

    # Registry key of the backend; the publisher also uses it to build the
    # "<name>://" prefix of storage paths.
    name: str

    @staticmethod
    @abstractmethod
    def config_options():
        """Return the list of ``click.Option`` objects this backend needs."""
        pass

    @abstractmethod
    def upload_file(self, file_path: str, upload_path: str) -> None:
        """Upload a single file.

        Presumably ``file_path`` is the local source and ``upload_path`` the
        storage destination, by analogy with ``upload_folder`` - confirm
        against the implementations.
        """
        pass

    @abstractmethod
    def download_file(self, file_path: str, download_path: str) -> None:
        """Download the storage object ``file_path`` to the local ``download_path``."""
        pass

    @abstractmethod
    def upload_folder(self, folder_path: str, upload_path: str) -> None:
        """Upload the local folder ``folder_path`` to the storage location ``upload_path``."""
        pass

    @abstractmethod
    def list_files(self, path: str) -> list[str]:
        """Return the storage paths of the files found under ``path``.

        The publisher treats every item as a string that may start with the
        "<name>://" prefix, so plain path strings are returned rather than
        ``Message`` objects.
        """
        pass
Oops, something went wrong.