diff --git a/comfy/ldm/flux/controlnet_xlabs.py b/comfy/ldm/flux/controlnet_xlabs.py
index 3f40021b2e2..e5743321758 100644
--- a/comfy/ldm/flux/controlnet_xlabs.py
+++ b/comfy/ldm/flux/controlnet_xlabs.py
@@ -64,8 +64,8 @@ def forward_orig(
         img = img + controlnet_cond
         vec = self.time_in(timestep_embedding(timesteps, 256))
         if self.params.guidance_embed:
-            vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
-        vec = vec + self.vector_in(y)
+            vec.add_(self.guidance_in(timestep_embedding(guidance, 256)))
+        vec.add_(self.vector_in(y))
         txt = self.txt_in(txt)
 
         ids = torch.cat((txt_ids, img_ids), dim=1)
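
The controlnet hunk above swaps the two out-of-place additions on `vec` for in-place `Tensor.add_` calls, so both updates accumulate into the storage that `self.time_in(...)` already produced instead of allocating a new tensor per sum. A minimal sketch of the equivalence (the shapes below are illustrative, not taken from the model, and this assumes the forward pass runs without autograd):

    import torch

    vec = torch.randn(1, 3072)       # stand-in for time_in(timestep_embedding(...))
    update = torch.randn(1, 3072)    # stand-in for guidance_in(...) or vector_in(y)

    reference = vec + update         # original form: allocates a fresh result tensor
    vec.add_(update)                 # patched form: writes the sum back into vec's storage

    assert torch.allclose(reference, vec)
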
diff --git a/comfy/ldm/flux/layers.py b/comfy/ldm/flux/layers.py
index da0cf61b1c3..0cd2200c6a3 100644
--- a/comfy/ldm/flux/layers.py
+++ b/comfy/ldm/flux/layers.py
@@ -150,14 +150,16 @@ def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor):
 
         # prepare image for attention
         img_modulated = self.img_norm1(img)
-        img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
+        img_mod1.scale += 1
+        img_modulated = img_mod1.scale * img_modulated + img_mod1.shift
         img_qkv = self.img_attn.qkv(img_modulated)
         img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
         img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
 
         # prepare txt for attention
         txt_modulated = self.txt_norm1(txt)
-        txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
+        txt_mod1.scale += 1
+        txt_modulated = txt_mod1.scale * txt_modulated + txt_mod1.shift
         txt_qkv = self.txt_attn.qkv(txt_modulated)
         txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
         txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
@@ -170,12 +172,12 @@ def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor):
         txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
 
         # calculate the img bloks
-        img = img + img_mod1.gate * self.img_attn.proj(img_attn)
-        img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
+        img = img.addcmul(img_mod1.gate, self.img_attn.proj(img_attn))
+        img.addcmul_(img_mod2.gate, self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift))
 
         # calculate the txt bloks
-        txt += txt_mod1.gate * self.txt_attn.proj(txt_attn)
-        txt += txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
+        txt.addcmul_(txt_mod1.gate, self.txt_attn.proj(txt_attn))
+        txt.addcmul_(txt_mod2.gate, self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift))
 
         if txt.dtype == torch.float16:
             txt = txt.clip(-65504, 65504)
@@ -221,8 +223,8 @@ def __init__(
 
     def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
         mod, _ = self.modulation(vec)
-        x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
-        qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
+        mod.scale += 1
+        qkv, mlp = torch.split(self.linear1(mod.scale * self.pre_norm(x) + mod.shift), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
 
         q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
         q, k = self.norm(q, k, v)
@@ -230,8 +232,7 @@ def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
         # compute attention
         attn = attention(q, k, v, pe=pe)
         # compute activation in mlp stream, cat again and run second linear layer
-        output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
-        x += mod.gate * output
+        x.addcmul_(mod.gate, self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)))
         if x.dtype == torch.float16:
             x = x.clip(-65504, 65504)
         return x
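
The layers.py hunks rewrite two recurring patterns in the double and single stream blocks. First, the modulation `(1 + mod.scale) * x + mod.shift` becomes `mod.scale += 1` followed by `mod.scale * x + mod.shift`, dropping the temporary for `1 + scale`; this mutates the tensor held by the modulation output in place, which looks safe here because each `scale` is read only once after the update. Second, residual updates of the form `x = x + gate * y` become `Tensor.addcmul_` calls, a fused multiply-add into the existing storage (the first image update keeps the out-of-place `addcmul`, presumably so the `img` tensor passed in by the caller is left untouched). A sketch of both rewrites with made-up shapes:

    import torch

    x = torch.randn(2, 4096, 3072)        # stand-in for the img/txt stream
    gate = torch.randn(2, 1, 3072)        # stand-in for mod.gate (broadcast over tokens)
    branch = torch.randn(2, 4096, 3072)   # stand-in for the attention/MLP branch output

    # residual update: x = x + gate * branch
    reference = x + gate * branch         # original: temporaries for gate * branch and the sum
    x.addcmul_(gate, branch)              # patched: fused multiply-add into x's storage
    assert torch.allclose(reference, x)

    # modulation: (1 + scale) * normed + shift
    scale = torch.randn(2, 1, 3072)
    shift = torch.randn(2, 1, 3072)
    normed = torch.randn(2, 4096, 3072)   # stand-in for img_norm1(img) / pre_norm(x)

    reference = (1 + scale) * normed + shift
    scale += 1                            # in-place; avoids the (1 + scale) temporary
    assert torch.allclose(reference, scale * normed + shift)
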
diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py
index b5373540a84..f052bd0e161 100644
--- a/comfy/ldm/flux/model.py
+++ b/comfy/ldm/flux/model.py
@@ -106,9 +106,9 @@ def forward_orig(
         if self.params.guidance_embed:
             if guidance is None:
                 raise ValueError("Didn't get guidance strength for guidance distilled model.")
-            vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
+            vec.add_(self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype)))
 
-        vec = vec + self.vector_in(y)
+        vec.add_(self.vector_in(y))
         txt = self.txt_in(txt)
 
         ids = torch.cat((txt_ids, img_ids), dim=1)
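
The model.py hunk mirrors the controlnet change: the guidance and pooled-text contributions are accumulated into `vec` with `add_` rather than re-allocated. The common thread of the patch is trading out-of-place arithmetic for in-place kernels on the largest activations, which should lower peak memory without changing the numerics; the trade-off is that in-place mutation is only legitimate when the overwritten values are not needed again (in particular, not saved for a backward pass), which these sampling-time forward passes are expected to satisfy. An illustrative way to eyeball the peak-memory difference on a CUDA device (the shapes and the helper below are made up for the comparison, not part of the patch):

    import torch

    def peak_bytes(fn):
        # Reset the allocator's high-water mark, run fn, and report the new peak.
        torch.cuda.reset_peak_memory_stats()
        fn()
        torch.cuda.synchronize()
        return torch.cuda.max_memory_allocated()

    x = torch.randn(2, 4096, 3072, device="cuda")
    gate = torch.randn(2, 1, 3072, device="cuda")
    branch = torch.randn(2, 4096, 3072, device="cuda")

    out_of_place = peak_bytes(lambda: x + gate * branch)     # materializes two large temporaries
    in_place = peak_bytes(lambda: x.addcmul_(gate, branch))  # reuses x's storage
    print(out_of_place, in_place)                            # expect the second number to be lower
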