@@ -37,12 +37,12 @@ def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
         all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, device = device)
 
         scale = (q.shape[-1] ** -0.5)
-        q = q * scale
 
         if not exists(mask):
             mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
         else:
-            mask = mask.split(q_bucket_size, dim = -2)
+            mask = rearrange(mask, 'b n -> b 1 1 n')
+            mask = mask.split(q_bucket_size, dim = -1)
 
         row_splits = zip(
             q.split(q_bucket_size, dim = -2),
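Why the rearrange: the incoming mask is a key-padding mask of shape (batch, n). Rearranged to (batch, 1, 1, n) it broadcasts against attention logits of shape (batch, heads, i, j), so it must be split along the last (key) dimension rather than dim = -2. A minimal sketch of the shapes involved (the sizes here are illustrative, not from the diff):

    import torch
    from einops import rearrange

    b, n, bucket = 2, 8, 4
    mask = torch.ones(b, n, dtype = torch.bool)   # key padding mask

    mask = rearrange(mask, 'b n -> b 1 1 n')      # broadcastable over (b, h, i, j)
    chunks = mask.split(bucket, dim = -1)         # split along key positions
    print([tuple(c.shape) for c in chunks])       # [(2, 1, 1, 4), (2, 1, 1, 4)]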
@@ -63,7 +63,7 @@ def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
             for k_ind, (kc, vc) in enumerate(col_splits):
                 k_start_index = k_ind * k_bucket_size
 
-                attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc)
+                attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale
 
                 if exists(row_mask):
                     attn_weights.masked_fill_(~row_mask, max_neg_value)
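Deferring the scale is a pure refactor: since (q * scale) @ kᵀ == (q @ kᵀ) * scale, scaling the logits gives identical values while leaving q itself unscaled, which matters because the unscaled q is what gets saved for the backward pass. A quick sanity check of the equivalence:

    import torch

    q = torch.randn(2, 4, 8)
    k = torch.randn(2, 6, 8)
    scale = q.shape[-1] ** -0.5

    # scaling the logits vs. pre-scaling q: same result
    a = torch.einsum('b i d, b j d -> b i j', q * scale, k)
    b_ = torch.einsum('b i d, b j d -> b i j', q, k) * scale
    assert torch.allclose(a, b_, atol = 1e-6)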
@@ -73,7 +73,6 @@ def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
                     attn_weights.masked_fill_(causal_mask, max_neg_value)
 
                 block_row_maxes = attn_weights.amax(dim = -1, keepdims = True)
-
                 attn_weights -= block_row_maxes
                 exp_weights = torch.exp(attn_weights)
 
@@ -82,7 +81,7 @@ def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
 
                 block_row_sums = exp_weights.sum(dim = -1, keepdims = True).clamp(min = EPSILON)
 
-                new_row_maxes = torch.maximum(block_row_maxes, row_sums)
+                new_row_maxes = torch.maximum(block_row_maxes, row_maxes)
 
                 exp_values = einsum('... i j, ... j d -> ... i d', exp_weights, vc)
 
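The row_sums → row_maxes change fixes a real bug: the running maximum for the numerically stable streaming softmax must be compared against the previous running maximum, never the running sum. A self-contained sketch of the update rule the fixed line implements (function name is illustrative):

    import torch

    def online_softmax_update(row_max, row_sum, block_logits):
        # per-row max of the current block of logits
        block_max = block_logits.amax(dim = -1, keepdim = True)
        # the corrected comparison: previous running max vs. block max
        new_max = torch.maximum(block_max, row_max)
        # rescale the old sum into the new frame and add the block's mass
        new_sum = row_sum * torch.exp(row_max - new_max) \
                + torch.exp(block_logits - new_max).sum(dim = -1, keepdim = True)
        return new_max, new_sum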
@@ -92,10 +91,11 @@ def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
                 new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums
 
                 oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values)
+
                 row_maxes.copy_(new_row_maxes)
                 row_sums.copy_(new_row_sums)
 
-        ctx.args = (causal, mask, q_bucket_size, k_bucket_size)
+        ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size)
         ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes)
 
         return o
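Threading scale through ctx.args guarantees forward and backward use the same value instead of each recomputing it. Stashing non-tensor state on ctx is the standard pattern for custom autograd Functions; a minimal, self-contained sketch (ScaledDot is hypothetical, not from this diff):

    import torch
    from torch.autograd import Function

    class ScaledDot(Function):
        @staticmethod
        def forward(ctx, q, k):
            scale = q.shape[-1] ** -0.5
            ctx.args = (scale,)            # plain Python state rides on ctx
            ctx.save_for_backward(q, k)    # tensors go through save_for_backward
            return (q @ k.transpose(-2, -1)) * scale

        @staticmethod
        def backward(ctx, grad_out):
            (scale,) = ctx.args
            q, k = ctx.saved_tensors
            dq = (grad_out @ k) * scale
            dk = (grad_out.transpose(-2, -1) @ q) * scale
            return dq, dk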
@@ -105,7 +105,7 @@ def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
     def backward(ctx, do):
         """ Algorithm 4 in the paper """
 
-        causal, mask, q_bucket_size, k_bucket_size = ctx.args
+        causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args
         q, k, v, o, l, m = ctx.saved_tensors
 
         device = q.device
@@ -117,8 +117,6 @@ def backward(ctx, do):
         dk = torch.zeros_like(k)
         dv = torch.zeros_like(v)
 
-        scale = q.shape[-1] ** -0.5
-
         row_splits = zip(
             q.split(q_bucket_size, dim = -2),
             o.split(q_bucket_size, dim = -2),
@@ -142,8 +140,7 @@ def backward(ctx, do):
             for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
                 k_start_index = k_ind * k_bucket_size
 
-                qc_scaled = qc * scale
-                attn_weights = einsum('... i d, ... j d -> ... i j', qc_scaled, kc)
+                attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale
 
                 if causal and q_start_index < (k_start_index + k_bucket_size - 1):
                     causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype = torch.bool, device = device).triu(q_start_index - k_start_index + 1)
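Since the bug class here is a forward/backward mismatch, a finite-difference check is a cheap guard: torch.autograd.gradcheck compares the analytic gradients against numerical ones (float64 inputs are required for its tolerances). A sketch using a plain scaled-dot function as the subject:

    import torch
    from torch.autograd import gradcheck

    def scaled_dot(q, k):
        return torch.einsum('... i d, ... j d -> ... i j', q, k) * q.shape[-1] ** -0.5

    q = torch.randn(2, 4, 8, dtype = torch.float64, requires_grad = True)
    k = torch.randn(2, 6, 8, dtype = torch.float64, requires_grad = True)
    assert gradcheck(scaled_dot, (q, k))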
@@ -197,7 +194,7 @@ def __init__(
 
         self.to_q = nn.Linear(dim, inner_dim, bias = False)
         self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
-        self.to_out = nn.Linear(inner_dim, dim)
+        self.to_out = nn.Linear(inner_dim, dim, bias = False)
 
         # memory efficient attention related parameters
         # can be overriden on forward
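With this change all three projections are bias-free, matching each other; one common rationale is that a surrounding normalization layer already provides a learned shift, so the output bias adds parameters without clear benefit. For illustration only, with made-up dimensions:

    import torch.nn as nn

    # hypothetical sizes, not from the diff
    dim, heads, dim_head = 512, 8, 64
    inner_dim = heads * dim_head

    to_q   = nn.Linear(dim, inner_dim, bias = False)
    to_kv  = nn.Linear(dim, inner_dim * 2, bias = False)
    to_out = nn.Linear(inner_dim, dim, bias = False)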