Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
285 changes: 285 additions & 0 deletions intermediate_source/variable_length_attention_tutorial.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,285 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Using Variable Length Attention in PyTorch\n",
"\n",
"## Summary\n",
"\n",
"In this tutorial, we will introduce a variable length attention API. This API is called `varlen_attn` and is a custom op in PyTorch, meaning it is also compilable using `torch.compile`. "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"> **Note:** \n",
"> This tutorial currently requires you to use the PyTorch nightly build.\n",
"\n",
"### What you will learn\n",
"\n",
"- Variable length attention and how it differs from `scaled_dot_product_attention`\n",
"- Explore an example of how to use `varlen_attn` in a simple Transformer attention layer \n",
"\n",
"### Prerequisites\n",
"\n",
"- PyTorch v2.10.0.dev or later\n",
"- A basic understanding of attention and our current offerings. Please reference these tutorials for more details on flex attention and SDPA. "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Overview of Variable Length Attention \n",
"\n",
"In normal SDPA, sequences are expected to be a fixed length. In practice, this means that input tensors are often **padded** to the same length in a batch. However, this wastes both memory and compute through storing this padding and performing unnecessary computations. \n",
"\n",
"Variable length attention handles sequences of varying length by **packing** the tensors in a batch together and essentially collapsing the batch dimension. \n",
"\n",
"However, we still need to maintain the boundaries between documents. To do so, we compute cumulative sequence positions for query and key/value that mark the end of documents. For example, if doc 1 is 3 tokens long and doc 2 is 5 tokens long, then `cu_seq = [0, 3, 8]`."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Below is the definition of `varlen_attn`. \n",
"\n",
"```python\n",
"def varlen_attn(\n",
" query: torch.Tensor,\n",
" key: torch.Tensor,\n",
" value: torch.Tensor,\n",
" cu_seq_q: torch.Tensor,\n",
" cu_seq_k: torch.Tensor,\n",
" max_q: int,\n",
" max_k: int,\n",
" is_causal: bool = False,\n",
" return_aux: AuxRequest | None = None,\n",
") -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:\n",
"```\n",
"\n",
"`query`, `key`, and `value` correspond to the `q`, `k`, and `v` of the packed input. `cu_seq_q` and `cu_seq_k` are the cumulative indices for query and key/value, respectively. These mark the logical boundaries that separate the documents in our input. `max_q` and `max_k` are the maximum sequence lengths of query and key, respectively. `is_causal` applies causal masking if set to True and `return_aux` specifies which auxiliary outputs to return (ie `lse`).\n",
"\n",
"`varlen_attn` returns the output tensor from the attention computation. "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Given an input batch, how would we construct the metadata that `varlen_attn` expects? More specifically, how do we calculate the cumulative sequence indices? \n",
"\n",
"The helper function `create_varlen_metadata` returns the required `cu_seqlens` and `max_seqlen` given `input_batch` and the end of sequence token ID that marks the end of documents."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"def create_varlen_metadata(input_batch: torch.Tensor, eos_id: int):\n",
" batch_size, seq_len = input_batch.shape\n",
" device = input_batch.device\n",
" cu_seqlens_list, all_seq_lengths = [], []\n",
" offset = 0\n",
"\n",
" for b in range(batch_size):\n",
" tokens = input_batch[b]\n",
" eos_positions = (tokens == eos_id).nonzero(as_tuple=True)[0].to(torch.int32)\n",
"\n",
" # we use the position of the eos tokens to mark the end of documents\n",
" sample_cu_seqlens = torch.cat(\n",
" [\n",
" torch.tensor([0], dtype=torch.int32, device=device),\n",
" eos_positions + 1,\n",
" torch.tensor([seq_len], dtype=torch.int32, device=device),\n",
" ]\n",
" )\n",
" sample_cu_seqlens = torch.unique_consecutive(sample_cu_seqlens)\n",
"\n",
" seq_lengths = torch.diff(sample_cu_seqlens)\n",
" all_seq_lengths.append(seq_lengths)\n",
"\n",
" cu_seqlens_adjusted = sample_cu_seqlens[:-1] + offset\n",
" cu_seqlens_list.append(cu_seqlens_adjusted)\n",
"\n",
" offset += seq_len\n",
"\n",
" packed_cu_seqlens = torch.cat(\n",
" cu_seqlens_list + [torch.tensor([offset], dtype=torch.int32, device=device)]\n",
" )\n",
"\n",
" max_seqlen = 0\n",
" if len(all_seq_lengths) > 0:\n",
" all_seq_lengths = torch.cat(all_seq_lengths)\n",
" max_seqlen = all_seq_lengths.max().item()\n",
"\n",
" return packed_cu_seqlens, max_seqlen"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's explore how we would use `varlen_attn` in an Attention module. We define an attention module as usual, but in the `forward` method, we call the new `varlen_attn` custom op. \n",
"\n",
"This function expects the `cu_seq` indices and `max_len` that we computed earlier using `create_varlen_metadata` to mark the boundaries of the different documents. \n",
"\n",
"Before we call `varlen_attn`, we also pack our input so that it has the shape `(total tokens, dim)`. Recall that variable length attention allows us to collapse the `batch_size` dimension so that we can lay out our input samples contiguously. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"from torch.nn.attention.varlen import varlen_attn\n",
"\n",
"\n",
"class SimpleVarlenAttention(nn.Module):\n",
" def __init__(self, embed_dim: int, num_heads: int):\n",
" super().__init__()\n",
" self.embed_dim = embed_dim\n",
" self.num_heads = num_heads\n",
" self.head_dim = embed_dim // num_heads\n",
"\n",
" self.qkv_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=False)\n",
" self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)\n",
"\n",
" def forward(\n",
" self, x: torch.Tensor, cu_seq: torch.Tensor, max_len: int\n",
" ) -> torch.Tensor:\n",
" batch_size, seq_len, _ = x.shape\n",
" x_packed = x.view(batch_size * seq_len, -1) # pack x into (total_tokens, dim)\n",
"\n",
" qkv = self.qkv_proj(x_packed)\n",
" q, k, v = qkv.chunk(3, dim=-1)\n",
"\n",
" q = q.view(-1, self.num_heads, self.head_dim)\n",
" k = k.view(-1, self.num_heads, self.head_dim)\n",
" v = v.view(-1, self.num_heads, self.head_dim)\n",
"\n",
" attn_out = varlen_attn(\n",
" query=q,\n",
" key=k,\n",
" value=v,\n",
" cu_seq_q=cu_seq,\n",
" cu_seq_k=cu_seq,\n",
" max_q=max_len,\n",
" max_k=max_len,\n",
" is_causal=True,\n",
" )\n",
" attn_out = attn_out.view(-1, self.embed_dim)\n",
" attn_out = self.out_proj(attn_out)\n",
" return attn_out.view(batch_size, seq_len, self.embed_dim)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now, we can use this `SimpleVarlenAttention` module in a simple Transformer."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class SimpleVarlenTransformer(nn.Module):\n",
" \"\"\"\n",
" simple 1 layer transformer with varlen attention\n",
" \"\"\"\n",
"\n",
" def __init__(self, vocab_size: int, embed_dim: int, num_heads: int):\n",
" super().__init__()\n",
" self.tok_embeddings = nn.Embedding(vocab_size, embed_dim)\n",
" self.attention = SimpleVarlenAttention(embed_dim, num_heads)\n",
" self.norm = nn.LayerNorm(embed_dim)\n",
"\n",
" def forward(\n",
" self, tokens: torch.Tensor, cu_seq: torch.Tensor, max_len: int\n",
" ) -> torch.Tensor:\n",
" x = self.tok_embeddings(tokens)\n",
" x = x + self.attention(x, cu_seq, max_len)\n",
" x = self.norm(x)\n",
" return x"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we're ready to put all the pieces together! Let's run a training step with our `SimpleVarlenTransformer`. We define our model, compute `cu_seq` and `max_len` using `create_varlen_metadata`, and run a forward and backward pass. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def main():\n",
" torch.manual_seed(42)\n",
"\n",
" batch_size = 3\n",
" seq_len = 64\n",
" vocab_size = 1000\n",
" embed_dim = 128\n",
" num_heads = 4\n",
" eos_id = 2\n",
" num_docs = 3\n",
" device = \"cuda\"\n",
" dtype = torch.bfloat16\n",
"\n",
" model = SimpleVarlenTransformer(vocab_size, embed_dim, num_heads).to(\n",
" device=device, dtype=dtype\n",
" )\n",
"\n",
" # create input_batch tokens\n",
" input_batch = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)\n",
"\n",
" for b in range(batch_size):\n",
" # getting random positions to cut the input into multiple documents\n",
" doc_positions = torch.randint(10, seq_len - 1, (num_docs - 1,))\n",
" for pos in doc_positions:\n",
" input_batch[b, pos] = eos_id # insert eos token to simulate end of sample\n",
" input_batch[b, -1] = eos_id\n",
"\n",
" cu_seq, max_len = create_varlen_metadata(input_batch, eos_id)\n",
" print(f\"cu_seq: {cu_seq}, max_len: {max_len}\") # cu_seq: tensor([0, 32, 47, 64, 92, 103, 128, 168, 177, 192]), max_len: 40\n",
"\n",
" # fwd pass\n",
" output = model(input_batch, cu_seq, max_len)\n",
" print(f\"output shape: {output.shape}\") # (3, 64, 128)\n",
"\n",
" # bwd pass\n",
" loss = output.mean()\n",
" loss.backward()\n",
"\n",
" print(f\"embedding grad shape: {model.tok_embeddings.weight.grad.shape}\") # (1000, 128)\n",
" print(f\"embedding grad norm: {model.tok_embeddings.weight.grad.norm().item()}\")\n",
"\n",
"\n",
"if __name__ == \"__main__\":\n",
" main()"
]
}
],
"metadata": {
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
Loading