refactor rotary embedding 3: so it is not on cpu (#9307)
change get_1d_rotary_pos_embed to accept pos as a torch tensor
yiyixuxu authored Aug 29, 2024
1 parent 4f495b0 commit 61d96c3
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions src/diffusers/models/embeddings.py
@@ -545,11 +545,14 @@ def get_1d_rotary_pos_embed(
     assert dim % 2 == 0
 
     if isinstance(pos, int):
-        pos = np.arange(pos)
+        pos = torch.arange(pos)
+    if isinstance(pos, np.ndarray):
+        pos = torch.from_numpy(pos)  # type: ignore  # [S]
 
     theta = theta * ntk_factor
     freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype)[: (dim // 2)] / dim)) / linear_factor  # [D/2]
-    t = torch.from_numpy(pos).to(freqs.device)  # type: ignore  # [S]
-    freqs = torch.outer(t, freqs)  # type: ignore  # [S, D/2]
+    freqs = freqs.to(pos.device)
+    freqs = torch.outer(pos, freqs)  # type: ignore  # [S, D/2]
     if use_real and repeat_interleave_real:
         # flux, hunyuan-dit, cogvideox
         freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float()  # [S, D]
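After this hunk, get_1d_rotary_pos_embed accepts pos as an int, an np.ndarray, or a torch.Tensor, and builds its frequency table on pos's device instead of always on the CPU. A minimal usage sketch (the dim=64, sequence length 128, and use_real=True arguments are illustrative, not part of this diff):

import torch
from diffusers.models.embeddings import get_1d_rotary_pos_embed

device = "cuda" if torch.cuda.is_available() else "cpu"
pos = torch.arange(128, dtype=torch.float32, device=device)  # [S] positions, already on the target device

# With use_real=True the function returns (cos, sin) tables; after this commit
# they are computed directly on pos.device with no NumPy/CPU round-trip.
freqs_cos, freqs_sin = get_1d_rotary_pos_embed(64, pos, use_real=True)
assert freqs_cos.device.type == pos.device.type  # the [S, D] tables live on the same device as pos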
@@ -626,7 +629,7 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor:
         n_axes = ids.shape[-1]
         cos_out = []
         sin_out = []
-        pos = ids.squeeze().float().cpu().numpy()
+        pos = ids.squeeze().float()
         is_mps = ids.device.type == "mps"
         freqs_dtype = torch.float32 if is_mps else torch.float64
         for i in range(n_axes):
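The second hunk drops the .cpu().numpy() round-trip in forward, so the position ids fed to get_1d_rotary_pos_embed in the per-axis loop stay on their original device. A rough before/after sketch of just that line (the ids shape is illustrative, not taken from the diff):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
ids = torch.zeros(1, 4096, 3, device=device)  # packed position ids, [batch, S, n_axes]

# before: forced a device-to-host sync and copy on every forward pass
# pos = ids.squeeze().float().cpu().numpy()

# after: pos remains a torch tensor on ids.device
pos = ids.squeeze().float()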
