Skip to content

Commit

Permalink
Add test for recompilation w/o hotswapping
Browse files Browse the repository at this point in the history
  • Loading branch information
BenjaminBossan committed Oct 18, 2024
1 parent ea12e0d commit ec4b0d5
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 5 deletions.
8 changes: 5 additions & 3 deletions tests/pipelines/run_compiled_model_hotswap.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
"""

import os
import sys
import tempfile

import torch
Expand All @@ -36,7 +37,7 @@


def get_small_unet():
# from UNet2DConditionModelTests
# from UNet2DConditionModelTests
torch.manual_seed(0)
init_dict = {
"block_out_channels": (4, 8),
Expand Down Expand Up @@ -147,5 +148,6 @@ def check_hotswap(do_hotswap):


if __name__ == "__main__":
# check_hotswap(False) will trigger recompilation
check_hotswap(True)
# check_hotswap(True) does not trigger recompilation
# check_hotswap(False) triggers recompilation
check_hotswap(do_hotswap=sys.argv[1] == "1")
19 changes: 17 additions & 2 deletions tests/pipelines/test_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -2080,14 +2080,14 @@ class TestLoraHotSwapping:
@require_torch_gpu
@require_peft_backend
def test_hotswapping_compiled_model_does_not_trigger_recompilation(self):
# TODO: kinda slow, should it get a slow marker?
env = os.environ.copy()
env["TORCH_LOGS"] = "guards,recompiles"
here = os.path.dirname(__file__)
file_name = os.path.join(here, "run_compiled_model_hotswap.py")

# first test with hotswapping: should not trigger recompilation
process = subprocess.Popen(
[sys.executable, file_name], env=env, stdout=subprocess.PIPE, stderr=subprocess.PIPE
[sys.executable, file_name, "1"], env=env, stdout=subprocess.PIPE, stderr=subprocess.PIPE
)

# Communicate will read the output and error streams, preventing deadlock
Expand All @@ -2099,3 +2099,18 @@ def test_hotswapping_compiled_model_does_not_trigger_recompilation(self):

# check that the recompilation message is not present
assert "__recompiles" not in stderr.decode()

# next, contingency check: without hotswapping, we *do* get recompilation
process = subprocess.Popen(
[sys.executable, file_name, "0"], env=env, stdout=subprocess.PIPE, stderr=subprocess.PIPE
)

# Communicate will read the output and error streams, preventing deadlock
stdout, stderr = process.communicate()
exit_code = process.returncode

# sanity check:
assert exit_code == 0

# check that the recompilation message is not present
assert "__recompiles" in stderr.decode()

0 comments on commit ec4b0d5

Please sign in to comment.