@@ -295,8 +295,9 @@ HWY_NOINLINE void Attention(
295295 constexpr size_t kHeads = TConfig::kHeads ;
296296 constexpr size_t kKVHeads = TConfig::kKVHeads ;
297297 constexpr size_t kSeqLen = TConfig::kSeqLen ;
298- GEMMA_CONSTEXPR_SQRT const float kQueryScale =
298+ GEMMA_CONSTEXPR_SQRT float kQueryScale =
299299 1 .0f / Sqrt (static_cast <float >(kQKVDim ));
300+
300301 constexpr bool kIsMHA = TActivations::kIsMHA ; // Multi-Head Attention
301302 const size_t batch_start = batch_and_query_start / num_queries;
302303 const size_t num_tokens_and_queries = num_tokens * num_queries;
@@ -350,7 +351,9 @@ HWY_NOINLINE void Attention(
350351 // Skip past the Q part of `q`, and copy KV to `kv`.
351352 memcpy (kv, q + kQKVDim , 2 * kQKVDim * sizeof (float ));
352353 }
353- Rope (kv, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim , pos);
354+ if (TConfig::kPostQK == PostQKType::Rope) {
355+ Rope (kv, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim , pos);
356+ }
354357 });
355358
356359 static_assert ((kHeads % kKVHeads ) == 0 ,
@@ -373,7 +376,10 @@ HWY_NOINLINE void Attention(
373376 activations.att .data () + head * kSeqLen
374377 + batch_and_query_idx * kHeads * kSeqLen ;
375378
376- Rope (q, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim , pos);
379+ if (TConfig::kPostQK == PostQKType::Rope) {
380+ Rope (q, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim , pos);
381+ }
382+
377383 MulByConst (kQueryScale , q, kQKVDim );
378384
379385 // Compute Q dot K scores
@@ -465,10 +471,12 @@ HWY_NOINLINE void FFW(Activations<TConfig, kBatchSize>& activations,
465471 namespace hn = hwy::HWY_NAMESPACE;
466472 using DF = hn::ScalableTag<float >;
467473 using VF = hn::Vec<DF>;
468- hn::Transform1 (DF (), activations.C1 .data (), kFFHiddenDim * num_tokens,
469- activations.C2 .data (), [](DF df, VF v, VF mul) HWY_ATTR {
470- return hn::Mul (mul, Gelu (df, v));
471- });
474+ if (TConfig::kActivation == ActivationType::Gelu) {
475+ hn::Transform1 (DF (), activations.C1 .data (), kFFHiddenDim * num_tokens,
476+ activations.C2 .data (), [](DF df, VF v, VF mul) HWY_ATTR {
477+ return hn::Mul (mul, Gelu (df, v));
478+ });
479+ }
472480
473481 MatMul_4x4_Batch<kFFHiddenDim , kModelDim >(num_tokens, activations.C1 .data (),
474482 layer_weights->linear_w .data (),
@@ -560,29 +568,34 @@ HWY_NOINLINE void TransformerLayer(
560568 layer_weights, kv_caches, pool);
561569 }
562570 }
563- if (TConfig::kPostNormScale ) {
571+
572+ if (TConfig::kPostNorm == PostNormType::Scale) {
564573 RMSNormInplaceBatched<kBatchSize * kQueryBatchSize >(
565574 num_tokens_and_queries,
566575 layer_weights->post_attention_norm_scale .data (),
567576 activations.att_post2 .data (), kModelDim );
568577 }
569- AddFromBatched<kBatchSize * kQueryBatchSize >(num_tokens_and_queries,
570- activations.att_post2 .data (),
571- activations.x .data (), kModelDim );
578+ if (TConfig::kResidual == ResidualType::Add) {
579+ AddFromBatched<kBatchSize * kQueryBatchSize >(
580+ num_tokens_and_queries, activations.att_post2 .data (),
581+ activations.x .data (), kModelDim );
582+ }
572583 RMSNormBatched<kBatchSize * kQueryBatchSize >(
573584 num_tokens_and_queries, activations.x .data (),
574585 layer_weights->pre_ffw_norm_scale .data (),
575586 activations.bf_pre_ffw_rms_out .data (), kModelDim );
576587 FFW<TConfig, kBatchSize * kQueryBatchSize >(
577588 activations, num_tokens_and_queries, layer_weights, pool);
578- if (TConfig::kPostNormScale ) {
589+ if (TConfig::kPostNorm == PostNormType::Scale ) {
579590 RMSNormInplaceBatched<kBatchSize * kQueryBatchSize >(
580591 num_tokens_and_queries, layer_weights->post_ffw_norm_scale .data (),
581592 activations.ffw_out .data (), kModelDim );
582593 }
583- AddFromBatched<kBatchSize * kQueryBatchSize >(
584- num_tokens_and_queries, activations.ffw_out .data (),
585- activations.x .data (), kModelDim );
594+ if (TConfig::kResidual == ResidualType::Add) {
595+ AddFromBatched<kBatchSize * kQueryBatchSize >(
596+ num_tokens_and_queries, activations.ffw_out .data (),
597+ activations.x .data (), kModelDim );
598+ }
586599}
587600
588601template <class TConfig , size_t kBatchSize , size_t kQueryBatchSize >
0 commit comments