Skip to content

Commit 05b1cce

Browse files
pculliton authored and copybara-github committed
Add support for a secondary EOS token
PiperOrigin-RevId: 738898976
1 parent 83219e3 commit 05b1cce

File tree

7 files changed

+21
-8
lines changed

7 files changed

+21
-8
lines changed

examples/hello_world/run.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ int main(int argc, char** argv) {
8383
++generated;
8484
if (generated < prompt_size) {
8585
// print feedback
86-
} else if (token != gcpp::EOS_ID) {
86+
} else if (!model.GetModelConfig().IsEOS(token)) {
8787
std::string token_text;
8888
HWY_ASSERT(model.Tokenizer().Decode({token}, &token_text));
8989
std::cout << token_text << std::flush;

examples/simplified_gemma/gemma.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ class SimplifiedGemma {
8080
++generated;
8181
if (generated < prompt_size) {
8282
// print feedback
83-
} else if (token != gcpp::EOS_ID) {
83+
} else if (!this->model_.GetModelConfig().IsEOS(token)) {
8484
std::string token_text;
8585
HWY_ASSERT(this->model_.Tokenizer().Decode({token}, &token_text));
8686
std::cout << token_text << std::flush;

gemma/configs.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,8 @@ static ModelConfig ConfigGemmaTiny() {
195195
config.attention_window_sizes = FixedAttentionWindowSizes<3>(32);
196196
// This is required for optimize_test to pass.
197197
config.final_cap = 30.0f;
198+
config.eos_id = 11;
199+
config.secondary_eos_id = 11;
198200
return config;
199201
}
200202

@@ -333,6 +335,8 @@ static ModelConfig ConfigBaseGemmaV3() {
333335
ModelConfig config = ConfigNoSSM();
334336
config.att_cap = 0.0f;
335337
config.final_cap = 0.0f;
338+
config.eos_id = 1;
339+
config.secondary_eos_id = 106;
336340
return config;
337341
}
338342

gemma/configs.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,8 @@ struct ModelConfig : public IFields {
294294

295295
const char* Name() const override { return "ModelConfig"; }
296296

297+
bool IsEOS(int id) const { return (id == eos_id || id == secondary_eos_id); }
298+
297299
void VisitFields(IFieldsVisitor& visitor) override {
298300
visitor(model_family_version);
299301
visitor(model_name);
@@ -315,6 +317,8 @@ struct ModelConfig : public IFields {
315317
visitor(norm_num_groups);
316318
visitor(vit_config);
317319
visitor(pool_dim);
320+
visitor(eos_id);
321+
visitor(secondary_eos_id);
318322
}
319323

320324
// Major version of the model family. It is used as a fallback to distinguish
@@ -341,6 +345,8 @@ struct ModelConfig : public IFields {
341345
// Dimensions related to image processing.
342346
VitConfig vit_config;
343347
uint32_t pool_dim = 1; // used only for VitConfig copy
348+
int eos_id = 1;
349+
int secondary_eos_id = 1;
344350
};
345351

346352
// Returns the config for the given model.

gemma/gemma-inl.h

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1299,16 +1299,17 @@ static size_t MaxQueryLength(const QueriesPromptTokens& queries_prompt) {
12991299
// Holds "is at end of stream" state for each query.
13001300
class TokenStreamer {
13011301
public:
1302-
explicit TokenStreamer(const RuntimeConfig& runtime_config)
1303-
: runtime_config_(runtime_config) {}
1302+
explicit TokenStreamer(const RuntimeConfig& runtime_config,
1303+
const ModelConfig& model_config)
1304+
: runtime_config_(runtime_config), model_config_(model_config) {}
13041305

13051306
// Returns whether the query was already at, or has just reached, the end of
13061307
// the stream: either via token == eos_id, or StreamToken returning false.
13071308
bool operator()(size_t query_idx, size_t pos, int token, float prob) {
13081309
if (HWY_UNLIKELY(is_eos_.Get(query_idx))) return true;
13091310

13101311
if (!runtime_config_.StreamToken(query_idx, pos, token, prob) ||
1311-
token == runtime_config_.eos_id) {
1312+
model_config_.IsEOS(token)) {
13121313
is_eos_.Set(query_idx);
13131314
return true;
13141315
}
@@ -1318,6 +1319,7 @@ class TokenStreamer {
13181319

13191320
private:
13201321
const RuntimeConfig& runtime_config_;
1322+
const ModelConfig& model_config_;
13211323
hwy::BitSet4096<> is_eos_;
13221324
};
13231325

@@ -1425,7 +1427,7 @@ void GenerateT(const ModelWeightsStorage& model, Activations& activations,
14251427
// Sanity check: prompts should not be empty, nor start with EOS.
14261428
for (size_t query_idx = 0; query_idx < queries_prompt.size(); ++query_idx) {
14271429
const PromptTokens& prompt = queries_prompt[query_idx];
1428-
HWY_ASSERT(prompt.size() != 0 && prompt[0] != runtime_config.eos_id);
1430+
HWY_ASSERT(prompt.size() != 0 && !model.Config().IsEOS(prompt[0]));
14291431
}
14301432

14311433
const size_t num_queries = queries_prompt.size();
@@ -1469,7 +1471,7 @@ void GenerateT(const ModelWeightsStorage& model, Activations& activations,
14691471
std::vector<int> gen_tokens(num_queries);
14701472

14711473
// Stream the last prompt token from each query and fill gen_tokens.
1472-
TokenStreamer token_streamer(runtime_config);
1474+
TokenStreamer token_streamer(runtime_config, model.Config());
14731475
for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
14741476
size_t last_token_pos_in_prompt =
14751477
queries_mutable_pos[query_idx] - queries_pos_in[query_idx];

gemma/run.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
118118
// callback function invoked for each generated token.
119119
auto stream_token = [&](int token, float) {
120120
++abs_pos;
121-
if (token == EOS_ID) {
121+
if (model.GetModelConfig().IsEOS(token)) {
122122
if (app.verbosity >= 2) {
123123
std::cout << "\n[ End ]\n";
124124
}

gemma/tokenizer.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ namespace gcpp {
2929

3030
// The tokenizer's end of sentence and beginning of sentence token ids.
3131
constexpr int EOS_ID = 1;
32+
constexpr int SECONDARY_EOS_ID = 106; // for Gemma 3
3233
constexpr int BOS_ID = 2;
3334

3435
class GemmaTokenizer {

0 commit comments

Comments
 (0)