@@ -491,14 +491,12 @@ impl<'ast> State<'_, 'ast> {
491491 let params_format = match header_style {
492492 MultilineFuncHeaderStyle :: ParamsAlways => ListFormat :: always_break ( ) ,
493493 MultilineFuncHeaderStyle :: All
494- if header. parameters . len ( ) > 1
495- && !self . can_header_be_inlined ( header, body. is_some ( ) ) =>
494+ if header. parameters . len ( ) > 1 && !self . can_header_be_inlined ( func) =>
496495 {
497496 ListFormat :: always_break ( )
498497 }
499498 MultilineFuncHeaderStyle :: AllParams
500- if !header. parameters . is_empty ( )
501- && !self . can_header_be_inlined ( header, body. is_some ( ) ) =>
499+ if !header. parameters . is_empty ( ) && !self . can_header_be_inlined ( func) =>
502500 {
503501 ListFormat :: always_break ( )
504502 }
@@ -552,7 +550,7 @@ impl<'ast> State<'_, 'ast> {
552550
553551 let attrib_box = self . config . multiline_func_header . params_first ( )
554552 || ( self . config . multiline_func_header . attrib_first ( )
555- && !self . can_header_params_be_inlined ( header ) ) ;
553+ && !self . can_header_params_be_inlined ( func ) ) ;
556554 if attrib_box {
557555 self . s . cbox ( 0 ) ;
558556 }
@@ -2554,49 +2552,66 @@ impl<'ast> State<'_, 'ast> {
25542552 els_opt. is_none_or ( |els| self . is_inline_stmt ( els, 6 ) )
25552553 }
25562554
2557- fn can_header_be_inlined ( & mut self , header : & ast:: FunctionHeader < ' _ > , has_body : bool ) -> bool {
2555+ fn can_header_be_inlined ( & mut self , func : & ast:: ItemFunction < ' _ > ) -> bool {
2556+ self . estimate_header_size ( func) <= self . space_left ( )
2557+ }
2558+
2559+ fn can_header_params_be_inlined ( & mut self , func : & ast:: ItemFunction < ' _ > ) -> bool {
2560+ self . estimate_header_params_size ( func) <= self . space_left ( )
2561+ }
2562+
2563+ fn estimate_header_size ( & mut self , func : & ast:: ItemFunction < ' _ > ) -> usize {
2564+ let ast:: ItemFunction { kind : _, ref header, ref body, body_span : _ } = * func;
2565+
25582566 // ' ' + visibility
25592567 let visibility = header. visibility . map_or ( 0 , |v| self . estimate_size ( v. span ) + 1 ) ;
25602568 // ' ' + state mutability
25612569 let mutability = header. state_mutability . map_or ( 0 , |sm| self . estimate_size ( sm. span ) + 1 ) ;
25622570 // ' ' + modifier + (' ' + modifier)
2563- let modifiers =
2564- header . modifiers . iter ( ) . fold ( 0 , |len , m| len + self . estimate_size ( m . span ( ) ) ) + 1 ;
2571+ let m = header . modifiers . iter ( ) . fold ( 0 , |len , m| len + self . estimate_size ( m . span ( ) ) ) ;
2572+ let modifiers = if m != 0 { m + 1 } else { 0 } ;
25652573 // ' ' + override
25662574 let override_ = header. override_ . as_ref ( ) . map_or ( 0 , |o| self . estimate_size ( o. span ) + 1 ) ;
2575+ // ' ' + virtual
2576+ let virtual_ = if header. virtual_ . is_none ( ) { 0 } else { 8 } ;
25672577 // ' returns(' + var + (', ' + var) + ')'
25682578 let returns = header. returns . as_ref ( ) . map_or ( 0 , |ret| {
25692579 ret. vars
25702580 . iter ( )
25712581 . fold ( 0 , |len, p| if len != 0 { len + 2 } else { 10 } + self . estimate_size ( p. span ) )
25722582 } ) ;
25732583 // ' {' or ';'
2574- let end = if has_body { 2 } else { 1 } ;
2584+ let end = if body . is_some ( ) { 2 } else { 1 } ;
25752585
2576- self . estimate_header_params_size ( header ) // accounts for 'function name(..)'
2586+ self . estimate_header_params_size ( func )
25772587 + visibility
25782588 + mutability
25792589 + modifiers
25802590 + override_
2591+ + virtual_
25812592 + returns
25822593 + end
2583- <= self . space_left ( )
25842594 }
25852595
2586- fn can_header_params_be_inlined ( & mut self , header : & ast:: FunctionHeader < ' _ > ) -> bool {
2587- self . estimate_header_params_size ( header) <= self . space_left ( )
2588- }
2596+ fn estimate_header_params_size ( & mut self , func : & ast:: ItemFunction < ' _ > ) -> usize {
2597+ let ast:: ItemFunction { kind, ref header, body : _, body_span : _ } = * func;
2598+
2599+ let kw = match kind {
2600+ ast:: FunctionKind :: Constructor => 11 , // 'constructor'
2601+ ast:: FunctionKind :: Function => 9 , // 'function '
2602+ ast:: FunctionKind :: Modifier => 9 , // 'modifier '
2603+ ast:: FunctionKind :: Fallback => 8 , // 'fallback'
2604+ ast:: FunctionKind :: Receive => 7 , // 'receive'
2605+ } ;
25892606
2590- fn estimate_header_params_size ( & mut self , header : & ast:: FunctionHeader < ' _ > ) -> usize {
25912607 // '(' + param + (', ' + param) + ')'
25922608 let params = header
25932609 . parameters
25942610 . vars
25952611 . iter ( )
25962612 . fold ( 0 , |len, p| if len != 0 { len + 2 } else { 2 } + self . estimate_size ( p. span ) ) ;
25972613
2598- // 'function ' + name + ' ' + params
2599- 9 + header. name . map_or ( 0 , |name| self . estimate_size ( name. span ) + 1 ) + params
2614+ kw + header. name . map_or ( 0 , |name| self . estimate_size ( name. span ) ) + std:: cmp:: max ( 2 , params)
26002615 }
26012616
26022617 fn estimate_lhs_size ( & self , expr : & ast:: Expr < ' _ > , parent_op : & ast:: BinOp ) -> usize {
@@ -2958,3 +2973,117 @@ pub(super) fn get_callee_head_size(callee: &ast::Expr<'_>) -> usize {
29582973 _ => 0 ,
29592974 }
29602975}
2976+
2977+ #[ cfg( test) ]
2978+ mod tests {
2979+ use super :: * ;
2980+ use crate :: { FormatterConfig , InlineConfig } ;
2981+ use foundry_common:: comments:: Comments ;
2982+ use solar:: {
2983+ interface:: { Session , source_map:: FileName } ,
2984+ sema:: Compiler ,
2985+ } ;
2986+ use std:: sync:: Arc ;
2987+
2988+ /// This helper extracts function headers from the AST and passes them to the test function.
2989+ fn parse_and_test < F > ( source : & str , test_fn : F )
2990+ where
2991+ F : FnOnce ( & mut State < ' _ , ' _ > , & ast:: ItemFunction < ' _ > ) + Send ,
2992+ {
2993+ let session = Session :: builder ( ) . with_buffer_emitter ( Default :: default ( ) ) . build ( ) ;
2994+ let mut compiler = Compiler :: new ( session) ;
2995+
2996+ compiler
2997+ . enter_mut ( |c| -> solar:: interface:: Result < ( ) > {
2998+ let mut pcx = c. parse ( ) ;
2999+ pcx. set_resolve_imports ( false ) ;
3000+
3001+ // Create a source file using stdin as the filename
3002+ let file = c
3003+ . sess ( )
3004+ . source_map ( )
3005+ . new_source_file ( FileName :: Stdin , source)
3006+ . map_err ( |e| c. sess ( ) . dcx . err ( e. to_string ( ) ) . emit ( ) ) ?;
3007+
3008+ pcx. add_file ( file. clone ( ) ) ;
3009+ pcx. parse ( ) ;
3010+ c. dcx ( ) . has_errors ( ) ?;
3011+
3012+ // Get AST from parsed source and setup the formatter
3013+ let gcx = c. gcx ( ) ;
3014+ let ( _, source_obj) = gcx. get_ast_source ( & file. name ) . expect ( "Failed to get AST" ) ;
3015+ let ast = source_obj. ast . as_ref ( ) . expect ( "No AST found" ) ;
3016+ let comments =
3017+ Comments :: new ( & source_obj. file , gcx. sess . source_map ( ) , true , false , None ) ;
3018+ let config = Arc :: new ( FormatterConfig :: default ( ) ) ;
3019+ let inline_config = InlineConfig :: default ( ) ;
3020+ let mut state = State :: new ( gcx. sess . source_map ( ) , config, inline_config, comments) ;
3021+
3022+ // Extract the first function header (either top-level or inside a contract)
3023+ let func = ast
3024+ . items
3025+ . iter ( )
3026+ . find_map ( |item| match & item. kind {
3027+ ast:: ItemKind :: Function ( func) => Some ( func) ,
3028+ ast:: ItemKind :: Contract ( contract) => {
3029+ contract. body . iter ( ) . find_map ( |contract_item| {
3030+ match & contract_item. kind {
3031+ ast:: ItemKind :: Function ( func) => Some ( func) ,
3032+ _ => None ,
3033+ }
3034+ } )
3035+ }
3036+ _ => None ,
3037+ } )
3038+ . expect ( "No function found in source" ) ;
3039+
3040+ // Run the closure
3041+ test_fn ( & mut state, func) ;
3042+
3043+ Ok ( ( ) )
3044+ } )
3045+ . expect ( "Test failed" ) ;
3046+ }
3047+
3048+ #[ test]
3049+ fn test_estimate_header_sizes ( ) {
3050+ let test_cases = [
3051+ ( "function foo();" , 14 , 15 ) ,
3052+ ( "function foo() {}" , 14 , 16 ) ,
3053+ ( "function foo() public {}" , 14 , 23 ) ,
3054+ ( "function foo(uint256 a) public {}" , 23 , 32 ) ,
3055+ ( "function foo(uint256 a, address b, bool c) public {}" , 42 , 51 ) ,
3056+ ( "function foo() public pure {}" , 14 , 28 ) ,
3057+ ( "function foo() public virtual {}" , 14 , 31 ) ,
3058+ ( "function foo() public override {}" , 14 , 32 ) ,
3059+ ( "function foo() public onlyOwner {}" , 14 , 33 ) ,
3060+ ( "function foo() public returns(uint256) {}" , 14 , 40 ) ,
3061+ ( "function foo() public returns(uint256, address) {}" , 14 , 49 ) ,
3062+ ( "function foo(uint256 a) public virtual override returns(uint256) {}" , 23 , 66 ) ,
3063+ ( "function foo() external payable {}" , 14 , 33 ) ,
3064+ // other function types
3065+ ( "contract C { constructor() {} }" , 13 , 15 ) ,
3066+ ( "contract C { constructor(uint256 a) {} }" , 22 , 24 ) ,
3067+ ( "contract C { modifier onlyOwner() {} }" , 20 , 22 ) ,
3068+ ( "contract C { modifier onlyRole(bytes32 role) {} }" , 31 , 33 ) ,
3069+ ( "contract C { fallback() external payable {} }" , 10 , 29 ) ,
3070+ ( "contract C { receive() external payable {} }" , 9 , 28 ) ,
3071+ ] ;
3072+
3073+ for ( source, expected_params, expected_header) in & test_cases {
3074+ parse_and_test ( source, |state, func| {
3075+ let params_size = state. estimate_header_params_size ( func) ;
3076+ assert_eq ! (
3077+ params_size, * expected_params,
3078+ "Failed params size: expected {expected_params}, got {params_size} for source: {source}" ,
3079+ ) ;
3080+
3081+ let header_size = state. estimate_header_size ( func) ;
3082+ assert_eq ! (
3083+ header_size, * expected_header,
3084+ "Failed header size: expected {expected_header}, got {header_size} for source: {source}" ,
3085+ ) ;
3086+ } ) ;
3087+ }
3088+ }
3089+ }
0 commit comments