Skip to content

Commit

Permalink
Create basic RunMetadata class
Browse files Browse the repository at this point in the history
This makes it a lot easier to interface with the SQL database, since we
can use dataclasses' asdict and provide a function to sanitize values if needed.

Also this partially begins to address #65
  • Loading branch information
WarmCyan committed Nov 9, 2023
1 parent dd80c4a commit 5e81fa5
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 14 deletions.
9 changes: 7 additions & 2 deletions curifactory/dbschema.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@
runs_table = Table(
"run",
metadata_obj,
Column("id", Integer, primary_key=True),
Column("reference", String),
Column("reference", String, primary_key=True),
Column("experiment_name", String),
Column("run_number", Integer),
Column("timestamp", DateTime),
Expand All @@ -33,6 +32,12 @@
Column("hostname", String),
Column("user", String),
Column("notes", String),
Column("uncommited_patch", String),
Column("pip_freeze", String),
Column("conda_env", String),
Column("conda_env_full", String),
Column("os", String),
Column("reproduce", String),
)


Expand Down
74 changes: 74 additions & 0 deletions curifactory/run.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
"""Class for tracking run metadata."""

from dataclasses import dataclass
from datetime import datetime


@dataclass
class RunMetadata:
    """Data structure for tracking all the relevant metadata for a run.

    Provides a consistent interface for accessing the information and for
    converting it into the formats necessary for saving.
    """

    reference: str
    """The full reference name of the experiment, usually
    ``[experiment_name]_[run_number]_[timestamp]``."""
    experiment_name: str
    """The name of the experiment and/or the prefix used for caching."""
    run_number: int
    """The run counter for experiments with the given name."""
    timestamp: datetime
    """The datetime timestamp for when the manager is initialized (and usually
    also when the experiment starts running.)"""

    param_files: list[str]
    """The list of parameter file names (minus extension, as they would be
    passed into the CLI.)"""
    params: dict[str, list[list[str]]]
    """A dictionary of parameter file names for keys, where each value is an array of arrays,
    each inner array containing the parameter set name and the parameter set hash, for the
    parameter sets that come from that parameter file.
    e.g. ``{"my_param_file": [ [ "my_param_set", "44b5e428e7165975a3e4f0d1674dbe5f" ] ] }``
    """
    full_store: bool
    """Whether this store was being fully exported or not."""

    commit: str
    """The current git commit hash."""
    workdir_dirty: bool
    """True if there are uncommitted changes in the git repo."""
    uncommited_patch: str
    """The output of ``git diff -p`` at runtime, to help more precisely reconstruct current codebase."""

    status: str
    """Run status: success/incomplete/error/etc."""
    cli: str
    """The CLI line this run was created with."""
    reproduce: str
    """The translated CLI line to reproduce this run."""

    hostname: str
    """The name of the machine this experiment ran on."""
    user: str
    """The name of the user account the experiment was run with."""
    notes: str
    """User-entered notes associated with a session/run to output into the report etc."""

    pip_freeze: str
    """The output from a ``pip freeze`` command."""
    conda_env: str
    """The output from ``conda env export --from-history``."""
    conda_env_full: str
    """The output from ``conda env export``."""
    os: str
    """The name of the current OS running curifactory."""

    def as_sql_safe_dict(self) -> dict:
        """Return this metadata as a flat dict suitable for the Runs sql table.

        Meant to be used when inserting/updating values in the Runs
        sql table. The targeted column names can be found in dbschema.py.

        The container-valued fields (``param_files`` and ``params``) are
        JSON-encoded into strings; every other field is passed through
        unchanged (``timestamp`` stays a ``datetime``).

        Returns:
            A new dictionary keyed by field name. The instance itself is not
            modified.
        """
        # Local imports keep this method self-contained; the module itself
        # only imports ``dataclass`` and ``datetime``.
        import json
        from dataclasses import asdict

        row = asdict(self)
        # NOTE(review): assumes the list/dict fields are stored in string
        # columns as JSON (the store decodes ``param_files`` with
        # ``json.loads``) -- confirm the same holds for ``params``.
        for key in ("param_files", "params"):
            row[key] = json.dumps(row[key])
        return row
23 changes: 11 additions & 12 deletions curifactory/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
class SQLStore:
"""EXPERIMENTAL, making an sqlite version of the data below."""

# TODO: (11/8/2023) make this take the full store path instead
def __init__(self, manager_cache_path: str):
self.path = manager_cache_path
"""The location to store the ``store.db``."""
Expand All @@ -26,9 +27,11 @@ def __init__(self, manager_cache_path: str):
self._ensure_tables()

def _ensure_tables(self):
    """Check for the existence of (and create if necessary) all of the tables
    listed in dbschema.py"""
    # create_all is issued against this store's engine.
    metadata_obj.create_all(self.engine)

def get_run(self, ref_name: str) -> tuple[dict, int]:
def get_run(self, ref_name: str) -> dict:
"""Get the metadata block for the run with the specified reference name.
Args:
Expand All @@ -41,22 +44,18 @@ def get_run(self, ref_name: str) -> tuple[dict, int]:
# https://docs.sqlalchemy.org/en/20/tutorial/data_select.html

with self.engine.connect() as conn:
stmt = select(runs_table).where(
runs_table.c.reference == ref_name
) # TODO: do I need to use prepare?
stmt = select(runs_table).where(runs_table.c.reference == ref_name)
# TODO: do I need to use prepare?
result = conn.execute(stmt)

# if we didn't get any rows back, this run doesn't exist.
if len(result) == 0:
return None, -1
return None

run = result[
0
]._asdict() # NOTE: documented function of namedtuple, _ here doesn't imply hidden/not intended for use
run = result[0]._asdict()
# NOTE: documented function of namedtuple, _ here doesn't imply hidden/not intended for use
run.param_files = json.loads(run.param_files)
return run, run.id

return None, -1
return run

def add_run(self, mngr) -> dict:
"""Add a new metadata block to the store for the passed ``ArtifactManager`` instance.
Expand All @@ -72,7 +71,7 @@ def add_run(self, mngr) -> dict:

# get the new run number
with self.engine.connect() as conn:
stmt = select(func.count(runs_table.c.id)).where(
stmt = select(func.count()).where(
runs_table.c.experiment_name == mngr.experiment_name
)
result = conn.execute(stmt)
Expand Down

0 comments on commit 5e81fa5

Please sign in to comment.