Add spacy embedding-clustering splitter (#784)
This PR implements the embedding-clustered splitting algorithm; some basic sanity checks are added to handle edge cases. It does not yet switch the default implementation of split_text.
Showing 6 changed files with 167 additions and 32 deletions.
@@ -0,0 +1,98 @@
"""SpaCy-based chunk splitting algorithms.""" | ||
import bisect | ||
import functools | ||
from typing import Callable, Optional, Sequence | ||
|
||
import numpy as np | ||
import spacy | ||
|
||
from .chunk_splitter import TextChunk | ||
|
||
|
||
@functools.cache | ||
def get_spacy() -> spacy.Language: | ||
"""Lazily instantiate and return a singeton SpaCy sentencizer object.""" | ||
sentencizer = spacy.blank('en') | ||
# This includes colon as a sentence boundary; LLM datasets tend to contain a lot of semantic | ||
# markers with a colon, like "Teacher: ... " or "Answer the following question: ..." | ||
sentencizer.add_pipe('sentencizer', config={'punct_chars': [':', ';', '.', '!', '?']}) | ||
# Increase the number of characters of the tokenizer as we're not using a parser or NER. | ||
sentencizer.max_length = 10_000_000 | ||
return sentencizer | ||
|
||
|
||
def simple_spacy_chunker(text: str, filter_short: int = 4) -> list[TextChunk]: | ||
"""Split text into sentence-based chunks, using SpaCy.""" | ||
sentencizer = get_spacy() | ||
chunks = [ | ||
(text[s.start_char:s.end_char], (s.start_char, s.end_char)) for s in sentencizer(text).sents | ||
] | ||
# Filter out stray whitespace, list numberings, etc. | ||
chunks = [c for c in chunks if len(c[0].strip()) > filter_short] | ||
return chunks | ||
|
||
|
||
def group_by_embedding(fulltext: str, chunks: list[TextChunk], embed_fn: Callable[[list[str]], | ||
list[np.ndarray]], | ||
target_num_groups: int, max_len: int) -> list[TextChunk]: | ||
"""Take a series of smaller chunks and cluster them together. | ||
Args: | ||
fulltext: Full text. | ||
chunks: Smaller chunks to combine. | ||
embed_fn: A function mapping strings to an embedding vector. | ||
target_num_groups: Target number of chunks in final output. | ||
max_len: Maximum size of a combined chunk. | ||
""" | ||
embeddings = np.array(embed_fn([c[0] for c in chunks])) | ||
# Center the embeddings for all sentences; this accentuates sentence semantics, | ||
# especially if the entire passage is roughly about the same topic | ||
embeddings -= np.mean(embeddings, axis=0) | ||
embeddings += np.random.uniform(size=embeddings.shape) * 1e-6 | ||
embeddings /= np.linalg.norm(embeddings, axis=-1, keepdims=True) | ||
|
||
neighbor_distances: Sequence[float] = (embeddings[:-1] * embeddings[1:]).sum(axis=1) | ||
potential_breaks: np.ndarray = np.array([c[1][1] for c in chunks[:-1]]) # end index of each chunk | ||
priority_sort_breaks: np.ndarray = potential_breaks[np.argsort(neighbor_distances)] | ||
|
||
# If there are fewer sentences than target number of groups, then this should degrade gracefully. | ||
breakpoints = [0] + sorted(priority_sort_breaks[:(target_num_groups - 1)]) + [chunks[-1][1][1]] | ||
|
||
def _find_long_spans(breakpoints: list[int]) -> Optional[tuple[int, int]]: | ||
for i, j in zip(breakpoints[:-1], breakpoints[1:]): | ||
if j - i > max_len: | ||
return (i, j) | ||
return None | ||
|
||
# Recursively break long spans until there are no more. | ||
while (span := _find_long_spans(breakpoints)) is not None: | ||
i, j = span | ||
for potential_break in priority_sort_breaks: | ||
if i < potential_break < j: | ||
bisect.insort(breakpoints, potential_break) | ||
break | ||
else: # No potential breaker was found. Arbitrarily split the span in half. | ||
bisect.insort(breakpoints, int((i + j) / 2)) | ||
|
||
return [ | ||
(fulltext[start:end], (start, end)) for start, end in zip(breakpoints[:-1], breakpoints[1:]) | ||
] | ||
|
||
|
||
def clustering_spacy_chunker( | ||
text: str, | ||
filter_short: int = 4, | ||
max_len: int = 512, | ||
target_num_groups: Optional[int] = None, | ||
embed_fn: Optional[Callable[[list[str]], list[np.ndarray]]] = None) -> list[TextChunk]: | ||
"""Split text into sentence-based chunks, with semantic clustering to join related sentences.""" | ||
chunks = simple_spacy_chunker(text, filter_short=filter_short) | ||
if embed_fn is None: | ||
return chunks | ||
|
||
if target_num_groups is None: | ||
# A rough heuristic for picking a number of target chunks. | ||
# These magic numbers were chosen by manually chunking 40 texts spanning 50-5000 characters in | ||
# length, and eyeballing a best-fit line from #num chunks vs. #length on a log-log plot. | ||
target_num_groups = max(1, int((len(text)**0.33) / 1.5)) | ||
return group_by_embedding(text, chunks, embed_fn, target_num_groups, max_len) |
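For readers skimming the diff, here is a hedged usage sketch of the new chunker. The toy_embed stub below is hypothetical and not part of this PR (any callable mapping a list of strings to a list of vectors will do), and the relative import assumes the same package context as the test file further down:

import numpy as np

from .spacy_splitter import clustering_spacy_chunker, simple_spacy_chunker


def toy_embed(texts: list[str]) -> list[np.ndarray]:
  """Hypothetical stand-in for a real embedding model: one random vector per input string."""
  rng = np.random.default_rng(0)
  return [rng.random(8) for _ in texts]


text = 'Teacher: Tell me the answer. Student: I have no idea.'

# Without an embed_fn, clustering_spacy_chunker degrades to plain sentence chunks.
assert clustering_spacy_chunker(text) == simple_spacy_chunker(text)

# With an embed_fn, neighboring sentences are merged by embedding similarity,
# subject to the max_len (character) constraint.
for chunk_text, (start, end) in clustering_spacy_chunker(
    text, embed_fn=toy_embed, target_num_groups=2, max_len=512):
  print(f'[{start}:{end}] {chunk_text!r}')

# When target_num_groups is omitted, the length heuristic above kicks in,
# e.g. for a ~3000-character passage: max(1, int((3000**0.33) / 1.5)) == 9.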
@@ -0,0 +1,55 @@
"""Tests the spacy chunk splitter.""" | ||
|
||
import numpy as np | ||
|
||
from .spacy_splitter import clustering_spacy_chunker, simple_spacy_chunker | ||
from .text_splitter_test_utils import clean_textchunks, text_to_textchunk | ||
|
||
|
||
def dummy_embbedder(chunks: list[str], embed_dim: int = 4) -> list[np.ndarray]: | ||
|
||
def _single_embed(str: str) -> np.ndarray: | ||
np.random.seed(hash(str) % (2**32 - 1)) | ||
return np.random.random(size=(1, embed_dim)) | ||
|
||
return np.concatenate([_single_embed(s) for s in chunks], axis=0) | ||
|
||
|
||
def test_short_snippets_filtered() -> None: | ||
text = '1. Hello. 2. World.' | ||
expected_spans = text_to_textchunk(text, ['Hello.', 'World.']) | ||
|
||
split_items = simple_spacy_chunker(text) | ||
assert split_items == expected_spans | ||
|
||
|
||
def test_colon_considered_as_splitter() -> None: | ||
text = 'Teacher: Tell me the answer. Student: I have no idea.' | ||
expected_spans = text_to_textchunk( | ||
text, ['Teacher:', 'Tell me the answer.', 'Student:', 'I have no idea.']) | ||
split_items = simple_spacy_chunker(text) | ||
assert split_items == expected_spans | ||
|
||
|
||
def test_long_spans_default_split() -> None: | ||
text = 'Blah blah blah.' | ||
expected_spans = text_to_textchunk(text, ['Blah bl', 'ah blah.']) | ||
|
||
split_items = clustering_spacy_chunker(text, embed_fn=dummy_embbedder, max_len=8) | ||
assert split_items == expected_spans | ||
|
||
|
||
def test_long_spans_preferred_splits() -> None: | ||
text = 'Blah. blah. bla. bl.' | ||
expected_spans = text_to_textchunk(text, ['Blah.', 'blah.', 'bla.', 'bl.']) | ||
# Even though target_num_groups = 1, the max len constraint causes breaking. | ||
split_items = clustering_spacy_chunker( | ||
text, embed_fn=dummy_embbedder, target_num_groups=1, max_len=6, filter_short=1) | ||
assert clean_textchunks(split_items) == expected_spans | ||
|
||
|
||
def test_similar_spans_grouped() -> None: | ||
text = 'Blah1. Blah2. Blah2.' | ||
expected_spans = text_to_textchunk(text, ['Blah1.', 'Blah2. Blah2.']) | ||
split_items = clustering_spacy_chunker(text, embed_fn=dummy_embbedder, target_num_groups=2) | ||
assert clean_textchunks(split_items) == expected_spans |
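A side note on why these tests are stable: the dummy embedder seeds NumPy's RNG with the hash of each input string, so identical strings map to identical vectors within a single test run. That makes the two 'Blah2.' sentences the most-similar neighbors, so the boundary between them is the lowest-priority break and they get grouped. A minimal sketch of that property (mirroring _single_embed above; not part of the diff):

import numpy as np


def _single_embed(s: str, embed_dim: int = 4) -> np.ndarray:
  # Same seeding trick as dummy_embbedder: the string's hash seeds the RNG,
  # so identical strings produce identical vectors within one Python process.
  np.random.seed(hash(s) % (2**32 - 1))
  return np.random.random(size=(1, embed_dim))


# Identical inputs -> identical embeddings; after centering and normalization in
# group_by_embedding their cosine similarity is near-maximal, so the break between
# the two 'Blah2.' sentences is the last one the algorithm would choose.
assert np.allclose(_single_embed('Blah2.'), _single_embed('Blah2.'))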
Two files were deleted in this commit; their contents are not shown.