Flush CUDA memory after inference
VikParuchuri committed May 9, 2024
1 parent 6933022 commit 4966f7a
Showing 8 changed files with 134 additions and 65 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -163,11 +163,11 @@ First 3 are non-arXiv books, last 3 are arXiv papers.
| marker | 0.536176 | 0.516833 | 0.70515 | 0.710657 | 0.690042 | 0.523467 |
| nougat | 0.44009 | 0.588973 | 0.322706 | 0.401342 | 0.160842 | 0.525663 |

-Peak GPU memory usage during the benchmark is `4.2GB` for nougat, and `6.5GB` for marker. Benchmarks were run on an A6000 Ada.
+Peak GPU memory usage during the benchmark is `4.2GB` for nougat, and `5.1GB` for marker. Benchmarks were run on an A6000 Ada.

**Throughput**

-Marker takes about 3GB of VRAM on average per task, so you can convert 16 documents in parallel on an A6000.
+Marker takes about 4GB of VRAM on average per task, so you can convert 12 documents in parallel on an A6000.

![Benchmark results](data/images/per_doc.png)
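
For context on the updated throughput figure, a minimal sketch of the arithmetic behind it (the 48 GB total VRAM figure for an A6000-class GPU is an assumption, not stated in the diff):

```python
# Illustrative worker-count arithmetic only.
GPU_VRAM_GB = 48      # assumed: total VRAM of an A6000 / A6000 Ada
VRAM_PER_TASK_GB = 4  # average per-task usage quoted in the README

max_parallel_docs = GPU_VRAM_GB // VRAM_PER_TASK_GB
print(max_parallel_docs)  # -> 12
```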

43 changes: 27 additions & 16 deletions benchmark.py
@@ -21,6 +21,22 @@
configure_logging()


+def start_memory_profiling():
+    torch.cuda.memory._record_memory_history(
+        max_entries=100000
+    )
+
+
+def stop_memory_profiling(memory_file):
+    try:
+        torch.cuda.memory._dump_snapshot(memory_file)
+    except Exception as e:
+        logger.error(f"Failed to capture memory snapshot {e}")
+
+    # Stop recording memory snapshot history.
+    torch.cuda.memory._record_memory_history(enabled=None)
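
The snapshots written by `stop_memory_profiling()` can be inspected after the run. A minimal sketch of loading one (the snapshot layout comes from PyTorch's internal `torch.cuda.memory._snapshot()`, so the keys below are assumptions and may vary across versions):

```python
import pickle

# Load a snapshot produced by stop_memory_profiling(); "model_load.pickle" is
# the filename used below when profiling model loading.
with open("model_load.pickle", "rb") as f:
    snapshot = pickle.load(f)

# Print a few cached-allocator segments: device index, segment type, size in bytes.
for seg in snapshot.get("segments", [])[:5]:
    print(seg.get("device"), seg.get("segment_type"), seg.get("total_size"))
```

Snapshots can also be dropped into the interactive viewer at https://pytorch.org/memory_viz.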


def nougat_prediction(pdf_filename, batch_size=1):
out_dir = tempfile.mkdtemp()
subprocess.run(["nougat", pdf_filename, "-o", out_dir, "--no-skipping", "--recompute", "--batchsize", str(batch_size)], check=True)
@@ -42,28 +58,28 @@ def main():
parser.add_argument("--nougat_batch_size", type=int, default=1, help="Batch size to use for nougat when making predictions.")
parser.add_argument("--md_out_path", type=str, default=None, help="Output path for generated markdown files")
parser.add_argument("--profile_memory", action="store_true", help="Profile memory usage", default=False)
parser.add_argument("--profile_memory_file", type=str, default="benchmark_memory.pickle", help="File to save memory profile to")

args = parser.parse_args()

methods = ["marker"]
if args.nougat:
methods.append("nougat")

+if args.profile_memory:
+    start_memory_profiling()

model_lst = load_all_models()

+if args.profile_memory:
+    stop_memory_profiling("model_load.pickle")

scores = defaultdict(dict)
benchmark_files = os.listdir(args.in_folder)
benchmark_files = [b for b in benchmark_files if b.endswith(".pdf")]
times = defaultdict(dict)
pages = defaultdict(int)

-if args.profile_memory:
-    torch.cuda.memory._record_memory_history(
-        max_entries=100000
-    )

-for fname in tqdm(benchmark_files):
+for idx, fname in tqdm(enumerate(benchmark_files)):
md_filename = fname.rsplit(".", 1)[0] + ".md"

reference_filename = os.path.join(args.reference_folder, md_filename)
@@ -77,7 +93,11 @@ def main():
for method in methods:
start = time.time()
if method == "marker":
+if args.profile_memory:
+    start_memory_profiling()
full_text, _, out_meta = convert_single_pdf(pdf_filename, model_lst, batch_multiplier=args.marker_batch_multiplier)
+if args.profile_memory:
+    stop_memory_profiling(f"marker_memory_{idx}.pickle")
elif method == "nougat":
full_text = nougat_prediction(pdf_filename, batch_size=args.nougat_batch_size)
elif method == "naive":
@@ -119,15 +139,6 @@ def main():

json.dump(write_data, f, indent=4)

-if args.profile_memory:
-    try:
-        torch.cuda.memory._dump_snapshot(args.profile_memory_file)
-    except Exception as e:
-        logger.error(f"Failed to capture memory snapshot {e}")
-
-    # Stop recording memory snapshot history.
-    torch.cuda.memory._record_memory_history(enabled=None)

summary_table = []
score_table = []
score_headers = benchmark_files
9 changes: 9 additions & 0 deletions marker/convert.py
@@ -1,4 +1,7 @@
import warnings

+from marker.utils import flush_cuda_memory

warnings.filterwarnings("ignore", category=UserWarning) # Filter torch pytree user warnings

import pypdfium2 as pdfium
@@ -75,16 +78,19 @@ def convert_single_pdf(

# Identify text lines on pages
surya_detection(doc, pages, detection_model, batch_multiplier=batch_multiplier)
+flush_cuda_memory()

# OCR pages as needed
pages, ocr_stats = run_ocr(doc, pages, langs, ocr_model, batch_multiplier=batch_multiplier)
+flush_cuda_memory()

out_meta["ocr_stats"] = ocr_stats
if len([b for p in pages for b in p.blocks]) == 0:
print(f"Could not extract any text blocks for {fname}")
return "", out_meta

surya_layout(doc, pages, layout_model, batch_multiplier=batch_multiplier)
+flush_cuda_memory()

# Find headers and footers
bad_span_ids = filter_header_footer(pages)
@@ -100,6 +106,7 @@ def convert_single_pdf(
# Sort blocks by reading order
surya_order(doc, pages, order_model, batch_multiplier=batch_multiplier)
sort_blocks_in_reading_order(pages)
+flush_cuda_memory()

# Fix code blocks
code_block_count = identify_code_blocks(pages)
@@ -121,6 +128,7 @@ def convert_single_pdf(
texify_model,
batch_multiplier=batch_multiplier
)
+flush_cuda_memory()
out_meta["block_stats"]["equations"] = eq_stats

# Extract images and figures
@@ -149,6 +157,7 @@ def convert_single_pdf(
edit_model,
batch_multiplier=batch_multiplier
)
+flush_cuda_memory()
out_meta["postprocess_stats"] = {"edit": edit_stats}
doc_images = images_to_dict(pages)
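
The commit inserts `flush_cuda_memory()` by hand after each GPU-bound stage above. Purely as a hedged alternative sketch (not what this commit does), the same effect could be reached with a small decorator that flushes on exit, which the stage functions could then opt into:

```python
from functools import wraps

from marker.utils import flush_cuda_memory


def flush_after(fn):
    # Run the wrapped stage, then release cached CUDA blocks regardless of outcome.
    @wraps(fn)
    def wrapper(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        finally:
            flush_cuda_memory()
    return wrapper
```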

2 changes: 1 addition & 1 deletion marker/layout/order.py
@@ -14,7 +14,7 @@ def get_batch_size():
    if settings.ORDER_BATCH_SIZE is not None:
        return settings.ORDER_BATCH_SIZE
    elif settings.TORCH_DEVICE_MODEL == "cuda":
-        return 12
+        return 6
    elif settings.TORCH_DEVICE_MODEL == "mps":
        return 6
    return 6
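
These per-device defaults interact with the `batch_multiplier` argument seen in `convert.py`. A minimal sketch of how an effective batch size could be derived, assuming the multiplier simply scales the per-device default (the exact call into surya is not shown in this diff):

```python
from marker.layout.order import get_batch_size


def effective_batch_size(batch_multiplier: int = 1) -> int:
    # Per-device default (6 on CUDA after this commit) scaled by the
    # caller-supplied multiplier; illustrative only.
    return get_batch_size() * batch_multiplier
```
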
4 changes: 2 additions & 2 deletions marker/ocr/detection.py
@@ -12,8 +12,8 @@ def get_batch_size():
    if settings.DETECTOR_BATCH_SIZE is not None:
        return settings.DETECTOR_BATCH_SIZE
    elif settings.TORCH_DEVICE_MODEL == "cuda":
-        return 6
-    return 6
+        return 4
+    return 4


def surya_detection(doc: PdfDocument, pages: List[Page], det_model, batch_multiplier=1):
6 changes: 3 additions & 3 deletions marker/settings.py
@@ -27,7 +27,7 @@ def TORCH_DEVICE_MODEL(self) -> str:
return "cpu"

INFERENCE_RAM: int = 40 # How much VRAM each GPU has (in GB).
-VRAM_PER_TASK: float = 3 # How much VRAM to allocate per task (in GB). Peak marker VRAM usage is around 3.5GB, but avg across workers is lower.
+VRAM_PER_TASK: float = 4 # How much VRAM to allocate per task (in GB). Peak marker VRAM usage is around 5GB, but avg across workers is lower.
DEFAULT_LANG: str = "English" # Default language we assume files to be in, should be one of the keys in TESSERACT_LANGUAGES

SUPPORTED_FILETYPES: Dict = {
@@ -57,7 +57,7 @@ def TORCH_DEVICE_MODEL(self) -> str:
TEXIFY_MODEL_MAX: int = 384 # Max inference length for texify
TEXIFY_TOKEN_BUFFER: int = 256 # Number of tokens to buffer above max for texify
TEXIFY_DPI: int = 96 # DPI to render images at
-TEXIFY_BATCH_SIZE: int = 2 if TORCH_DEVICE_MODEL == "cpu" else 6 # Batch size for texify, lower on cpu due to float32
+TEXIFY_BATCH_SIZE: Optional[int] = None # Defaults to 6 for cuda, 12 otherwise
TEXIFY_MODEL_NAME: str = "vikp/texify"

# Layout model
@@ -72,7 +72,7 @@ def TORCH_DEVICE_MODEL(self) -> str:
ORDER_BATCH_SIZE: Optional[int] = None # Defaults to 12 for cuda, 6 otherwise

# Final editing model
-EDITOR_BATCH_SIZE: int = 4
+EDITOR_BATCH_SIZE: Optional[int] = None # Defaults to 6 for cuda, 12 otherwise
EDITOR_MAX_LENGTH: int = 1024
EDITOR_MODEL_NAME: str = "vikp/pdf_postprocessor_t5"
ENABLE_EDITOR_MODEL: bool = False # The editor model can create false positives
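
With these fields now `Optional[int] = None`, the per-device defaults in each module's `get_batch_size()` take over. A hedged sketch of how a user who prefers fixed values could still pin them (this assumes the settings object accepts plain attribute assignment, which is not shown in this diff):

```python
from marker.settings import settings

# Pin explicit batch sizes instead of relying on the per-device defaults.
settings.TEXIFY_BATCH_SIZE = 6
settings.EDITOR_BATCH_SIZE = 4  # the pre-commit default value
```
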
7 changes: 7 additions & 0 deletions marker/utils.py
@@ -0,0 +1,7 @@
+import torch
+from marker.settings import settings
+
+
+def flush_cuda_memory():
+    if settings.TORCH_DEVICE_MODEL == "cuda":
+        torch.cuda.empty_cache()
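
A quick way to observe what this helper actually frees (a standalone sketch, not part of the commit): `empty_cache()` returns cached-but-unallocated blocks to the driver, so `memory_reserved()` drops while `memory_allocated()` (tensors still alive) is unaffected.

```python
import torch

if torch.cuda.is_available():
    x = torch.randn(1024, 1024, device="cuda")
    del x  # the tensor is freed, but the caching allocator keeps the block
    print(torch.cuda.memory_allocated(), torch.cuda.memory_reserved())
    torch.cuda.empty_cache()
    print(torch.cuda.memory_allocated(), torch.cuda.memory_reserved())
```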
