Add spacy embedding-clustering splitter (#784)
This PR implements the embedding-clustered splitting algorithm; some
basic sanity checks are added to handle edge cases. This PR does not yet
switch the default implementation of split_text.
brilee authored Oct 23, 2023
1 parent a6442b6 commit 98ce890
Showing 6 changed files with 167 additions and 32 deletions.
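
For context, here is a minimal usage sketch of the new splitter (not part of the commit). The `fake_embed` function is a hypothetical stand-in for a real embedding function; when no `embed_fn` is supplied, `clustering_spacy_chunker` simply falls back to plain sentence chunks.

```python
# Hypothetical usage sketch; `fake_embed` is a made-up stand-in for a real embedding function.
import numpy as np

from lilac.splitters.spacy_splitter import clustering_spacy_chunker, simple_spacy_chunker


def fake_embed(texts: list[str]) -> list[np.ndarray]:
  # Deterministic per-string vectors, just to exercise the clustering path.
  return [np.random.RandomState(abs(hash(t)) % 2**32).rand(8) for t in texts]


text = 'Teacher: What is 2 + 2? Student: It is 4. Teacher: Correct!'
print(simple_spacy_chunker(text))  # Sentence-level chunks with (start, end) spans.
print(clustering_spacy_chunker(text, embed_fn=fake_embed))  # Semantically grouped chunks.
```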
29 changes: 0 additions & 29 deletions lilac/splitters/chunk_splitter_test.py
@@ -29,32 +29,3 @@ def test_newlines_with_overlap() -> None:
  expected_chunks = text_to_textchunk(
    text, ['Hello.', 'World.', 'This will', 'will get', 'get split.'], allowable_overlap=5)
  assert split_items == expected_chunks


def test_split_code() -> None:
  text = """
We expected the entire code to be one span.
```python
def hello():
  echo('hello')
```
This is the rest of the text.
"""
  split_items = split_text(text, chunk_size=60, chunk_overlap=0)
  expected_chunks = text_to_textchunk(text, [
    """
We expected the entire code to be one span.
""",
    """```python
def hello():
  echo('hello')
```""",
    """
This is the rest of the text.
""",
  ])
  assert split_items == expected_chunks
98 changes: 98 additions & 0 deletions lilac/splitters/spacy_splitter.py
@@ -0,0 +1,98 @@
"""SpaCy-based chunk splitting algorithms."""
import bisect
import functools
from typing import Callable, Optional, Sequence

import numpy as np
import spacy

from .chunk_splitter import TextChunk


@functools.cache
def get_spacy() -> spacy.Language:
"""Lazily instantiate and return a singeton SpaCy sentencizer object."""
sentencizer = spacy.blank('en')
# This includes colon as a sentence boundary; LLM datasets tend to contain a lot of semantic
# markers with a colon, like "Teacher: ... " or "Answer the following question: ..."
sentencizer.add_pipe('sentencizer', config={'punct_chars': [':', ';', '.', '!', '?']})
# Increase the number of characters of the tokenizer as we're not using a parser or NER.
sentencizer.max_length = 10_000_000
return sentencizer


def simple_spacy_chunker(text: str, filter_short: int = 4) -> list[TextChunk]:
"""Split text into sentence-based chunks, using SpaCy."""
sentencizer = get_spacy()
chunks = [
(text[s.start_char:s.end_char], (s.start_char, s.end_char)) for s in sentencizer(text).sents
]
# Filter out stray whitespace, list numberings, etc.
chunks = [c for c in chunks if len(c[0].strip()) > filter_short]
return chunks


def group_by_embedding(fulltext: str, chunks: list[TextChunk],
                       embed_fn: Callable[[list[str]], list[np.ndarray]],
                       target_num_groups: int, max_len: int) -> list[TextChunk]:
  """Take a series of smaller chunks and cluster them together.

  Args:
    fulltext: Full text.
    chunks: Smaller chunks to combine.
    embed_fn: A function mapping a list of strings to a list of embedding vectors.
    target_num_groups: Target number of chunks in the final output.
    max_len: Maximum size, in characters, of a combined chunk.
  """
  embeddings = np.array(embed_fn([c[0] for c in chunks]))
  # Center the embeddings for all sentences; this accentuates sentence semantics,
  # especially if the entire passage is roughly about the same topic.
  embeddings -= np.mean(embeddings, axis=0)
  # A tiny perturbation keeps the normalization below well-defined, even if a centered embedding
  # is exactly zero (e.g. identical sentences in a short passage).
  embeddings += np.random.uniform(size=embeddings.shape) * 1e-6
  embeddings /= np.linalg.norm(embeddings, axis=-1, keepdims=True)

  # Cosine similarity between adjacent sentences; a low similarity marks a good place to break.
  neighbor_distances: Sequence[float] = (embeddings[:-1] * embeddings[1:]).sum(axis=1)
  potential_breaks: np.ndarray = np.array([c[1][1] for c in chunks[:-1]])  # End index of each chunk.
  priority_sort_breaks: np.ndarray = potential_breaks[np.argsort(neighbor_distances)]

  # If there are fewer sentences than the target number of groups, this degrades gracefully.
  breakpoints = [0] + sorted(priority_sort_breaks[:(target_num_groups - 1)]) + [chunks[-1][1][1]]

  def _find_long_spans(breakpoints: list[int]) -> Optional[tuple[int, int]]:
    for i, j in zip(breakpoints[:-1], breakpoints[1:]):
      if j - i > max_len:
        return (i, j)
    return None

  # Repeatedly break long spans until there are none left.
  while (span := _find_long_spans(breakpoints)) is not None:
    i, j = span
    for potential_break in priority_sort_breaks:
      if i < potential_break < j:
        bisect.insort(breakpoints, potential_break)
        break
    else:  # No potential break point was found. Arbitrarily split the span in half.
      bisect.insort(breakpoints, int((i + j) / 2))

  return [
    (fulltext[start:end], (start, end)) for start, end in zip(breakpoints[:-1], breakpoints[1:])
  ]


def clustering_spacy_chunker(
    text: str,
    filter_short: int = 4,
    max_len: int = 512,
    target_num_groups: Optional[int] = None,
    embed_fn: Optional[Callable[[list[str]], list[np.ndarray]]] = None) -> list[TextChunk]:
  """Split text into sentence-based chunks, with semantic clustering to join related sentences."""
  chunks = simple_spacy_chunker(text, filter_short=filter_short)
  if embed_fn is None:
    return chunks

  if target_num_groups is None:
    # A rough heuristic for picking the number of target chunks. These magic numbers were chosen
    # by manually chunking 40 texts spanning 50-5000 characters in length, and eyeballing a
    # best-fit line of num_chunks vs. text length on a log-log plot.
    target_num_groups = max(1, int((len(text)**0.33) / 1.5))
  return group_by_embedding(text, chunks, embed_fn, target_num_groups, max_len)
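
For intuition, the sketch below (standalone, not part of the commit) shows the core break-prioritization idea from `group_by_embedding`: adjacent sentences are scored by cosine similarity, and the boundaries with the lowest similarity are taken as break points first. The sentences and toy 2-D embeddings are invented for illustration, and the centering step and length constraint are omitted.

```python
# Toy illustration of break prioritization: the weakest adjacent similarity breaks first.
import numpy as np

sentences = ['Dogs bark.', 'Puppies bark.', 'Stocks fell.', 'Markets dropped.']
embeddings = np.array([[1.0, 0.0], [0.9, 0.1], [0.1, 0.9], [0.05, 1.0]])  # Made-up vectors.
embeddings /= np.linalg.norm(embeddings, axis=-1, keepdims=True)

similarities = (embeddings[:-1] * embeddings[1:]).sum(axis=1)
# The 'Puppies bark.' -> 'Stocks fell.' boundary has the lowest similarity, so it is the
# first break the algorithm would choose.
print(np.argsort(similarities))  # Prints [1 0 2]: the break after sentence index 1 is taken first.
```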
55 changes: 55 additions & 0 deletions lilac/splitters/spacy_splitter_test.py
@@ -0,0 +1,55 @@
"""Tests the spacy chunk splitter."""

import numpy as np

from .spacy_splitter import clustering_spacy_chunker, simple_spacy_chunker
from .text_splitter_test_utils import clean_textchunks, text_to_textchunk


def dummy_embedder(chunks: list[str], embed_dim: int = 4) -> list[np.ndarray]:

  def _single_embed(s: str) -> np.ndarray:
    # Seed from the string's hash so identical strings map to identical embeddings.
    np.random.seed(hash(s) % (2**32 - 1))
    return np.random.random(size=(1, embed_dim))

  return np.concatenate([_single_embed(s) for s in chunks], axis=0)


def test_short_snippets_filtered() -> None:
  text = '1. Hello. 2. World.'
  expected_spans = text_to_textchunk(text, ['Hello.', 'World.'])

  split_items = simple_spacy_chunker(text)
  assert split_items == expected_spans


def test_colon_considered_as_splitter() -> None:
  text = 'Teacher: Tell me the answer. Student: I have no idea.'
  expected_spans = text_to_textchunk(
    text, ['Teacher:', 'Tell me the answer.', 'Student:', 'I have no idea.'])
  split_items = simple_spacy_chunker(text)
  assert split_items == expected_spans


def test_long_spans_default_split() -> None:
  text = 'Blah blah blah.'
  expected_spans = text_to_textchunk(text, ['Blah bl', 'ah blah.'])

  split_items = clustering_spacy_chunker(text, embed_fn=dummy_embedder, max_len=8)
  assert split_items == expected_spans


def test_long_spans_preferred_splits() -> None:
  text = 'Blah. blah. bla. bl.'
  expected_spans = text_to_textchunk(text, ['Blah.', 'blah.', 'bla.', 'bl.'])
  # Even though target_num_groups = 1, the max_len constraint forces further breaking.
  split_items = clustering_spacy_chunker(
    text, embed_fn=dummy_embedder, target_num_groups=1, max_len=6, filter_short=1)
  assert clean_textchunks(split_items) == expected_spans


def test_similar_spans_grouped() -> None:
  text = 'Blah1. Blah2. Blah2.'
  expected_spans = text_to_textchunk(text, ['Blah1.', 'Blah2. Blah2.'])
  split_items = clustering_spacy_chunker(text, embed_fn=dummy_embedder, target_num_groups=2)
  assert clean_textchunks(split_items) == expected_spans
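
A side note on why these tests behave deterministically within a run: the dummy embedder seeds NumPy from each string's hash, so identical sentences always get identical vectors, and in `test_similar_spans_grouped` the lowest-similarity boundary is therefore always the one between 'Blah1.' and the first 'Blah2.'. A standalone sketch of that property:

```python
# Standalone check (not part of the commit): seeding NumPy from the string's hash means
# identical strings map to identical embedding vectors within a single process.
import numpy as np


def embed_one(s: str, embed_dim: int = 4) -> np.ndarray:
  np.random.seed(hash(s) % (2**32 - 1))
  return np.random.random(size=(embed_dim,))


# Identical input -> identical embedding -> maximal similarity between the two 'Blah2.' chunks,
# so the single allowed break lands between 'Blah1.' and the first 'Blah2.'.
assert np.allclose(embed_one('Blah2.'), embed_one('Blah2.'))
```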
1 change: 0 additions & 1 deletion lilac/splitters/text_splitter_spacy.py

This file was deleted.

1 change: 0 additions & 1 deletion lilac/splitters/text_splitter_spacy_test.py

This file was deleted.

15 changes: 14 additions & 1 deletion lilac/splitters/text_splitter_test_utils.py
@@ -44,5 +44,18 @@ def text_to_textchunk(text: str, splits: list[str], allowable_overlap: int = 0)
  expected_textchunks = []
  for span in spans:
    start, end = span[SPAN_KEY][TEXT_SPAN_START_FEATURE], span[SPAN_KEY][TEXT_SPAN_END_FEATURE]
-    expected_textchunks.append((text[start:end].strip(), (start, end)))
+    expected_textchunks.append((text[start:end], (start, end)))
  return expected_textchunks


def clean_textchunks(chunks: list[TextChunk]) -> list[TextChunk]:
"""Strip whitespace from TextChunks."""

def clean_chunk(chunk: TextChunk) -> TextChunk:
text, (start, _) = chunk
stripped_text = text.strip()
new_start = text.find(stripped_text) + start
new_end = new_start + len(stripped_text)
return stripped_text, (new_start, new_end)

return [clean_chunk(chunk) for chunk in chunks]
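
As a quick illustration of the span arithmetic in `clean_chunk` (standalone, not part of the commit): for a chunk whose text is ' Hello. ' and whose span is (10, 18), the stripped text starts one character later, so the span tightens to (11, 17).

```python
# Worked example of the span-tightening arithmetic in clean_chunk.
text, (start, _) = (' Hello. ', (10, 18))
stripped_text = text.strip()                  # 'Hello.'
new_start = text.find(stripped_text) + start  # 1 + 10 == 11
new_end = new_start + len(stripped_text)      # 11 + 6 == 17
assert (stripped_text, (new_start, new_end)) == ('Hello.', (11, 17))
```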
