Skip to content

Commit

Permalink
npz serialization
Browse files Browse the repository at this point in the history
  • Loading branch information
pmaher86 committed Oct 12, 2024
1 parent 861ae5c commit 375954d
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 16 deletions.
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ repos:
- id: end-of-file-fixer
- id: check-yaml
- id: check-added-large-files
exclude: src/blacksquare/word_list.npz
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.6.8
hooks:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,4 @@ local_scheme = "no-local-version"
where = ["src"]

[tool.setuptools.package-data]
blacksquare = ["*.dict"]
blacksquare = ["*.npz"]
Binary file added src/blacksquare/word_list.npz
Binary file not shown.
40 changes: 25 additions & 15 deletions src/blacksquare/word_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,20 +84,24 @@ def __init__(
Raises:
ValueError: If input type is not recognized
"""
if (
isinstance(source, str)
or isinstance(source, Path)
or isinstance(source, io.IOBase)
):
df = pd.read_csv(
source,
sep=";",
header=None,
names=["word", "score"],
dtype={"word": str, "score": float},
na_filter=False,
)
raw_words_scores = df.values
if isinstance(source, str) or isinstance(source, Path):
if Path(source).suffix == ".npz":
loaded = np.load(source)
length_keys = [k for k in loaded.keys() if k not in ("words", "scores")]
self._words = loaded["words"]
self._scores = loaded["scores"]
self._word_scores_by_length = {int(k): loaded[k] for k in length_keys}
return
else:
df = pd.read_csv(
source,
sep=";",
header=None,
names=["word", "score"],
dtype={"word": str, "score": float},
na_filter=False,
)
raw_words_scores = df.values
elif isinstance(source, list):
assert len(source) > 0 and isinstance(source[0], str)
raw_words_scores = [(w, 1) for w in source]
Expand Down Expand Up @@ -212,6 +216,12 @@ def score_filter(self, threshold: float) -> WordList:
def filter(self, filter_fn: Callable[[ScoredWord], bool]) -> WordList:
    """Build a new word list from the entries accepted by a predicate.

    Args:
        filter_fn: Called once with each entry produced by iterating this
            word list; entries for which it returns a truthy value are kept.

    Returns:
        A new WordList constructed from a dict of the kept entries.
    """
    # dict() consumes the generator directly, so no intermediate list is
    # materialized. Each entry must be a (key, value) pair for dict() to
    # accept it.
    return WordList(dict(entry for entry in self if filter_fn(entry)))

def to_npz(self, file: str | Path) -> None:
by_length_str_key = {str(k): v for k, v in self._word_scores_by_length.items()}
np.savez_compressed(
file, words=self._words, scores=self._scores, **by_length_str_key
)

def __len__(self):
    """Return the number of entries held in ``self._words``."""
    words = self._words
    return len(words)

Expand Down Expand Up @@ -347,4 +357,4 @@ def _normalize(word: str) -> str:
return word.upper().replace(" ", "")


# Default word list, built at module level from a data file located inside the
# `blacksquare` package.
# NOTE(review): both assignments appear here; the later one (word_list.npz)
# is the binding that takes effect.
DEFAULT_WORDLIST = WordList(files("blacksquare").joinpath("xwordlist.dict"))
DEFAULT_WORDLIST = WordList(files("blacksquare").joinpath("word_list.npz"))

0 comments on commit 375954d

Please sign in to comment.