@@ -138,14 +138,16 @@ def update(
 
 
 class StaticAttentionMask:
-    def __init__(self, input_len, cache_len, style, mask_val=float("-inf")):
+    def __init__(
+        self, input_len, cache_len, style, mask_val=float("-inf"), dtype=torch.float32
+    ):
         self.input_len = input_len
         self.cache_len = cache_len
         assert style in ("shift_pointer", "smart_mask")
         self.style = style
         self.mask_val = mask_val
         self.unmasked_len = 0
-        self.tensor = torch.zeros(1, input_len, input_len + cache_len)
+        self.tensor = torch.zeros(1, input_len, input_len + cache_len, dtype=dtype)
         self.reset()
 
     def reset(self):
@@ -200,44 +202,45 @@ def __init__(
         config: ModelArgs,
         input_len: int,
         cache_len: int,
+        dtype=torch.float32,
         style: str = "shift_pointer",
         mask_val: float = float("-inf"),
     ):
         self.mask = StaticAttentionMask(
-            input_len, cache_len, style=style, mask_val=mask_val
+            input_len, cache_len, style=style, mask_val=mask_val, dtype=dtype
         )
 
         rope = Rope(config)
         freqs = rope.get_freqs(None, config.max_seq_len)
-        self.freqs_cos = freqs[0]
-        self.freqs_sin = freqs[1]
+        self.freqs_cos = freqs[0].to(dtype)
+        self.freqs_sin = freqs[1].to(dtype)
 
         split_mha = config.attention_type in ("static", "static_shas")
         if split_mha:
             self.k_caches = {
                 StaticKVCache.calculate_cache_key(layer_id, head_id): torch.zeros(
-                    1, cache_len, config.head_dim
+                    1, cache_len, config.head_dim, dtype=dtype
                 )
                 for layer_id in range(config.n_layers)
                 for head_id in range(config.n_kv_heads)
             }
             self.v_caches = {
                 StaticKVCache.calculate_cache_key(layer_id, head_id): torch.zeros(
-                    1, cache_len, config.head_dim
+                    1, cache_len, config.head_dim, dtype=dtype
                 )
                 for layer_id in range(config.n_layers)
                 for head_id in range(config.n_kv_heads)
             }
         else:
             self.k_caches = {
                 StaticKVCache.calculate_cache_key(layer_id, 0): torch.zeros(
-                    1, config.n_kv_heads, cache_len, config.head_dim
+                    1, config.n_kv_heads, cache_len, config.head_dim, dtype=dtype
                 )
                 for layer_id in range(config.n_layers)
             }
             self.v_caches = {
                 StaticKVCache.calculate_cache_key(layer_id, 0): torch.zeros(
-                    1, config.n_kv_heads, cache_len, config.head_dim
+                    1, config.n_kv_heads, cache_len, config.head_dim, dtype=dtype
                 )
                 for layer_id in range(config.n_layers)
             }