Skip to content

Commit

Permalink
Add some helper functions, tweak code, set initial values
Browse files Browse the repository at this point in the history
  • Loading branch information
MattGPT-ai committed Jul 6, 2023
1 parent d788e48 commit 5d67242
Showing 1 changed file with 24 additions and 12 deletions.
36 changes: 24 additions & 12 deletions fast_censor/profanity_check_trie.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,12 @@ def set_delimiters(self, delimiters: Union[Set[str], str, List[str]]) -> None:
else:
self.delimiters = {delim for delim in delimiters} # use set for quick access

def get_delimiters(self) -> Set[str]:
"""Returns:
set of delimiter chars
"""
# Hands back the `self.delimiters` attribute itself rather than a copy.
return self.delimiters

def __init__(
self,
words: Optional[Collection[str]] = None, # overrides wordlist
Expand Down Expand Up @@ -158,14 +164,12 @@ def add_word(self, word: str) -> Optional[TrieNode]:
nxt = TrieNode(c)
if self.debug:
print('created node', nxt)
chars = self.mapping.get(c)
if chars is None:
chars = tuple(c)
chars = self.mapping.get(c, tuple(c))
for char in chars:
pointer.children[char].append(nxt)
if self.debug:
print(f"set {pointer.val} node child {char} to {nxt.val}")
else: # character already
else: # character already a child of this node
for child in c_children:
if child.val == c:
nxt = child
Expand Down Expand Up @@ -204,6 +208,13 @@ def build_trie(self, words: Union[Set[str], List[str]]) -> TrieNode:
if self.debug:
self.bfs()

return self.head_node

def get_trie(self) -> TrieNode:
"""Returns:
the head node"""
# Hands back the `self.head_node` attribute directly; no copy is made.
return self.head_node

def bfs(self):
"""breadth-first-search and print nodes for debugging"""
nodes = [self.head_node]
Expand Down Expand Up @@ -275,22 +286,22 @@ def check_text(self, string: str, allow_repetitions: bool = True) -> List[Tuple[
else:
if self.debug:
print('appending pointer', new_pointer.val)
new_pointers.add((new_pointer, length+1)) # advance
new_pointers.add((new_pointer, length + 1)) # advance
# allow additional repeated characters after all repetitions have been traversed
if allow_repetitions and (c == pointer.val or
(pointer.val in self.mapping and c in
self.mapping[pointer.val])):
new_pointers.add((pointer, length+1)) # don't advance
new_pointers.add((pointer, length + 1)) # don't advance
# else pointer is not continued

pointers = new_pointers # advance all

return profanity_matches

def censor_text(self, text: str) -> str:
def censor_text(self, text: str, allow_repetitions: bool = True) -> str:
"""returns string with all profanity matches replaced with `censor_char`"""
# Work on a list of characters so each matched span can be overwritten by slice assignment.
text_list = list(text)
matches = self.check_text(text)
# `allow_repetitions` is forwarded unchanged to `check_text`.
matches = self.check_text(text, allow_repetitions)
for i, j in matches:
# NOTE(review): the slice i:j+1 covers j - i + 1 characters, but only j - i copies of
# `censor_char` are written. Assuming `censor_char` is a single character, each
# replacement shortens the list by one and shifts the positions of later matches.
# TODO confirm whether the end index from `check_text` is inclusive and align the two.
text_list[i:j+1] = self.censor_char * (j - i)
return ''.join(text_list)
Expand Down Expand Up @@ -330,6 +341,7 @@ def __iter__(self):

def __next__(self):
word_start = True
pointers = set()
while self.i < len(self.string):
c = self.string[self.i]
if self.trie.is_delimiter(c):
Expand All @@ -347,12 +359,12 @@ def __next__(self):
for new_pointer in pointer.children[c]:
if new_pointer.end_node_string:
yield
new_pointers.add((new_pointer, length+1)) # advance
new_pointers.add((new_pointer, length + 1)) # advance
# allow additional repeated characters after all repetitions have been traversed
if self.allow_repetitions and (c == pointer.val or
(pointer.val in self.trie.mapping and c in
self.trie.mapping[pointer.val])):
new_pointers.add((pointer, length+1)) # don't advance
(pointer.val in self.trie.mapping and c in
self.trie.mapping[pointer.val])):
new_pointers.add((pointer, length + 1)) # don't advance
# else pointer is not continued

pointers = new_pointers # advance all
Expand Down

0 comments on commit 5d67242

Please sign in to comment.