@@ -442,16 +442,24 @@ def compact_number(v):
442442def get_grid_dimensions (current_problem_size , params , grid_div , block_size_names ):
443443 """Compute grid dims based on problem sizes and listed grid divisors."""
444444
445- def get_dimension_divisor (divisor_list , default , params ):
446- if divisor_list is None :
447- if default in params :
448- divisor_list = [default ]
449- else :
450- return 1
451- if callable (divisor_list ):
452- return divisor_list (params )
445+ def get_dimension_divisor (divisor , default , params ):
446+ divisor_num = 1
447+
448+ if divisor is None :
449+ divisor_num = params .get (default , 1 )
450+ elif isinstance (divisor , int ):
451+ divisor_num = divisor
452+ elif callable (divisor ):
453+ divisor_num = divisor (params )
454+ elif isinstance (divisor , str ):
455+ divisor_num = int (eval (replace_param_occurrences (divisor , params )))
456+ elif np .iterable (divisor ):
457+ for div in divisor :
458+ divisor_num *= get_dimension_divisor (div , 1 , params )
453459 else :
454- return np .prod ([int (eval (replace_param_occurrences (s , params ))) for s in divisor_list ])
460+ raise ValueError ("Error: unrecognized type in grid divisor list, should be any of int, str, callable, or iterable" )
461+
462+ return divisor_num
455463
456464 divisors = [get_dimension_divisor (d , block_size_names [i ], params ) for i , d in enumerate (grid_div )]
457465 return tuple (int (np .ceil (float (current_problem_size [i ]) / float (d ))) for i , d in enumerate (divisors ))
0 commit comments