@@ -1299,16 +1299,17 @@ static size_t MaxQueryLength(const QueriesPromptTokens& queries_prompt) {
 // Holds "is at end of stream" state for each query.
 class TokenStreamer {
  public:
-  explicit TokenStreamer(const RuntimeConfig& runtime_config)
-      : runtime_config_(runtime_config) {}
+  explicit TokenStreamer(const RuntimeConfig& runtime_config,
+                         const ModelConfig& model_config)
+      : runtime_config_(runtime_config), model_config_(model_config) {}
 
   // Returns whether the query was already at, or has just reached, the end of
   // the stream: either via token == eos_id, or StreamToken returning false.
   bool operator()(size_t query_idx, size_t pos, int token, float prob) {
     if (HWY_UNLIKELY(is_eos_.Get(query_idx))) return true;
 
     if (!runtime_config_.StreamToken(query_idx, pos, token, prob) ||
-        token == runtime_config_.eos_id) {
+        model_config_.IsEOS(token)) {
       is_eos_.Set(query_idx);
       return true;
     }
@@ -1318,6 +1319,7 @@ class TokenStreamer {
 
  private:
   const RuntimeConfig& runtime_config_;
+  const ModelConfig& model_config_;
   hwy::BitSet4096<> is_eos_;
 };
 
@@ -1425,7 +1427,7 @@ void GenerateT(const ModelWeightsStorage& model, Activations& activations,
   // Sanity check: prompts should not be empty, nor start with EOS.
   for (size_t query_idx = 0; query_idx < queries_prompt.size(); ++query_idx) {
     const PromptTokens& prompt = queries_prompt[query_idx];
-    HWY_ASSERT(prompt.size() != 0 && prompt[0] != runtime_config.eos_id);
+    HWY_ASSERT(prompt.size() != 0 && !model.Config().IsEOS(prompt[0]));
   }
 
   const size_t num_queries = queries_prompt.size();
@@ -1469,7 +1471,7 @@ void GenerateT(const ModelWeightsStorage& model, Activations& activations,
   std::vector<int> gen_tokens(num_queries);
 
   // Stream the last prompt token from each query and fill gen_tokens.
-  TokenStreamer token_streamer(runtime_config);
+  TokenStreamer token_streamer(runtime_config, model.Config());
   for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
     size_t last_token_pos_in_prompt =
         queries_mutable_pos[query_idx] - queries_pos_in[query_idx];
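For context, the diff routes end-of-stream detection through the model config's IsEOS() instead of comparing against the single runtime_config.eos_id. The definition of ModelConfig::IsEOS is not shown in this diff; the standalone sketch below is only a plausible stand-in, assuming the config keeps a small list of end-of-sequence token ids (for example a primary terminator plus an end-of-turn marker), which is one reason such a check might live on the model config rather than the runtime config. All names and token ids here are illustrative, not taken from the repository.

// Minimal sketch, not the real gemma.cpp ModelConfig.
// Assumes the config stores one or more end-of-sequence token ids.
#include <cstdio>
#include <vector>

struct ModelConfigSketch {
  // Hypothetical ids for illustration only.
  std::vector<int> eos_ids{1, 107};

  // Returns true if `token` matches any configured terminator.
  bool IsEOS(int token) const {
    for (int id : eos_ids) {
      if (token == id) return true;
    }
    return false;
  }
};

int main() {
  ModelConfigSketch config;
  // A streamer built on IsEOS() can stop on any configured terminator,
  // rather than on a single hard-coded eos_id.
  const int samples[] = {5, 107, 1};
  for (int token : samples) {
    std::printf("token %d -> %s\n", token, config.IsEOS(token) ? "EOS" : "keep");
  }
  return 0;
}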