@@ -1196,27 +1196,26 @@ class TokenStreamer {
11961196 hwy::BitSet4096<> is_eos_;
11971197};
11981198
// ChooseSampleFunc: picks the callback that maps a logits array to a
// TokenAndProb. Selection order, as visible below:
//   1. runtime_config.sample_func, if the user supplied one.
//   2. A top-1 fast path (Top1OfSoftmax) when top_k == 1 and no accept_token
//      is set.
//   3. Otherwise Softmax followed by SampleTopK.
// Diff note: the `-` lines take top_k as a separate parameter; the `+` lines
// drop that parameter and read runtime_config.top_k instead.
1199- HWY_INLINE SampleFunc ChooseSampleFunc (int top_k,
1200- const RuntimeConfig& runtime_config) {
1199+ HWY_INLINE SampleFunc ChooseSampleFunc (const RuntimeConfig& runtime_config) {
12011200 // If user provided a sample_func, use it.
12021201 if (runtime_config.sample_func ) return runtime_config.sample_func ;
12031202
12041203 // Fast path for top-1 with no accept_token.
1205- if (top_k == 1 && !runtime_config.accept_token ) {
1204+ if (runtime_config. top_k == 1 && !runtime_config.accept_token ) {
12061205 return [](float * logits, size_t vocab_size) HWY_ATTR -> TokenAndProb {
12071206 PROFILER_ZONE (" Gen.Sample Top1" );
12081207 return Top1OfSoftmax (logits, vocab_size);
12091208 };
12101209 }
12111210
12121211 // General case: Softmax with top-k sampling.
// NOTE(review): this lambda captures runtime_config by reference and is
// returned to the caller, so runtime_config must stay alive for as long as the
// returned SampleFunc is invoked. It also dereferences runtime_config.gen on
// each call -- presumably required to be non-null here; confirm at call sites.
1213- return [top_k, &runtime_config](float * logits,
1214- size_t vocab_size) HWY_ATTR -> TokenAndProb {
1212+ return [&runtime_config](float * logits,
1213+ size_t vocab_size) HWY_ATTR -> TokenAndProb {
12151214 PROFILER_ZONE (" Gen.Sample general" );
12161215 Softmax (logits, vocab_size);
1217- const int token =
1218- SampleTopK ( logits, top_k, vocab_size, *runtime_config.gen ,
1219- runtime_config.temperature , runtime_config.accept_token );
1216+ const int token = SampleTopK (
1217+ logits, runtime_config. top_k , vocab_size, *runtime_config.gen ,
1218+ runtime_config.temperature , runtime_config.accept_token );
// The reported probability is read back from logits after the Softmax call
// above (presumably Softmax overwrites logits in place -- TODO confirm).
12201219 return TokenAndProb{.token = token, .prob = logits[token]};
12211220 };
12221221}
@@ -1276,8 +1275,7 @@ void GenerateT(const ModelWeightsStorage& model, Activations& activations,
12761275 size_t max_prompt_size = MaxQueryLength (queries_prompt);
12771276 size_t max_generated_tokens = runtime_config.max_generated_tokens ;
12781277 RangeChecks (weights.weights_config , max_generated_tokens, max_prompt_size);
1279- const SampleFunc sample_token =
1280- ChooseSampleFunc (weights.weights_config .top_k , runtime_config);
1278+ const SampleFunc sample_token = ChooseSampleFunc (runtime_config);
12811279
12821280 // Prefill stops before min_prompt_size - 1 because the last prompt
12831281 // token is the first input token for generation.
0 commit comments