diff --git a/end_to_end/test_multi_tier_checkpointing.sh b/end_to_end/test_multi_tier_checkpointing.sh
new file mode 100644
index 000000000..977265332
--- /dev/null
+++ b/end_to_end/test_multi_tier_checkpointing.sh
@@ -0,0 +1,37 @@
+#!/bin/bash
+# End-to-end multi-tier checkpointing test. Runs train.py twice with
+# enable_emergency_checkpoint=true and the same run_name, then runs
+# eval_assert.py checkpoint_save_restore on learning/loss.
+#
+# Usage: bash end_to_end/test_multi_tier_checkpointing.sh RUN_PREFIX OUTPUT_PATH DATASET_PATH
+#   RUN_PREFIX    prefix of run_name, the current date and hour are appended
+#   OUTPUT_PATH   passed to train.py as base_output_directory
+#   DATASET_PATH  passed to train.py as dataset_path
+set -ex
+
+# Abort with a clear message instead of launching training with empty arguments.
+RUN_NAME="${1:?missing run name prefix, argument 1}_$(date +%Y-%m-%d-%H)"
+OUTPUT_PATH="${2:?missing output path, argument 2}"
+DATASET_PATH="${3:?missing dataset path, argument 3}"
+export TPU_PREMAPPED_BUFFER_SIZE=20000014336
+export TPU_PREMAPPED_BUFFER_TRANSFER_THRESHOLD_BYTES=20000014336
+
+# Train and save checkpoint: 100 steps, local_checkpoint_period=20.
+python3 MaxText/train.py MaxText/configs/base.yml remat_policy=full \
+  base_output_directory="$OUTPUT_PATH" dataset_path="$DATASET_PATH" \
+  steps=100 enable_emergency_checkpoint=true checkpoint_period=200 \
+  local_checkpoint_directory=/local local_checkpoint_period=20 \
+  run_name="$RUN_NAME" metrics_file='saved_metrics.txt'
+
+# Retrieve checkpoint: same run_name and directories, steps raised to 110.
+# TODO confirm: presumably this run resumes from the checkpoint saved above.
+python3 MaxText/train.py MaxText/configs/base.yml remat_policy=full \
+  base_output_directory="$OUTPUT_PATH" dataset_path="$DATASET_PATH" \
+  steps=110 enable_emergency_checkpoint=true checkpoint_period=200 \
+  local_checkpoint_directory=/local local_checkpoint_period=20 \
+  run_name="$RUN_NAME" metrics_file='restored_metrics.txt'
+
+# TODO confirm: eval_assert.py receives metrics.txt although the runs above
+# write saved_metrics.txt and restored_metrics.txt. Presumably it derives
+# both file names from this base name.
+python3 end_to_end/tpu/eval_assert.py checkpoint_save_restore metrics.txt learning/loss