Commit 5f5f42c (1 parent: abdc349)

doc: graph: update sdpa and gqa backward pattern

4 files changed: +12 −8 lines

doc/graph/fusion_patterns/gqa.md

Lines changed: 6 additions & 4 deletions
@@ -132,15 +132,17 @@ The blue nodes are required while the brown nodes are optional.
  7. The Dropout operation takes the gradients of the dropped probabilities as
     input and computes gradients with respect to the normalized probabilities.
     See [Dropout](@ref dev_guide_op_dropout) in Graph API.
- 8. The SoftMaxBackward operation computes the gradients of the scaled output.
-    See [SoftMaxBackward](@ref dev_guide_op_softmaxbackward) in Graph API.
- 9. The Scale node after SoftMaxBackward corresponds to the forward Scale node
+ 8. The Multiply, ReduceSum, Subtract, and Multiply operations are used to
+    compute the gradients with respect to the scaled output according to the
+    formula: dS = P * (dP - ReduceSum(O * dO)), where P denotes the normalized
+    probabilities and dP denotes the gradients with respect to them.
+ 9. The Scale node after Multiply corresponds to the forward Scale node
     and is used to compute the gradients of the score.
 10. The TypeCast, two MatMul and ReduceSum operations after the Scale node
     compute the gradients with respect to Query and Key, respectively. TypeCast
     is required for bf16 and f16 training scenarios. ReduceSum reduces the Key
     gradients from (N, H_kv, N_rep, S, D) to (N, H_kv, 1, S, D).
-11. The optional End operation marks the output of SoftMaxBackward as a
+11. The optional End operation marks the output of Multiply as a
     partition output, representing the gradients with respect to the Mask. Note
     that the output shape of `dM` is (N, H_kv, N_rep, S, S) and the data
     type is f32. The library does not perform any reduction or typecast on this
(This commit also updates two pattern diagram images: 6.24 KB and 3.63 KB.)

doc/graph/fusion_patterns/sdpa.md

Lines changed: 6 additions & 4 deletions
@@ -135,14 +135,16 @@ are optional.
  7. The Dropout operation takes the gradients of the dropped probabilities as
     input and computes gradients with respect to the normalized probabilities.
     See [Dropout](@ref dev_guide_op_dropout) in Graph API.
- 8. The SoftMaxBackward operation computes the gradients of the scaled output.
-    See [SoftMaxBackward](@ref dev_guide_op_softmaxbackward) in Graph API.
- 9. The Scale node after SoftMaxBackward corresponds to the forward Scale node
+ 8. The Multiply, ReduceSum, Subtract, and Multiply operations are used to
+    compute the gradients with respect to the scaled output according to the
+    formula: dS = P * (dP - ReduceSum(O * dO)), where P denotes the normalized
+    probabilities and dP denotes the gradients with respect to them.
+ 9. The Scale node after Multiply corresponds to the forward Scale node
     and is used to compute the gradients of the score.
 10. The TypeCast and two MatMul operations after the Scale node compute the
     gradients with respect to Query and Key, respectively. TypeCast is required
     for bf16 and f16 training scenarios.
-11. The optional End operation marks the output of SoftMaxBackward as a
+11. The optional End operation marks the output of Multiply as a
     partition output, representing the gradients with respect to the Mask. Note
     that the output shape of `dM` is (N, H, S, S) and the data
     type is f32. The library does not perform any reduction or typecast on this
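Both hunks replace the monolithic SoftMaxBackward op with a decomposed Multiply, ReduceSum, Subtract, Multiply chain computing dS = P * (dP - ReduceSum(O * dO)). The NumPy sketch below illustrates why this decomposition is equivalent to the classic softmax backward pass; it is not oneDNN code, and it takes O to be the softmax output P itself (i.e., the no-dropout case) as an assumption for illustration. All variable names are hypothetical.

```python
import numpy as np

# Illustration only (not the library's implementation): check that the
# decomposed Multiply -> ReduceSum -> Subtract -> Multiply chain matches
# the softmax Jacobian-vector product dS_i = P_i * (dP_i - sum_j P_j dP_j).
rng = np.random.default_rng(0)
S = rng.standard_normal((2, 4, 8, 8))      # scaled scores, shape (N, H, S, S)
dP = rng.standard_normal(S.shape)          # incoming gradients w.r.t. P

# Forward softmax over the last axis (numerically stabilized).
e = np.exp(S - S.max(axis=-1, keepdims=True))
P = e / e.sum(axis=-1, keepdims=True)

# Decomposed backward, one op per step. Here O := P (no dropout).
t = P * dP                                 # Multiply
r = t.sum(axis=-1, keepdims=True)          # ReduceSum
dS = P * (dP - r)                          # Subtract, then Multiply

# Reference: explicit softmax Jacobian applied row by row.
dS_ref = np.empty_like(dS)
for idx in np.ndindex(*S.shape[:-1]):
    p = P[idx]
    J = np.diag(p) - np.outer(p, p)        # Jacobian of softmax at p
    dS_ref[idx] = J @ dP[idx]

assert np.allclose(dS, dS_ref, atol=1e-10)
```

The decomposition lets the graph compiler fuse these elementwise and reduction ops with the surrounding MatMul and Scale nodes, rather than treating softmax backward as an opaque unit.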

0 commit comments