Skip to content

Commit 0f5eff2

Browse files
authored
test(fmt): ensure fn header sizes are computed correctly (#12350)
* fix(fmt): properly calc fn header size * docs: add more cmnts * fix: revert bun.lock changes * test: estimate_header_size * simplify tests * style: clippy
1 parent 0cf0bac commit 0f5eff2

File tree

1 file changed

+146
-17
lines changed

1 file changed

+146
-17
lines changed

crates/fmt/src/state/sol.rs

Lines changed: 146 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -491,14 +491,12 @@ impl<'ast> State<'_, 'ast> {
491491
let params_format = match header_style {
492492
MultilineFuncHeaderStyle::ParamsAlways => ListFormat::always_break(),
493493
MultilineFuncHeaderStyle::All
494-
if header.parameters.len() > 1
495-
&& !self.can_header_be_inlined(header, body.is_some()) =>
494+
if header.parameters.len() > 1 && !self.can_header_be_inlined(func) =>
496495
{
497496
ListFormat::always_break()
498497
}
499498
MultilineFuncHeaderStyle::AllParams
500-
if !header.parameters.is_empty()
501-
&& !self.can_header_be_inlined(header, body.is_some()) =>
499+
if !header.parameters.is_empty() && !self.can_header_be_inlined(func) =>
502500
{
503501
ListFormat::always_break()
504502
}
@@ -552,7 +550,7 @@ impl<'ast> State<'_, 'ast> {
552550

553551
let attrib_box = self.config.multiline_func_header.params_first()
554552
|| (self.config.multiline_func_header.attrib_first()
555-
&& !self.can_header_params_be_inlined(header));
553+
&& !self.can_header_params_be_inlined(func));
556554
if attrib_box {
557555
self.s.cbox(0);
558556
}
@@ -2554,49 +2552,66 @@ impl<'ast> State<'_, 'ast> {
25542552
els_opt.is_none_or(|els| self.is_inline_stmt(els, 6))
25552553
}
25562554

2557-
fn can_header_be_inlined(&mut self, header: &ast::FunctionHeader<'_>, has_body: bool) -> bool {
2555+
fn can_header_be_inlined(&mut self, func: &ast::ItemFunction<'_>) -> bool {
2556+
self.estimate_header_size(func) <= self.space_left()
2557+
}
2558+
2559+
fn can_header_params_be_inlined(&mut self, func: &ast::ItemFunction<'_>) -> bool {
2560+
self.estimate_header_params_size(func) <= self.space_left()
2561+
}
2562+
2563+
fn estimate_header_size(&mut self, func: &ast::ItemFunction<'_>) -> usize {
2564+
let ast::ItemFunction { kind: _, ref header, ref body, body_span: _ } = *func;
2565+
25582566
// ' ' + visibility
25592567
let visibility = header.visibility.map_or(0, |v| self.estimate_size(v.span) + 1);
25602568
// ' ' + state mutability
25612569
let mutability = header.state_mutability.map_or(0, |sm| self.estimate_size(sm.span) + 1);
25622570
// ' ' + modifier + (' ' + modifier)
2563-
let modifiers =
2564-
header.modifiers.iter().fold(0, |len, m| len + self.estimate_size(m.span())) + 1;
2571+
let m = header.modifiers.iter().fold(0, |len, m| len + self.estimate_size(m.span()));
2572+
let modifiers = if m != 0 { m + 1 } else { 0 };
25652573
// ' ' + override
25662574
let override_ = header.override_.as_ref().map_or(0, |o| self.estimate_size(o.span) + 1);
2575+
// ' ' + virtual
2576+
let virtual_ = if header.virtual_.is_none() { 0 } else { 8 };
25672577
// ' returns(' + var + (', ' + var) + ')'
25682578
let returns = header.returns.as_ref().map_or(0, |ret| {
25692579
ret.vars
25702580
.iter()
25712581
.fold(0, |len, p| if len != 0 { len + 2 } else { 10 } + self.estimate_size(p.span))
25722582
});
25732583
// ' {' or ';'
2574-
let end = if has_body { 2 } else { 1 };
2584+
let end = if body.is_some() { 2 } else { 1 };
25752585

2576-
self.estimate_header_params_size(header) // accounts for 'function name(..)'
2586+
self.estimate_header_params_size(func)
25772587
+ visibility
25782588
+ mutability
25792589
+ modifiers
25802590
+ override_
2591+
+ virtual_
25812592
+ returns
25822593
+ end
2583-
<= self.space_left()
25842594
}
25852595

2586-
fn can_header_params_be_inlined(&mut self, header: &ast::FunctionHeader<'_>) -> bool {
2587-
self.estimate_header_params_size(header) <= self.space_left()
2588-
}
2596+
fn estimate_header_params_size(&mut self, func: &ast::ItemFunction<'_>) -> usize {
2597+
let ast::ItemFunction { kind, ref header, body: _, body_span: _ } = *func;
2598+
2599+
let kw = match kind {
2600+
ast::FunctionKind::Constructor => 11, // 'constructor'
2601+
ast::FunctionKind::Function => 9, // 'function '
2602+
ast::FunctionKind::Modifier => 9, // 'modifier '
2603+
ast::FunctionKind::Fallback => 8, // 'fallback'
2604+
ast::FunctionKind::Receive => 7, // 'receive'
2605+
};
25892606

2590-
fn estimate_header_params_size(&mut self, header: &ast::FunctionHeader<'_>) -> usize {
25912607
// '(' + param + (', ' + param) + ')'
25922608
let params = header
25932609
.parameters
25942610
.vars
25952611
.iter()
25962612
.fold(0, |len, p| if len != 0 { len + 2 } else { 2 } + self.estimate_size(p.span));
25972613

2598-
// 'function ' + name + ' ' + params
2599-
9 + header.name.map_or(0, |name| self.estimate_size(name.span) + 1) + params
2614+
kw + header.name.map_or(0, |name| self.estimate_size(name.span)) + std::cmp::max(2, params)
26002615
}
26012616

26022617
fn estimate_lhs_size(&self, expr: &ast::Expr<'_>, parent_op: &ast::BinOp) -> usize {
@@ -2958,3 +2973,117 @@ pub(super) fn get_callee_head_size(callee: &ast::Expr<'_>) -> usize {
29582973
_ => 0,
29592974
}
29602975
}
2976+
2977+
#[cfg(test)]
2978+
mod tests {
2979+
use super::*;
2980+
use crate::{FormatterConfig, InlineConfig};
2981+
use foundry_common::comments::Comments;
2982+
use solar::{
2983+
interface::{Session, source_map::FileName},
2984+
sema::Compiler,
2985+
};
2986+
use std::sync::Arc;
2987+
2988+
/// This helper extracts function headers from the AST and passes them to the test function.
2989+
fn parse_and_test<F>(source: &str, test_fn: F)
2990+
where
2991+
F: FnOnce(&mut State<'_, '_>, &ast::ItemFunction<'_>) + Send,
2992+
{
2993+
let session = Session::builder().with_buffer_emitter(Default::default()).build();
2994+
let mut compiler = Compiler::new(session);
2995+
2996+
compiler
2997+
.enter_mut(|c| -> solar::interface::Result<()> {
2998+
let mut pcx = c.parse();
2999+
pcx.set_resolve_imports(false);
3000+
3001+
// Create a source file using stdin as the filename
3002+
let file = c
3003+
.sess()
3004+
.source_map()
3005+
.new_source_file(FileName::Stdin, source)
3006+
.map_err(|e| c.sess().dcx.err(e.to_string()).emit())?;
3007+
3008+
pcx.add_file(file.clone());
3009+
pcx.parse();
3010+
c.dcx().has_errors()?;
3011+
3012+
// Get AST from parsed source and setup the formatter
3013+
let gcx = c.gcx();
3014+
let (_, source_obj) = gcx.get_ast_source(&file.name).expect("Failed to get AST");
3015+
let ast = source_obj.ast.as_ref().expect("No AST found");
3016+
let comments =
3017+
Comments::new(&source_obj.file, gcx.sess.source_map(), true, false, None);
3018+
let config = Arc::new(FormatterConfig::default());
3019+
let inline_config = InlineConfig::default();
3020+
let mut state = State::new(gcx.sess.source_map(), config, inline_config, comments);
3021+
3022+
// Extract the first function header (either top-level or inside a contract)
3023+
let func = ast
3024+
.items
3025+
.iter()
3026+
.find_map(|item| match &item.kind {
3027+
ast::ItemKind::Function(func) => Some(func),
3028+
ast::ItemKind::Contract(contract) => {
3029+
contract.body.iter().find_map(|contract_item| {
3030+
match &contract_item.kind {
3031+
ast::ItemKind::Function(func) => Some(func),
3032+
_ => None,
3033+
}
3034+
})
3035+
}
3036+
_ => None,
3037+
})
3038+
.expect("No function found in source");
3039+
3040+
// Run the closure
3041+
test_fn(&mut state, func);
3042+
3043+
Ok(())
3044+
})
3045+
.expect("Test failed");
3046+
}
3047+
3048+
#[test]
3049+
fn test_estimate_header_sizes() {
3050+
let test_cases = [
3051+
("function foo();", 14, 15),
3052+
("function foo() {}", 14, 16),
3053+
("function foo() public {}", 14, 23),
3054+
("function foo(uint256 a) public {}", 23, 32),
3055+
("function foo(uint256 a, address b, bool c) public {}", 42, 51),
3056+
("function foo() public pure {}", 14, 28),
3057+
("function foo() public virtual {}", 14, 31),
3058+
("function foo() public override {}", 14, 32),
3059+
("function foo() public onlyOwner {}", 14, 33),
3060+
("function foo() public returns(uint256) {}", 14, 40),
3061+
("function foo() public returns(uint256, address) {}", 14, 49),
3062+
("function foo(uint256 a) public virtual override returns(uint256) {}", 23, 66),
3063+
("function foo() external payable {}", 14, 33),
3064+
// other function types
3065+
("contract C { constructor() {} }", 13, 15),
3066+
("contract C { constructor(uint256 a) {} }", 22, 24),
3067+
("contract C { modifier onlyOwner() {} }", 20, 22),
3068+
("contract C { modifier onlyRole(bytes32 role) {} }", 31, 33),
3069+
("contract C { fallback() external payable {} }", 10, 29),
3070+
("contract C { receive() external payable {} }", 9, 28),
3071+
];
3072+
3073+
for (source, expected_params, expected_header) in &test_cases {
3074+
parse_and_test(source, |state, func| {
3075+
let params_size = state.estimate_header_params_size(func);
3076+
assert_eq!(
3077+
params_size, *expected_params,
3078+
"Failed params size: expected {expected_params}, got {params_size} for source: {source}",
3079+
);
3080+
3081+
let header_size = state.estimate_header_size(func);
3082+
assert_eq!(
3083+
header_size, *expected_header,
3084+
"Failed header size: expected {expected_header}, got {header_size} for source: {source}",
3085+
);
3086+
});
3087+
}
3088+
}
3089+
}

0 commit comments

Comments
 (0)