Skip to content

Commit

Permalink
Merge pull request #2327 from mahsanghani:patch-2
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 673915278
  • Loading branch information
copybara-github committed Sep 12, 2024
2 parents 7df1042 + 9555b66 commit 7d04d5d
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions site/en/tutorials/distribute/keras.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@
"# Define the checkpoint directory to store the checkpoints.\n",
"checkpoint_dir = './training_checkpoints'\n",
"# Define the name of the checkpoint files.\n",
"checkpoint_prefix = os.path.join(checkpoint_dir, \"ckpt_{epoch}\")"
"checkpoint_prefix = os.path.join(checkpoint_dir, \"ckpt_{epoch:04d}.weights.h5\")"
]
},
{
Expand Down Expand Up @@ -396,7 +396,7 @@
"# Define a callback for printing the learning rate at the end of each epoch.\n",
"class PrintLR(tf.keras.callbacks.Callback):\n",
" def on_epoch_end(self, epoch, logs=None):\n",
" print('\\nLearning rate for epoch {} is {}'.format( epoch + 1, model.optimizer.lr.numpy()))"
" print('\\nLearning rate for epoch {} is {}'.format(epoch + 1, model.optimizer.learning_rate.numpy()))"
]
},
{
Expand Down Expand Up @@ -486,7 +486,10 @@
},
"outputs": [],
"source": [
"model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))\n",
"import pathlib\n",
"latest_checkpoint = sorted(pathlib.Path(checkpoint_dir).glob('*'))[-1]\n",
"\n",
"model.load_weights(latest_checkpoint)\n",
"\n",
"eval_loss, eval_acc = model.evaluate(eval_dataset)\n",
"\n",
Expand Down

0 comments on commit 7d04d5d

Please sign in to comment.