Skip to content

Commit

Permalink
added new models
Browse files Browse the repository at this point in the history
  • Loading branch information
ThomasHelfer committed Aug 5, 2024
1 parent b2261e0 commit 3d281a2
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions SuperResolution/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,6 @@ def check_performance(
net (torch.nn.Module): The neural network model to evaluate.
datafolder (str): Path to the folder containing the data.
my_loss (torch.nn.Module): The loss function used for evaluation.
device (torch.device): The device to perform the computations on (e.g., 'cpu' or 'cuda').
batchsize (int, optional): The batch size for data loading. Defaults to 50.
Returns:
Expand Down Expand Up @@ -259,6 +258,7 @@ def load_model(directory_path: str) -> Tuple[nn.Module, Dict[str, Any]]:
"""
# Pattern to match the checkpoint files
pattern = r"model_epoch_counter_(\d+)_data_time_\d+\.pth"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# List to store the checkpoint file names and their indices
checkpoints = []
Expand Down Expand Up @@ -342,9 +342,14 @@ def load_model(directory_path: str) -> Tuple[nn.Module, Dict[str, Any]]:
path_to_largest_checkpoint_file = os.path.join(
directory_path, largest_checkpoint_file
)
print(device)
# Load the model state if restarting
if os.path.exists(path_to_largest_checkpoint_file):
net.load_state_dict(torch.load(path_to_largest_checkpoint_file))
net.load_state_dict(
torch.load(
path_to_largest_checkpoint_file, map_location=torch.device(device)
)
)
print(f"loaded model from {path_to_largest_checkpoint_file}")
else:
print(
Expand Down

0 comments on commit 3d281a2

Please sign in to comment.