split out tile dimension in rope
RdoubleA committed Jan 13, 2025
1 parent 8dd77e5 commit a6e761d
Showing 3 changed files with 50 additions and 82 deletions.
79 changes: 20 additions & 59 deletions tests/torchtune/modules/test_position_embeddings.py
@@ -136,20 +136,21 @@ def test_rope_init_meta_device(self, input_params):

class TestVisionRotaryPositionEmbedding:

EXPECTED_X_OUT_MEAN = tensor(0.0789793)
EXPECTED_X_OUT_SUM = tensor(25.2733822)
EXPECTED_X_OUT_MAX = tensor(3.1225626)
EXPECTED_X_OUT_MEAN = tensor(-0.00903320)
EXPECTED_X_OUT_SUM = tensor(-29.48437119)
EXPECTED_X_OUT_MAX = tensor(4.07074356)

@pytest.fixture
def input_params(self):
bsz = 2
max_num_tiles = 3
num_heads = 8
embed_dim = 32
head_dim = embed_dim // num_heads
seq_len = 5
patch_size = 4
tile_size = 16
return bsz, num_heads, head_dim, seq_len, patch_size, tile_size
seq_len = ((tile_size // patch_size) ** 2 + 1) * max_num_tiles
return bsz, num_heads, head_dim, seq_len, max_num_tiles, patch_size, tile_size
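For the fixture values above, the new formula gives a concrete sequence length; a worked example (the arithmetic below is illustrative and not part of the diff):

# (tile_size // patch_size) ** 2 patches per tile, plus one CLS token, times max_num_tiles
seq_len = ((16 // 4) ** 2 + 1) * 3  # (16 + 1) * 3 = 51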

@pytest.fixture
def input(self, input_params) -> tensor:
@@ -158,9 +159,12 @@ def input(self, input_params) -> tensor:

@pytest.fixture
def rope(self, input_params):
_, _, head_dim, _, patch_size, tile_size = input_params
_, _, head_dim, _, max_num_tiles, patch_size, tile_size = input_params
return VisionRotaryPositionalEmbeddings(
patch_size=patch_size, tile_size=tile_size, dim=head_dim // 2
patch_size=patch_size,
tile_size=tile_size,
max_num_tiles=max_num_tiles,
dim=head_dim // 2,
)

@mps_ignored_test()
@@ -175,63 +179,20 @@ def test_forward(self, input, rope) -> None:
# check shapes
assert_expected(x_out.shape, input.shape)

@mps_ignored_test()
def test_forward_with_curr_pos(self, input, rope) -> None:
(
_,
seq_len,
_,
_,
) = input.shape
x_out = rope(input, input_pos=torch.arange(seq_len))

# these values should be exactly the same as test_forward
# since in this case input_pos covers the entire input
# sequence. This tests that input_pos works as expected i.e.
# extracts the embeddings for the relevant positions
assert_expected(x_out.mean(), self.EXPECTED_X_OUT_MEAN, atol=1e-4)
assert_expected(x_out.sum(), self.EXPECTED_X_OUT_SUM)
assert_expected(x_out.max(), self.EXPECTED_X_OUT_MAX)

# check shapes
assert_expected(x_out.shape, input.shape)

@mps_ignored_test()
def test_forward_with_packed_pos(self, input, rope) -> None:
"""
Use input_pos to indicate positions of each token relative to its sequence
when sample is packed.
"""
(
bsz,
seq_len,
_,
_,
) = input.shape
x_out = rope(
input, input_pos=torch.arange(seq_len).unsqueeze(0).expand(bsz, seq_len)
)

# these values should be exactly the same as test_forward
# AND test_forward_with_current_pos. In this case input_pos
# covers the entire batch dim and is defined for each sample separately.
# This tests that input_pos works as expected i.e.
# extracts the embeddings for the relevant positions for each sample
assert_expected(x_out.mean(), self.EXPECTED_X_OUT_MEAN, atol=1e-4)
assert_expected(x_out.sum(), self.EXPECTED_X_OUT_SUM)
assert_expected(x_out.max(), self.EXPECTED_X_OUT_MAX)

# check shapes
assert_expected(x_out.shape, input.shape)

def test_rope_init_meta_device(self, input_params):
_, _, head_dim, _, patch_size, tile_size = input_params
_, _, head_dim, _, max_num_tiles, patch_size, tile_size = input_params
rope_on_device = VisionRotaryPositionalEmbeddings(
dim=head_dim, patch_size=patch_size, tile_size=tile_size
dim=head_dim,
patch_size=patch_size,
max_num_tiles=max_num_tiles,
tile_size=tile_size,
)
with torch.device("meta"):
meta_rope = VisionRotaryPositionalEmbeddings(
dim=head_dim, patch_size=patch_size, tile_size=tile_size
dim=head_dim,
patch_size=patch_size,
tile_size=tile_size,
max_num_tiles=max_num_tiles,
)

meta_rope.rope_init()
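The meta-device test above relies on ``rope_init`` rebuilding the cache buffer after the module is constructed without real storage. A minimal standalone sketch of that pattern, assuming the import path below (not shown in this diff) and reusing the fixture's sizes:

import torch

from torchtune.modules.position_embeddings import VisionRotaryPositionalEmbeddings

with torch.device("meta"):
    rope = VisionRotaryPositionalEmbeddings(
        dim=4, patch_size=4, tile_size=16, max_num_tiles=3
    )
# construction under the meta device allocates no real cache; calling
# rope_init() afterwards recomputes and registers it on the default device
rope.rope_init()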
1 change: 1 addition & 0 deletions torchtune/models/clip/_component_builders.py
@@ -104,6 +104,7 @@ def clip_vision_encoder(
VisionRotaryPositionalEmbeddings(
patch_size=patch_size,
tile_size=tile_size,
max_num_tiles=max_num_tiles,
dim=head_dim // 2,
base=10_000,
append_cls_token=append_cls_token,
52 changes: 29 additions & 23 deletions torchtune/modules/position_embeddings.py
@@ -127,7 +127,9 @@ class VisionRotaryPositionalEmbeddings(nn.Module):
This class implements two-dimensional Rotary Positional Embeddings (RoPE) for images
based on the axial frequency 2D RoPE described in https://arxiv.org/pdf/2403.13298.
The position embedding is simply applied to the x-axis and y-axis separately.
The position embedding is simply applied to the x-axis and y-axis separately, encoding
the x and y position of each patch within every tile. The embedding is applied to each
tile identically.
Note: This module assumes the CLS token embedding is appended at the end of the sequence.
@@ -136,6 +138,8 @@ class VisionRotaryPositionalEmbeddings(nn.Module):
E.g. for ``patch_size=40``, a tile of shape (400, 400) will have 10x10 grid of patches.
tile_size (int): The size of your image tiles, if the image was tile-cropped in advance. Otherwise,
the size of the full input image. In this case, the function will consider your image as a single tile.
max_num_tiles (int): The maximum number of tiles in the image. This is used to unfold the input sequence
length into sequence length per tile so RoPE can be applied to each tile separately.
dim (int): Embedding dimension. Unlike :class:`~torchtune.modules.RotaryPositionalEmbeddings`, this is
usually set to the dim of each head in the attention module divided by 2, computed as
``embed_dim // num_heads // 2``. The divide by 2 accounts for x and y positions.
@@ -149,12 +153,14 @@ def __init__(
self,
patch_size: int,
tile_size: int,
max_num_tiles: int,
dim: int,
base: int = 10_000,
append_cls_token: bool = True,
) -> None:
super().__init__()
self.patch_grid_size = tile_size // patch_size
self.max_num_tiles = max_num_tiles
self.dim = dim
self.base = base
self.append_cls_token = append_cls_token
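With the constructor arguments used in the tests, the per-tile grid and cache length work out as below; a worked example (the cache itself is built in the unchanged ``build_rope_cache``, not shown in full here):

patch_grid_size = 16 // 4                   # 4x4 grid of patches per tile
tokens_per_tile = patch_grid_size ** 2 + 1  # 17 positions, including the appended CLS token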
@@ -209,46 +215,46 @@ def build_rope_cache(self) -> None:
self.register_buffer("cache", cache, persistent=False)

def forward(
self, x: torch.Tensor, *, input_pos: Optional[torch.Tensor] = None
self,
x: torch.Tensor,
**kwargs: Any,
) -> torch.Tensor:
"""
Args:
x (torch.Tensor): input tensor with shape
``[b, s, n_h, h_d]``
input_pos (Optional[torch.Tensor]): Optional tensor which contains the position ids
of each token. During training, this is used to indicate the positions
of each token relative to its sample when packed, shape [b, s].
During inference, this indicates the position of the current token.
If none, assume the index of the token is its position id. Default is None.
x (torch.Tensor): input tensor with shape ``[b, s, n_h, h_d]``
**kwargs (Any): additional keyword arguments. This is kept to match the forward signature of
:class:`~torchtune.modules.RotaryPositionalEmbeddings`.
Returns:
torch.Tensor: output tensor with shape ``[b, s, n_h, h_d]``
Raises:
ValueError: if sequence length of input tensor does not match the 2D RoPE cache's sequence length
Notation used for tensor shapes:
- b: batch size
- s: sequence length
- n_h: num heads
- h_d: head dim
"""
# input tensor has shape [b, s, n_h, h_d]
seq_len = x.size(1)

# extract the values based on whether input_pos is set or not
rope_cache = (
self.cache[:seq_len] if input_pos is None else self.cache[input_pos]
)
bsz, _, n_h, h_d = x.shape

# reshape input; the last dimension is used for computing the output.
# Split tile dimension from the sequence dimension
# Cast to float to match the reference implementation
# tensor has shape [b, s, n_h, h_d // 2, 2]
xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
# tensor has shape [b, max_num_tiles, s // max_num_tiles, n_h, h_d // 2, 2]
xshaped = x.float().reshape(bsz, self.max_num_tiles, -1, n_h, h_d // 2, 2)
seq_len = xshaped.size(2)

if seq_len != self.cache.shape[0]:
raise ValueError(
f"Input sequence length {seq_len} does not match 2D RoPE cache sequence length {self.cache.shape[0]}."
)

# reshape the cache for broadcasting
# tensor has shape [b, s, 1, h_d // 2, 2] if packed samples,
# otherwise has shape [1, s, 1, h_d // 2, 2]
rope_cache = rope_cache.view(-1, xshaped.size(1), 1, xshaped.size(3), 2)
rope_cache = self.cache.view(1, 1, seq_len, 1, h_d // 2, 2)

# tensor has shape [b, s, n_h, h_d // 2, 2]
# tensor has shape [b, max_num_tiles, s, n_h, h_d // 2, 2]
x_out = torch.stack(
[
xshaped[..., 0] * rope_cache[..., 0]
Expand All @@ -260,5 +266,5 @@ def forward(
)

# tensor has shape [b, s, n_h, h_d]
x_out = x_out.flatten(3)
x_out = x_out.reshape(bsz, self.max_num_tiles * seq_len, n_h, h_d)
return x_out.type_as(x)
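A shape-only sketch of the tile split performed in this forward pass, using the test fixture's sizes (bsz=2, max_num_tiles=3, 17 tokens per tile, 8 heads, head_dim=4); the cache here is random and purely illustrative, and the rotation mirrors the stack above:

import torch

bsz, max_num_tiles, tokens_per_tile, n_h, h_d = 2, 3, 17, 8, 4
x = torch.randn(bsz, max_num_tiles * tokens_per_tile, n_h, h_d)  # [b, s, n_h, h_d]

# split the tile dimension out of the sequence dimension
xshaped = x.float().reshape(bsz, max_num_tiles, -1, n_h, h_d // 2, 2)
assert xshaped.shape == (2, 3, 17, 8, 2, 2)

# a per-tile cache viewed for broadcasting over batch, tile, and head dims
rope_cache = torch.randn(tokens_per_tile, h_d // 2, 2).view(
    1, 1, tokens_per_tile, 1, h_d // 2, 2
)

# standard RoPE rotation, applied identically to every tile
x_out = torch.stack(
    [
        xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],
        xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],
    ],
    -1,
)

# fold the tile dimension back into the sequence dimension
x_out = x_out.reshape(bsz, max_num_tiles * tokens_per_tile, n_h, h_d)
assert x_out.shape == x.shape  # [2, 51, 8, 4]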
