@@ -26,7 +26,8 @@ def total_flops(self):
 
     @property
     def total_memory(self):
-        return 2 * self.batch * self.seq_len * self.dim * (self.heads + self.heads_kv) * self.dtype.itemsize
+        return 2 * self.batch * self.seq_len * self.dim * (self.heads +
+                                                           self.heads_kv) * self.dtype.itemsize
 
     def gen_inputs(self):
         Q = torch.randn(
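Note (not part of the diff): the change above is line wrapping only. The forward memory estimate itself presumably counts Q and O once each at query-head size and K and V once each at kv-head size, i.e. 2 * B * S * D * (H + H_kv) * itemsize bytes; that reading of the factor of 2 is my interpretation. A quick sanity check with assumed example sizes:

```python
import torch

# Assumed example sizes, not taken from the diff.
batch, heads, heads_kv, seq_len, dim = 1, 32, 8, 4096, 128
itemsize = torch.float16.itemsize  # 2 bytes per element

# The 2x covers Q and O (query-head sized) plus K and V (kv-head sized).
fwd_memory = 2 * batch * seq_len * dim * (heads + heads_kv) * itemsize
print(fwd_memory / 2**20, "MiB")  # 80.0 MiB for these sizes
```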
@@ -38,11 +39,12 @@ def gen_inputs(self):
         return Q, K, V
 
     def ref_program(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor):
-        q_bhsd = Q.transpose(1, 2) # [B, H, S, D]
+        q_bhsd = Q.transpose(1, 2)  # [B, H, S, D]
         k_bhsd = K.transpose(1, 2)
         v_bhsd = V.transpose(1, 2)
         with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
-            output_bhsd = F.scaled_dot_product_attention(q_bhsd, k_bhsd, v_bhsd, is_causal=self.is_causal, enable_gqa=True)
+            output_bhsd = F.scaled_dot_product_attention(
+                q_bhsd, k_bhsd, v_bhsd, is_causal=self.is_causal, enable_gqa=True)
         output = output_bhsd.transpose(1, 2).contiguous()
         return output, None  # do not check lse
 
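For context on the reference path being re-wrapped here: `scaled_dot_product_attention` accepts K/V with fewer heads than Q when `enable_gqa=True`, broadcasting the kv heads across the query heads. A minimal standalone sketch with assumed shapes (requires a PyTorch version that exposes `enable_gqa`):

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

# Assumed shapes: [B, H, S, D] for Q and [B, H_kv, S, D] for K/V.
B, H, H_kv, S, D = 1, 32, 8, 1024, 128
Q = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)
K = torch.randn(B, H_kv, S, D, device="cuda", dtype=torch.float16)
V = torch.randn(B, H_kv, S, D, device="cuda", dtype=torch.float16)

# enable_gqa=True lets the kernel repeat the H_kv key/value heads across
# the H query heads instead of expanding K/V manually.
with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
    out = F.scaled_dot_product_attention(Q, K, V, is_causal=True, enable_gqa=True)

print(out.shape)  # torch.Size([1, 32, 1024, 128])
```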
@@ -68,7 +70,8 @@ def total_flops(self):
 
     @property
     def total_memory(self):
-        return self.batch * (3 * self.heads + 4 * self.heads_kv) * self.seq_len * self.dim * self.dtype.itemsize
+        return self.batch * (3 * self.heads +
+                             4 * self.heads_kv) * self.seq_len * self.dim * self.dtype.itemsize
 
     def gen_inputs(self):
         Q = torch.randn(
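As with the forward case, only the line break changes. For reviewers, the backward estimate B * (3H + 4H_kv) * S * D * itemsize presumably counts three query-head-sized tensors (Q, dO, dQ) and four kv-head-sized tensors (K, V, dK, dV); that accounting is my reading of the coefficients, not something stated in the diff. Worked with the same assumed sizes:

```python
import torch

# Assumed example sizes (same as above), not taken from the diff.
batch, heads, heads_kv, seq_len, dim = 1, 32, 8, 4096, 128
itemsize = torch.float16.itemsize

# 3 * heads    -> Q, dO, dQ (query-head sized tensors)
# 4 * heads_kv -> K, V, dK, dV (kv-head sized tensors)
bwd_memory = batch * (3 * heads + 4 * heads_kv) * seq_len * dim * itemsize
print(bwd_memory / 2**20, "MiB")  # 128.0 MiB for these sizes
```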
@@ -127,7 +130,7 @@ def ref_program(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, O: torc
 
 
 class gqa_benchmark(Benchmark):
-
+
     def __init__(self, batch, heads, heads_kv, seq_len, dim, is_causal, dtype, grad=True):
         self.batch = batch
         self.heads = heads
@@ -138,8 +141,10 @@ def __init__(self, batch, heads, heads_kv, seq_len, dim, is_causal, dtype, grad=
         self.dtype = dtype
         self.grad = grad
 
-        self.gqa_fwd_bench = gqa_fwd_benchmark(batch, heads, heads_kv, seq_len, dim, is_causal, dtype)
-        self.gqa_bwd_bench = gqa_bwd_benchmark(batch, heads, heads_kv, seq_len, dim, is_causal, dtype)
+        self.gqa_fwd_bench = gqa_fwd_benchmark(batch, heads, heads_kv, seq_len, dim, is_causal,
+                                               dtype)
+        self.gqa_bwd_bench = gqa_bwd_benchmark(batch, heads, heads_kv, seq_len, dim, is_causal,
+                                               dtype)
 
     @property
     def total_flops(self):
@@ -148,14 +153,14 @@ def total_flops(self):
     @property
     def total_memory(self):
         return self.gqa_fwd_bench.total_memory + self.gqa_bwd_bench.total_memory
-
+
     def gen_inputs(self):
         if self.grad:
             Q, K, V, _, _, _ = self.gqa_bwd_bench.gen_inputs()
             return Q, K, V
         else:
             return self.gqa_fwd_bench.gen_inputs()
-
+
     def ref_program(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor):
 
         output = self.gqa_fwd_bench.ref_program(Q, K, V)[0]
@@ -165,4 +170,3 @@ def ref_program(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor):
         loss = output.sum()
         loss.backward()
         return output, Q.grad, K.grad, V.grad
-
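End to end, the combined benchmark's reference path relies on grad-enabled leaf inputs so that Q.grad, K.grad, and V.grad are populated after loss.backward(). A hedged usage sketch; the constructor arguments and shapes below are assumptions for illustration, not values from this PR:

```python
import torch

# Hypothetical instantiation of the combined fwd+bwd benchmark.
bench = gqa_benchmark(
    batch=1, heads=32, heads_kv=8, seq_len=1024, dim=128,
    is_causal=True, dtype=torch.float16, grad=True)

Q, K, V = bench.gen_inputs()  # with grad=True these come from the bwd benchmark
out, dQ, dK, dV = bench.ref_program(Q, K, V)
print(out.shape, dQ.shape, dK.shape, dV.shape)
```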