Skip to content

Commit

Permalink
Merge pull request #816 from MolSSI/fix_cache
Browse the repository at this point in the history
Fixes some cache issues in v0.54
  • Loading branch information
bennybp authored Apr 11, 2024
2 parents 832c9de + bae24bf commit d81c496
Show / hide file tree
Showing 7 changed files with 34 additions and 34 deletions.
10 changes: 2 additions & 8 deletions qcarchivetesting/qcarchivetesting/testing_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ def stop_job_runner(self) -> None:
"""
self._stop_job_runner()

def client(self, username=None, password=None, cache_dir=None, shared_memory_cache=False) -> PortalClient:
def client(self, username=None, password=None, cache_dir=None) -> PortalClient:
"""
Obtain a client connected to this snowflake
Expand All @@ -260,20 +260,14 @@ def client(self, username=None, password=None, cache_dir=None, shared_memory_cac
The password to use
cache_dir
Directory to store cache files in
shared_memory_cache
Whether to use a shared memory cache
Note: We generally don't want a shared memory cache for tests, so the default is False here
rather than True as in the main codebase
Returns
-------
:
A PortalClient that is connected to this snowflake
"""

client = PortalClient(self.get_uri(), username=username, password=password, cache_dir=cache_dir,
shared_memory_cache=shared_memory_cache)
client = PortalClient(self.get_uri(), username=username, password=password, cache_dir=cache_dir)
client.encoding = self.encoding
return client

Expand Down
22 changes: 16 additions & 6 deletions qcportal/qcportal/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import os
import sqlite3
import threading
import uuid
from typing import TYPE_CHECKING, Optional, TypeVar, Type, Any, List, Iterable, Tuple, Sequence
from urllib.parse import urlparse

Expand Down Expand Up @@ -553,7 +554,9 @@ def get_existing_dataset_records(


class PortalCache:
def __init__(self, server_uri: str, cache_dir: Optional[str], max_size: int, shared_memory: bool = True):
def __init__(
self, server_uri: str, cache_dir: Optional[str], max_size: int, shared_memory_key: Optional[str] = None
):
parsed_url = urlparse(server_uri)

# Should work as a reasonable fingerprint?
Expand All @@ -566,19 +569,23 @@ def __init__(self, server_uri: str, cache_dir: Optional[str], max_size: int, sha
os.makedirs(self.cache_dir, exist_ok=True)
else:
self._is_disk = False
self._shared_memory = shared_memory

# If no shared memory key specified, make a unique one
if shared_memory_key is None:
shared_memory_key = f"{server_uri}_{os.getpid()}_{uuid.uuid4()}"

self._shared_memory_key = shared_memory_key
self.cache_dir = None

def get_cache_uri(self, cache_name: str) -> str:
if self._is_disk:
file_path = os.path.join(self.cache_dir, f"{cache_name}.sqlite")
uri = f"file:{file_path}"
elif self._shared_memory:
else:
# We always want some shared cache due to the use of threads.
# vfs=memdb seems to be a better way than mode=memory&cache=shared . Very little docs about it though
# The / after the : is apparently very important. Otherwise, the shared stuff doesn't work
uri = f"file:/qca_cache_{cache_name}?vfs=memdb"
else:
uri = f"file:/qca_cache_{cache_name}?mode=memory"
uri = f"file:/{self._shared_memory_key}_{cache_name}?vfs=memdb"

return uri

Expand Down Expand Up @@ -676,6 +683,9 @@ class functions of the record types.
existing_records = record_cache.get_records(record_ids, record_type)
records_tofetch = set(record_ids) - {x.id for x in existing_records}

for r in existing_records:
r.propagate_client(client)

if records_tofetch:
if client is None:
raise RuntimeError("Need to fetch some records, but not connected to a client")
Expand Down
12 changes: 6 additions & 6 deletions qcportal/qcportal/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def __init__(
*,
cache_dir: Optional[str] = None,
cache_max_size: int = 0,
shared_memory_cache: bool = True,
memory_cache_key: Optional[str] = None,
) -> None:
"""
Parameters
Expand All @@ -147,15 +147,15 @@ def __init__(
Directory to store an internal cache of records and other data
cache_max_size
Maximum size of the cache directory
shared_memory_cache
If True (and cache_dir is not specified), the memory-backed cache will be marked as shared,
meaning that multiple instances of the PortalClient can share the same cache, as well as making the
cache shared among threads. Generally, this is what you want, but
memory_cache_key
If set, all clients with the same memory_cache_key will share an in-memory cache. If not specified,
a unique one will be generated, meaning this client will not share a memory-based cache with any
other clients. Not used if cache_dir is set.
"""

PortalClientBase.__init__(self, address, username, password, verify, show_motd)
self._logger = logging.getLogger("PortalClient")
self.cache = PortalCache(address, cache_dir, cache_max_size, shared_memory_cache)
self.cache = PortalCache(address, cache_dir, cache_max_size, memory_cache_key)

def __repr__(self) -> str:
"""A short representation of the current PortalClient.
Expand Down
16 changes: 8 additions & 8 deletions qcportal/qcportal/dataset_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,10 +525,10 @@ def fetch_specifications(
@property
def specification_names(self) -> List[str]:
if not self._specification_names:
self._specification_names = self._cache_data.get_specification_names()

if not self._specification_names and not self.is_view:
self.fetch_specification_names()
if self.is_view:
self._specification_names = self._cache_data.get_specification_names()
else:
self.fetch_specification_names()

return self._specification_names

Expand Down Expand Up @@ -769,10 +769,10 @@ def iterate_entries(
@property
def entry_names(self) -> List[str]:
if not self._entry_names:
self._entry_names = self._cache_data.get_entry_names()

if not self._entry_names and not self.is_view:
self.fetch_entry_names()
if self.is_view:
self._entry_names = self._cache_data.get_entry_names()
else:
self.fetch_entry_names()

return self._entry_names

Expand Down
1 change: 1 addition & 0 deletions qcportal/qcportal/dataset_testing_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,7 @@ def run_dataset_model_rename_spec(snowflake_client, ds, test_entries, test_specs

assert set(spec_names) == {"spec_2", "spec_1_new"}
assert set(ds.specifications.keys()) == {"spec_2", "spec_1_new"}
assert set(ds.specification_names) == {"spec_2", "spec_1_new"}
assert set(ds._specification_names) == {"spec_2", "spec_1_new"}
assert set(ds._cache_data.get_specification_names()) == {"spec_2", "spec_1_new"}
assert ds._cache_data.get_specification("spec_1_new").name == "spec_1_new"
Expand Down
4 changes: 1 addition & 3 deletions qcportal/qcportal/test_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,9 +131,7 @@ def test_dataset_cache_update(snowflake_client: PortalClient):


def test_dataset_cache_multithread(snowflake: QCATestingSnowflake):
# Use shared memory cache
# This is the default for a regular client, but a testing snowflake uses :memory:
snowflake_client: PortalClient = snowflake.client(shared_memory_cache=True)
snowflake_client: PortalClient = snowflake.client()

ds: SinglepointDataset = snowflake_client.add_dataset("singlepoint", "Test dataset")

Expand Down
3 changes: 0 additions & 3 deletions qcportal/qcportal/torsiondrive/record_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,9 +174,6 @@ def _fetch_children_multi(
if r.minimum_optimizations_ and do_minopt:
opt_ids.update(r.minimum_optimizations_.values())

if not opt_ids:
return

opt_ids = list(opt_ids)
opt_records = get_records_with_cache(
client, record_cache, OptimizationRecord, opt_ids, include=include, force_fetch=force_fetch
Expand Down

0 comments on commit d81c496

Please sign in to comment.