@@ -43,90 +43,152 @@ Expr evaluate_polynomial(Expr x, float *coeff, int n) {
 }
 }
 
-// Copied from halide_ext, plan is to add this to Halide.
-Halide::Tuple halide_ext_exp(const Expr &x_full) {
-    // Type type = x_full.type();
-    // CHECK_EQ(type.element_of(), Float(32));
-
-    const float ln2_part1 = 0.6931457519f;
-    const float ln2_part2 = 1.4286067653e-6f;
-    const float one_over_ln2 = 1.0f / logf(2.0f);
+/* Extended exponential which produces two output values,
+ * each of the same precision as the input, as described in
+ * "The Two-Pass Softmax Algorithm" by Marat Dukhan and
+ * Artsiom Ablavatski [https://arxiv.org/abs/2001.04438].
+ *
+ * The first element of the returned Tuple is a pseudo-mantissa and
+ * the second is an exponent, which is an integer. The product of the
+ * pseudo-mantissa and 2 raised to the returned exponent is the
+ * desired result e^a. For arguments up to slightly greater than
+ * 11629079, the pseudo-mantissa is guaranteed to be within the
+ * interval (-e, e). For larger arguments, the exponent result of the
+ * tuple may not be able to represent the exact integer necessary to
+ * keep the pseudo-mantissa within bounds, so the pseudo-mantissa can
+ * become progressively larger in magnitude as the argument increases.
+ *
+ * Ideally this routine would maintain a degree of accuracy through the
+ * entire range and be able to produce results out to the end of the
+ * numeric range. At present neither of these properties holds, due to
+ * the following issues:
+ *  - Range reduction may overflow when scaling the argument.
+ *  - Range reduction is increasingly inaccurate in reducing the value
+ *    due to the implementation. This results in overflow in the
+ *    polynomial evaluation.
+ *  - Even if the above two issues were resolved, the approximation
+ *    polynomial would have to run on values outside its intended
+ *    approximation range.
+ */
+Halide::Tuple extended_exp(const Expr &x_full) {
+    float ln2_part1 = 0.6931457519f;
+    float ln2_part2 = 1.4286067653e-6f;
+    float one_over_ln2 = 1.0f / logf(2.0f);
 
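+    // Range reduction: find integer k with x_full = k*ln(2) + x and
+    // 0 <= x < ln(2), so that e^x_full = e^x * 2^k. ln(2) is applied in two
+    // parts (a truncated value plus a small correction) to reduce rounding
+    // error in the reduction.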
     Expr scaled = x_full * one_over_ln2;
     Expr k_real = floor(scaled);
 
     Expr x = x_full - k_real * ln2_part1;
-    x -= k_real * ln2_part2;
-
-    float coeff[] = {0.00031965933071842413f,
-                     0.00119156835564003744f,
-                     0.00848988645943932717f,
-                     0.04160188091348320655f,
-                     0.16667983794100929562f,
-                     0.49999899033463041098f,
-                     1.0f,
-                     1.0f};
+    x = x - k_real * ln2_part2;
+
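+    // Coefficients for a polynomial approximation of e^x on the reduced
+    // interval, evidently ordered from the highest-degree term down to the
+    // constant: they roughly track the Taylor coefficients of e^x
+    // (..., 1/120, 1/24, 1/6, 1/2, 1, 1) but appear to have been fitted.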
+    float coeff[] = {
+        0.00031965933071842413f,
+        0.00119156835564003744f,
+        0.00848988645943932717f,
+        0.04160188091348320655f,
+        0.16667983794100929562f,
+        0.49999899033463041098f,
+        1.0f,
+        1.0f};
     Expr result = evaluate_polynomial(x, coeff, sizeof(coeff) / sizeof(coeff[0]));
 
-    result = Halide::Internal::common_subexpression_elimination(result);
+    // Ensure that the mantissa part is not a NaN or itself an infinity.
+    result = strict_float(select(!is_finite(k_real), 1, result));
+    result = common_subexpression_elimination(result);
 
     return {result, k_real};
 }
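+
+// A minimal illustrative sketch (this helper is not used elsewhere in this
+// file): collapsing the Tuple returned by extended_exp back into a plain
+// e^x, per the identity described above (pseudo-mantissa * 2^exponent).
+Expr extended_exp_to_exp(const Expr &x_full) {
+    Halide::Tuple parts = extended_exp(x_full);
+    return parts[0] * pow(2.0f, parts[1]);
+}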
 
 } // anonymous namespace
 
 struct Softmax : public Halide::NamesInterface {
-    Softmax(const std::string &base_name)
+    enum class Algorithm {
+        Naive,
+        TwoPass,
+        ThreePass,
+    };
+
+    Softmax(const std::string &base_name,
+            Algorithm algorithm = Algorithm::TwoPass)
         : base_name(base_name),
+          algorithm(algorithm),
           result(base_name + "_softmax"),
           ext_exp(base_name + "_softmax_ext_exp"),
           exponentials(base_name + "_softmax_exponentials"),
-          softmax_sums(base_name + "_softmax_sum") {
+          softmax_sum(base_name + "_softmax_sum") {
     }
     std::string base_name;
+    Algorithm algorithm;
     Func result;
-    Func ext_exp;
+
+    // Naive algorithm
     Func exponentials;
-    Func softmax_sums;
+
+    // Two-pass algorithm
+    Func ext_exp;
+
+    // Three-pass algorithm
+    Func max_bias;
+    Func biased_exp;
+
+    // Common to the different algorithms
+    Func softmax_sum;
     Var result_inner;
     RVar softmax_sum_inner; // TODO: Remove this.
     Var softmax_sum_inner_var;
     LoopLevel softmax_sum_compute_at;
 
-    // Keeping this to either use for testing or turn into a comment.
-#if 0
-    void naive_algorithm(Func input, const Type &generating_type) {
-        auto args = input.args();
-        RDom r(0, size);
-
-        exponentials(args) =
-            default_exp(cast<double>(clamp(input(args), -1e12f, 1e12f)));
-
-        std::vector<Var> args_sum(args.begin() + 1, args.end());
-        std::vector<Expr> args_reduction;
-        args_reduction.emplace_back(r.x);
-        args_reduction.insert(args_reduction.end(), args_sum.begin(),
-                              args_sum.end());
-
-        softmax_sum(args_sum) = Expr(0.0);
-        softmax_sum(args_sum) += exponentials(args_reduction);
-        softmax_sum_inner = r.x;
-
-        result(args) = cast(generating_type,
-                            input(args) / select(softmax_sum(args_sum) < Expr(1e-5),
-                                                 1, softmax_sum(args_sum)));
-        result_inner = args[0];
-    }
-#endif
+    void apply(Func input, Expr size, const Type &generating_type) {
+        switch (algorithm) {
+        case Algorithm::Naive:
+            naive_algorithm(input, size, generating_type);
+            break;
+        case Algorithm::TwoPass:
+            two_pass_algorithm(input, size, generating_type);
+            break;
+        case Algorithm::ThreePass:
+            three_pass_algorithm(input, size, generating_type);
+            break;
+        }
+    }
+
+    void naive_algorithm(Func input, Expr size, const Type &generating_type) {
+        auto args = input.args();
+        RDom r(0, size);
+
+        exponentials(args) =
+            default_exp(cast<double>(clamp(input(args), -1e12f, 1e12f)));
+
+        std::vector<Var> args_sum(args.begin() + 1, args.end());
+        std::vector<Expr> args_reduction;
+        args_reduction.emplace_back(r.x);
+        args_reduction.insert(args_reduction.end(), args_sum.begin(),
+                              args_sum.end());
+
+        softmax_sum(args_sum) = Expr(0.0);
+        softmax_sum(args_sum) += exponentials(args_reduction);
+        softmax_sum_inner = r.x;
+        softmax_sum_inner_var = args_sum[0];
+
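+        // Guard against dividing by a zero or vanishingly small sum.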
+        result(args) = cast(generating_type,
+                            exponentials(args) / select(softmax_sum(args_sum) < Expr(1e-5),
+                                                        1, softmax_sum(args_sum)));
+        result_inner = args[0];
+        softmax_sum_compute_at = LoopLevel(result, args[1]);
+    }
 
     // Implementation based on the algorithm in
     // https://arxiv.org/pdf/2001.04438.pdf
-    void apply(Func input, Expr size, const Type &generating_type) {
+    void two_pass_algorithm(Func input, Expr size, const Type &generating_type) {
         auto args = input.args();
         RDom r(0, size);
 
-        // TODO: avoid needing double here
-        ext_exp(args) = halide_ext_exp(cast<double>(input(args)));
+        // TODO: It should not be necessary to use double for computation here.
+#define USE_DOUBLE 1
+#if USE_DOUBLE
+        ext_exp(args) = extended_exp(cast<double>(input(args)));
+#else
+        ext_exp(args) = extended_exp(input(args));
+#endif
 
         std::vector<Var> args_inner(args.begin() + 1, args.end());
         std::vector<Expr> args_reduction;
@@ -136,32 +198,71 @@ struct Softmax : public Halide::NamesInterface {
 
         // This reduction maintains a Tuple with the running sum and the
         // maximum exponent seen so far, both as floating point numbers.
-        softmax_sums(args_inner) =
-            Tuple(cast<double>(0), Expr(std::numeric_limits<double>::lowest()));
+        softmax_sum(args_inner) =
+#if USE_DOUBLE
+            Halide::Tuple(Expr(0.0), Expr(std::numeric_limits<double>::lowest()));
+#else
+            Halide::Tuple(0.0f, Expr(std::numeric_limits<float>::lowest()));
+#endif
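+        // Each update rescales both the accumulated sum and the incoming term
+        // to the larger of the two exponents before adding:
+        //   sum' = sum * 2^(m - m') + mantissa_i * 2^(e_i - m'),  m' = max(m, e_i),
+        // which keeps the accumulated pseudo-mantissa in range without ever
+        // materializing e^x directly.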
         Expr running_max_exp =
-            max(softmax_sums(args_inner)[1], ext_exp(args_reduction)[1]);
+            max(softmax_sum(args_inner)[1], ext_exp(args_reduction)[1]);
         Expr m_sub_i_term = ext_exp(args_reduction)[0] *
                             pow(2.0f, ext_exp(args_reduction)[1] - running_max_exp);
-        Expr m_sum_term = softmax_sums(args_inner)[0] *
-                          pow(2.0f, softmax_sums(args_inner)[1] - running_max_exp);
+        Expr m_sum_term = softmax_sum(args_inner)[0] *
+                          pow(2.0f, softmax_sum(args_inner)[1] - running_max_exp);
         Expr running_sum = m_sub_i_term + m_sum_term;
-        softmax_sums(args_inner) = Tuple(running_sum, running_max_exp);
-        Expr lambda = 1 / softmax_sums(args_inner)[0];
+        softmax_sum(args_inner) = Tuple(running_sum, running_max_exp);
+        Expr lambda = 1 / softmax_sum(args_inner)[0];
         Expr t =
             cast(generating_type,
                  ext_exp(args)[0] * lambda *
-                     pow(2.0f, ext_exp(args)[1] - softmax_sums(args_inner)[1]));
+                     pow(2.0f, ext_exp(args)[1] - softmax_sum(args_inner)[1]));
         result(args) = t;
         result_inner = args[0];
         softmax_sum_inner = r;
         softmax_sum_inner_var = args_inner[0];
         softmax_sum_compute_at = LoopLevel(result, args[1]);
     }
 
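+    // Classic three-pass softmax: pass one finds the maximum of the input,
+    // pass two sums exp(x - max), and pass three normalizes by that sum.
+    // Subtracting the maximum keeps every exponent non-positive, so the
+    // exponentials cannot overflow.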
+    void three_pass_algorithm(Func input, Expr size, const Type &generating_type) {
+        auto args = input.args();
+        RDom r(0, size);
+
+        std::vector<Var> args_inner(args.begin() + 1, args.end());
+        std::vector<Expr> args_reduction;
+        args_reduction.emplace_back(r.x);
+        args_reduction.insert(args_reduction.end(), args_inner.begin(),
+                              args_inner.end());
+
+        max_bias(args_inner) = std::numeric_limits<float>::lowest();
+        max_bias(args_inner) = max(max_bias(args_inner), input(args_reduction));
+
+        biased_exp(args) = halide_exp(input(args) - max_bias(args_inner));
+        softmax_sum(args_inner) = 0.0f;
+        softmax_sum(args_inner) += biased_exp(args_reduction);
+
+        Expr lambda = 1 / softmax_sum(args_inner);
+        result(args) = halide_exp(input(args) - max_bias(args_inner)) * lambda;
+        result_inner = args[0];
+        softmax_sum_inner = r;
+        softmax_sum_inner_var = args_inner[0];
+        softmax_sum_compute_at = LoopLevel(result, args[1]);
+    }
+
+    // TODO: Add support for reuse vs. recompute scheduling on exp operations.
+
     void default_schedule(LoopLevel result_loop_level, const Target &t,
                           bool vectorize) {
-        ext_exp.compute_inline();
-        softmax_sums.compute_at(softmax_sum_compute_at)
+        if (algorithm == Algorithm::Naive) {
+            exponentials.compute_at(softmax_sum_compute_at);
+        } else if (algorithm == Algorithm::TwoPass) {
+            ext_exp.compute_inline();
+        } else if (algorithm == Algorithm::ThreePass) {
+            max_bias.compute_at(softmax_sum_compute_at);
+            // TODO: Vectorize the max loop, maybe parallelize it.
+            biased_exp.compute_at(softmax_sum_compute_at);
+        }
+        softmax_sum.compute_at(softmax_sum_compute_at)
             .store_in(MemoryType::Register)
            .vectorize(softmax_sum_inner_var, t.natural_vector_size<float>())
             .update(0)
@@ -170,7 +271,11 @@ struct Softmax : public Halide::NamesInterface {
         if (vectorize) {
             // In some modes, this dimension is narrow and we don't want to
             // vectorize it.
+#if USE_DOUBLE
             result.vectorize(result_inner, t.natural_vector_size<double>());
+#else
+            result.vectorize(result_inner, t.natural_vector_size<float>());
+#endif
         }
     }
 };
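+
+// A minimal usage sketch (hypothetical names: `input` is a Func whose first
+// argument is the softmax axis, `size` is that axis's extent, and `target`
+// is the Target being compiled for):
+//
+//     Softmax softmax("demo", Softmax::Algorithm::TwoPass);
+//     softmax.apply(input, size, Float(32));
+//     softmax.default_schedule(LoopLevel::root(), target, /*vectorize=*/true);
+//     Buffer<float> out = softmax.result.realize({width, height});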