From a6e761d0dde8b2761220651d00c44412611a4852 Mon Sep 17 00:00:00 2001 From: RdoubleA Date: Mon, 13 Jan 2025 08:52:02 -0800 Subject: [PATCH] split out tile dimension in rope --- .../modules/test_position_embeddings.py | 79 +++++-------------- torchtune/models/clip/_component_builders.py | 1 + torchtune/modules/position_embeddings.py | 52 ++++++------ 3 files changed, 50 insertions(+), 82 deletions(-) diff --git a/tests/torchtune/modules/test_position_embeddings.py b/tests/torchtune/modules/test_position_embeddings.py index 282fce085c..af9dead941 100644 --- a/tests/torchtune/modules/test_position_embeddings.py +++ b/tests/torchtune/modules/test_position_embeddings.py @@ -136,20 +136,21 @@ def test_rope_init_meta_device(self, input_params): class TestVisionRotaryPositionEmbedding: - EXPECTED_X_OUT_MEAN = tensor(0.0789793) - EXPECTED_X_OUT_SUM = tensor(25.2733822) - EXPECTED_X_OUT_MAX = tensor(3.1225626) + EXPECTED_X_OUT_MEAN = tensor(-0.00903320) + EXPECTED_X_OUT_SUM = tensor(-29.48437119) + EXPECTED_X_OUT_MAX = tensor(4.07074356) @pytest.fixture def input_params(self): bsz = 2 + max_num_tiles = 3 num_heads = 8 embed_dim = 32 head_dim = embed_dim // num_heads - seq_len = 5 patch_size = 4 tile_size = 16 - return bsz, num_heads, head_dim, seq_len, patch_size, tile_size + seq_len = ((tile_size // patch_size) ** 2 + 1) * max_num_tiles + return bsz, num_heads, head_dim, seq_len, max_num_tiles, patch_size, tile_size @pytest.fixture def input(self, input_params) -> tensor: @@ -158,9 +159,12 @@ def input(self, input_params) -> tensor: @pytest.fixture def rope(self, input_params): - _, _, head_dim, _, patch_size, tile_size = input_params + _, _, head_dim, _, max_num_tiles, patch_size, tile_size = input_params return VisionRotaryPositionalEmbeddings( - patch_size=patch_size, tile_size=tile_size, dim=head_dim // 2 + patch_size=patch_size, + tile_size=tile_size, + max_num_tiles=max_num_tiles, + dim=head_dim // 2, ) @mps_ignored_test() @@ -175,63 +179,20 @@ def test_forward(self, input, rope) -> None: # check shapes assert_expected(x_out.shape, input.shape) - @mps_ignored_test() - def test_forward_with_curr_pos(self, input, rope) -> None: - ( - _, - seq_len, - _, - _, - ) = input.shape - x_out = rope(input, input_pos=torch.arange(seq_len)) - - # these values should be exactly the same as test_forward - # since in this case input_pos covers the entire input - # sequence. This tests that input_pos works as expected i.e. - # extracts the embeddings for the relevant positions - assert_expected(x_out.mean(), self.EXPECTED_X_OUT_MEAN, atol=1e-4) - assert_expected(x_out.sum(), self.EXPECTED_X_OUT_SUM) - assert_expected(x_out.max(), self.EXPECTED_X_OUT_MAX) - - # check shapes - assert_expected(x_out.shape, input.shape) - - @mps_ignored_test() - def test_forward_with_packed_pos(self, input, rope) -> None: - """ - Use input_pos to indicate positions of each token relative to its sequence - when sample is packed. - """ - ( - bsz, - seq_len, - _, - _, - ) = input.shape - x_out = rope( - input, input_pos=torch.arange(seq_len).unsqueeze(0).expand(bsz, seq_len) - ) - - # these values should be exactly the same as test_forward - # AND test_forward_with_current_pos. In this case input_pos - # covers the entire batch dim and is defined for each sample separately. - # This tests that input_pos works as expected i.e. 
-        # extracts the embeddings for the relevant positions for each sample
-        assert_expected(x_out.mean(), self.EXPECTED_X_OUT_MEAN, atol=1e-4)
-        assert_expected(x_out.sum(), self.EXPECTED_X_OUT_SUM)
-        assert_expected(x_out.max(), self.EXPECTED_X_OUT_MAX)
-
-        # check shapes
-        assert_expected(x_out.shape, input.shape)
-
     def test_rope_init_meta_device(self, input_params):
-        _, _, head_dim, _, patch_size, tile_size = input_params
+        _, _, head_dim, _, max_num_tiles, patch_size, tile_size = input_params
         rope_on_device = VisionRotaryPositionalEmbeddings(
-            dim=head_dim, patch_size=patch_size, tile_size=tile_size
+            dim=head_dim,
+            patch_size=patch_size,
+            max_num_tiles=max_num_tiles,
+            tile_size=tile_size,
         )
         with torch.device("meta"):
             meta_rope = VisionRotaryPositionalEmbeddings(
-                dim=head_dim, patch_size=patch_size, tile_size=tile_size
+                dim=head_dim,
+                patch_size=patch_size,
+                tile_size=tile_size,
+                max_num_tiles=max_num_tiles,
             )
             meta_rope.rope_init()
diff --git a/torchtune/models/clip/_component_builders.py b/torchtune/models/clip/_component_builders.py
index b67061a3f1..edbb31ad32 100644
--- a/torchtune/models/clip/_component_builders.py
+++ b/torchtune/models/clip/_component_builders.py
@@ -104,6 +104,7 @@ def clip_vision_encoder(
                 VisionRotaryPositionalEmbeddings(
                     patch_size=patch_size,
                     tile_size=tile_size,
+                    max_num_tiles=max_num_tiles,
                     dim=head_dim // 2,
                     base=10_000,
                     append_cls_token=append_cls_token,
diff --git a/torchtune/modules/position_embeddings.py b/torchtune/modules/position_embeddings.py
index 5f07772d82..ad4004c89e 100644
--- a/torchtune/modules/position_embeddings.py
+++ b/torchtune/modules/position_embeddings.py
@@ -127,7 +127,9 @@ class VisionRotaryPositionalEmbeddings(nn.Module):
     This class implements two-dimensional Rotary Positional Embeddings (RoPE) for images
     based on the axial frequency 2D RoPE described in https://arxiv.org/pdf/2403.13298.
 
-    The position embedding is simply applied to the x-axis and y-axis separately.
+    The position embedding is simply applied to the x-axis and y-axis separately, encoding
+    the x and y position of each patch within every tile. The embedding is applied to each
+    tile identically.
 
     Note: This module assumes the CLS token embedding is appended at the end of the sequence.
 
@@ -136,6 +138,8 @@ class VisionRotaryPositionalEmbeddings(nn.Module):
            E.g. for ``patch_size=40``, a tile of shape (400, 400) will have 10x10 grid of patches.
         tile_size (int): The size of your image tiles, if the image was tile-cropped in advance. Otherwise,
             the size of the full input image. In this case, the function will consider your image as a single tile.
+        max_num_tiles (int): The maximum number of tiles in the image. This is used to unfold the input sequence
+            length into sequence length per tile so RoPE can be applied to each tile separately.
         dim (int): Embedding dimension. Unlike :class:`~torchtune.modules.RotaryPositionalEmbeddings`, this is
             usually set to the dim of each head in the attention module divided by 2, computed as
             ``embed_dim // num_heads // 2``. The divide by 2 accounts for x and y positions.
@@ -149,12 +153,14 @@ def __init__(
         self,
         patch_size: int,
         tile_size: int,
+        max_num_tiles: int,
         dim: int,
         base: int = 10_000,
         append_cls_token: bool = True,
     ) -> None:
         super().__init__()
         self.patch_grid_size = tile_size // patch_size
+        self.max_num_tiles = max_num_tiles
         self.dim = dim
         self.base = base
         self.append_cls_token = append_cls_token
@@ -209,46 +215,46 @@ def build_rope_cache(self) -> None:
         self.register_buffer("cache", cache, persistent=False)
 
     def forward(
-        self, x: torch.Tensor, *, input_pos: Optional[torch.Tensor] = None
+        self,
+        x: torch.Tensor,
+        **kwargs: Any,
     ) -> torch.Tensor:
         """
         Args:
-            x (torch.Tensor): input tensor with shape
-                ``[b, s, n_h, h_d]``
-            input_pos (Optional[torch.Tensor]): Optional tensor which contains the position ids
-                of each token. During training, this is used to indicate the positions
-                of each token relative to its sample when packed, shape [b, s].
-                During inference, this indicates the position of the current token.
-                If none, assume the index of the token is its position id. Default is None.
+            x (torch.Tensor): input tensor with shape ``[b, s, n_h, h_d]``
+            **kwargs (Any): additional keyword arguments. This is kept to match the forward signature of
+                :class:`~torchtune.modules.RotaryPositionalEmbeddings`.
 
         Returns:
             torch.Tensor: output tensor with shape ``[b, s, n_h, h_d]``
 
+        Raises:
+            ValueError: if the per-tile sequence length of the input tensor does not match the 2D RoPE cache's sequence length
+
         Notation used for tensor shapes:
             - b: batch size
             - s: sequence length
             - n_h: num heads
             - h_d: head dim
         """
-        # input tensor has shape [b, s, n_h, h_d]
-        seq_len = x.size(1)
-
-        # extract the values based on whether input_pos is set or not
-        rope_cache = (
-            self.cache[:seq_len] if input_pos is None else self.cache[input_pos]
-        )
+        bsz, _, n_h, h_d = x.shape
 
         # reshape input; the last dimension is used for computing the output.
+        # Split tile dimension from the sequence dimension
         # Cast to float to match the reference implementation
-        # tensor has shape [b, s, n_h, h_d // 2, 2]
-        xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
+        # tensor has shape [b, max_num_tiles, s // max_num_tiles, n_h, h_d // 2, 2]
+        xshaped = x.float().reshape(bsz, self.max_num_tiles, -1, n_h, h_d // 2, 2)
+        seq_len = xshaped.size(2)
+
+        if seq_len != self.cache.shape[0]:
+            raise ValueError(
+                f"Per-tile input sequence length {seq_len} does not match 2D RoPE cache sequence length {self.cache.shape[0]}."
+            )
 
         # reshape the cache for broadcasting
-        # tensor has shape [b, s, 1, h_d // 2, 2] if packed samples,
-        # otherwise has shape [1, s, 1, h_d // 2, 2]
-        rope_cache = rope_cache.view(-1, xshaped.size(1), 1, xshaped.size(3), 2)
+        rope_cache = self.cache.view(1, 1, seq_len, 1, h_d // 2, 2)
 
-        # tensor has shape [b, s, n_h, h_d // 2, 2]
+        # tensor has shape [b, max_num_tiles, s // max_num_tiles, n_h, h_d // 2, 2]
         x_out = torch.stack(
             [
                 xshaped[..., 0] * rope_cache[..., 0]
@@ -260,5 +266,5 @@ def forward(
         )
 
         # tensor has shape [b, s, n_h, h_d]
-        x_out = x_out.flatten(3)
+        x_out = x_out.reshape(bsz, self.max_num_tiles * seq_len, n_h, h_d)
         return x_out.type_as(x)
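
Usage sketch: a minimal example of how the updated VisionRotaryPositionalEmbeddings is
constructed and called after this change. The values mirror the test fixture above, and the
import path is assumed from the module location shown in the diff
(torchtune/modules/position_embeddings.py); it is illustrative only, not part of the patch.

    import torch

    from torchtune.modules.position_embeddings import VisionRotaryPositionalEmbeddings

    # Fixture-style hyperparameters: 4x4 patches, 16x16 tiles, up to 3 tiles per image.
    patch_size, tile_size, max_num_tiles = 4, 16, 3
    num_heads, embed_dim = 8, 32
    head_dim = embed_dim // num_heads

    # With append_cls_token=True (the default), each tile contributes
    # (tile_size // patch_size) ** 2 + 1 positions, and the full sequence length
    # is that count times max_num_tiles.
    tokens_per_tile = (tile_size // patch_size) ** 2 + 1
    seq_len = tokens_per_tile * max_num_tiles

    rope = VisionRotaryPositionalEmbeddings(
        patch_size=patch_size,
        tile_size=tile_size,
        max_num_tiles=max_num_tiles,
        dim=head_dim // 2,  # half the head dim: the split accounts for x and y positions
    )

    # Input is [b, s, n_h, h_d]; forward() splits s into (max_num_tiles, s // max_num_tiles)
    # and applies the same per-tile 2D RoPE cache to every tile, so the output keeps the
    # input shape.
    x = torch.randn(2, seq_len, num_heads, head_dim)
    x_out = rope(x)
    assert x_out.shape == x.shape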