Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 73 additions & 66 deletions crates/apr-cli/src/commands/eval/inference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1533,52 +1533,41 @@ pub(crate) fn run_mbpp(
Ok(())
}

/// ALB-085: Run MBPP with actual model inference + Python test execution.
/// ALB-085 + PMAT-CODE-MBPP-H4-FIX (2026-05-12): Run MBPP with actual model
/// inference + Python test execution.
///
/// Routes through `realizar::run_inference` + `InferenceConfig::with_prompt`
/// (ChatML auto-wrap for instruct models) — mirrors the §70 HumanEval H4 +
/// R1+R2 cascade. MBPP prompts are natural language ("Write a python
/// function to..."); without ChatML wrap, instruct models emit NL-prose
/// continuations ("Example: Input: ... Output: ...") instead of code (see
/// `evidence/section-72-mbpp-cascade-2026-05-12/findings.json` for the
/// pre-fix MBPP/11 SyntaxError evidence).
///
/// Parse `\`\`\`python ... \`\`\`` markdown blocks from the response. MBPP
/// has no Python imports in the prompt, so the §70 RC3 prompt-preamble
/// handling does not apply — the extracted code block is the program.
#[cfg(feature = "inference")]
fn run_mbpp_inference(
model_path: &Path,
problems: &[MbppProblem],
_k_values: &[usize],
json_output: bool,
) -> std::result::Result<(usize, Vec<(String, String, bool)>), String> {
use realizar::apr_transformer::{AprKVCache, AprTransformer};
use realizar::safetensors_infer::SafetensorsToAprConverter;
use realizar::{run_inference, InferenceConfig};

if !json_output {
println!(" {} Loading model for inference...", "→".dimmed());
}
let transformer: AprTransformer = if model_path.extension().is_some_and(|e| e == "apr")
|| model_path.join("model-best.apr").exists()
{
let apr_path = if model_path.is_dir() {
model_path.join("model-best.apr")
} else {
model_path.to_path_buf()
};
AprTransformer::from_apr_file(&apr_path)
.map_err(|e| format!("Cannot load APR model: {e}"))?
} else {
SafetensorsToAprConverter::convert(model_path)
.map_err(|e| format!("Cannot load model: {e}"))?
.into_inner()
};

let tokenizer = realizar::apr::AprV2Model::load_tokenizer(model_path)
.ok_or_else(|| "No tokenizer found".to_string())?;

if !json_output {
println!(
" {} Model loaded ({} layers, vocab={})",
"✓".green(),
transformer.config.num_layers,
transformer.config.vocab_size
);
println!(" {} Tokenizer loaded", "✓".green());
}

let mut passed = 0usize;
let mut results = Vec::new();
let temperature = 0.0f32;
let mut rng_state: u64 = 42;

for (i, problem) in problems.iter().enumerate() {
let task_id = match &problem.task_id {
Expand All @@ -1587,49 +1576,67 @@ fn run_mbpp_inference(
v => format!("MBPP/{v}"),
};

// MBPP prompt: natural language description -> model writes complete function
let prompt = format!("{}\n", problem.text);

let prompt_tokens = tokenizer.encode(&prompt);
if prompt_tokens.is_empty() {
results.push((task_id, String::new(), false));
continue;
}

// Generate completion (max 512 tokens -- MBPP solutions are longer)
let mut cache = AprKVCache::new(&transformer.config);
let mut tokens = prompt_tokens.clone();

for (pos, &tok) in prompt_tokens.iter().enumerate() {
let _ = transformer.forward_with_cache(tok, &mut cache, pos);
}
// MBPP canonical prompt format: NL description + test_list hint.
//
// Without the test_list hint, the model invents its own function name
// (e.g., `remove_first_last_occurrence` for MBPP/11) and fails the
// assertion (`remove_Occ` expected). The standard MBPP format used by
// Bigcode + lm-eval-harness + the canonical paper includes the first
// 1-3 test assertions as `Your code should pass these tests:` hints —
// this implicitly specifies the function name and signature.
let test_hints = if problem.test_list.is_empty() {
String::new()
} else {
format!(
"\nYour code should pass these tests:\n{}\n",
problem.test_list.join("\n")
)
};
let prompt = format!("{}{}", problem.text, test_hints);

let max_new = 512;
for step in 0..max_new {
let pos = prompt_tokens.len() + step;
let last_tok = *tokens.last().expect("last(");
let logits = transformer
.forward_with_cache(last_tok, &mut cache, pos)
.map_err(|e| format!("Generation failed: {e}"))?;

let next = sample_token(&logits, temperature, &mut rng_state);
tokens.push(next);
// H4 fix: route through ChatML auto-wrap via `with_prompt` (instruct
// models). Raw NL → ChatML user message → assistant emits markdown
// code block.
let config_chatml = InferenceConfig::new(model_path)
.with_prompt(prompt.clone())
.with_max_tokens(512)
.with_temperature(0.0)
.with_top_k(1);

if next == 0 {
break;
}
if let Some(eos) = transformer.config.eos_token_id {
if next == eos {
break;
let result = match run_inference(&config_chatml) {
Ok(r) => r,
Err(e) => {
if !json_output {
eprintln!(" [FAIL] {task_id}: inference error: {e}");
}
results.push((task_id, String::new(), false));
continue;
}
}

let completion_tokens = &tokens[prompt_tokens.len()..];
let completion = tokenizer.decode(completion_tokens);
};

// Truncate at next top-level definition (same as HumanEval)
let completion = truncate_at_function_boundary(&completion);
// R1+R2: extract Python code block. MBPP has no entry_point in the
// problem schema (unlike HumanEval), so we pass None — the
// first-non-empty-block fallback is appropriate.
let completion_owned =
if let Some(code) = extract_python_code_block_targeted(&result.text, None) {
// ChatML/markdown path: assistant emitted `\`\`\`python\n…\n\`\`\``.
code
} else {
// Raw-continuation fallback (no code block found). Slice past the
// prompt; truncate at next top-level def.
let raw = if let Some(stripped) = result.text.strip_prefix(&prompt) {
stripped.to_string()
} else {
let completion_tokens = if result.tokens.len() > result.input_token_count {
&result.tokens[result.input_token_count..]
} else {
&result.tokens[..]
};
tokenizer.decode(completion_tokens)
};
truncate_at_function_boundary(&raw).to_string()
};
let completion: &str = &completion_owned;

// Build test program: completion + setup_code + test assertions
let setup = problem.test_setup_code.as_deref().unwrap_or("").trim();
Expand All @@ -1647,7 +1654,7 @@ fn run_mbpp_inference(
write_apr_eval_debug(
&task_id,
&prompt,
&tokenizer.decode(&tokens),
&result.text,
completion,
&full_program,
&exec_result,
Expand Down
Loading