diff --git a/gbmi/model.py b/gbmi/model.py
index 2d344cf6..1eb3fbdb 100644
--- a/gbmi/model.py
+++ b/gbmi/model.py
@@ -263,14 +263,16 @@ def try_load_model_from_wandb_download(


 def try_load_model_from_wandb(
-    config: Config, wandb_model_path: str
+    config: Config, wandb_model_path: str, root: Optional[Union[str, Path]] = None
 ) -> Optional[Tuple[RunData, HookedTransformer]]:
     # Try loading the model from wandb
+    if root is None:
+        root = get_trained_model_dir(create=True)
     model_dir = None
     try:
         api = wandb.Api()
         model_at = api.artifact(wandb_model_path)
-        model_dir = Path(model_at.download())
+        model_dir = Path(model_at.download(str(root)))
     except Exception as e:
         logging.warning(f"Could not load model {wandb_model_path} from wandb:\n", e)
     if model_dir is not None:
diff --git a/notebooks_jason/max_of_2_grokking.py b/notebooks_jason/max_of_2_grokking.py
index 05ff74f2..609d9b83 100644
--- a/notebooks_jason/max_of_2_grokking.py
+++ b/notebooks_jason/max_of_2_grokking.py
@@ -10,6 +10,8 @@
 import torch
 import wandb

+from gbmi.utils import get_trained_model_dir
+
 api = wandb.Api()

 # %%
@@ -65,8 +67,9 @@
 # %%
 model_artifacts = list(artifact.logged_by().logged_artifacts())
 # %%
+root = get_trained_model_dir(create=True)
 models = [
-    (a.version, try_load_model_from_wandb_download(cfg, a.download()), a)
+    (a.version, try_load_model_from_wandb_download(cfg, a.download(str(root))), a)
     for a in model_artifacts
     if a.type == "model"
 ]
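
The change makes the wandb artifact download directory configurable: `try_load_model_from_wandb` now accepts an optional `root` and falls back to `get_trained_model_dir(create=True)` when it is omitted, so artifacts land in the repo's trained-model cache rather than wandb's default `./artifacts` directory. A minimal usage sketch of the new signature follows; the `fetch_model` wrapper, the placeholder artifact path in the commented-out calls, and the `cfg` object (a gbmi `Config` assumed to be constructed elsewhere, e.g. in the grokking notebook) are illustrative, not part of this PR.

```python
from pathlib import Path
from typing import Optional, Union

from gbmi.model import try_load_model_from_wandb


def fetch_model(cfg, artifact_path: str, root: Optional[Union[str, Path]] = None):
    """Try to load `artifact_path` from wandb, downloading into `root`.

    With root=None, the patched function falls back to
    get_trained_model_dir(create=True), so the artifact is cached in the
    repo's trained-model directory.
    """
    result = try_load_model_from_wandb(cfg, artifact_path, root=root)
    if result is None:
        raise RuntimeError(f"Could not fetch {artifact_path} from wandb")
    run_data, model = result  # RunData and the HookedTransformer
    return run_data, model


# Example calls (placeholder artifact path; `cfg` must be a gbmi Config
# built elsewhere, as in notebooks_jason/max_of_2_grokking.py):
# run_data, model = fetch_model(cfg, "entity/project/model-abc:v0")
# run_data, model = fetch_model(cfg, "entity/project/model-abc:v0", root=Path("/tmp/gbmi-models"))
```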