feat(nlp): revise clip tokenizer
LutingWang committed Sep 8, 2024
1 parent 6272cab commit 083f105
Showing 3 changed files with 171 additions and 103 deletions.
todd/patches/py/codecs.py (15 changes: 9 additions & 6 deletions)
@@ -6,28 +6,31 @@
 import codecs
 import urllib.parse
 
+ASCII = 'ascii'
+FILENAME = 'filename'
+
 
 def filename_codec(name: str) -> codecs.CodecInfo | None:
-    if name != 'filename':
+    if name != FILENAME:
         return None
 
     def encode(s: str, *args, **kwargs) -> tuple[bytes, int]:
-        f = urllib.parse.quote(s, safe='').encode('ascii')
+        f = urllib.parse.quote(s, safe='').encode(ASCII)
         return f, len(s)
 
     def decode(f: bytes, *args, **kwargs) -> tuple[str, int]:
-        s = urllib.parse.unquote(f.decode('ascii'))
+        s = urllib.parse.unquote(f.decode(ASCII))
         return s, len(f)
 
-    return codecs.CodecInfo(encode, decode, name=name)
+    return codecs.CodecInfo(encode, decode, name=FILENAME)
 
 
 codecs.register(filename_codec)
 
 
 def encode_filename(s: str) -> str:
-    return codecs.encode(s, 'filename').decode('ascii')
+    return codecs.encode(s, FILENAME).decode(ASCII)
 
 
 def decode_filename(f: str) -> str:
-    return codecs.decode(f.encode('ascii'), 'filename')
+    return codecs.decode(f.encode(ASCII), FILENAME)
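For reference, a minimal usage sketch of the filename codec above. The import path mirrors the file shown in the diff, and the sample path and its percent-encoded form are illustrative only.

```python
# Sketch: round-tripping a string through the registered 'filename' codec.
# quote(..., safe='') percent-encodes every reserved character, so the
# result is safe to use as a file name.
from todd.patches.py.codecs import decode_filename, encode_filename

encoded = encode_filename('runs/exp 1.pth')
print(encoded)  # runs%2Fexp%201.pth
assert decode_filename(encoded) == 'runs/exp 1.pth'
```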
todd/tasks/natural_language_processing/tokenizers/base.py (53 changes: 27 additions & 26 deletions)
@@ -7,6 +7,8 @@
 
 import torch
 
+from todd.patches.py import classproperty
+
 from ..bpe import TokenSequence
 

@@ -17,19 +19,23 @@ class BaseTokenizer:
     def __init__(
         self,
         *args,
-        word2token: Mapping[str, int],
-        special_word2token: Mapping[str, int],
+        text2token: Mapping[str, int],
+        special_text2token: Mapping[str, int],
         **kwargs,
     ) -> None:
         super().__init__(*args, **kwargs)
-        self._word2token = word2token
-        self._token2word = {v: k for k, v in word2token.items()}
-        self._special_word2token = special_word2token
-        self._special_token2word = {
+        self._text2token = text2token
+        self._token2text = {v: k for k, v in text2token.items()}
+        self._special_text2token = special_text2token
+        self._special_token2text = {
             v: k
-            for k, v in special_word2token.items()
+            for k, v in special_text2token.items()
         }
 
+    @classproperty
+    def special_texts(self) -> list[str]:
+        return [self.SOS, self.EOS]
+
     @abstractmethod
     def encode(
         self,
@@ -39,33 +45,20 @@ def encode(
     ) -> TokenSequence:
         pass
 
-    @abstractmethod
-    def decode(self, token_sequence: TokenSequence) -> str:
-        pass
-
-    def word_to_token(self, word: str) -> int:
-        if word in self._special_word2token:
-            return self._special_word2token[word]
-        return self._word2token[word]
-
-    def token_to_word(self, token: int) -> str:
-        if token in self._special_token2word:
-            return self._special_token2word[token]
-        return self._token2word[token]
-
     def encodes(
         self,
         texts: Iterable[str],
         *,
         max_length: int | None = None,
     ) -> torch.Tensor:
-        if max_length is not None:
-            max_length -= 2
-        sos_token = self._special_word2token[self.SOS]
-        eos_token = self._special_word2token[self.EOS]
+        sos_token = self._special_text2token[self.SOS]
+        eos_token = self._special_text2token[self.EOS]
         token_sequences = [[
             sos_token,
-            *self.encode(text, max_length=max_length),
+            *self.encode(
+                text,
+                max_length=None if max_length is None else max_length - 2,
+            ),
             eos_token,
         ] for text in texts]
         tokens = torch.zeros(
@@ -76,3 +69,11 @@ def encodes(
         for i, token_sequence in enumerate(token_sequences):
             tokens[i, :len(token_sequence)] = torch.tensor(token_sequence)
         return tokens
+
+    def _token_to_text(self, token: int) -> str:
+        if token in self._special_token2text:
+            return self._special_token2text[token]
+        return self._token2text[token]
+
+    def decode(self, token_sequence: TokenSequence) -> str:
+        return ''.join(self._token_to_text(token) for token in token_sequence)
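A plain-Python paraphrase (not library code) of the behavioural change in `encodes` above: `max_length` is no longer decremented in place; instead each `encode` call receives a budget of `max_length - 2`, so the SOS and EOS tokens always fit within the requested length.

```python
# Standalone sketch of the per-text sequence construction in encodes();
# the token values are arbitrary placeholders.
def build_sequence(body, sos, eos, max_length=None):
    budget = None if max_length is None else max_length - 2
    clipped = body if budget is None else body[:budget]
    return [sos, *clipped, eos]

assert build_sequence([5, 6, 7, 8], sos=1, eos=2, max_length=4) == [1, 5, 6, 2]
assert build_sequence([5, 6], sos=1, eos=2) == [1, 5, 6, 2]
```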
todd/tasks/natural_language_processing/tokenizers/clip.py (206 changes: 135 additions & 71 deletions)
@@ -8,119 +8,183 @@
 import pathlib
 from typing import Any
 
+import ftfy
 import regex as re
 
 from ..bpe import BPE, TokenSequence
 from .base import BaseTokenizer
 
+UTF_8 = 'utf-8'
 
-class CLIPTokenizer(BaseTokenizer):
-    SOS = '<|startoftext|>'
-    EOS = '<|endoftext|>'
 
-    EOW = '</w>'
+class Codec:
+    NUM_CODES = 256
 
-    def __init__(self, *args, bpe_path: Any = None, **kwargs) -> None:
-        byte2unicode = list(range(256))
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+
+        byte2ord = list(range(self.NUM_CODES))
+
         for i, byte in enumerate(
             itertools.chain(
                 range(ord('!')),
                 range(ord('~') + 1, ord('¡')),
                 range(ord('¬') + 1, ord('®')),
             ),
-            256,
+            self.NUM_CODES,
         ):
-            byte2unicode[byte] = i
+            byte2ord[byte] = i
 
-        self._byte2unicode = list(map(chr, byte2unicode))
-        self._unicode2byte = {
+        byte2unicode = list(map(chr, byte2ord))
+        unicode2byte = {
             unicode: byte
-            for byte, unicode in enumerate(self._byte2unicode)
+            for byte, unicode in enumerate(byte2unicode)
         }
 
-        word2token = {
+        self._byte2unicode = byte2unicode
+        self._unicode2byte = unicode2byte
+
+    @property
+    def codes(self) -> list[str]:
+        return sorted(self._byte2unicode)
+
+    def encode(self, text: str) -> str:
+        return ''.join(self._byte2unicode[byte] for byte in text.encode(UTF_8))
+
+    def decode(self, encoded_text: str) -> str:
+        bytes_ = bytes(self._unicode2byte[unicode] for unicode in encoded_text)
+        return bytes_.decode(UTF_8)
+
+
+def load_bpe(path: Any, size: int) -> list[tuple[str, str]]:
+    if path is None:
+        path = pathlib.Path(__file__).parent / 'clip_bpe.txt.gz'
+    bpe: list[tuple[str, str]] = []
+    with gzip.open(path, 'rt', encoding=UTF_8) as f:
+        f.readline()  # skip first line
+        for _ in range(size):
+            text1, text2 = f.readline().split()  # ensure length is 2
+            bpe.append((text1, text2))
+    return bpe
+
+
+class Parser:
+    PATTERN = '|'.join([
+        r'<\|startoftext\|>',
+        r'<\|endoftext\|>',
+        "'s",
+        "'t",
+        "'re",
+        "'ve",
+        "'m",
+        "'ll",
+        "'d",
+        r'[\p{L}]+',
+        r'[\p{N}]',
+        r'[^\s\p{L}\p{N}]+',
+    ])
+
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+        self._pattern = re.compile(self.PATTERN, re.IGNORECASE)
+
+    def __call__(self, text: str) -> list[str]:
+        text = ftfy.fix_text(text)
+        text = html.unescape(text)
+        text = html.unescape(text)  # do this a second time
+        text = re.sub(r'\s+', ' ', text)
+        text = text.strip().lower()
+        return self._pattern.findall(text)
+
+
+class Cache:
+
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+        self._text2token_sequence: dict[str, TokenSequence] = dict()
+
+    def get(self, text: str) -> TokenSequence | None:
+        return self._text2token_sequence.get(text)
+
+    def set_(self, text: str, token_sequence: TokenSequence) -> None:
+        self._text2token_sequence[text] = token_sequence
+
+
+class CLIPTokenizer(BaseTokenizer):
+    SOS = '<|startoftext|>'
+    EOS = '<|endoftext|>'
+
+    EOW = '</w>'  # not a special token
+
+    def __init__(self, *args, bpe_path: Any = None, **kwargs) -> None:
+        codec = Codec()
+        bpe = load_bpe(
+            bpe_path,
+            3 * 2**14 - codec.NUM_CODES - len(self.special_texts),
+        )
+
+        text2token = {
             unicode: byte
-            for byte, unicode in enumerate(sorted(self._byte2unicode))
+            for byte, unicode in enumerate(codec.codes)
         }
-        word2token.update({
-            unicode + self.EOW: byte + len(word2token)
-            for unicode, byte in word2token.items()
+        text2token.update({
+            unicode + self.EOW: byte + codec.NUM_CODES
+            for unicode, byte in text2token.items()
         })
+        text2token.update({
+            text1 + text2: i
+            for i, (text1, text2) in enumerate(bpe, len(text2token))
+        })
 
-        token_pairs = []
-        if bpe_path is None:
-            bpe_path = pathlib.Path(__file__).parent / 'clip_bpe.txt.gz'
-        with gzip.open(bpe_path, 'rt', encoding='utf-8') as f:
-            f.readline()  # skip first line
-            for _ in range(49152 - 256 - 2):
-                word1, word2 = f.readline().split()
-                token_pairs.append((word2token[word1], word2token[word2]))
-                word2token[word1 + word2] = len(word2token)
-
-        special_words = [self.SOS, self.EOS]
-        special_word2token = {
-            word: len(word2token) + i
-            for i, word in enumerate(special_words)
+        special_text2token = {
+            text: len(text2token) + i
+            for i, text in enumerate(self.special_texts)
         }
 
         super().__init__(
             *args,
-            word2token=word2token,
-            special_word2token=special_word2token,
+            text2token=text2token,
+            special_text2token=special_text2token,
             **kwargs,
         )
 
-        self._bpe = BPE(len(self._byte2unicode) * 2, token_pairs)
-
-        self._word2token_sequence: dict[str, TokenSequence] = dict()  # cache
-        self._word_pattern = re.compile(
-            r"<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d"
-            r"|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+",
-            re.IGNORECASE,
+        self._codec = codec
+        self._parser = Parser()
+        self._cache = Cache()
+        self._bpe = BPE(
+            codec.NUM_CODES * 2,
+            [(text2token[text1], text2token[text2]) for text1, text2 in bpe],
         )
 
-    def _parse(self, text: str) -> list[str]:
-        import ftfy
-        text = ftfy.fix_text(text)
-        text = html.unescape(text)
-        text = html.unescape(text)  # do this a second time
-        text = re.sub(r'\s+', ' ', text)
-        text = text.strip().lower()
-        return self._word_pattern.findall(text)
-
-    def _encode(self, word: str) -> TokenSequence:
-        if word in self._special_word2token:
-            return [self._special_word2token[word]]
-
-        word_sequence = [
-            self._byte2unicode[byte] for byte in word.encode('utf-8')
-        ]
-        word_sequence[-1] = word_sequence[-1] + self.EOW
-        token_sequence = [
-            self._word2token[unicode] for unicode in word_sequence
-        ]
-        return self._bpe.encode(token_sequence)
+    def _encode(self, text: str) -> TokenSequence:
+        token_sequence = self._cache.get(text)
+        if token_sequence is not None:
+            return token_sequence
+
+        if text in self._special_text2token:
+            token_sequence = [self._special_text2token[text]]
+        else:
+            code_sequence = list(self._codec.encode(text))
+            code_sequence[-1] = code_sequence[-1] + self.EOW
+            token_sequence = [
+                self._text2token[unicode] for unicode in code_sequence
+            ]
+            token_sequence = self._bpe.encode(token_sequence)
+
+        self._cache.set_(text, token_sequence)
+        return token_sequence
 
     def encode(
         self,
         text: str,
         *,
         max_length: int | None = None,
     ) -> TokenSequence:
-        token_sequence: TokenSequence = []
-        for word in self._parse(text):
-            word_token_sequence = self._word2token_sequence.get(word)
-            if word_token_sequence is None:
-                word_token_sequence = self._encode(word)
-                self._word2token_sequence[word] = word_token_sequence
-            token_sequence.extend(word_token_sequence)
+        token_sequence = sum(map(self._encode, self._parser(text)), [])
         if max_length is not None:
             token_sequence = token_sequence[:max_length]
         return token_sequence
 
     def decode(self, token_sequence: TokenSequence) -> str:
-        text = ''.join(self.token_to_word(token) for token in token_sequence)
-        text = bytearray(self._unicode2byte[c] for c in text).decode('utf-8')
-        text = text.replace(self.EOW, ' ')
-        return text
+        encoded_text = super().decode(token_sequence)
+        return self._codec.decode(encoded_text).replace(self.EOW, ' ')
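Two notes on the constants above. The merge count requested from `load_bpe` is `3 * 2**14 - 256 - 2 = 48894`, the same number the old `49152 - 256 - 2` loop read, so the vocabulary still totals 256 byte codes + 256 `</w>` variants + 48894 merges + 2 special texts = 49408 tokens. Below is a standalone sketch of the byte-to-unicode mapping that `Codec.__init__` builds (the GPT-2 style byte encoder that CLIP reuses); the sample string is arbitrary.

```python
# Sketch only: the byte <-> unicode mapping from Codec.__init__, copied into
# a standalone form. Bytes in the printable Latin-1 ranges map to themselves;
# the remaining bytes are shifted above 255 so every byte gets a distinct,
# visible character.
import itertools

NUM_CODES = 256
byte2ord = list(range(NUM_CODES))
for i, byte in enumerate(
    itertools.chain(
        range(ord('!')),
        range(ord('~') + 1, ord('¡')),
        range(ord('¬') + 1, ord('®')),
    ),
    NUM_CODES,
):
    byte2ord[byte] = i

byte2unicode = [chr(o) for o in byte2ord]
unicode2byte = {u: b for b, u in enumerate(byte2unicode)}

encoded = ''.join(byte2unicode[b] for b in 'café'.encode('utf-8'))
assert bytes(unicode2byte[u] for u in encoded).decode('utf-8') == 'café'
```

With this split, `CLIPTokenizer.encode` parses a text into words with `Parser`, looks each word up in `Cache`, and only runs the codec plus BPE for unseen words; `decode` reverses the same pipeline through `BaseTokenizer.decode` and `Codec.decode`.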
