Skip to content

Commit

Permalink
Integrate new texify model
Browse files Browse the repository at this point in the history
  • Loading branch information
VikParuchuri committed Jan 29, 2025
1 parent 5fdb25b commit 597db72
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 80 deletions.
85 changes: 8 additions & 77 deletions marker/processors/equation.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class EquationProcessor(BaseProcessor):
model_max_length: Annotated[
int,
"The maximum number of tokens to allow for the Texify model.",
] = 384
] = 768
texify_batch_size: Annotated[
Optional[int],
"The batch size to use for the Texify model.",
Expand Down Expand Up @@ -65,27 +65,7 @@ def __call__(self, document: Document):
continue

block = document.get_block(equation_d["block_id"])
block.html = self.parse_latex_to_html(prediction)

def parse_latex_to_html(self, latex: str):
html_out = ""
try:
latex = self.parse_latex(latex)
except ValueError as e:
# If we have mismatched delimiters, we'll treat it as a single block
# Strip the $'s from the latex
latex = [
{"class": "block", "content": latex.replace("$", "")}
]

for el in latex:
if el["class"] == "block":
html_out += f'<math display="block">{el["content"]}</math>'
elif el["class"] == "inline":
html_out += f'<math display="inline">{el["content"]}</math>'
else:
html_out += f" {el['content']} "
return html_out.strip()
block.html = prediction

def get_batch_size(self):
if self.texify_batch_size is not None:
Expand All @@ -106,71 +86,22 @@ def get_latex_batched(self, equation_data: List[dict]):
max_idx = min(min_idx + batch_size, len(equation_data))

batch_equations = equation_data[min_idx:max_idx]
max_length = max([eq["token_count"] for eq in batch_equations])
max_length = min(max_length, self.model_max_length)
max_length += self.token_buffer

batch_images = [eq["image"] for eq in batch_equations]

model_output = self.texify_model(
batch_images,
max_tokens=max_length
batch_images
)

for j, output in enumerate(model_output):
token_count = self.get_total_texify_tokens(output)
if token_count >= max_length - 1:
output = ""
token_count = self.get_total_texify_tokens(output.text)
if token_count >= self.model_max_length - 1:
output.text = ""

image_idx = i + j
predictions[image_idx] = output
predictions[image_idx] = output.text
return predictions

def get_total_texify_tokens(self, text):
tokenizer = self.texify_model.processor.tokenizer
tokens = tokenizer(text)
return len(tokens["input_ids"])


@staticmethod
def parse_latex(text: str):
if text.count("$") % 2 != 0:
raise ValueError("Mismatched delimiters in LaTeX")

DELIMITERS = [
("$$", "block"),
("$", "inline")
]

text = text.replace("\n", "<br>") # we can't handle \n's inside <p> properly if we don't do this

i = 0
stack = []
result = []
buffer = ""

while i < len(text):
for delim, class_name in DELIMITERS:
if text[i:].startswith(delim):
if stack and stack[-1] == delim: # Closing
stack.pop()
result.append({"class": class_name, "content": buffer})
buffer = ""
i += len(delim)
break
elif not stack: # Opening
if buffer:
result.append({"class": "text", "content": buffer})
stack.append(delim)
buffer = ""
i += len(delim)
break
else:
raise ValueError(f"Nested {class_name} delimiters not supported")
else: # No delimiter match
buffer += text[i]
i += 1

if buffer:
result.append({"class": "text", "content": buffer})
return result
return len(tokens["input_ids"])
11 changes: 8 additions & 3 deletions marker/renderers/markdown.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,16 @@
from marker.schema.document import Document


def escape_dollars(text):
return text.replace("$", r"\$")

def cleanup_text(full_text):
full_text = re.sub(r'\n{3,}', '\n\n', full_text)
full_text = re.sub(r'(\n\s){3,}', '\n\n', full_text)
return full_text.strip()

def get_formatted_table_text(element):

text = []
for content in element.contents:
if content is None:
Expand All @@ -26,13 +30,14 @@ def get_formatted_table_text(element):
if isinstance(content, NavigableString):
stripped = content.strip()
if stripped:
text.append(stripped)
text.append(escape_dollars(stripped))
elif content.name == 'br':
text.append('<br>')
elif content.name == "math":
text.append("$" + content.text + "$")
else:
text.append(str(content))
content_str = escape_dollars(str(content))
text.append(content_str)

full_text = ""
for i, t in enumerate(text):
Expand Down Expand Up @@ -120,7 +125,7 @@ def convert_table(self, el, text, convert_as_inline):
if r == 0 and c == 0:
grid[row_idx][col_idx] = value
else:
grid[row_idx + r][col_idx + c] = ''
grid[row_idx + r][col_idx + c] = '' # Empty cell due to rowspan/colspan
except IndexError:
# Sometimes the colspan/rowspan predictions can overflow
print(f"Overflow in columns: {col_idx + c} >= {total_cols}")
Expand Down

0 comments on commit 597db72

Please sign in to comment.