From 3d281a20b8b6f9c1491bd22246b8beb78db98381 Mon Sep 17 00:00:00 2001 From: thelfer1 Date: Mon, 5 Aug 2024 14:25:23 -0400 Subject: [PATCH] added new models --- SuperResolution/models.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/SuperResolution/models.py b/SuperResolution/models.py index af3297c..e9c0998 100644 --- a/SuperResolution/models.py +++ b/SuperResolution/models.py @@ -169,7 +169,6 @@ def check_performance( net (torch.nn.Module): The neural network model to evaluate. datafolder (str): Path to the folder containing the data. my_loss (torch.nn.Module): The loss function used for evaluation. - device (torch.device): The device to perform the computations on (e.g., 'cpu' or 'cuda'). batchsize (int, optional): The batch size for data loading. Defaults to 50. Returns: @@ -259,6 +258,7 @@ def load_model(directory_path: str) -> Tuple[nn.Module, Dict[str, Any]]: """ # Pattern to match the checkpoint files pattern = r"model_epoch_counter_(\d+)_data_time_\d+\.pth" + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # List to store the checkpoint file names and their indices checkpoints = [] @@ -342,9 +342,14 @@ def load_model(directory_path: str) -> Tuple[nn.Module, Dict[str, Any]]: path_to_largest_checkpoint_file = os.path.join( directory_path, largest_checkpoint_file ) + print(device) # Load the model state if restarting if os.path.exists(path_to_largest_checkpoint_file): - net.load_state_dict(torch.load(path_to_largest_checkpoint_file)) + net.load_state_dict( + torch.load( + path_to_largest_checkpoint_file, map_location=torch.device(device) + ) + ) print(f"loaded model from {path_to_largest_checkpoint_file}") else: print(