101 changes: 101 additions & 0 deletions crates/apr-cli/src/commands/pretrain.rs
@@ -94,6 +94,30 @@ pub(crate) struct ResolvedHp {
pub target_val_loss: f32,
}

/// SPEC §82 P1-A: Estimate transformer parameter count from arch dims.
///
/// Formula (decoder-only, tied or untied embedding):
/// N ≈ vocab × hidden (embedding)
/// + L × (4·hidden² + 3·hidden·intermediate) (per-layer attn + ffn)
/// + hidden (final norm)
///
/// Embedding is counted once (assumes tied lm_head; for untied add a 2nd
/// `vocab × hidden`). This is a coarse estimate suitable for Chinchilla
/// scaling sanity checks, not a precise param report — for that, use
/// `apr inspect --json | jq .parameters`.
fn estimate_param_count(arch: &TransformerConfig) -> u64 {
let vocab = arch.vocab_size as u64;
let hidden = arch.hidden_size as u64;
let inter = arch.intermediate_size as u64;
let layers = arch.num_hidden_layers as u64;
let embed = vocab.saturating_mul(hidden);
let attn_per_layer = 4u64.saturating_mul(hidden).saturating_mul(hidden);
let ffn_per_layer = 3u64.saturating_mul(hidden).saturating_mul(inter);
let per_layer = attn_per_layer.saturating_add(ffn_per_layer);
let layer_total = layers.saturating_mul(per_layer);
embed.saturating_add(layer_total).saturating_add(hidden)
}
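
As a sanity check on the formula, here is what it yields for the Qwen2.5-0.5B dims used in the tests below (illustrative arithmetic, not part of the diff):

// Sketch: Qwen2.5-0.5B dims from the test suite, fed through the formula above.
//   embed     = 151_936 * 896            ≈ 136.1M
//   per layer = 4*896^2 + 3*896*4_864    ≈ 16.3M
//   24 layers = 24 * 16_285_696          ≈ 390.9M
//   total     ≈ 527M (published figure ~494M; coarse, but within 2×)
let mut cfg = TransformerConfig::llama2_7b();
cfg.hidden_size = 896;
cfg.num_hidden_layers = 24;
cfg.intermediate_size = 4864;
cfg.vocab_size = 151936;
let n = estimate_param_count(&cfg); // ≈ 527_000_000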

pub(crate) fn mode_defaults(
mode: PretrainMode,
vocab_size: u32,
@@ -172,6 +196,42 @@ pub(crate) fn run(

let hp = mode_defaults(mode, vocab_size, lr, warmup_steps, target_val_loss);

// SPEC §82 P1-A: Chinchilla compute-optimal gate (arXiv:2203.15556).
// Compute-optimal pretraining requires train tokens D ≈ 20·N where N is
// the parameter count. If D < 5·N we're severely under-trained; the
// model will memorize the small corpus instead of generalizing.
//
// Triggered for `--init` runs where we can read the arch dims to
// estimate N; from-scratch synthetic runs are exempt because the
// operator usually knows what they're doing. Non-fatal warning only.
if let Some(arch) = init_arch.as_ref() {
let n_params = estimate_param_count(arch);
let d_tokens = (num_steps as u64)
.saturating_mul(batch_size as u64)
.saturating_mul(seq_length as u64);
let ratio = d_tokens as f64 / n_params as f64;
if ratio < 5.0 {
eprintln!(
"[P1-A] Chinchilla gate WARNING: train tokens D = {} ({:.1}M) is {:.2}× param count N = {} ({:.1}M); \
Chinchilla compute-optimal target is D ≈ 20·N. Run is severely under-trained — \
expect val_loss plateau driven by data exhaustion (memorization), not optimization. \
Consider increasing --num-steps to ~{} or reducing model size.",
d_tokens, d_tokens as f64 / 1e6,
ratio,
n_params, n_params as f64 / 1e6,
(20 * n_params) / (batch_size as u64 * seq_length as u64),
);
} else if ratio < 20.0 {
eprintln!(
"[P1-A] Chinchilla gate: train tokens D = {} ({:.1}M) is {:.1}× param count N = {} ({:.1}M); \
below compute-optimal 20·N target — model has room for more training.",
d_tokens, d_tokens as f64 / 1e6,
ratio,
n_params, n_params as f64 / 1e6,
);
}
}
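
To make the gate arithmetic concrete, a sketch using the ~527M estimate from above and assumed run settings (batch_size = 32 and seq_length = 2048 are illustrative values, not defaults of this CLI):

// Sketch of the gate arithmetic; batch 32 / seq 2048 are assumed, not CLI defaults.
let n_params: u64 = 527_000_000;                  // N from the estimate above
let d_target = 20 * n_params;                     // Chinchilla D ≈ 20·N ≈ 10.5B tokens
let tokens_per_step = 32u64 * 2048;               // batch_size * seq_length = 65_536
let suggested_steps = d_target / tokens_per_step; // ≈ 160_800 steps
// A 1_000-step run sees D ≈ 65.5M tokens, so D/N ≈ 0.12 — well under the 5·N
// threshold, and the WARNING branch above fires.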

// Validation: GATE-TRAIN-003 requires target_val_loss > 0.
if hp.target_val_loss <= 0.0 {
return Err(CliError::ValidationFailed(format!(
@@ -819,6 +879,47 @@ mod tests {
std::fs::write(dir.join("vocab.json"), json).expect("write vocab.json");
}

/// SPEC §82 P1-A: parameter count estimator should be order-of-magnitude
/// correct for known reference models. Qwen2.5-0.5B has ~500M params;
/// our coarse formula should be within 2× of that.
#[test]
fn estimate_param_count_qwen2_05b_within_2x() {
let mut cfg = TransformerConfig::llama2_7b();
cfg.hidden_size = 896;
cfg.num_hidden_layers = 24;
cfg.num_attention_heads = 14;
cfg.num_kv_heads = 2;
cfg.intermediate_size = 4864;
cfg.vocab_size = 151936;
let n = estimate_param_count(&cfg);
// True Qwen2.5-0.5B = ~494M. Our estimate counts tied embedding once
// and ignores GQA reduction; expect ~400-700M.
let ref_params: u64 = 494_000_000;
assert!(
n > ref_params / 2 && n < ref_params * 2,
"Qwen2.5-0.5B estimate {n} should be within 2× of 494M",
);
}

/// SPEC §82 P1-A: estimator should grow with depth (fixed embedding cost plus a linear per-layer term).
#[test]
fn estimate_param_count_scales_with_layers() {
let mut cfg = TransformerConfig::llama2_7b();
cfg.hidden_size = 512;
cfg.num_hidden_layers = 1;
cfg.intermediate_size = 2048;
cfg.vocab_size = 32000;
let n1 = estimate_param_count(&cfg);
cfg.num_hidden_layers = 24;
let n24 = estimate_param_count(&cfg);
// 24× per-layer params + shared embedding ≈ 5-6× total for small models
// where embedding dominates per-layer contribution.
assert!(
n24 > n1 * 4,
"24-layer model {n24} should be at least 4× 1-layer model {n1}",
);
}
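
The "5-6×" figure in the comment follows directly from the formula (illustrative arithmetic, not test output):

embed     = 32_000 * 512              = 16.38M
per layer = 4*512^2 + 3*512*2_048     =  4.19M
n1  ≈ 16.38M +  1 × 4.19M ≈  20.6M
n24 ≈ 16.38M + 24 × 4.19M ≈ 117.0M   →  n24/n1 ≈ 5.7×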

#[test]
fn preflight_accepts_matching_vocab() {
// GATE-ARCH-370M-011 acceptance case: tokenizer vocab.json with