|
| 1 | +{ |
| 2 | + "cells": [ |
| 3 | + { |
| 4 | + "cell_type": "markdown", |
| 5 | + "metadata": {}, |
| 6 | + "source": [ |
| 7 | + "# Using Variable Length Attention in PyTorch\n", |
| 8 | + "\n", |
| 9 | + "## Summary\n", |
| 10 | + "\n", |
| 11 | + "In this tutorial, we will introduce a variable length attention API. This API is called `varlen_attn` and is a custom op in PyTorch, meaning it is also compilable using `torch.compile`. " |
| 12 | + ] |
| 13 | + }, |
| 14 | + { |
| 15 | + "cell_type": "markdown", |
| 16 | + "metadata": {}, |
| 17 | + "source": [ |
| 18 | + "> **Note:** \n", |
| 19 | + "> This tutorial currently requires you to use the PyTorch nightly build.\n", |
| 20 | + "\n", |
| 21 | + "### What you will learn\n", |
| 22 | + "\n", |
| 23 | + "- Variable length attention and how it differs from `scaled_dot_product_attention`\n", |
| 24 | + "- Explore an example of how to use `varlen_attn` in a simple Transformer attention layer \n", |
| 25 | + "\n", |
| 26 | + "### Prerequisites\n", |
| 27 | + "\n", |
| 28 | + "- PyTorch v2.10.0.dev or later\n", |
| 29 | + "- A basic understanding of attention and our current offerings. Please reference these tutorials for more details on flex attention and SDPA. " |
| 30 | + ] |
| 31 | + }, |
| 32 | + { |
| 33 | + "cell_type": "markdown", |
| 34 | + "metadata": {}, |
| 35 | + "source": [ |
| 36 | + "## Overview of Variable Length Attention \n", |
| 37 | + "\n", |
| 38 | + "In normal SDPA, sequences are expected to be a fixed length. In practice, this means that input tensors are often **padded** to the same length in a batch. However, this wastes both memory and compute through storing this padding and performing unnecessary computations. \n", |
| 39 | + "\n", |
| 40 | + "Variable length attention handles sequences of varying length by **packing** the tensors in a batch together and essentially collapsing the batch dimension. \n", |
| 41 | + "\n", |
| 42 | + "However, we still need to maintain the boundaries between documents. To do so, we compute cumulative sequence positions for query and key/value that mark the end of documents. For example, if doc 1 is 3 tokens long and doc 2 is 5 tokens long, then `cu_seq = [0, 3, 8]`." |
| 43 | + ] |
| 44 | + }, |
| 45 | + { |
| 46 | + "cell_type": "markdown", |
| 47 | + "metadata": {}, |
| 48 | + "source": [ |
| 49 | + "Below is the definition of `varlen_attn`. \n", |
| 50 | + "\n", |
| 51 | + "```python\n", |
| 52 | + "def varlen_attn(\n", |
| 53 | + " query: torch.Tensor,\n", |
| 54 | + " key: torch.Tensor,\n", |
| 55 | + " value: torch.Tensor,\n", |
| 56 | + " cu_seq_q: torch.Tensor,\n", |
| 57 | + " cu_seq_k: torch.Tensor,\n", |
| 58 | + " max_q: int,\n", |
| 59 | + " max_k: int,\n", |
| 60 | + " is_causal: bool = False,\n", |
| 61 | + " return_aux: AuxRequest | None = None,\n", |
| 62 | + ") -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:\n", |
| 63 | + "```\n", |
| 64 | + "\n", |
| 65 | + "`query`, `key`, and `value` correspond to the `q`, `k`, and `v` of the packed input. `cu_seq_q` and `cu_seq_k` are the cumulative indices for query and key/value, respectively. These mark the logical boundaries that separate the documents in our input. `max_q` and `max_k` are the maximum sequence lengths of query and key, respectively. `is_causal` applies causal masking if set to True. " |
| 66 | + ] |
| 67 | + }, |
| 68 | + { |
| 69 | + "cell_type": "markdown", |
| 70 | + "metadata": {}, |
| 71 | + "source": [ |
| 72 | + "Given an input batch, how would we construct the metadata that `varlen_attn` expects? More specifically, how do we calculate the cumulative sequence indices? \n", |
| 73 | + "\n", |
| 74 | + "The helper function `create_varlen_metadata` returns the required `cu_seqlens` and `max_seqlen` given `input_batch` and the end of sequence token ID that marks the end of documents." |
| 75 | + ] |
| 76 | + }, |
| 77 | + { |
| 78 | + "cell_type": "code", |
| 79 | + "execution_count": null, |
| 80 | + "metadata": {}, |
| 81 | + "outputs": [], |
| 82 | + "source": [ |
| 83 | + "import torch\n", |
| 84 | + "\n", |
| 85 | + "def create_varlen_metadata(input_batch: torch.Tensor, eos_id: int):\n", |
| 86 | + " batch_size, seq_len = input_batch.shape\n", |
| 87 | + " device = input_batch.device\n", |
| 88 | + " cu_seqlens_list, all_seq_lengths = [], []\n", |
| 89 | + " offset = 0\n", |
| 90 | + "\n", |
| 91 | + " for b in range(batch_size):\n", |
| 92 | + " tokens = input_batch[b]\n", |
| 93 | + " eos_positions = (tokens == eos_id).nonzero(as_tuple=True)[0].to(torch.int32)\n", |
| 94 | + "\n", |
| 95 | + " # we use the position of the eos tokens to mark the end of documents\n", |
| 96 | + " sample_cu_seqlens = torch.cat(\n", |
| 97 | + " [\n", |
| 98 | + " torch.tensor([0], dtype=torch.int32, device=device),\n", |
| 99 | + " eos_positions + 1,\n", |
| 100 | + " torch.tensor([seq_len], dtype=torch.int32, device=device),\n", |
| 101 | + " ]\n", |
| 102 | + " )\n", |
| 103 | + " sample_cu_seqlens = torch.unique_consecutive(sample_cu_seqlens)\n", |
| 104 | + "\n", |
| 105 | + " seq_lengths = torch.diff(sample_cu_seqlens)\n", |
| 106 | + " all_seq_lengths.append(seq_lengths)\n", |
| 107 | + "\n", |
| 108 | + " cu_seqlens_adjusted = sample_cu_seqlens[:-1] + offset\n", |
| 109 | + " cu_seqlens_list.append(cu_seqlens_adjusted)\n", |
| 110 | + "\n", |
| 111 | + " offset += seq_len\n", |
| 112 | + "\n", |
| 113 | + " packed_cu_seqlens = torch.cat(\n", |
| 114 | + " cu_seqlens_list + [torch.tensor([offset], dtype=torch.int32, device=device)]\n", |
| 115 | + " )\n", |
| 116 | + "\n", |
| 117 | + " max_seqlen = 0\n", |
| 118 | + " if len(all_seq_lengths) > 0:\n", |
| 119 | + " all_seq_lengths = torch.cat(all_seq_lengths)\n", |
| 120 | + " max_seqlen = all_seq_lengths.max().item()\n", |
| 121 | + "\n", |
| 122 | + " return packed_cu_seqlens, max_seqlen" |
| 123 | + ] |
| 124 | + }, |
| 125 | + { |
| 126 | + "cell_type": "markdown", |
| 127 | + "metadata": {}, |
| 128 | + "source": [ |
| 129 | + "Let's explore how we would use `varlen_attn` in an Attention module. We define an attention module as usual, but in the `forward` method, we call the new `varlen_attn` custom op. \n", |
| 130 | + "\n", |
| 131 | + "This function expects the `cu_seq` indices amd `max_len` that we computed earlier using `create_varlen_metadata` to mark the boundaries of the different documents. \n", |
| 132 | + "\n", |
| 133 | + "Before we call `varlen_attn`, we also pack our input so that it has the shape `(total tokens, dim)`. Recall that variable length attention allows us to collapse the `batch_size` dimension so that we can lay out our input samples contiguously. " |
| 134 | + ] |
| 135 | + }, |
| 136 | + { |
| 137 | + "cell_type": "code", |
| 138 | + "execution_count": null, |
| 139 | + "metadata": {}, |
| 140 | + "outputs": [], |
| 141 | + "source": [ |
| 142 | + "import torch\n", |
| 143 | + "import torch.nn as nn\n", |
| 144 | + "from torch.nn.attention.varlen import varlen_attn\n", |
| 145 | + "\n", |
| 146 | + "\n", |
| 147 | + "class SimpleVarlenAttention(nn.Module):\n", |
| 148 | + " def __init__(self, embed_dim: int, num_heads: int):\n", |
| 149 | + " super().__init__()\n", |
| 150 | + " self.embed_dim = embed_dim\n", |
| 151 | + " self.num_heads = num_heads\n", |
| 152 | + " self.head_dim = embed_dim // num_heads\n", |
| 153 | + "\n", |
| 154 | + " self.qkv_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=False)\n", |
| 155 | + " self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)\n", |
| 156 | + "\n", |
| 157 | + " def forward(\n", |
| 158 | + " self, x: torch.Tensor, cu_seq: torch.Tensor, max_len: int\n", |
| 159 | + " ) -> torch.Tensor:\n", |
| 160 | + " batch_size, seq_len, _ = x.shape\n", |
| 161 | + " x_packed = x.view(batch_size * seq_len, -1) # pack x into (total_tokens, dim)\n", |
| 162 | + "\n", |
| 163 | + " qkv = self.qkv_proj(x_packed)\n", |
| 164 | + " q, k, v = qkv.chunk(3, dim=-1)\n", |
| 165 | + "\n", |
| 166 | + " q = q.view(-1, self.num_heads, self.head_dim)\n", |
| 167 | + " k = k.view(-1, self.num_heads, self.head_dim)\n", |
| 168 | + " v = v.view(-1, self.num_heads, self.head_dim)\n", |
| 169 | + "\n", |
| 170 | + " attn_out = varlen_attn(\n", |
| 171 | + " query=q,\n", |
| 172 | + " key=k,\n", |
| 173 | + " value=v,\n", |
| 174 | + " cu_seq_q=cu_seq,\n", |
| 175 | + " cu_seq_k=cu_seq,\n", |
| 176 | + " max_q=max_len,\n", |
| 177 | + " max_k=max_len,\n", |
| 178 | + " is_causal=True,\n", |
| 179 | + " )\n", |
| 180 | + " attn_out = attn_out.view(-1, self.embed_dim)\n", |
| 181 | + " attn_out = self.out_proj(attn_out)\n", |
| 182 | + " return attn_out.view(batch_size, seq_len, self.embed_dim)" |
| 183 | + ] |
| 184 | + }, |
| 185 | + { |
| 186 | + "cell_type": "markdown", |
| 187 | + "metadata": {}, |
| 188 | + "source": [ |
| 189 | + "Now, we can use this `SimpleVarlenAttention` module in a simple Transformer." |
| 190 | + ] |
| 191 | + }, |
| 192 | + { |
| 193 | + "cell_type": "code", |
| 194 | + "execution_count": null, |
| 195 | + "metadata": {}, |
| 196 | + "outputs": [], |
| 197 | + "source": [ |
| 198 | + "class SimpleVarlenTransformer(nn.Module):\n", |
| 199 | + " \"\"\"\n", |
| 200 | + " simple 1 layer transformer with varlen attention\n", |
| 201 | + " \"\"\"\n", |
| 202 | + "\n", |
| 203 | + " def __init__(self, vocab_size: int, embed_dim: int, num_heads: int):\n", |
| 204 | + " super().__init__()\n", |
| 205 | + " self.tok_embeddings = nn.Embedding(vocab_size, embed_dim)\n", |
| 206 | + " self.attention = SimpleVarlenAttention(embed_dim, num_heads)\n", |
| 207 | + " self.norm = nn.LayerNorm(embed_dim)\n", |
| 208 | + "\n", |
| 209 | + " def forward(\n", |
| 210 | + " self, tokens: torch.Tensor, cu_seq: torch.Tensor, max_len: int\n", |
| 211 | + " ) -> torch.Tensor:\n", |
| 212 | + " x = self.tok_embeddings(tokens)\n", |
| 213 | + " x = x + self.attention(x, cu_seq, max_len)\n", |
| 214 | + " x = self.norm(x)\n", |
| 215 | + " return x" |
| 216 | + ] |
| 217 | + }, |
| 218 | + { |
| 219 | + "cell_type": "markdown", |
| 220 | + "metadata": {}, |
| 221 | + "source": [ |
| 222 | + "Now we're ready to put all the pieces together! Let's run a training step with our `SimpleVarlenTransformer`. We define our model, compute `cu_seq` and `max_len` using `create_varlen_metadata`, and run a forward and backward pass. " |
| 223 | + ] |
| 224 | + }, |
| 225 | + { |
| 226 | + "cell_type": "code", |
| 227 | + "execution_count": null, |
| 228 | + "metadata": {}, |
| 229 | + "outputs": [], |
| 230 | + "source": [ |
| 231 | + "def main():\n", |
| 232 | + " torch.manual_seed(42)\n", |
| 233 | + "\n", |
| 234 | + " batch_size = 3\n", |
| 235 | + " seq_len = 64\n", |
| 236 | + " vocab_size = 1000\n", |
| 237 | + " embed_dim = 128\n", |
| 238 | + " num_heads = 4\n", |
| 239 | + " eos_id = 2\n", |
| 240 | + " num_docs = 3\n", |
| 241 | + " device = \"cuda\"\n", |
| 242 | + " dtype = torch.bfloat16\n", |
| 243 | + "\n", |
| 244 | + " model = SimpleVarlenTransformer(vocab_size, embed_dim, num_heads).to(\n", |
| 245 | + " device=device, dtype=dtype\n", |
| 246 | + " )\n", |
| 247 | + "\n", |
| 248 | + " # create input_batch tokens\n", |
| 249 | + " input_batch = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)\n", |
| 250 | + "\n", |
| 251 | + " for b in range(batch_size):\n", |
| 252 | + " # getting random positions to cut the input into multiple documents\n", |
| 253 | + " doc_positions = torch.randint(10, seq_len - 1, (num_docs - 1,))\n", |
| 254 | + " for pos in doc_positions:\n", |
| 255 | + " input_batch[b, pos] = eos_id # insert eos token to simulate end of sample\n", |
| 256 | + " input_batch[b, -1] = eos_id\n", |
| 257 | + "\n", |
| 258 | + " cu_seq, max_len = create_varlen_metadata(input_batch, eos_id)\n", |
| 259 | + " print(f\"cu_seq: {cu_seq}, max_len: {max_len}\") # cu_seq: tensor([0, 32, 47, 64, 92, 103, 128, 168, 177, 192]), max_len: 40\n", |
| 260 | + "\n", |
| 261 | + " # fwd pass\n", |
| 262 | + " output = model(input_batch, cu_seq, max_len)\n", |
| 263 | + " print(f\"output shape: {output.shape}\") # (3, 64, 128)\n", |
| 264 | + "\n", |
| 265 | + " # bwd pass\n", |
| 266 | + " loss = output.mean()\n", |
| 267 | + " loss.backward()\n", |
| 268 | + "\n", |
| 269 | + " print(f\"embedding grad shape: {model.tok_embeddings.weight.grad.shape}\") # (1000, 128)\n", |
| 270 | + " print(f\"embedding grad norm: {model.tok_embeddings.weight.grad.norm().item()}\")\n", |
| 271 | + "\n", |
| 272 | + "\n", |
| 273 | + "if __name__ == \"__main__\":\n", |
| 274 | + " main()" |
| 275 | + ] |
| 276 | + } |
| 277 | + ], |
| 278 | + "metadata": { |
| 279 | + "orig_nbformat": 4 |
| 280 | + }, |
| 281 | + "nbformat": 4, |
| 282 | + "nbformat_minor": 2 |
| 283 | +} |
0 commit comments