From 20fc35fa705a6f0a5cea936df8afc3216b036456 Mon Sep 17 00:00:00 2001 From: Michael Goulet Date: Mon, 4 Aug 2025 21:53:34 +0000 Subject: [PATCH] Allocate arguments from topmost frame into temporary in init_fn_tail_call --- .../rustc_const_eval/src/interpret/call.rs | 45 +++++++++++++++++-- .../rustc_const_eval/src/interpret/stack.rs | 10 +++++ .../rustc_const_eval/src/interpret/step.rs | 2 +- src/tools/miri/tests/pass/tail_call_temp.rs | 12 +++++ 4 files changed, 64 insertions(+), 5 deletions(-) create mode 100644 src/tools/miri/tests/pass/tail_call_temp.rs diff --git a/compiler/rustc_const_eval/src/interpret/call.rs b/compiler/rustc_const_eval/src/interpret/call.rs index b8a653698258f..471a658af650c 100644 --- a/compiler/rustc_const_eval/src/interpret/call.rs +++ b/compiler/rustc_const_eval/src/interpret/call.rs @@ -5,6 +5,7 @@ use std::borrow::Cow; use either::{Left, Right}; use rustc_abi::{self as abi, ExternAbi, FieldIdx, Integer, VariantIdx}; +use rustc_data_structures::fx::FxHashSet; use rustc_hir::def_id::DefId; use rustc_middle::ty::layout::{IntegerExt, TyAndLayout}; use rustc_middle::ty::{self, AdtDef, Instance, Ty, VariantDef}; @@ -19,7 +20,7 @@ use super::{ Projectable, Provenance, ReturnAction, ReturnContinuation, Scalar, StackPopInfo, interp_ok, throw_ub, throw_ub_custom, throw_unsup_format, }; -use crate::interpret::EnteredTraceSpan; +use crate::interpret::{EnteredTraceSpan, MemoryKind}; use crate::{enter_trace_span, fluent_generated as fluent}; /// An argument passed to a function. @@ -752,11 +753,41 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> { &mut self, fn_val: FnVal<'tcx, M::ExtraFnVal>, (caller_abi, caller_fn_abi): (ExternAbi, &FnAbi<'tcx, Ty<'tcx>>), - args: &[FnArg<'tcx, M::Provenance>], + mut args: Vec>, with_caller_location: bool, ) -> InterpResult<'tcx> { trace!("init_fn_tail_call: {:#?}", fn_val); + let mut local_temps = vec![]; + let frame_locals = &self.stack().last().unwrap().locals; + if frame_locals.iter().any(|frame_local| frame_local.is_allocation()) { + // Allocations corresponding to the locals in the last frame. + let local_allocs: FxHashSet<_> = frame_locals + .iter() + .filter_map(|local| local.as_mplace_or_imm()?.left()?.0.provenance?.get_alloc_id()) + .collect(); + + for arg in &mut args { + let mplace = match arg { + FnArg::Copy(op) => match op.as_mplace_or_imm() { + Left(mplace) => mplace, + Right(_) => continue, + }, + FnArg::InPlace(mplace) => mplace.clone(), + }; + + if let Some(prov) = mplace.ptr().provenance + && let Some(alloc_id) = prov.get_alloc_id() + && local_allocs.contains(&alloc_id) + { + let temp_mplace = self.allocate(*arg.layout(), MemoryKind::Stack)?; + self.copy_op(&mplace, &temp_mplace)?; + local_temps.push(temp_mplace.clone()); + *arg = FnArg::Copy(temp_mplace.into()); + } + } + } + // This is the "canonical" implementation of tails calls, // a pop of the current stack frame, followed by a normal call // which pushes a new stack frame, with the return address from @@ -785,12 +816,18 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> { self.init_fn_call( fn_val, (caller_abi, caller_fn_abi), - args, + &args, with_caller_location, &return_place, ret, unwind, - ) + )?; + + for local_temp in local_temps { + self.deallocate_ptr(local_temp.ptr(), None, MemoryKind::Stack)?; + } + + interp_ok(()) } pub(super) fn init_drop_in_place_call( diff --git a/compiler/rustc_const_eval/src/interpret/stack.rs b/compiler/rustc_const_eval/src/interpret/stack.rs index 73cc87508ef95..86c996c43f7ff 100644 --- a/compiler/rustc_const_eval/src/interpret/stack.rs +++ b/compiler/rustc_const_eval/src/interpret/stack.rs @@ -205,6 +205,16 @@ impl<'tcx, Prov: Provenance> LocalState<'tcx, Prov> { LocalValue::Live(val) => interp_ok(val), } } + + pub(super) fn is_allocation(&self) -> bool { + match self.value { + LocalValue::Dead => false, + LocalValue::Live(val) => match val { + Operand::Immediate(_) => false, + Operand::Indirect(_) => true, + }, + } + } } /// What we store about a frame in an interpreter backtrace. diff --git a/compiler/rustc_const_eval/src/interpret/step.rs b/compiler/rustc_const_eval/src/interpret/step.rs index 9df49c0f4ccdf..684dff3bb656e 100644 --- a/compiler/rustc_const_eval/src/interpret/step.rs +++ b/compiler/rustc_const_eval/src/interpret/step.rs @@ -544,7 +544,7 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> { let EvaluatedCalleeAndArgs { callee, args, fn_sig, fn_abi, with_caller_location } = self.eval_callee_and_args(terminator, func, args)?; - self.init_fn_tail_call(callee, (fn_sig.abi, fn_abi), &args, with_caller_location)?; + self.init_fn_tail_call(callee, (fn_sig.abi, fn_abi), args, with_caller_location)?; if self.frame_idx() != old_frame_idx { span_bug!( diff --git a/src/tools/miri/tests/pass/tail_call_temp.rs b/src/tools/miri/tests/pass/tail_call_temp.rs new file mode 100644 index 0000000000000..7d6634b3a30be --- /dev/null +++ b/src/tools/miri/tests/pass/tail_call_temp.rs @@ -0,0 +1,12 @@ +#![feature(explicit_tail_calls)] +#![expect(incomplete_features)] + +struct Wrapper(#[expect(unused)] usize); + +fn f(t: bool, x: Wrapper) { + if t { become f(false, x); } +} + +fn main() { + f(true, Wrapper(5)); +}