diff --git a/examples/consistency_distillation/train_lcm_distill_sd_wds.py b/examples/consistency_distillation/train_lcm_distill_sd_wds.py index d873cb8deb58..5dcad9f6cc39 100644 --- a/examples/consistency_distillation/train_lcm_distill_sd_wds.py +++ b/examples/consistency_distillation/train_lcm_distill_sd_wds.py @@ -945,7 +945,7 @@ def main(args): # 8. Create target student U-Net. This will be updated via EMA updates (polyak averaging). # Initialize from (online) unet - target_unet = UNet2DConditionModel(**teacher_unet.config) + target_unet = UNet2DConditionModel.from_config(unet.config) target_unet.load_state_dict(unet.state_dict()) target_unet.train() target_unet.requires_grad_(False) diff --git a/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py b/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py index 862777411ccc..a7deca72a86f 100644 --- a/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py +++ b/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py @@ -1004,7 +1004,7 @@ def main(args): # 8. Create target student U-Net. This will be updated via EMA updates (polyak averaging). # Initialize from (online) unet - target_unet = UNet2DConditionModel(**teacher_unet.config) + target_unet = UNet2DConditionModel.from_config(unet.config) target_unet.load_state_dict(unet.state_dict()) target_unet.train() target_unet.requires_grad_(False)