187 changes: 110 additions & 77 deletions crates/aprender-train/src/train/pretrain_real.rs
@@ -142,11 +142,8 @@ pub fn populate_trainer_from_init_tensors(
transformer: &mut Transformer,
init_tensors: &BTreeMap<String, (Vec<f32>, Vec<usize>)>,
) -> Result<usize, String> {
let expected: Vec<(String, usize)> = transformer
.named_parameters()
.into_iter()
.map(|(name, t)| (name, t.len()))
.collect();
let expected: Vec<(String, usize)> =
transformer.named_parameters().into_iter().map(|(name, t)| (name, t.len())).collect();
let mut populated = 0usize;
let mut errors: Vec<String> = Vec::new();

@@ -162,9 +159,7 @@ pub fn populate_trainer_from_init_tensors(
}
let tensor = Tensor::from_vec(data.clone(), true);
if !transformer.set_named_parameter(name, tensor) {
errors.push(format!(
"{name}: set_named_parameter rejected the assignment"
));
errors.push(format!("{name}: set_named_parameter rejected the assignment"));
continue;
}
populated += 1;
@@ -458,8 +453,8 @@ mod tests {
fn load_init_tensors_missing_file_errors_with_falsifier_id() {
let tmp = tempfile::TempDir::new().expect("tempdir");
let missing = tmp.path().join("does-not-exist.apr");
let err = load_init_tensors_from_apr(&missing)
.expect_err("missing init APR file MUST fail-fast");
let err =
load_init_tensors_from_apr(&missing).expect_err("missing init APR file MUST fail-fast");
assert!(
err.contains("FALSIFY-APR-PRETRAIN-INIT-006"),
"error must cite falsifier id (auditability): {err}"
@@ -485,10 +480,7 @@ mod tests {
// if the signature drifts, this test stops compiling.
fn _check_signature<F>(_f: F)
where
F: Fn(
&Path,
)
-> Result<BTreeMap<String, (Vec<f32>, Vec<usize>)>, String>,
F: Fn(&Path) -> Result<BTreeMap<String, (Vec<f32>, Vec<usize>)>, String>,
{
}
_check_signature(|p| load_init_tensors_from_apr(p));
@@ -524,10 +516,7 @@ mod tests {
assert_eq!(result.intermediate_size, baseline.intermediate_size);
assert_eq!(result.num_hidden_layers, baseline.num_hidden_layers);
assert_eq!(result.vocab_size, baseline.vocab_size);
assert_eq!(
result.max_position_embeddings,
baseline.max_position_embeddings
);
assert_eq!(result.max_position_embeddings, baseline.max_position_embeddings);
assert!((result.rms_norm_eps - baseline.rms_norm_eps).abs() < f32::EPSILON);
assert!((result.rope_theta - baseline.rope_theta).abs() < f32::EPSILON);
assert_eq!(result.use_bias, baseline.use_bias);
@@ -548,29 +537,17 @@ mod tests {
let qwen = TransformerConfig::qwen2_0_5b();
let result = build_transformer_config(Some(&qwen));
assert_eq!(result.hidden_size, qwen.hidden_size, "hidden_size");
assert_eq!(
result.num_attention_heads, qwen.num_attention_heads,
"num_attention_heads"
);
assert_eq!(result.num_attention_heads, qwen.num_attention_heads, "num_attention_heads");
assert_eq!(result.num_kv_heads, qwen.num_kv_heads, "num_kv_heads");
assert_eq!(
result.intermediate_size, qwen.intermediate_size,
"intermediate_size"
);
assert_eq!(
result.num_hidden_layers, qwen.num_hidden_layers,
"num_hidden_layers"
);
assert_eq!(result.intermediate_size, qwen.intermediate_size, "intermediate_size");
assert_eq!(result.num_hidden_layers, qwen.num_hidden_layers, "num_hidden_layers");
assert_eq!(result.vocab_size, qwen.vocab_size, "vocab_size");
assert_eq!(
result.max_position_embeddings, qwen.max_position_embeddings,
"max_position_embeddings"
);
assert_eq!(result.use_bias, qwen.use_bias, "use_bias");
assert_eq!(
result.tie_word_embeddings, qwen.tie_word_embeddings,
"tie_word_embeddings"
);
assert_eq!(result.tie_word_embeddings, qwen.tie_word_embeddings, "tie_word_embeddings");
assert_eq!(result.architecture, qwen.architecture, "architecture");
// GQA-7:1 ratio preserved (Qwen2.5-0.5B: 14 / 2 = 7)
assert_eq!(
@@ -650,10 +627,7 @@ mod tests {
err.contains("FALSIFY-APR-PRETRAIN-ARCH-007"),
"error must cite falsifier id: {err}"
);
assert!(
err.contains("Encoder"),
"error must name the architecture family: {err}"
);
assert!(err.contains("Encoder"), "error must name the architecture family: {err}");
assert!(
err.contains("decoder-only"),
"error must explain why this is wrong (decoder trainer): {err}"
@@ -758,10 +732,7 @@ mod tests {
let init_tensors = tensors_map_from_transformer(&transformer);
let expected_count = transformer.named_parameters().len();
let result = populate_trainer_from_init_tensors(&mut transformer, &init_tensors);
assert!(
result.is_ok(),
"happy-path populate must succeed: {result:?}"
);
assert!(result.is_ok(), "happy-path populate must succeed: {result:?}");
assert_eq!(
result.unwrap(),
expected_count,
@@ -778,16 +749,11 @@ mod tests {
let mut transformer = tiny_test_transformer();
let mut init_tensors = tensors_map_from_transformer(&transformer);
// Inject a fictitious extra parameter that the model does not have.
init_tensors.insert(
"model.layers.999.fictitious.weight".to_string(),
(vec![0.0; 4], vec![4]),
);
init_tensors
.insert("model.layers.999.fictitious.weight".to_string(), (vec![0.0; 4], vec![4]));
let expected_count = transformer.named_parameters().len();
let result = populate_trainer_from_init_tensors(&mut transformer, &init_tensors);
assert!(
result.is_ok(),
"extra init entries must NOT cause Err: {result:?}"
);
assert!(result.is_ok(), "extra init entries must NOT cause Err: {result:?}");
assert_eq!(result.unwrap(), expected_count);
}

@@ -802,19 +768,13 @@ mod tests {
let any_name = transformer.named_parameters()[0].0.clone();
init_tensors.insert(any_name.clone(), (vec![0.0; 7], vec![7]));
let result = populate_trainer_from_init_tensors(&mut transformer, &init_tensors);
assert!(
result.is_err(),
"length-mismatch must Err, not silently truncate"
);
assert!(result.is_err(), "length-mismatch must Err, not silently truncate");
let err = result.unwrap_err();
assert!(
err.contains("FALSIFY-APR-PRETRAIN-INIT-007"),
"error must cite falsifier id; got: {err}"
);
assert!(
err.contains(&any_name),
"error must name the offending parameter; got: {err}"
);
assert!(err.contains(&any_name), "error must name the offending parameter; got: {err}");
assert!(
err.contains("init length 7"),
"error must report the actual init length; got: {err}"
@@ -834,19 +794,13 @@ mod tests {
let any_name = transformer.named_parameters()[0].0.clone();
init_tensors.remove(&any_name);
let result = populate_trainer_from_init_tensors(&mut transformer, &init_tensors);
assert!(
result.is_err(),
"missing-required must Err, not silently leave random init"
);
assert!(result.is_err(), "missing-required must Err, not silently leave random init");
let err = result.unwrap_err();
assert!(
err.contains("FALSIFY-APR-PRETRAIN-INIT-007"),
"error must cite falsifier id; got: {err}"
);
assert!(
err.contains(&any_name),
"error must name the missing parameter; got: {err}"
);
assert!(err.contains(&any_name), "error must name the missing parameter; got: {err}");
assert!(
err.contains("not present in init APR"),
"error must say what was missing; got: {err}"
@@ -867,10 +821,10 @@ mod tests {
// The baseline polymorphic dispatch produces a Llama370M-shaped model.
// Embedding shape `vocab × hidden` is the cleanest non-stale check.
let embed_len = model.model().named_parameters()[0].1.len();
let expected_embed_len =
Llama370MConfig::VOCAB_SIZE * Llama370MConfig::HIDDEN_DIM;
let expected_embed_len = Llama370MConfig::VOCAB_SIZE * Llama370MConfig::HIDDEN_DIM;
assert_eq!(
embed_len, expected_embed_len,
embed_len,
expected_embed_len,
"init=None must produce Llama370M-shaped embedding (vocab={} × hidden={})",
Llama370MConfig::VOCAB_SIZE,
Llama370MConfig::HIDDEN_DIM
@@ -885,17 +839,11 @@ mod tests {
// arch Some, path None
let cfg = TransformerConfig::qwen2_0_5b();
let result = build_shared_trainer_with_init(1.0e-4, 128, 42, Some(&cfg), None);
assert!(
result.is_err(),
"unpaired (arch=Some, path=None) must Err"
);
assert!(result.is_err(), "unpaired (arch=Some, path=None) must Err");
// arch None, path Some
let dummy_path = std::path::PathBuf::from("/dev/null");
let result = build_shared_trainer_with_init(1.0e-4, 128, 42, None, Some(&dummy_path));
assert!(
result.is_err(),
"unpaired (arch=None, path=Some) must Err"
);
assert!(result.is_err(), "unpaired (arch=None, path=Some) must Err");
}

/// `build_shared_trainer_with_init(Some(encoder), Some(path))` rejects
@@ -940,4 +888,89 @@ mod tests {
"decoder family must NOT trigger encoder-rejection; got: {err}"
);
}

/// FALSIFY-H4-CPU-FORWARD-001 (H4 residual cascade — bisect to CPU vs CUDA):
/// CPU `aprender::Transformer::forward` on a populated Qwen 0.5B model
/// MUST produce sensibly-distributed logits. Host-gated test that
/// bisects whether the val_loss > ln(vocab) defect is in the
/// populate path / CPU forward (RED here = bug there) or in CUDA
/// (GREEN here, RED in eval_batch = bug in CUDA path).
#[test]
fn falsify_h4_cpu_forward_qwen_logits_sensible() {
let fresh = std::path::Path::new("/mnt/nvme-raid0/models/qwen2.5-coder-0.5b-fresh.apr");
let legacy =
std::path::Path::new("/mnt/nvme-raid0/models/qwen2.5-coder-0.5b-instruct-fp16.apr");
let path = if fresh.exists() {
fresh
} else if legacy.exists() {
legacy
} else {
eprintln!("[falsify-h4-cpu-forward-001] skipping: host lacks Qwen 0.5B APR");
return;
};

let tensors = load_init_tensors_from_apr(path).expect("load_init_tensors_from_apr");
let cfg = TransformerConfig::qwen2_0_5b();
let mut transformer = Transformer::new(&cfg);
let populated = populate_trainer_from_init_tensors(&mut transformer, &tensors)
.expect("populate_trainer_from_init_tensors");
eprintln!("[falsify-h4-cpu-forward-001] populated {populated} tensors");

let token_ids = vec![100_u32];
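// One forward step over a single arbitrary in-vocabulary token id is enough
// to exercise the full decoder stack for this check.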
let logits = transformer.forward(&token_ids);
let data = logits.data();
let slice = data.as_slice().expect("logits contiguous");

let mut nan_count = 0usize;
let mut inf_count = 0usize;
let mut min = f32::INFINITY;
let mut max = f32::NEG_INFINITY;
let mut sum = 0.0_f64;
let mut sum_sq = 0.0_f64;
let mut argmax_idx = 0_usize;
for (i, &v) in slice.iter().enumerate() {
if v.is_nan() {
nan_count += 1;
} else if v.is_infinite() {
inf_count += 1;
} else {
if v < min {
min = v;
}
if v > max {
max = v;
argmax_idx = i;
}
sum += v as f64;
sum_sq += (v as f64) * (v as f64);
}
}
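// Single-pass moments: std is computed as sqrt(E[x^2] - E[x]^2), which is
// accurate enough for this coarse "are the logits degenerate?" check.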
let n = slice.len() as f64;
let mean = sum / n;
let std = (sum_sq / n - mean * mean).sqrt();

eprintln!(
"[falsify-h4-cpu-forward-001] token=100 logits: n={} nan={nan_count} inf={inf_count} \
min={min:.4} max={max:.4} mean={mean:.4} std={std:.4} argmax={argmax_idx}",
slice.len()
);

assert_eq!(nan_count, 0, "logits contain NaN — forward corruption");
assert_eq!(inf_count, 0, "logits contain Inf — forward corruption");
assert!(
std > 0.01,
"FALSIFY-H4-CPU-FORWARD-001: logits std={std} < 0.01 — essentially constant"
);
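// "Peak-to-mean" below is the max logit's z-score: how many standard deviations
// the top logit sits above the mean of the whole logit vector.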
let peak_to_mean = (max as f64 - mean).abs() / std.max(1e-9);
assert!(
peak_to_mean > 1.5,
"FALSIFY-H4-CPU-FORWARD-001: peak-to-mean ratio = {peak_to_mean} < 1.5 — \
logits are essentially uniform"
);
assert!(
(argmax_idx as u32) < cfg.vocab_size as u32,
"FALSIFY-H4-CPU-FORWARD-001: argmax_idx={argmax_idx} >= vocab_size={}",
cfg.vocab_size
);
}
}
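For context on the "val_loss > ln(vocab)" threshold cited in the FALSIFY-H4 doc comment: a model that emits uniform logits over the whole vocabulary scores a cross-entropy of ln(vocab_size) nats, so any validation loss above that bound is worse than uniform guessing. A minimal standalone sketch of that baseline (not part of this PR; the 151,936 figure assumes Qwen2.5's padded tokenizer vocabulary):

// Hypothetical helper, for illustration only: the uniform-logits cross-entropy
// baseline that "val_loss > ln(vocab)" refers to.
fn uniform_baseline_nats(vocab_size: usize) -> f64 {
    (vocab_size as f64).ln()
}

fn main() {
    // An assumed Qwen2.5 vocab_size of 151_936 gives a baseline of about 11.93 nats.
    println!("uniform baseline = {:.2}", uniform_baseline_nats(151_936));
}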