Skip to content

Commit

Permalink
Download artifacts in trained-models/, not root
Browse files Browse the repository at this point in the history
Fixes #5
  • Loading branch information
JasonGross committed Jan 19, 2024
1 parent 7e09009 commit e28e04c
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 3 deletions.
6 changes: 4 additions & 2 deletions gbmi/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,14 +263,16 @@ def try_load_model_from_wandb_download(


def try_load_model_from_wandb(
config: Config, wandb_model_path: str
config: Config, wandb_model_path: str, root: Optional[Union[str, Path]] = None
) -> Optional[Tuple[RunData, HookedTransformer]]:
# Try loading the model from wandb
if root is None:
root = get_trained_model_dir(create=True)
model_dir = None
try:
api = wandb.Api()
model_at = api.artifact(wandb_model_path)
model_dir = Path(model_at.download())
model_dir = Path(model_at.download(str(root)))
except Exception as e:
logging.warning(f"Could not load model {wandb_model_path} from wandb:\n", e)
if model_dir is not None:
Expand Down
5 changes: 4 additions & 1 deletion notebooks_jason/max_of_2_grokking.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import torch
import wandb

from gbmi.utils import get_trained_model_dir

api = wandb.Api()

# %%
Expand Down Expand Up @@ -65,8 +67,9 @@
# %%
model_artifacts = list(artifact.logged_by().logged_artifacts())
# %%
root = get_trained_model_dir(create=True)
models = [
(a.version, try_load_model_from_wandb_download(cfg, a.download()), a)
(a.version, try_load_model_from_wandb_download(cfg, a.download(str(root))), a)
for a in model_artifacts
if a.type == "model"
]
Expand Down

0 comments on commit e28e04c

Please sign in to comment.