[RFC]: hide continuous batching complexity through forward context #9098

youkaichao opened this issue Oct 5, 2024 · 1 comment
youkaichao commented Oct 5, 2024

Motivation.

Take a look at the current Llama forward computation logic:

class LlamaMLP(nn.Module):
    def forward(self, x):
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class LlamaAttention(nn.Module):
    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output


class LlamaDecoderLayer(nn.Module):
    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


class LlamaModel(nn.Module):
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.get_input_embeddings(input_ids)
        residual = None

        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[i - self.start_layer],
                attn_metadata,
                residual,
            )

        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class LlamaForCausalLM(nn.Module, SupportsLoRA):
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        model_output = self.model(input_ids, positions, kv_caches,
                                  attn_metadata)
        return model_output

If we don't consider attn_metadata and kv_caches, it can be simplified to:

class LlamaMLP(nn.Module):
    def forward(self, x):
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class LlamaAttention(nn.Module):
    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output


class LlamaDecoderLayer(nn.Module):
    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


class LlamaModel(nn.Module):
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
    ) -> torch.Tensor:
        hidden_states = self.get_input_embeddings(input_ids)
        residual = None

        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                residual,
            )

        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class LlamaForCausalLM(nn.Module):
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
    ) -> torch.Tensor:
        model_output = self.model(input_ids, positions)
        return model_output

Arguably, attn_metadata is the most complicated part of the forward computation logic, and it becomes even more complicated when we consider:

  • continuous batching, where we batch data from different sequences together
  • heterogeneous models, where different layers can have different attention metadata (e.g. Gemma 2, which interleaves sliding-window and full-attention layers)
  • optimized torch.compile logic, where we want to hide the complexity of the attention layer from the compiler

Therefore, I'm considering hiding the complexity of continuous batching behind a forward context. The idea is to have a global forward context that the model runner sets during every forward pass. The forward context stores the attention metadata, and the model accesses the attention metadata through it.
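As a rough illustration, the forward context could be a process-global object guarded by a context manager. This is only a sketch; the names ForwardContext, set_forward_context, and get_forward_context are placeholders for illustration, not a final API:

from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, List, Optional

import torch


@dataclass
class ForwardContext:
    # one entry per attention layer (see "Proposed Change." below)
    attn_metadata: List[Any]
    kv_caches: List[torch.Tensor]


_forward_context: Optional[ForwardContext] = None


def get_forward_context() -> ForwardContext:
    assert _forward_context is not None, "forward context is not set"
    return _forward_context


@contextmanager
def set_forward_context(attn_metadata: List[Any],
                        kv_caches: List[torch.Tensor]):
    # install the context for the duration of one forward pass
    global _forward_context
    prev = _forward_context
    _forward_context = ForwardContext(attn_metadata, kv_caches)
    try:
        yield
    finally:
        _forward_context = prev

The model runner would then wrap each forward pass, e.g. with set_forward_context([attn_metadata] * num_layers, kv_caches): hidden_states = model(input_ids, positions), so the model's own signature never mentions the metadata.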

Proposed Change.

The changes are:

  • the model runner will set the forward context before running the model, and the forward context will be used to store the attention metadata and kvcache.
    • For the sake of generality, the forward context should contain a list of attention metadata and kvcache, where each element corresponds to the attention metadata and kvcache of one layer. In the common case where all layers share the same attention metadata, the model runner is responsible for duplicating it across the list.
  • all the files in vllm/model_executor/models will know nothing about attention metadata and kvcache. They will only know about the input tensors and the output tensors, as if they are just doing token-wise computation. Every attention layer will have a new self.layer_index attribute, which will be used to index the attention metadata and kvcache in the forward context.
  • every attention implementation will be wrapped in a PyTorch custom op so that it is easy to compile. The custom op will only take input tensors and will retrieve the attention metadata and kvcache from the forward context. This way, the complexity of attention metadata and kvcache is hidden from the compiler (see the sketch after this list).

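As a sketch of the last two points, building on the hypothetical get_forward_context helper above and using the torch.library.custom_op API available in PyTorch 2.4+. The op name vllm::unified_attention and the layer_index wiring are illustrative, not the final design:

import torch
from torch.library import custom_op


@custom_op("vllm::unified_attention", mutates_args=())
def unified_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                      layer_index: int) -> torch.Tensor:
    # the op only takes plain tensors; per-layer metadata and kv cache are
    # looked up from the forward context, hidden from the compiler
    ctx = get_forward_context()
    attn_metadata = ctx.attn_metadata[layer_index]
    kv_cache = ctx.kv_caches[layer_index]
    # ... dispatch to the actual attention backend using attn_metadata
    # and kv_cache here ...
    return torch.empty_like(q)  # placeholder for the real attention output


@unified_attention.register_fake
def _(q, k, v, layer_index):
    # shape/dtype-only implementation so torch.compile can trace the op
    return torch.empty_like(q)


class Attention(torch.nn.Module):
    def __init__(self, layer_index: int):
        super().__init__()
        # index into the per-layer lists stored in the forward context
        self.layer_index = layer_index

    def forward(self, q: torch.Tensor, k: torch.Tensor,
                v: torch.Tensor) -> torch.Tensor:
        return unified_attention(q, k, v, self.layer_index)

With this shape, LlamaAttention.forward shrinks to the simplified signature shown above, and the compiler only ever sees a custom op over plain tensors.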
See #9029 and #9097 for the initial steps.

Feedback Period.

No response

CC List.

No response

Any Other Things.

No response

Before submitting a new issue...

  • Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the documentation page, which can answer lots of frequently asked questions.
youkaichao (Member, Author) commented:

Quoting the proposal above: "the model runner will set the forward context before running the model, and the forward context will be used to store the attention metadata and kvcache."

There's one alternative: the top-level model owns and sets the forward context. In the Llama case, LlamaForCausalLM sets the forward context.

This approach works better for encoder-decoder models and multi-modality models.
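A rough sketch of this alternative, reusing the hypothetical set_forward_context helper from the motivation section; here LlamaForCausalLM still receives the metadata from the model runner, but LlamaModel and the layers below it do not:

class LlamaForCausalLM(nn.Module):
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        # duplicate the shared metadata per layer and expose everything to
        # the attention layers through the forward context
        with set_forward_context([attn_metadata] * len(kv_caches), kv_caches):
            return self.model(input_ids, positions)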
