diff --git a/marker/processors/equation.py b/marker/processors/equation.py
index 868f98a2..5f5be17c 100644
--- a/marker/processors/equation.py
+++ b/marker/processors/equation.py
@@ -22,7 +22,7 @@ class EquationProcessor(BaseProcessor):
model_max_length: Annotated[
int,
"The maximum number of tokens to allow for the Texify model.",
- ] = 384
+ ] = 768
texify_batch_size: Annotated[
Optional[int],
"The batch size to use for the Texify model.",
@@ -65,27 +65,7 @@ def __call__(self, document: Document):
continue
block = document.get_block(equation_d["block_id"])
- block.html = self.parse_latex_to_html(prediction)
-
- def parse_latex_to_html(self, latex: str):
- html_out = ""
- try:
- latex = self.parse_latex(latex)
- except ValueError as e:
- # If we have mismatched delimiters, we'll treat it as a single block
- # Strip the $'s from the latex
- latex = [
- {"class": "block", "content": latex.replace("$", "")}
- ]
-
- for el in latex:
- if el["class"] == "block":
- html_out += f''
- elif el["class"] == "inline":
- html_out += f''
- else:
- html_out += f" {el['content']} "
- return html_out.strip()
+ block.html = prediction
def get_batch_size(self):
if self.texify_batch_size is not None:
@@ -106,71 +86,22 @@ def get_latex_batched(self, equation_data: List[dict]):
max_idx = min(min_idx + batch_size, len(equation_data))
batch_equations = equation_data[min_idx:max_idx]
- max_length = max([eq["token_count"] for eq in batch_equations])
- max_length = min(max_length, self.model_max_length)
- max_length += self.token_buffer
-
batch_images = [eq["image"] for eq in batch_equations]
model_output = self.texify_model(
- batch_images,
- max_tokens=max_length
+ batch_images
)
for j, output in enumerate(model_output):
- token_count = self.get_total_texify_tokens(output)
- if token_count >= max_length - 1:
- output = ""
+ token_count = self.get_total_texify_tokens(output.text)
+ if token_count >= self.model_max_length - 1:
+ output.text = ""
image_idx = i + j
- predictions[image_idx] = output
+ predictions[image_idx] = output.text
return predictions
def get_total_texify_tokens(self, text):
tokenizer = self.texify_model.processor.tokenizer
tokens = tokenizer(text)
- return len(tokens["input_ids"])
-
-
- @staticmethod
- def parse_latex(text: str):
- if text.count("$") % 2 != 0:
- raise ValueError("Mismatched delimiters in LaTeX")
-
- DELIMITERS = [
- ("$$", "block"),
- ("$", "inline")
- ]
-
- text = text.replace("\n", "
") # we can't handle \n's inside
properly if we don't do this
-
- i = 0
- stack = []
- result = []
- buffer = ""
-
- while i < len(text):
- for delim, class_name in DELIMITERS:
- if text[i:].startswith(delim):
- if stack and stack[-1] == delim: # Closing
- stack.pop()
- result.append({"class": class_name, "content": buffer})
- buffer = ""
- i += len(delim)
- break
- elif not stack: # Opening
- if buffer:
- result.append({"class": "text", "content": buffer})
- stack.append(delim)
- buffer = ""
- i += len(delim)
- break
- else:
- raise ValueError(f"Nested {class_name} delimiters not supported")
- else: # No delimiter match
- buffer += text[i]
- i += 1
-
- if buffer:
- result.append({"class": "text", "content": buffer})
- return result
+ return len(tokens["input_ids"])
\ No newline at end of file
diff --git a/marker/renderers/markdown.py b/marker/renderers/markdown.py
index 0762ab3c..f6b78933 100644
--- a/marker/renderers/markdown.py
+++ b/marker/renderers/markdown.py
@@ -12,12 +12,16 @@
from marker.schema.document import Document
+def escape_dollars(text):
+ return text.replace("$", r"\$")
+
def cleanup_text(full_text):
full_text = re.sub(r'\n{3,}', '\n\n', full_text)
full_text = re.sub(r'(\n\s){3,}', '\n\n', full_text)
return full_text.strip()
def get_formatted_table_text(element):
+
text = []
for content in element.contents:
if content is None:
@@ -26,13 +30,14 @@ def get_formatted_table_text(element):
if isinstance(content, NavigableString):
stripped = content.strip()
if stripped:
- text.append(stripped)
+ text.append(escape_dollars(stripped))
elif content.name == 'br':
text.append('
')
elif content.name == "math":
text.append("$" + content.text + "$")
else:
- text.append(str(content))
+ content_str = escape_dollars(str(content))
+ text.append(content_str)
full_text = ""
for i, t in enumerate(text):
@@ -120,7 +125,7 @@ def convert_table(self, el, text, convert_as_inline):
if r == 0 and c == 0:
grid[row_idx][col_idx] = value
else:
- grid[row_idx + r][col_idx + c] = ''
+ grid[row_idx + r][col_idx + c] = '' # Empty cell due to rowspan/colspan
except IndexError:
# Sometimes the colspan/rowspan predictions can overflow
print(f"Overflow in columns: {col_idx + c} >= {total_cols}")