diff --git a/.azure-pipelines/gpu-tests.yml b/.azure-pipelines/gpu-tests.yml index 5d4ca00..db3f48f 100644 --- a/.azure-pipelines/gpu-tests.yml +++ b/.azure-pipelines/gpu-tests.yml @@ -11,6 +11,7 @@ trigger: include: - "main" - "release/*" + - "model_parallel_exp_support" # temporarily add for new test infra enhancement validation - "refs/tags/*" paths: include: @@ -86,11 +87,21 @@ jobs: python -m coverage run --source src/finetuning_scheduler -m pytest src/finetuning_scheduler tests -v --junitxml=$(Build.Repository.LocalPath)/test-results.xml --durations=50 displayName: 'Testing: standard' + # - bash: | + # . /tmp/venvs/fts_dev/bin/activate + # bash ./tests/standalone_tests.sh -k test_f + # displayName: 'Testing: standalone multi-gpu' + - bash: | . /tmp/venvs/fts_dev/bin/activate - bash ./tests/standalone_tests.sh -k test_f + bash ./tests/special_tests.sh --mark_type=standalone --filter_pattern='test_f' displayName: 'Testing: standalone multi-gpu' + # - bash: | + # . /tmp/venvs/fts_dev/bin/activate + # bash ./tests/special_tests.sh --mark_type=exp_patch --filter_pattern='test_f' --experiment_patch_mask="1 0 0" + # displayName: 'Testing: experimental einsum patch' + - bash: | . /tmp/venvs/fts_dev/bin/activate python -m coverage report diff --git a/tests/.experiments b/tests/.experiments new file mode 100644 index 0000000..1b9f259 --- /dev/null +++ b/tests/.experiments @@ -0,0 +1,3 @@ +ENABLE_FTS_EINSUM_STRATEGY_PATCH +ENABLE_FTS_NUMPY_EXTRACTOR_PATCH +ENABLE_FTS_TRITON_CODEGEN_PATCH diff --git a/tests/conftest.py b/tests/conftest.py index 724e846..19bc8b7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -163,9 +163,10 @@ def single_process_pg(): os.environ.clear() os.environ.update(orig_environ) - def pytest_collection_modifyitems(items): - # filter out special tests + # select special tests, all special tests run standalone + # note standalone tests take precedence over experimental tests if both env vars are set + # tests depending on experimental patches do not run in CI by default if os.getenv("PL_RUN_STANDALONE_TESTS", "0") == "1": items[:] = [ item @@ -174,11 +175,10 @@ def pytest_collection_modifyitems(items): # has `@RunIf(standalone=True)` if marker.name == "skipif" and marker.kwargs.get("standalone") ] - elif os.getenv("PL_RUN_SLOW_TESTS", "0") == "1": + elif os.getenv("FTS_EXPERIMENTAL_PATCH_TESTS", "0") == "1": items[:] = [ item for item in items for marker in item.own_markers - # has `@RunIf(slow=True)` - if marker.name == "skipif" and marker.kwargs.get("slow") + if marker.name == "skipif" and marker.kwargs.get("exp_patch") ] diff --git a/tests/helpers/common.py b/tests/helpers/common.py index 6aef27e..44bd695 100644 --- a/tests/helpers/common.py +++ b/tests/helpers/common.py @@ -53,7 +53,7 @@ def multiwarn_check( unmatched_warns = partial(multiwarn_check, expected_mode=True) class ExpectedResults(NamedTuple): - expected_state: Dict + expected_state: Optional[Dict] = None warns_expected: Optional[Tuple] = None exceptions_expected: Optional[Tuple] = None diff --git a/tests/helpers/runif.py b/tests/helpers/runif.py index d4b7121..c90b497 100644 --- a/tests/helpers/runif.py +++ b/tests/helpers/runif.py @@ -28,10 +28,12 @@ # RunIf aliases RUNIF_MAP = { "min2_5": {"min_torch": "2.5.0"}, + "alone": {"standalone": True}, + "bf16_cuda_alone": {"bf16_cuda": True, "standalone": True}, "min2_2": {"min_torch": "2.2.0"}, "max3_12_min2_3": {"max_python": "3.12", "min_torch": "2.3.0"}, "max3_12_min2_2": {"max_python": "3.12", "min_torch": "2.2.0"}, - "einsum_exp": {"exp_patch": 
{ExpPatch.EINSUM_STRATEGIES}, "min_torch": "2.5.0"}, + "einsum_exp": {"exp_patch": {ExpPatch.EINSUM_STRATEGIES}}, } @@ -59,7 +61,6 @@ def __new__( skip_mac_os: bool = False, standalone: bool = False, deepspeed: bool = False, - slow: bool = False, exp_patch: Optional[ExpPatch|Set[ExpPatch]] = None, **kwargs, ): @@ -78,8 +79,6 @@ def __new__( This requires that the ``PL_RUN_STANDALONE_TESTS=1`` environment variable is set. deepspeed: Require that microsoft/DeepSpeed is installed. exp_patch: Require that a given experimental patch is installed. - slow: Mark the test as slow, our CI will run it in a separate job. - This requires that the ``PL_RUN_SLOW_TESTS=1`` environment variable is set. **kwargs: Any :class:`pytest.mark.skipif` keyword arguments. """ conditions = [] @@ -148,17 +147,19 @@ def __new__( reasons.append("Deepspeed") if exp_patch: - if not isinstance(exp_patch, Set): - exp_patch = {exp_patch} - conditions.append(not exp_patch.issubset(_ACTIVE_PATCHES)) - reasons.append(f"Required experimental patch configuration {exp_patch} is not active.") - - if slow: - env_flag = os.getenv("PL_RUN_SLOW_TESTS", "0") - conditions.append(env_flag != "1") - reasons.append("Slow test") - # used in tests/conftest.py::pytest_collection_modifyitems - kwargs["slow"] = True + # since we want to ensure we separate all experimental test combinations from normal unpatched tests, we + # gate experimental patches with both an environmental flag and the required subset of active patches + env_flag = os.getenv("FTS_EXPERIMENTAL_PATCH_TESTS", "0") + if env_exp_flag := (env_flag != "1"): + conditions.append(env_exp_flag) + reasons.append("Experimental tests not enabled via 'FTS_EXPERIMENTAL_PATCH_TESTS' env variable") + else: + if not isinstance(exp_patch, Set): + exp_patch = {exp_patch} + conditions.append(not exp_patch.issubset(_ACTIVE_PATCHES)) + reasons.append(f"Required experimental patch configuration {exp_patch} is not active.") + # used in conftest.py::pytest_collection_modifyitems + kwargs["exp_patch"] = True reasons = [rs for cond, rs in zip(conditions, reasons) if cond] return pytest.mark.skipif( diff --git a/tests/infra_utils.sh b/tests/infra_utils.sh new file mode 100755 index 0000000..b3af515 --- /dev/null +++ b/tests/infra_utils.sh @@ -0,0 +1,192 @@ +#!/bin/bash +# Test infra utility functions +# Note we use local variables for many of these to allow more usage flexibility in different contexts + +toggle_experimental_patches() { + # Function to encapsulate toggling of the current FTS experimental patch flags on and off. Usage example: + # toggle_experimental_patches /path/to/.experiments 1 0 1 + export patch_report='' + filepath="$1" + shift + + declare -a exp_patch_flags=($(cat "$filepath")) + declare -a patch_mask=("$@") + + if [[ ${#exp_patch_flags[@]} -ne ${#patch_mask[@]} ]]; then + echo "Error: There are currently ${#exp_patch_flags[@]} defined experiments, provided mask should have that length." 
>&2 + return 1 + fi + + for i in "${!exp_patch_flags[@]}"; do + let arg_idx=i+1 + if [[ ${patch_mask[$i]} -eq 1 ]]; then + export "${exp_patch_flags[$i]}"=1 + patch_report+="${exp_patch_flags[$i]} value is now: ${!exp_patch_flags[$i]}\n" + else + unset "${exp_patch_flags[$i]}" + fi + done +} + +collect_tests(){ + local collect_def="$1" + local collect_log="$2" + if special_tests=$(python3 ${collect_def}); then + # match only lines with tests + declare -a -g parameterizations=($(grep -oP '\S+::test_\S+' <<< "$special_tests")) + echo `printf "%0.s-" {1..120} && printf "\n"` | tee -a $collect_log + printf "Collected the following tests: \n" | tee -a $collect_log + echo `printf "%0.s-" {1..120} && printf "\n"` | tee -a $collect_log + printf '%s\n' "${parameterizations[@]}" | tee -a $collect_log + num_collected_tests="${#parameterizations[@]}" + echo "Total number of tests: ${#parameterizations[@]}" | tee -a $collect_log + printf '\n' | tee -a $collect_log + else + printf "No tests were found with the following collection command: python3 ${collect_def} \n" | tee -a $collect_log + printf "Exiting without running tests. \n" | tee -a $collect_log + export no_tests_collected=1 + exit 0 + fi +} + +execute_tests(){ + ensure_tests + local execute_def="$1" + local execute_log="$2" + local tmp_out="$3" + # hardcoded tests to skip - space separated + blocklist='' + export report='' + echo `printf "%0.s-" {1..120} && printf "\n"` | tee -a $execute_log + printf "Running the collected tests: \n" | tee -a $execute_log + echo `printf "%0.s-" {1..120} && printf "\n"` | tee -a $execute_log + + for i in "${!parameterizations[@]}"; do + parameterization=${parameterizations[$i]} + + # check blocklist + if echo $blocklist | grep -F "${parameterization}"; then + report+="Skipped\t$parameterization\n" + continue + fi + + # run the test + echo "Running ${parameterization}" | tee -a $execute_log + (python ${execute_def} ${parameterization} 2>&1 | sed "s,\x1b\[[0-9;]*[a-zA-Z],,g" >> $tmp_out) > /dev/null + test_to_find=`echo ${parameterization} | sed 's/\[/\\\[/g; s/\]/\\\]/g'` + if pass_or_fail=$(grep -E "(PASSED|FAILED|XPASS|XFAIL) .*${test_to_find}" $tmp_out); then + parameterization_result=`echo $pass_or_fail | awk 'NR==1 {print $2 ": " $1}'`; + elif skipped=$(grep -E "${test_to_find}.*SKIPPED" $tmp_out); then + parameterization_result=`echo $skipped | awk 'NR==1 {print $1 ": " $2}'`; + else + echo "Could not parse result!" 
| tee -a $execute_log + parameterization_result="UNKNOWN: see $tmp_out" + fi + report+="Ran\t${parameterization_result}\n" + done +} + +show_test_counts(){ + local test_log="$1" + export num_failed=0 + export num_other=0 + if grep_succ=($(printf "$report" | grep -c "PASSED\|XPASSED\|XFAIL")); then num_succ=$grep_succ; else num_succ=0; fi + if grep_failed=($(printf "$report" | grep -c "FAILED")); then num_failed=$grep_failed; fi + if grep_skipped=($(printf "$report" | grep -c "SKIPPED")); then num_skipped=$grep_skipped; else num_skipped=0; fi + printf "\n" | tee -a $test_log + echo `printf "%0.s-" {1..120} && printf "\n"` | tee -a $test_log + printf "Test count summary: \n" | tee -a $test_log + echo `printf "%0.s-" {1..120} && printf "\n"` | tee -a $test_log + printf "Collected: $num_collected_tests \n" | tee -a $test_log + printf "Succeeded (passed+xpassed+xfail): $num_succ \n" | tee -a $test_log + printf "Intentionally skipped: $num_skipped \n" | tee -a $test_log + printf "Failed: $num_failed \n" | tee -a $test_log + num_other=$(($num_collected_tests - $num_succ - $num_failed - $num_skipped)) + if [ $num_other -gt 0 ]; then + printf "Other (usually tests skipped due to prior test failure): $num_other \n" | tee -a $test_log + fi + printf '\n' | tee -a $test_log +} + +show_summary(){ + local test_log="$1" + # summarize test report + echo `printf "%0.s-" {1..120} && printf "\n"` | tee -a $test_log + printf "Finished Tests: \n" | tee -a $test_log + echo `printf "%0.s-" {1..120} && printf "\n"` | tee -a $test_log + printf "$report" | tee -a $test_log +} + +show_final_summary(){ + local test_log="$1" + local tmp_out="${2:-}" + show_summary "$test_log" + show_test_counts "$test_log" + show_elapsed_time "$test_log" + exit_with_status "$test_log" +} + +exit_with_status(){ + local test_log="$1" + exit_code=0 + if [ $num_failed -gt 0 ] || [ $num_other -gt 0 ]; then + exit_code=1 + printf "**Failure (${num_failed}) or other (${num_other}) test counts were greater than 0**! \n" | tee -a $test_log + else + printf "Failure (${num_failed}) and other (${num_other}) test counts were not greater than 0. \n" | tee -a $test_log + fi + printf "Exiting with status code ${exit_code}. \n" | tee -a $test_log + exit $exit_code +} + +ensure_tests(){ + if [ -n "$no_tests_collected" ]; then + exit 0 + fi +} + +show_test_results(){ + ensure_tests + local test_log="$1" + local tmp_out="$2" + if [ -f ${tmp_out} ]; then + if grep_errors=($(grep --ignore-case --extended-regexp 'error|exception|traceback|failed' ${tmp_out})); then + echo `printf "%0.s-" {1..120} && printf "\n"` | tee -a $test_log + printf "Potential errors detected. See ${tmp_out} for details. Exception/error lines to follow. \n" | tee -a $test_log + echo `printf "%0.s-" {1..120} && printf "\n"` | tee -a $test_log + printf "\n" | tee -a $test_log + show_final_summary "$test_log" + echo `printf "%0.s-" {1..120} && printf "\n"` | tee -a $test_log + printf "Grepped exception/error lines: \n" | tee -a $test_log + echo `printf "%0.s-" {1..120} && printf "\n"` | tee -a $test_log + grep --ignore-case --extended-regexp 'error|exception' ${tmp_out} | tee -a $test_log + printf "\n" | tee -a $test_log + else + printf "No detected errors. \n" | tee -a $test_log + printf "\n" | tee -a $test_log + show_final_summary "$test_log" + fi + elif [ -f ${test_log} ]; then # if the log but not the out exists, check for collection errors + if grep --ignore-case --extended-regexp 'traceback|failed' ${test_log} ; then + echo "Potential collection error!" 
| tee -a $test_log + show_final_summary "$test_log" + exit 1 + fi + fi +} + +show_elapsed_time(){ + local test_log="$1" + script_name=${2:-$(basename "$0")} + ## write elapsed time in user-friendly fashion + end_time=$(date +%s) + elapsed_seconds=$(($end_time-$start_time)) + if (( $elapsed_seconds/60 == 0 )); then + printf "${script_name} completed in $elapsed_seconds seconds \n" | tee -a $test_log + elif (( $elapsed_seconds%60 == 0 )); then + printf "${script_name} completed in $(($elapsed_seconds/60)) minutes \n" | tee -a $test_log + else + printf "${script_name} completed in $(($elapsed_seconds/60)) minutes and $(($elapsed_seconds%60)) seconds \n" | tee -a $test_log + fi + printf "\n" | tee -a $test_log +} diff --git a/tests/special_tests.sh b/tests/special_tests.sh new file mode 100755 index 0000000..ca45eea --- /dev/null +++ b/tests/special_tests.sh @@ -0,0 +1,135 @@ +#!/bin/bash +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +set -eo pipefail # only disable this when debugging to allow more context + +unset mark_type +unset log_file +unset filter_pattern +unset experiments_list +unset experiment_patch_mask +unset PL_RUN_STANDALONE_TESTS +unset FTS_RUN_STANDALONE_TESTS +unset FTS_EXPERIMENTAL_PATCH_TESTS + +source $(dirname "$0")/infra_utils.sh + +usage(){ +>&2 cat << EOF +Usage: $0 + [ --mark_type input] + [ --log_file input] + [ --filter_pattern input] + [ --experiments_list input] + [ --experiment_patch_mask input] + [ --help ] + Examples: + # run all standalone tests (but not experimental ones, note --mark_type defaults to 'standalone' ): + # ./tests/special_tests.sh + # run all standalone tests following a pattern: + # ./tests/special_tests.sh --mark_type=standalone --filter_pattern='test_f' + # run all standalone tests passing a parent process log file to use: + # ./tests/special_tests.sh --mark_type=standalone --log_file=/tmp/some_parent_process_file_to_append_to.log + # run all experimental tests following a pattern that are supported by a given experimental patch mask using the + # default `tests/.experiments` experiments definition location: + # ./tests/special_tests.sh --mark_type=exp_patch --filter_pattern='test_f' --experiment_patch_mask="1 0 0" + # same as above, but use a custom experiments definition location: + # ./tests/special_tests.sh --mark_type=exp_patch --filter_pattern='model_parallel' --experiments_list=tests/.my_experiments --experiment_patch_mask="1 0 0" +EOF +exit 1 +} + +args=$(getopt -o '' --long mark_type:,log_file:,filter_pattern:,experiments_list:,experiment_patch_mask:,help -- "$@") +if [[ $? 
-gt 0 ]]; then + usage +fi + + +eval set -- ${args} +while : +do + case $1 in + --mark_type) mark_type=$2 ; shift 2 ;; + --log_file) log_file=$2 ; shift 2 ;; + --filter_pattern) filter_pattern=$2 ; shift 2 ;; + --experiments_list) experiments_list=$2 ; shift 2 ;; + --experiment_patch_mask) experiment_patch_mask+=($2) ; shift 2 ;; + --help) usage ; shift ;; + --) shift; break ;; + *) >&2 echo Unsupported option: $1 + usage ;; + esac +done + +d=`date +%Y%m%d%H%M%S` +tmp_log_dir="/tmp" +mark_type=${mark_type:-"standalone"} +experiments_list=${experiments_list:-$(dirname "$0")/.experiments} +if [ -s "${experiments_list}" ]; then + if [ -z "${experiment_patch_mask:-}" ]; then + experiment_patch_mask=($(cat tests/.experiments | awk '{for(i=1;i<=NF;i++) print "0"}')) + fi +fi +special_test_session_log=${log_file:-"${tmp_log_dir}/special_tests_${mark_type}_${d}.log"} +test_session_tmp_out="${tmp_log_dir}/special_tests_${mark_type}_${d}.out" + +# default python coverage arguments +exec_defaults='-m coverage run --source src/finetuning_scheduler --append -m pytest --capture=no --no-header -v -s -rA' +collect_defaults='-m pytest tests -q --collect-only --pythonwarnings ignore' +start_time=$(date +%s) +echo `printf "%0.s-" {1..120} && printf "\n"` | tee -a $special_test_session_log +printf "FTS special tests beginning execution at ${d} PT \n" | tee -a $special_test_session_log +echo `printf "%0.s-" {1..120} && printf "\n"` | tee -a $special_test_session_log +printf "\n" | tee -a $special_test_session_log + +define_configuration(){ + echo `printf "%0.s-" {1..120} && printf "\n"` | tee -a $special_test_session_log + printf "Run configuration: \n" | tee -a $special_test_session_log + echo `printf "%0.s-" {1..120} && printf "\n"` | tee -a $special_test_session_log + case ${mark_type} in + standalone) + echo "Collecting and running standalone tests." | tee -a $special_test_session_log + export PL_RUN_STANDALONE_TESTS=1 + ;; + exp_patch) + echo "Collecting and running only experimental patch tests supported w/ provided patch mask (${experiment_patch_mask[@]})." | tee -a $special_test_session_log + export FTS_EXPERIMENTAL_PATCH_TESTS=1 + ;; + *) + echo "no matching `mark_type` found, exiting..." | tee -a $special_test_session_log + exit 1 + ;; + esac + + if [ -s "${experiments_list}" ]; then + # toggle optional experimental patches if requested + toggle_experimental_patches ${experiments_list} "${experiment_patch_mask[@]}" + else + echo "No experimental patches were found in the environment." | tee -a $special_test_session_log + fi + printf "${patch_report}" | tee -a $special_test_session_log + + if [[ -n ${filter_pattern} ]]; then + echo "Using filter pattern: ${filter_pattern}" | tee -a $special_test_session_log + exec_defaults+=" -k ${filter_pattern}" + collect_defaults+=" -k ${filter_pattern}" + fi + printf '\n' | tee -a $special_test_session_log +} + +trap 'show_test_results "$special_test_session_log" "$test_session_tmp_out"' EXIT # show the output on exit + +## Special coverage collection flow +define_configuration +collect_tests "$collect_defaults" "$special_test_session_log" +execute_tests "$exec_defaults" "$special_test_session_log" "$test_session_tmp_out" diff --git a/tests/standalone_tests.sh b/tests/standalone_tests.sh deleted file mode 100755 index 2f5df59..0000000 --- a/tests/standalone_tests.sh +++ /dev/null @@ -1,77 +0,0 @@ -#!/bin/bash -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# Originally based on https://bit.ly/3AZGVVT -set -e - -# this environment variable allows special tests to run -export PL_RUN_STANDALONE_TESTS=1 -# python arguments -defaults='-m coverage run --source src/finetuning_scheduler --append -m pytest --capture=no --no-header -v -s' - -# find tests marked as `@RunIf(standalone=True)`. done manually instead of with pytest because it is faster -grep_output=$(grep --recursive --word-regexp 'tests' --regexp 'standalone=True' --include '*.py' --exclude 'tests/conftest.py') - -# file paths, remove duplicates -files=$(echo "$grep_output" | cut -f1 -d: | sort | uniq) - -# get the list of parametrizations. we need to call them separately. the last two lines are removed. -# note: if there's a syntax error, this will fail with some garbled output -if [[ "$OSTYPE" == "darwin"* ]]; then - parametrizations=$(pytest $files --collect-only --quiet "$@" | tail -r | sed -e '1,3d' | tail -r) -else - parametrizations=$(pytest $files --collect-only --quiet "$@" | head -n -2) -fi -parametrizations_arr=($parametrizations) - -# tests to skip - space separated -blocklist='' -report='' - -rm -f standalone_test_output.txt # in case it exists, remove it -function show_output { - if [ -f standalone_test_output.txt ]; then # if exists - cat standalone_test_output.txt - # heuristic: stop if there's mentions of errors. this can prevent false negatives when only some of the ranks fail - if grep --quiet --ignore-case --extended-regexp 'error|exception|traceback|failed' standalone_test_output.txt; then - echo "Potential error! Stopping." 
- rm standalone_test_output.txt - exit 1 - fi - rm standalone_test_output.txt - fi -} -trap show_output EXIT # show the output on exit - -for i in "${!parametrizations_arr[@]}"; do - parametrization=${parametrizations_arr[$i]} - - # check blocklist - if echo $blocklist | grep -F "${parametrization}"; then - report+="Skipped\t$parametrization\n" - continue - fi - - # run the test - echo "Running ${parametrization}" - python ${defaults} "${parametrization}" - - report+="Ran\t$parametrization\n" -done - -show_output - -# echo test report -printf '=%.s' {1..80} -printf "\n$report" -printf '=%.s' {1..80} -printf '\n' diff --git a/tests/test_finetuning_scheduler_callback.py b/tests/test_finetuning_scheduler_callback.py index b274e14..a24af15 100644 --- a/tests/test_finetuning_scheduler_callback.py +++ b/tests/test_finetuning_scheduler_callback.py @@ -313,7 +313,7 @@ def __init__( **kwargs, ): super().__init__(*args, **kwargs) - self.expected_state = expected_state + self.expected_state = expected_state or {} self.lrs_state = lrs_state self.mock_strategy = mock_strategy self.state_log_dir = state_log_dir diff --git a/tests/test_model_parallel.py b/tests/test_model_parallel.py index d0a2429..481108e 100644 --- a/tests/test_model_parallel.py +++ b/tests/test_model_parallel.py @@ -208,16 +208,18 @@ def on_train_epoch_start(self, trainer, pl_module): super(TestFinetuningScheduler, self).on_train_epoch_start(trainer, pl_module) model_parallel_sample = {} state_key = trainer.current_epoch - target_p_keys = self.expected_state[state_key][0]['p_states'].keys() - model_parallel_sample['p_states'] = self._collect_p_states(target_p_keys) - if target_mod_keys := self.expected_state[state_key][0].get('fsdp_mod_states', {}).keys(): - model_parallel_sample['fsdp_mod_states'] = self._collect_fsdp_mod_states(target_mod_keys) - current_state = ( - model_parallel_sample, - len(self._fts_state._curr_thawed_params), - ) - lrs_state = None - self.inspect_or_assert(current_state, lrs_state, state_key) + if expected_epoch_state := self.expected_state.get(state_key): + if target_p_keys := expected_epoch_state[0].get('p_states', {}).keys(): + model_parallel_sample['p_states'] = self._collect_p_states(target_p_keys) + if target_mod_keys := expected_epoch_state[0].get('fsdp_mod_states', {}).keys(): + model_parallel_sample['fsdp_mod_states'] = self._collect_fsdp_mod_states(target_mod_keys) + if target_mod_keys or target_p_keys: + current_state = ( + model_parallel_sample, + len(self._fts_state._curr_thawed_params), + ) + lrs_state = None + self.inspect_or_assert(current_state, lrs_state, state_key) def _collect_p_states(self, tp_keys: KeysView) -> Dict[Any, Dict]: p_states = {} @@ -425,8 +427,12 @@ def gen_apply_transformer_tp_plan(model: nn.Module, device_mesh: DeviceMesh, los ## Lightning Trainer Configuration Aliases trainer_defaults = {"accelerator": "gpu", "devices": 2, 'limit_train_batches': 2, 'limit_val_batches': 2, 'num_sanity_val_steps': 0} -no_sanity_val = {"num_sanity_val_steps": 0} -max_epoch_4 = {"max_epochs": 4} +# no_sanity_val = {"num_sanity_val_steps": 0} +# max_epoch_4 = {"max_epochs": 4} + +## Precision Configuration Aliases +fp16 = {"precision": "16-true"} +bf16 = {"precision": "bf16-true"} ## cust ckpt cfg no_ckpt_save = {"save_top_k": 0} @@ -442,7 +448,6 @@ def gen_apply_transformer_tp_plan(model: nn.Module, device_mesh: DeviceMesh, los @dataclass class ModelParallelTestConfig: model_cfg_key: str - expected_results: ExpectedResults model_cls: Callable model_cfg: Dict = field(default_factory=dict) 
trainer_cfg: Dict = field(default_factory=lambda: {'max_epochs': 3}) @@ -457,6 +462,7 @@ class ModelParallelTestConfig: es_cfg: Dict = field(default_factory=lambda: {"patience": 1}) ckpt_cfg: Dict = field(default_factory=lambda: {"save_top_k": 3}) ckpt_cls: Callable = FTSCheckpoint + expected_results: ExpectedResults = ExpectedResults() runif_alias: Optional[str] = None def __post_init__(self): self.default_fts_cfg = { @@ -473,24 +479,32 @@ def test_torch_greater_equal_2_5(): ModelParallelStrategyAdapter() ## Model Parallel Test Definitions -FTS_MODEL_PARALLEL_TESTS = ( - ModelParallelTestConfig(model_cfg_key="tt_fsdp_tp", model_cls=tt_mod_parallel, +FTS_MODEL_PARALLEL_PATH_TESTS = ( + ModelParallelTestConfig(model_cfg_key="path_tt_fsdp_tp", model_cls=tt_mod_parallel, model_cfg=tt_fsdp_tp, fts_cfg=no_restore_best, ckpt_cfg=no_ckpt_save, strategy_cfg=dp2_tp1, runif_alias="einsum_exp", expected_results=ExpectedResults(expected_state=path_tt_fsdp_tp)), - ModelParallelTestConfig(model_cfg_key="tt_fsdp_no_tp", model_cls=tt_mod_parallel, - model_cfg=tt_fsdp_no_tp, strategy_cfg=dp2_tp1, runif_alias="min2_5", + ModelParallelTestConfig(model_cfg_key="path_tt_fsdp_no_tp", model_cls=tt_mod_parallel, + model_cfg=tt_fsdp_no_tp, strategy_cfg=dp2_tp1, runif_alias="alone", expected_results=ExpectedResults(expected_state=path_tt_fsdp_no_tp)), - ModelParallelTestConfig(model_cfg_key="tt_tp_no_fsdp_lp", model_cls=tt_mod_parallel, - model_cfg=tt_tp_no_fsdp_lp, strategy_cfg=dp1_tp2, runif_alias="min2_5", + ModelParallelTestConfig(model_cfg_key="path_tt_tp_no_fsdp_lp", model_cls=tt_mod_parallel, + model_cfg=tt_tp_no_fsdp_lp, strategy_cfg=dp1_tp2, runif_alias="alone", expected_results=ExpectedResults(expected_state=path_tt_tp_no_fsdp)), - ModelParallelTestConfig(model_cfg_key="tt_tp_no_fsdp_no_lp", model_cls=tt_mod_parallel, - model_cfg=tt_tp_no_fsdp_no_lp, strategy_cfg=dp1_tp2, runif_alias="min2_5", + ModelParallelTestConfig(model_cfg_key="path_tt_tp_no_fsdp_no_lp", model_cls=tt_mod_parallel, + model_cfg=tt_tp_no_fsdp_no_lp, strategy_cfg=dp1_tp2, runif_alias="alone", expected_results=ExpectedResults(expected_state=path_tt_tp_no_fsdp)), + ModelParallelTestConfig(model_cfg_key="tt_fsdp_no_tp_fp16", model_cls=tt_mod_parallel, + fts_cfg=no_restore_best, + precision_opts=fp16, + model_cfg=tt_fsdp_no_tp, strategy_cfg=dp2_tp1, runif_alias="alone"), + # ModelParallelTestConfig(model_cfg_key="tt_tp_no_fsdp_bf16", model_cls=tt_mod_parallel, precision_opts=bf16, + # model_cfg=tt_tp_no_fsdp_lp, strategy_cfg=dp1_tp2, runif_alias="bf16_alone"), + # ModelParallelTestConfig(model_cfg_key="tt_tp_no_fsdp_fp16", model_cls=tt_mod_parallel, precision_opts=fp16, + # model_cfg=tt_tp_no_fsdp_no_lp, strategy_cfg=dp1_tp2, runif_alias="alone") ) -@RunIf(min_cuda_gpus=2, standalone=True) -@pytest.mark.parametrize("test_cfg", pytest_param_factory(FTS_MODEL_PARALLEL_TESTS)) -def test_fts_model_parallel(tmpdir, recwarn, model_parallel_ft_schedule, test_cfg): +@RunIf(min_cuda_gpus=2, min_torch="2.5.0") +@pytest.mark.parametrize("test_cfg", pytest_param_factory(FTS_MODEL_PARALLEL_PATH_TESTS)) +def test_fts_model_parallel_integration(tmpdir, recwarn, model_parallel_ft_schedule, test_cfg): """Validate :class:`~finetuning_scheduler.FinetuningScheduler` functions properly in a supported 'ddp' distributed context.""" seed_everything(42) @@ -506,8 +520,8 @@ def test_fts_model_parallel(tmpdir, recwarn, model_parallel_ft_schedule, test_cf with trainer.init_module(empty_init=True): model = test_cfg.model_cls(**test_cfg.model_cfg) # TODO: verify updated 
tt_cfg is applied here configured_model = torch.compile(model) if use_dynamo else model - if test_cfg.expected_results.exceptions_expected: - gen_exceptions(trainer, configured_model, test_cfg.model_cfg_key, test_cfg.expected_results.exceptions_expected) + if exc_expect := test_cfg.expected_results.exceptions_expected: + gen_exceptions(trainer, configured_model, test_cfg.model_cfg_key, exc_expect) else: trainer.fit(configured_model) default_fts_sanity_chk(trainer) @@ -516,7 +530,6 @@ def test_fts_model_parallel(tmpdir, recwarn, model_parallel_ft_schedule, test_cf warns_expected=test_cfg.expected_results.warns_expected, expected_warns_dynamo=MODEL_PARALLEL_DYNAMO_EXPECTED_WARNS, use_dynamo=use_dynamo) - def gen_exceptions(trainer, model, exception_expected): with pytest.raises(MisconfigurationException, match=exception_expected): trainer.fit(model)
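
For reference, a minimal sketch of how the pieces introduced above are expected to fit together when invoked locally or from the Azure pipeline steps added in this change. It assumes the default tests/.experiments definition shipped in this patch (whose first entry is ENABLE_FTS_EINSUM_STRATEGY_PATCH); the invocations mirror the usage() examples in tests/special_tests.sh.

    # 1. Standalone tests: special_tests.sh exports PL_RUN_STANDALONE_TESTS=1, so
    #    conftest.py::pytest_collection_modifyitems keeps only items marked @RunIf(standalone=True).
    ./tests/special_tests.sh --mark_type=standalone --filter_pattern='test_f'

    # 2. Experimental-patch tests: special_tests.sh exports FTS_EXPERIMENTAL_PATCH_TESTS=1 and
    #    toggle_experimental_patches sets ENABLE_FTS_EINSUM_STRATEGY_PATCH=1 per the "1 0 0" mask;
    #    conftest.py then keeps only items carrying the exp_patch RunIf kwarg, and RunIf skips any
    #    test whose required patch set is not a subset of the active patches.
    ./tests/special_tests.sh --mark_type=exp_patch --filter_pattern='test_f' --experiment_patch_mask="1 0 0"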