41
41
"range_num_stages" ,
42
42
"range_multi_buffers" ,
43
43
"range_flattens" ,
44
+ "static_ranges" ,
44
45
"num_warps" ,
45
46
"num_stages" ,
46
47
"use_yz_grid" ,
@@ -81,6 +82,9 @@ class ConfigSpec:
81
82
range_flattens : BlockIdSequence [RangeFlattenSpec ] = dataclasses .field (
82
83
default_factory = BlockIdSequence
83
84
)
85
+ static_ranges : BlockIdSequence [StaticRangeSpec ] = dataclasses .field (
86
+ default_factory = BlockIdSequence
87
+ )
84
88
user_defined_tunables : dict [str , ConfigSpecFragment ] = dataclasses .field (
85
89
default_factory = dict
86
90
)
@@ -95,6 +99,7 @@ def _remove_duplicates(self) -> None:
95
99
self .range_num_stages ._remove_duplicates ()
96
100
self .range_multi_buffers ._remove_duplicates ()
97
101
self .range_flattens ._remove_duplicates ()
102
+ self .static_ranges ._remove_duplicates ()
98
103
99
104
def normalize (self , config : helion .Config | dict [str , object ]) -> None :
100
105
"""Normalize the config to match the block_sizes and validate the config."""
@@ -113,6 +118,7 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None:
113
118
"range_num_stage" ,
114
119
"range_multi_buffer" ,
115
120
"range_flatten" ,
121
+ "static_range" ,
116
122
):
117
123
if name in config :
118
124
names = f"{ name } s"
@@ -131,11 +137,32 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None:
131
137
("range_num_stages" , self .range_num_stages , True ),
132
138
("range_multi_buffers" , self .range_multi_buffers , True ),
133
139
("range_flattens" , self .range_flattens , True ),
140
+ ("static_ranges" , self .static_ranges , True ),
134
141
]:
135
142
config [name ] = mapping ._normalize (
136
143
name , config .get (name , ()), flatten = flatten
137
144
)
138
145
146
+ static_range_block_ids = []
147
+ for block_id in self .static_ranges .valid_block_ids ():
148
+ use_static_range = self .static_ranges .config_get (
149
+ config .get ("static_ranges" , ()), # pyre-ignore[6]
150
+ block_id ,
151
+ )
152
+ if use_static_range :
153
+ static_range_block_ids .append (block_id )
154
+
155
+ for name , mapping in (
156
+ ("range_unroll_factors" , self .range_unroll_factors ),
157
+ ("range_warp_specializes" , self .range_warp_specialize ),
158
+ ("range_num_stages" , self .range_num_stages ),
159
+ ("range_multi_buffers" , self .range_multi_buffers ),
160
+ ("range_flattens" , self .range_flattens ),
161
+ ):
162
+ config [name ] = mapping ._reset_to_default (
163
+ name , config .get (name , ()), block_ids = static_range_block_ids
164
+ )
165
+
139
166
for name in (
140
167
"loop_orders" ,
141
168
"l2_groupings" ,
@@ -146,6 +173,7 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None:
146
173
"range_num_stages" ,
147
174
"range_multi_buffers" ,
148
175
"range_flattens" ,
176
+ "static_ranges" ,
149
177
):
150
178
if not config [name ]:
151
179
config .pop (name )
@@ -180,6 +208,7 @@ def flat_config(self, fn: Callable[[ConfigSpecFragment], object]) -> helion.Conf
180
208
"range_num_stages" : self .range_num_stages ._flat_config (self , fn ),
181
209
"range_multi_buffers" : self .range_multi_buffers ._flat_config (self , fn ),
182
210
"range_flattens" : self .range_flattens ._flat_config (self , fn ),
211
+ "static_ranges" : self .static_ranges ._flat_config (self , fn ),
183
212
"num_warps" : fn (NumWarpsFragment (1 , 32 , DEFAULT_NUM_WARPS )),
184
213
"num_stages" : fn (IntegerFragment (1 , 8 , DEFAULT_NUM_STAGES )),
185
214
"indexing" : fn (
@@ -211,6 +240,7 @@ def flat_config(self, fn: Callable[[ConfigSpecFragment], object]) -> helion.Conf
211
240
"range_num_stages" ,
212
241
"range_multi_buffers" ,
213
242
"range_flattens" ,
243
+ "static_ranges" ,
214
244
):
215
245
if not config [name ]:
216
246
config .pop (name )
@@ -399,6 +429,20 @@ class RangeFlattenSpec(_OptionalBoolSpec):
399
429
pass
400
430
401
431
432
+ class StaticRangeSpec (_BlockIdItem ):
433
+ def _fragment (self , base : ConfigSpec ) -> BooleanFragment :
434
+ return BooleanFragment ()
435
+
436
+ def _normalize (self , name : str , value : object ) -> bool :
437
+ if not isinstance (value , bool ):
438
+ raise InvalidConfig (f"{ name } must be a boolean, got { value !r} " )
439
+ return value
440
+
441
+ def _fill_missing (self ) -> bool :
442
+ """Provide a value when not provided by the user."""
443
+ return False
444
+
445
+
402
446
def _product (seq : Sequence [int ]) -> int :
403
447
"""Return the product of the elements in the sequence."""
404
448
return functools .reduce (operator .mul , seq , 1 )
0 commit comments