diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 1f29622bdf20..5e9863ab0d0a 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -545,11 +545,14 @@ def get_1d_rotary_pos_embed( assert dim % 2 == 0 if isinstance(pos, int): - pos = np.arange(pos) + pos = torch.arange(pos) + if isinstance(pos, np.ndarray): + pos = torch.from_numpy(pos) # type: ignore # [S] + theta = theta * ntk_factor freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype)[: (dim // 2)] / dim)) / linear_factor # [D/2] - t = torch.from_numpy(pos).to(freqs.device) # type: ignore # [S] - freqs = torch.outer(t, freqs) # type: ignore # [S, D/2] + freqs = freqs.to(pos.device) + freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2] if use_real and repeat_interleave_real: # flux, hunyuan-dit, cogvideox freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D] @@ -626,7 +629,7 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor: n_axes = ids.shape[-1] cos_out = [] sin_out = [] - pos = ids.squeeze().float().cpu().numpy() + pos = ids.squeeze().float() is_mps = ids.device.type == "mps" freqs_dtype = torch.float32 if is_mps else torch.float64 for i in range(n_axes):