Skip to content

Commit

Permalink
SpaczzRuler persistence consistency prep for spaCy v3
Browse files: browse the repository at this point in the history.
  • Loading branch information
gandersen101 committed Feb 25, 2021
1 parent 9ee8d2d commit 43fc507
Showing 1 changed file with 41 additions and 41 deletions.
82 changes: 41 additions & 41 deletions src/spaczz/pipeline/spaczzruler.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,32 +109,32 @@ def __init__(self: SpaczzRuler, nlp: Language, **cfg: Any) -> None:
self._ent_ids: defaultdict[Any, Any] = defaultdict(dict)
self.overwrite = cfg.get("spaczz_overwrite_ents", False)
default_names = (
"spaczz_fuzzy_defaults",
"spaczz_regex_defaults",
"spaczz_token_defaults",
("spaczz_fuzzy_defaults", "fuzzy_defaults"),
("spaczz_regex_defaults", "regex_defaults"),
("spaczz_token_defaults", "token_defaults"),
)
self.defaults = {}
for name in default_names:
if name in cfg:
if isinstance(cfg[name], dict):
self.defaults[name] = cfg[name]
for cfg_name, real_name in default_names:
if cfg_name in cfg:
if isinstance(cfg[cfg_name], dict):
self.defaults[real_name] = cfg[cfg_name]
else:
raise TypeError(
(
"Defaults must be a dictionary of keyword arguments,",
f"not {type(cfg[name])}.",
f"not {type(cfg[cfg_name])}.",
)
)
self.fuzzy_matcher = FuzzyMatcher(
nlp.vocab, **self.defaults.get("spaczz_fuzzy_defaults", {}),
nlp.vocab, **self.defaults.get("fuzzy_defaults", {}),
)
self.regex_matcher = RegexMatcher(
nlp.vocab,
cfg.get("spaczz_regex_config", "default"),
**self.defaults.get("spaczz_regex_defaults", {}),
**self.defaults.get("regex_defaults", {}),
)
self.token_matcher = TokenMatcher(
nlp.vocab, **self.defaults.get("spaczz_token_defaults", {})
nlp.vocab, **self.defaults.get("token_defaults", {})
)
patterns = cfg.get("spaczz_patterns")
if patterns is not None:
Expand Down Expand Up @@ -492,7 +492,7 @@ def from_bytes(
Args:
patterns_bytes : The bytestring to load.
**kwargs: Other config paramters, mostly for consistency.
**kwargs: Other config parameters, mostly for consistency.
Returns:
The loaded spaczz ruler.
Expand All @@ -512,22 +512,22 @@ def from_bytes(
"""
cfg = srsly.msgpack_loads(patterns_bytes)
if isinstance(cfg, dict):
self.add_patterns(cfg.get("spaczz_patterns", cfg))
self.defaults = cfg.get("spaczz_defaults", {})
if self.defaults.get("spaczz_fuzzy_defaults"):
self.add_patterns(cfg.get("patterns", cfg))
self.defaults = cfg.get("defaults", {})
if self.defaults.get("fuzzy_defaults"):
self.fuzzy_matcher = FuzzyMatcher(
self.nlp.vocab, **self.defaults["spaczz_fuzzy_defaults"]
self.nlp.vocab, **self.defaults["fuzzy_defaults"]
)
if self.defaults.get("spaczz_regex_defaults"):
if self.defaults.get("regex_defaults"):
self.regex_matcher = RegexMatcher(
self.nlp.vocab, **self.defaults["spaczz_regex_defaults"]
self.nlp.vocab, **self.defaults["regex_defaults"]
)
if self.defaults.get("spaczz_token_defaults"):
if self.defaults.get("token_defaults"):
self.token_matcher = TokenMatcher(
self.nlp.vocab, **self.defaults["spaczz_token_defaults"]
self.nlp.vocab, **self.defaults["token_defaults"]
)
self.overwrite = cfg.get("spaczz_overwrite", False)
self.ent_id_sep = cfg.get("spaczz_ent_id_sep", DEFAULT_ENT_ID_SEP)
self.overwrite = cfg.get("overwrite", False)
self.ent_id_sep = cfg.get("ent_id_sep", DEFAULT_ENT_ID_SEP)
else:
self.add_patterns(cfg)
return self
Expand All @@ -554,10 +554,10 @@ def to_bytes(self: SpaczzRuler, **kwargs: Any) -> bytes:
"""
serial = OrderedDict(
(
("spaczz_overwrite", self.overwrite),
("spaczz_ent_id_sep", self.ent_id_sep),
("spaczz_patterns", self.patterns),
("spaczz_defaults", self.defaults),
("overwrite", self.overwrite),
("ent_id_sep", self.ent_id_sep),
("patterns", self.patterns),
("defaults", self.defaults),
)
)
return srsly.msgpack_dumps(serial)
Expand Down Expand Up @@ -601,27 +601,27 @@ def from_disk(
else:
cfg = {}
deserializers_patterns = {
"spaczz_patterns": lambda p: self.add_patterns(
"patterns": lambda p: self.add_patterns(
srsly.read_jsonl(p.with_suffix(".jsonl"))
)
}
deserializers_cfg = {"cfg": lambda p: cfg.update(srsly.read_json(p))}
read_from_disk(path, deserializers_cfg, {})
self.overwrite = cfg.get("spaczz_overwrite", False)
self.defaults = cfg.get("spaczz_defaults", {})
if self.defaults.get("spaczz_fuzzy_defaults"):
self.overwrite = cfg.get("overwrite", False)
self.defaults = cfg.get("defaults", {})
if self.defaults.get("fuzzy_defaults"):
self.fuzzy_matcher = FuzzyMatcher(
self.nlp.vocab, **self.defaults["spaczz_fuzzy_defaults"]
self.nlp.vocab, **self.defaults["fuzzy_defaults"]
)
if self.defaults.get("spaczz_regex_defaults"):
if self.defaults.get("regex_defaults"):
self.regex_matcher = RegexMatcher(
self.nlp.vocab, **self.defaults["spaczz_regex_defaults"]
self.nlp.vocab, **self.defaults["regex_defaults"]
)
if self.defaults.get("spaczz_token_defaults"):
if self.defaults.get("token_defaults"):
self.token_matcher = TokenMatcher(
self.nlp.vocab, **self.defaults["spaczz_token_defaults"]
self.nlp.vocab, **self.defaults["token_defaults"]
)
self.ent_id_sep = cfg.get("spaczz_ent_id_sep", DEFAULT_ENT_ID_SEP)
self.ent_id_sep = cfg.get("ent_id_sep", DEFAULT_ENT_ID_SEP)
read_from_disk(path, deserializers_patterns, {})
return self

Expand Down Expand Up @@ -651,12 +651,12 @@ def to_disk(self: SpaczzRuler, path: Union[str, Path], **kwargs: Any) -> None:
"""
path = ensure_path(path)
cfg = {
"spaczz_overwrite": self.overwrite,
"spaczz_defaults": self.defaults,
"spaczz_ent_id_sep": self.ent_id_sep,
"overwrite": self.overwrite,
"defaults": self.defaults,
"ent_id_sep": self.ent_id_sep,
}
serializers = {
"spaczz_patterns": lambda p: srsly.write_jsonl(
"patterns": lambda p: srsly.write_jsonl(
p.with_suffix(".jsonl"), self.patterns
),
"cfg": lambda p: srsly.write_json(p, cfg),
Expand Down Expand Up @@ -716,7 +716,7 @@ def _create_label(self: SpaczzRuler, label: str, ent_id: Union[str, None]) -> st
The label and ent_id joined with configured ent_id_sep.
"""
if isinstance(ent_id, str):
label = "{}{}{}".format(label, self.ent_id_sep, ent_id)
label = f"{label}{self.ent_id_sep}{ent_id}"
return label

def _split_label(self: SpaczzRuler, label: str) -> tuple[str, Union[str, None]]:
Expand Down

0 comments on commit 43fc507

Please sign in to comment.