@@ -384,35 +384,36 @@ def generate_stmt_invocation(program, function, invoked_func=None):
384384
385385def generate_stmt_return (program , function , exp = None ):
386386 c_type = function .return_type
387- if exp is None :
388- if isinstance (c_type , ast .SignedInt ) and probs_helper .random_value (probs .int_emulate_bool ):
389- value = probs_helper .random_value ({0 : 0.5 , 1 : 0.5 })
387+ if isinstance (c_type , ast .Void ):
388+ return ast .Return ()
389+ if isinstance (c_type , ast .SignedInt ) and probs_helper .random_value (probs .int_emulate_bool ):
390+ value = probs_helper .random_value ({0 : 0.5 , 1 : 0.5 })
391+ if not exp :
390392 exp = ast .Literal ("/* EMULATED BOOL LITERAL */ ({}) {}" .format (c_type .name , value ), c_type )
391- return ast .Return (exp , c_type )
392- elif not isinstance (c_type , ast .Void ): # return <expression>;
393- exp = generate_expression (program , function , c_type , probs .return_exp_depth_prob )
394393 return ast .Return (exp , c_type )
395- return ast .Return () # return; (without expression)
396-
394+ else :
395+ if not exp :
396+ exp = generate_expression (program , function , c_type , probs .return_exp_depth_prob )
397+ return ast .Return (exp , c_type )
397398
398399def generate_stmt_block (program : Program , function : Function , stmt_depth : int ) -> Block :
399400 """Generates a block statement"""
400- number_statements = probs_helper .random_value (probs .number_stmts_func_block )
401+ number_statements = probs_helper .random_value (probs .number_stmts_block_prob )
401402 statements = [generate_stmt_func (program , function , stmt_depth - 1 ) for _ in range (number_statements )]
402403 return ast .Block (statements )
403404
404405
405406def generate_stmt_while (program : Program , function : Function , stmt_depth : int ) -> While :
406407 """Generates a while statement"""
407408 condition = generate_expression (program , function , SignedInt (), None )
408- number_statements = probs_helper .random_value (probs .number_stmts_func_block )
409+ number_statements = probs_helper .random_value (probs .number_stmts_block_prob )
409410 statements = [generate_stmt_func (program , function , stmt_depth - 1 ) for _ in range (number_statements )]
410411 return ast .While (condition , statements )
411412
412413
413414def generate_stmt_do (program : Program , function : Function , stmt_depth : int ) -> Do :
414415 """Generates a do statement"""
415- number_statements = probs_helper .random_value (probs .number_stmts_func_block )
416+ number_statements = probs_helper .random_value (probs .number_stmts_block_prob )
416417 statements = [generate_stmt_func (program , function , stmt_depth - 1 ) for _ in range (number_statements )]
417418 condition = generate_expression (program , function , SignedInt (), None )
418419 return ast .Do (statements , condition )
@@ -423,7 +424,7 @@ def generate_stmt_if(program: Program, function: Function, stmt_depth: int) -> I
423424 # generate condition
424425 condition = generate_expression (program , function , SignedInt (), None )
425426 # generate if body
426- number_if_statements = probs_helper .random_value (probs .number_stmts_func_block )
427+ number_if_statements = probs_helper .random_value (probs .number_stmts_block_prob )
427428 if_statements = [generate_stmt_func (program , function , stmt_depth - 1 ) for _ in range (number_if_statements )]
428429 is_there_return_stmt = not isinstance (function .return_type , Void ) and \
429430 probs_helper .random_value (probs .return_at_end_if_else_bodies_prob )
@@ -433,7 +434,7 @@ def generate_stmt_if(program: Program, function: Function, stmt_depth: int) -> I
433434 # generate else body
434435 is_there_else_body = probs_helper .random_value (probs .else_body_prob )
435436 if is_there_else_body :
436- number_else_statements = probs_helper .random_value (probs .number_stmts_func_block )
437+ number_else_statements = probs_helper .random_value (probs .number_stmts_block_prob )
437438 else_statements = [generate_stmt_func (program , function , stmt_depth - 1 ) for _ in range (number_else_statements )]
438439 if is_there_return_stmt : # generate a return statement at the end of the else block
439440 else_statements .append (generate_stmt_return (program , function ,
@@ -448,7 +449,7 @@ def generate_stmt_for(program: Program, function: Function, stmt_depth: int) ->
448449 initialization = generate_basic_stmt (program , function )
449450 condition = generate_expression (program , function , SignedInt (), None )
450451 increment = generate_basic_stmt (program , function )
451- number_statements = probs_helper .random_value (probs .number_stmts_func_block )
452+ number_statements = probs_helper .random_value (probs .number_stmts_block_prob )
452453 statements = [generate_stmt_func (program , function , stmt_depth - 1 ) for _ in range (number_statements )]
453454 return For (initialization , condition , increment , statements )
454455
@@ -467,7 +468,7 @@ def generate_stmt_switch(program: Program, function: Function, stmt_depth: int)
467468 literal = generate_literal (program , function , condition_type )
468469 if literal in cases .keys ():
469470 continue # avoid repeated literals in case conditions
470- number_statements = probs_helper .random_value (probs .number_stmts_func_block )
471+ number_statements = probs_helper .random_value (probs .number_stmts_block_prob )
471472 case_statements = [generate_stmt_func (program , function , stmt_depth - 1 ) for _ in range (number_statements )]
472473 if is_there_return_stmt : # append a return statement at the end of the case block
473474 case_statements .append (generate_stmt_return (program , function ,
@@ -480,7 +481,7 @@ def generate_stmt_switch(program: Program, function: Function, stmt_depth: int)
480481 # generate default, if necessary
481482 is_there_default = probs_helper .random_value (probs .default_switch_prob )
482483 if is_there_default :
483- number_statements = probs_helper .random_value (probs .number_stmts_func_block )
484+ number_statements = probs_helper .random_value (probs .number_stmts_block_prob )
484485 default_statements = [generate_stmt_func (program , function , stmt_depth - 1 ) for _ in range (number_statements )]
485486 if is_there_return_stmt :
486487 default_statements .append (generate_stmt_return (program , function ,
@@ -606,7 +607,7 @@ def generate_program():
606607 program .main = main_function = ast .Function ("main" , ast .SignedInt (), [])
607608 for i in range (number_statements ):
608609 program .main .stmts .append (generate_stmt_func (program , main_function ))
609- main_function .stmts .append (generate_stmt_return (program , main_function , exp = 0 ))
610+ main_function .stmts .append (generate_stmt_return (program , main_function , exp = None ))
610611 return program
611612
612613
@@ -712,7 +713,7 @@ def count_functions_generated_by_group() -> Dict[str, int]:
712713 ###
713714 # Add return statement
714715 ###
715- main_function .stmts .append (generate_stmt_return (program , main_function , exp = 0 ))
716+ main_function .stmts .append (generate_stmt_return (program , main_function , exp = None ))
716717
717718 return program
718719
0 commit comments