Reducing logspam in caching.py
PiperOrigin-RevId: 555520027
RyanMullins authored and LIT team committed Aug 10, 2023
1 parent 33b9f41 commit dde3006
Showing 1 changed file with 34 additions and 15 deletions.

lit_nlp/lib/caching.py
@@ -168,26 +168,32 @@ def save_to_disk(self):
     """Save cache data to disk."""
     # No cache loader is created if no cache directory was provided, in which
     # case this is a no-op.
-    if not self._cache_loader:
+    cache_loader = self._cache_loader
+    if not cache_loader:
       return
     if self._num_persisted == len(self._d):
       logging.info("No need to re-save cache to %s", self._cache_dir)
       return
-    logging.info("Saving cache (%d entries) to %s", len(self._d),
-                 self._cache_dir)
-    self._cache_loader.save(self._d)
+    logging.info(
+        "Saving cache (%d entries) to %s", len(self._d), self._cache_dir
+    )
+    cache_loader.save(self._d)
     self._num_persisted = len(self._d)

   def load_from_disk(self):
     """Load cache data from disk."""
     # No cache loader is created if no cache directory was provided, in which
     # case this is a no-op.
-    if not self._cache_loader:
+    cache_loader = self._cache_loader
+    if not cache_loader:
       return
-    self._d = self._cache_loader.load()
+    self._d = cache_loader.load()
     self._num_persisted = len(self._d)
-    logging.info("Loaded cache (%d entries) from %s", self._num_persisted,
-                 self._cache_dir)
+    logging.info(
+        "Loaded cache (%d entries) from %s",
+        self._num_persisted,
+        self._cache_dir,
+    )


 class CachingModelWrapper(lit_model.ModelWrapper):
@@ -234,11 +240,7 @@ def save_cache(self):
     self._cache.save_to_disk()

   def key_fn(self, d) -> CacheKey:
-    if not (d_id := d.get("_id")):
-      logging.warning(
-          "Found empty or missing example ID - using empty cache ID.")
-      return None
-    return (self._name, d_id)
+    return (self._name, d_id) if (d_id := d.get("_id")) else None

   def _validate_ids(self, inputs: Iterable[JsonDict]):
     for ex in inputs:
@@ -261,10 +263,19 @@ def fit_transform(self, inputs: Iterable[JsonDict]):
     )

     inputs_as_list = list(inputs)
+    cache_keys = [self.key_fn(d) for d in inputs_as_list]
+    if (none_keys := [k for k in cache_keys if k is None]):
+      logging.warning(
+          "Attempting to cache %d (of %d) examples where the cache key is"
+          " None - this can be from a missing or empty example id. These"
+          " will be recomputed on subsequent attempts.",
+          len(none_keys),
+          len(cache_keys),
+      )
     outputs = list(wrapped.fit_transform(inputs_as_list))
     with self._cache.lock:
-      for inp, output in zip(inputs_as_list, outputs):
-        self._cache.put(output, self.key_fn(inp))
+      for cache_key, output in zip(cache_keys, outputs, strict=True):
+        self._cache.put(output, cache_key)
     return outputs

   # TODO(b/170662608) Remove once batching logic changes are done.
@@ -284,6 +295,14 @@ def predict(self,

     # Try to get results from the cache.
     input_keys = [self.key_fn(d) for d in inputs_as_list]
+    if (none_keys := [k for k in input_keys if k is None]):
+      logging.warning(
+          "Attempting to retrieve %d (of %d) predictions from the cache where"
+          " the cache key is None - this can be from a missing or empty example"
+          " id. These will call model.predict() on this and subsequent calls.",
+          len(none_keys),
+          len(input_keys),
+      )
     if self._cache.pred_lock_key(input_keys):
       with self._cache.get_pred_lock(input_keys):
         cached_results = self._get_results_from_cache(input_keys)
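
The pattern behind this commit, in miniature: key_fn stops warning on every example that lacks an "_id", and the batch-level callers (fit_transform, predict) emit one aggregated warning per batch instead. Below is a minimal, self-contained sketch of that pattern; key_fn and warn_once_per_batch are hypothetical stand-ins for illustration, not the wrapper's real API.

import logging

def key_fn(name: str, example: dict):
  """Returns a cache key, or None when the example has no usable ID.

  Note: no per-example logging here; that was the source of the logspam.
  """
  return (name, d_id) if (d_id := example.get("_id")) else None

def warn_once_per_batch(keys: list) -> None:
  """Emits one warning summarizing every uncacheable example in a batch."""
  if none_keys := [k for k in keys if k is None]:
    logging.warning(
        "%d (of %d) examples have no cache key and will be recomputed.",
        len(none_keys),
        len(keys),
    )

if __name__ == "__main__":
  batch = [{"_id": "a"}, {}, {"_id": ""}, {"_id": "b"}]
  keys = [key_fn("my_model", ex) for ex in batch]
  warn_once_per_batch(keys)  # One warning for the whole batch, not two.

Note also zip(cache_keys, outputs, strict=True) in fit_transform: strict=True (Python 3.10+) raises ValueError if the model returns a different number of outputs than inputs, so a key/output mismatch fails loudly rather than silently caching results under the wrong keys.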