Skip to content

Commit

Permalink
Fixing model checkpoint naming compatibility
Browse files Browse the repository at this point in the history
  • Loading branch information
dev-geof authored Nov 7, 2024
1 parent 7680687 commit acf7fd3
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,11 +190,11 @@ def transformer_training(

class_CPname = (
outputdir
+ "/classification_training/checkpoints/model.{epoch:02d}-loss-{loss:.5f}-{val_loss:.5f}-acc-{accuracy:.5f}-{val_accuracy:.5f}-auc-{auc:.5f}-{val_auc:.5f}.h5"
+ "/classification_training/checkpoints/model.{epoch:02d}-loss-{loss:.5f}-{val_loss:.5f}-acc-{accuracy:.5f}-{val_accuracy:.5f}-auc-{auc:.5f}-{val_auc:.5f}.keras"
)
reg_CPname = filepath = (
outputdir
+ "/regression_training/checkpoints/model-{epoch:02d}-loss-{loss:.5f}-{val_loss:.5f}-mse-{mean_squared_error:.5f}-{val_mean_squared_error:.5f}.h5"
+ "/regression_training/checkpoints/model-{epoch:02d}-loss-{loss:.5f}-{val_loss:.5f}-mse-{mean_squared_error:.5f}-{val_mean_squared_error:.5f}.keras"
)
CPoutput = class_CPname if training_mode == "classification" else reg_CPname

Expand Down

0 comments on commit acf7fd3

Please sign in to comment.