@@ -132,15 +132,17 @@ The blue nodes are required while the brown nodes are optional.
 7. The Dropout operation takes the gradients of the dropped probabilities as
    input and computes gradients with respect to the normalized probabilities.
    See [Dropout](@ref dev_guide_op_dropout) in Graph API.
-8. The SoftMaxBackward operation computes the gradients of the scaled output.
-   See [SoftMaxBackward](@ref dev_guide_op_softmaxbackward) in Graph API.
-9. The Scale node after SoftMaxBackward corresponds to the forward Scale node
+8. The Multiply, ReduceSum, Subtract, and Multiply operations compute the
+   gradients with respect to the scaled output according to the formula:
+   dS = P * (dP - ReduceSum(O * dO)), where P denotes the normalized
+   probabilities, dP the gradients with respect to them, O the forward
+   output, and dO the gradients with respect to O.
+9. The Scale node after Multiply corresponds to the forward Scale node
    and is used to compute the gradients of the score.
 10. The TypeCast, two MatMul, and ReduceSum operations after the Scale node
     compute the gradients with respect to Query and Key, respectively. TypeCast
     is required for bf16 and f16 training scenarios. ReduceSum reduces the Key
     gradients from (N, H_kv, N_rep, S, D) to (N, H_kv, 1, S, D).
-11. The optional End operation marks the output of SoftMaxBackward as a
+11. The optional End operation marks the output of Multiply as a
     partition output, representing the gradients with respect to the Mask. Note
     that the output shape of `dM` is (N, H_kv, N_rep, S, S) and the data
     type is f32. The library does not perform any reduction or typecast on this