Commit a6cb078

varlen attention tutorial

1 parent 86b1c62 commit a6cb078

1 file changed: 283 additions & 0 deletions

@@ -0,0 +1,283 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Using Variable Length Attention in PyTorch\n",
"\n",
"## Summary\n",
"\n",
"In this tutorial, we introduce a variable length attention API called `varlen_attn`. It is a custom op in PyTorch, which means it can also be compiled with `torch.compile`. "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"> **Note:** \n",
"> This tutorial currently requires the PyTorch nightly build.\n",
"\n",
"### What you will learn\n",
"\n",
"- What variable length attention is and how it differs from `scaled_dot_product_attention`\n",
"- How to use `varlen_attn` in a simple Transformer attention layer\n",
"\n",
"### Prerequisites\n",
"\n",
"- PyTorch v2.10.0.dev or later\n",
"- A basic understanding of attention and PyTorch's existing attention APIs; see the FlexAttention and `scaled_dot_product_attention` tutorials for more details. "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Overview of Variable Length Attention \n",
"\n",
"In standard SDPA, sequences are expected to have a fixed length. In practice, this means that input tensors are often **padded** to the same length within a batch. This wastes both memory and compute: the padding has to be stored, and attention is computed over tokens that carry no information. \n",
"\n",
"Variable length attention handles sequences of varying length by **packing** the tensors in a batch together, essentially collapsing the batch dimension. \n",
"\n",
"However, we still need to maintain the boundaries between documents. To do so, we compute cumulative sequence positions for query and key/value that mark the ends of documents. For example, if doc 1 is 3 tokens long and doc 2 is 5 tokens long, then `cu_seq = [0, 3, 8]`."
]
},
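{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make this concrete, here is a small sketch (not part of the original tutorial) with two hypothetical documents. It contrasts the padded layout that fixed-length SDPA would see with the packed layout that variable length attention sees, and builds `cu_seq` from the document lengths."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"# hypothetical example: doc 1 has 3 tokens, doc 2 has 5 tokens\n",
"doc1 = torch.tensor([11, 12, 13])\n",
"doc2 = torch.tensor([21, 22, 23, 24, 25])\n",
"\n",
"# padded layout (fixed-length SDPA): shape (2, 5), two slots are wasted padding\n",
"padded = torch.zeros(2, 5, dtype=torch.int64)\n",
"padded[0, : doc1.numel()] = doc1\n",
"padded[1, :] = doc2\n",
"\n",
"# packed layout (variable length attention): shape (8,), no padding\n",
"packed = torch.cat([doc1, doc2])\n",
"\n",
"# cumulative sequence boundaries: one entry per document end, plus a leading 0\n",
"cu_seq = torch.tensor([0, doc1.numel(), doc1.numel() + doc2.numel()], dtype=torch.int32)\n",
"print(packed.shape, cu_seq)  # torch.Size([8]) tensor([0, 3, 8], dtype=torch.int32)"
]
},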
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Below is the definition of `varlen_attn`. \n",
"\n",
"```python\n",
"def varlen_attn(\n",
"    query: torch.Tensor,\n",
"    key: torch.Tensor,\n",
"    value: torch.Tensor,\n",
"    cu_seq_q: torch.Tensor,\n",
"    cu_seq_k: torch.Tensor,\n",
"    max_q: int,\n",
"    max_k: int,\n",
"    is_causal: bool = False,\n",
"    return_aux: AuxRequest | None = None,\n",
") -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:\n",
"```\n",
"\n",
"`query`, `key`, and `value` correspond to the `q`, `k`, and `v` of the packed input. `cu_seq_q` and `cu_seq_k` are the cumulative sequence indices for query and key/value, respectively; they mark the logical boundaries that separate the documents in our input. `max_q` and `max_k` are the maximum sequence lengths for query and key, respectively. `is_causal` applies causal masking if set to True. `return_aux` optionally requests auxiliary outputs, in which case the op returns a tuple instead of a single tensor. "
]
},
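{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before wiring `varlen_attn` into a model, here is a minimal sketch (not from the original tutorial) of calling it directly on randomly generated packed tensors. It assumes a CUDA device and bfloat16 inputs, mirroring the training example later in this tutorial, and reuses the `cu_seq = [0, 3, 8]` example from above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch.nn.attention.varlen import varlen_attn\n",
"\n",
"num_heads, head_dim = 4, 32\n",
"\n",
"# two packed documents of lengths 3 and 5 -> 8 total tokens\n",
"cu_seq = torch.tensor([0, 3, 8], dtype=torch.int32, device=\"cuda\")\n",
"total_tokens, max_len = 8, 5\n",
"\n",
"# packed q, k, v have shape (total_tokens, num_heads, head_dim)\n",
"q = torch.randn(total_tokens, num_heads, head_dim, device=\"cuda\", dtype=torch.bfloat16)\n",
"k = torch.randn_like(q)\n",
"v = torch.randn_like(q)\n",
"\n",
"out = varlen_attn(q, k, v, cu_seq, cu_seq, max_len, max_len, is_causal=True)\n",
"print(out.shape)  # expected: torch.Size([8, 4, 32])"
]
},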
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Given an input batch, how would we construct the metadata that `varlen_attn` expects? More specifically, how do we calculate the cumulative sequence indices? \n",
"\n",
"The helper function `create_varlen_metadata` below returns the required `cu_seqlens` and `max_seqlen`, given `input_batch` and the end-of-sequence token ID that marks the end of each document."
]
},
77+
{
78+
"cell_type": "code",
79+
"execution_count": null,
80+
"metadata": {},
81+
"outputs": [],
82+
"source": [
83+
"import torch\n",
84+
"\n",
85+
"def create_varlen_metadata(input_batch: torch.Tensor, eos_id: int):\n",
86+
" batch_size, seq_len = input_batch.shape\n",
87+
" device = input_batch.device\n",
88+
" cu_seqlens_list, all_seq_lengths = [], []\n",
89+
" offset = 0\n",
90+
"\n",
91+
" for b in range(batch_size):\n",
92+
" tokens = input_batch[b]\n",
93+
" eos_positions = (tokens == eos_id).nonzero(as_tuple=True)[0].to(torch.int32)\n",
94+
"\n",
95+
" # we use the position of the eos tokens to mark the end of documents\n",
96+
" sample_cu_seqlens = torch.cat(\n",
97+
" [\n",
98+
" torch.tensor([0], dtype=torch.int32, device=device),\n",
99+
" eos_positions + 1,\n",
100+
" torch.tensor([seq_len], dtype=torch.int32, device=device),\n",
101+
" ]\n",
102+
" )\n",
103+
" sample_cu_seqlens = torch.unique_consecutive(sample_cu_seqlens)\n",
104+
"\n",
105+
" seq_lengths = torch.diff(sample_cu_seqlens)\n",
106+
" all_seq_lengths.append(seq_lengths)\n",
107+
"\n",
108+
" cu_seqlens_adjusted = sample_cu_seqlens[:-1] + offset\n",
109+
" cu_seqlens_list.append(cu_seqlens_adjusted)\n",
110+
"\n",
111+
" offset += seq_len\n",
112+
"\n",
113+
" packed_cu_seqlens = torch.cat(\n",
114+
" cu_seqlens_list + [torch.tensor([offset], dtype=torch.int32, device=device)]\n",
115+
" )\n",
116+
"\n",
117+
" max_seqlen = 0\n",
118+
" if len(all_seq_lengths) > 0:\n",
119+
" all_seq_lengths = torch.cat(all_seq_lengths)\n",
120+
" max_seqlen = all_seq_lengths.max().item()\n",
121+
"\n",
122+
" return packed_cu_seqlens, max_seqlen"
123+
]
124+
},
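{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check (not part of the original tutorial), we can run the helper on a small hypothetical batch with `eos_id = 2`. The first sample contains two documents of length 3, and the second contains documents of lengths 2 and 4, so the packed boundaries should be `[0, 3, 6, 8, 12]` and the longest document should have length 4."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"toy_batch = torch.tensor(\n",
"    [\n",
"        [5, 7, 2, 9, 4, 2],  # two docs: tokens [5, 7, 2] and [9, 4, 2]\n",
"        [8, 2, 6, 3, 1, 5],  # two docs: tokens [8, 2] and [6, 3, 1, 5]\n",
"    ]\n",
")\n",
"\n",
"cu_seq, max_len = create_varlen_metadata(toy_batch, eos_id=2)\n",
"print(cu_seq)   # expected: tensor([ 0,  3,  6,  8, 12], dtype=torch.int32)\n",
"print(max_len)  # expected: 4"
]
},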
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's explore how we would use `varlen_attn` in an attention module. We define the module as usual, but in the `forward` method, we call the new `varlen_attn` custom op. \n",
"\n",
"This op expects the `cu_seq` indices and `max_len` that we computed earlier using `create_varlen_metadata` to mark the boundaries between the different documents. \n",
"\n",
"Before we call `varlen_attn`, we also pack our input so that it has the shape `(total_tokens, dim)`. Recall that variable length attention allows us to collapse the `batch_size` dimension so that we can lay out our input samples contiguously. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"from torch.nn.attention.varlen import varlen_attn\n",
"\n",
"\n",
"class SimpleVarlenAttention(nn.Module):\n",
"    def __init__(self, embed_dim: int, num_heads: int):\n",
"        super().__init__()\n",
"        self.embed_dim = embed_dim\n",
"        self.num_heads = num_heads\n",
"        self.head_dim = embed_dim // num_heads\n",
"\n",
"        self.qkv_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=False)\n",
"        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)\n",
"\n",
"    def forward(\n",
"        self, x: torch.Tensor, cu_seq: torch.Tensor, max_len: int\n",
"    ) -> torch.Tensor:\n",
"        batch_size, seq_len, _ = x.shape\n",
"        x_packed = x.view(batch_size * seq_len, -1)  # pack x into (total_tokens, dim)\n",
"\n",
"        qkv = self.qkv_proj(x_packed)\n",
"        q, k, v = qkv.chunk(3, dim=-1)\n",
"\n",
"        # reshape the packed projections to (total_tokens, num_heads, head_dim)\n",
"        q = q.view(-1, self.num_heads, self.head_dim)\n",
"        k = k.view(-1, self.num_heads, self.head_dim)\n",
"        v = v.view(-1, self.num_heads, self.head_dim)\n",
"\n",
"        attn_out = varlen_attn(\n",
"            query=q,\n",
"            key=k,\n",
"            value=v,\n",
"            cu_seq_q=cu_seq,\n",
"            cu_seq_k=cu_seq,\n",
"            max_q=max_len,\n",
"            max_k=max_len,\n",
"            is_causal=True,\n",
"        )\n",
"        # merge the heads back, project, then restore the (batch, seq, dim) layout\n",
"        attn_out = attn_out.view(-1, self.embed_dim)\n",
"        attn_out = self.out_proj(attn_out)\n",
"        return attn_out.view(batch_size, seq_len, self.embed_dim)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now, we can use this `SimpleVarlenAttention` module in a simple Transformer."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class SimpleVarlenTransformer(nn.Module):\n",
"    \"\"\"\n",
"    A minimal one-layer Transformer with varlen attention.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, vocab_size: int, embed_dim: int, num_heads: int):\n",
"        super().__init__()\n",
"        self.tok_embeddings = nn.Embedding(vocab_size, embed_dim)\n",
"        self.attention = SimpleVarlenAttention(embed_dim, num_heads)\n",
"        self.norm = nn.LayerNorm(embed_dim)\n",
"\n",
"    def forward(\n",
"        self, tokens: torch.Tensor, cu_seq: torch.Tensor, max_len: int\n",
"    ) -> torch.Tensor:\n",
"        x = self.tok_embeddings(tokens)\n",
"        x = x + self.attention(x, cu_seq, max_len)  # residual connection around attention\n",
"        x = self.norm(x)\n",
"        return x"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we're ready to put all the pieces together! Let's run a training step with our `SimpleVarlenTransformer`. We define our model, compute `cu_seq` and `max_len` using `create_varlen_metadata`, and run a forward and backward pass. "
]
},
225+
{
226+
"cell_type": "code",
227+
"execution_count": null,
228+
"metadata": {},
229+
"outputs": [],
230+
"source": [
231+
"def main():\n",
232+
" torch.manual_seed(42)\n",
233+
"\n",
234+
" batch_size = 3\n",
235+
" seq_len = 64\n",
236+
" vocab_size = 1000\n",
237+
" embed_dim = 128\n",
238+
" num_heads = 4\n",
239+
" eos_id = 2\n",
240+
" num_docs = 3\n",
241+
" device = \"cuda\"\n",
242+
" dtype = torch.bfloat16\n",
243+
"\n",
244+
" model = SimpleVarlenTransformer(vocab_size, embed_dim, num_heads).to(\n",
245+
" device=device, dtype=dtype\n",
246+
" )\n",
247+
"\n",
248+
" # create input_batch tokens\n",
249+
" input_batch = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)\n",
250+
"\n",
251+
" for b in range(batch_size):\n",
252+
" # getting random positions to cut the input into multiple documents\n",
253+
" doc_positions = torch.randint(10, seq_len - 1, (num_docs - 1,))\n",
254+
" for pos in doc_positions:\n",
255+
" input_batch[b, pos] = eos_id # insert eos token to simulate end of sample\n",
256+
" input_batch[b, -1] = eos_id\n",
257+
"\n",
258+
" cu_seq, max_len = create_varlen_metadata(input_batch, eos_id)\n",
259+
" print(f\"cu_seq: {cu_seq}, max_len: {max_len}\") # cu_seq: tensor([0, 32, 47, 64, 92, 103, 128, 168, 177, 192]), max_len: 40\n",
260+
"\n",
261+
" # fwd pass\n",
262+
" output = model(input_batch, cu_seq, max_len)\n",
263+
" print(f\"output shape: {output.shape}\") # (3, 64, 128)\n",
264+
"\n",
265+
" # bwd pass\n",
266+
" loss = output.mean()\n",
267+
" loss.backward()\n",
268+
"\n",
269+
" print(f\"embedding grad shape: {model.tok_embeddings.weight.grad.shape}\") # (1000, 128)\n",
270+
" print(f\"embedding grad norm: {model.tok_embeddings.weight.grad.norm().item()}\")\n",
271+
"\n",
272+
"\n",
273+
"if __name__ == \"__main__\":\n",
274+
" main()"
275+
]
276+
}
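{
"cell_type": "markdown",
"metadata": {},
"source": [
"Since `varlen_attn` is a custom op, the model above can also be run under `torch.compile`, as mentioned in the summary. Below is a minimal sketch (not part of the original tutorial) that compiles a fresh `SimpleVarlenTransformer` and runs a forward pass on a small hypothetical batch; it assumes a CUDA device and the nightly build, as in the training example."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def compile_example():\n",
"    model = SimpleVarlenTransformer(vocab_size=1000, embed_dim=128, num_heads=4).to(\n",
"        device=\"cuda\", dtype=torch.bfloat16\n",
"    )\n",
"    compiled_model = torch.compile(model)\n",
"\n",
"    # small hypothetical batch: 2 samples of 32 tokens, eos_id = 2\n",
"    input_batch = torch.randint(0, 1000, (2, 32), device=\"cuda\")\n",
"    input_batch[:, 15] = 2  # split each sample into two documents\n",
"    input_batch[:, -1] = 2\n",
"\n",
"    cu_seq, max_len = create_varlen_metadata(input_batch, 2)\n",
"    output = compiled_model(input_batch, cu_seq, max_len)\n",
"    print(output.shape)  # expected: torch.Size([2, 32, 128])\n",
"\n",
"\n",
"if __name__ == \"__main__\":\n",
"    compile_example()"
]
}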
],
"metadata": {
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
