@@ -340,17 +340,7 @@ def forward(
         fmha_out = None

         # NOTE: (changwenbin) qkv_a_proj horizontal fusion
-        paddle.device.synchronize()
-        print("==RyanDebug, the hidden_states is:", hidden_states)  # this is an input; we assume it is fine, but a check could be added as well
-        print("==RyanDebug, hidden_states contains NaN:", paddle.any(paddle.isnan(hidden_states)).item())
-
         qkv_a_out = self.qkv_a_proj_with_mqa(hidden_states)
-        paddle.device.synchronize()
-
-        # --- NaN Check Start ---
-        print("===RyanDebug, the qkv_a_out is:", qkv_a_out)
-        print(" >>> RyanDebug, qkv_a_out contains NaN:", paddle.any(paddle.isnan(qkv_a_out)).item())
-        # --- NaN Check End ---

         query, compressed_kv, key_pe = qkv_a_out.split(
             [self.q_lora_rank, self.kv_lora_rank, self.qk_rope_head_dim], axis=-1
@@ -363,13 +353,10 @@ def forward(

         key_pe.reshape_([-1, 1, self.qk_rope_head_dim])
         query_pe, key_pe = self.rotary_emb(position_ids, query_pe, key_pe)
-        paddle.device.synchronize()

         compressed_kv = self.kv_a_layernorm(compressed_kv)[0]

-        print("===RyanDebug, in #370, forward_meta.max_len_tensor_cpu[1] is:", forward_meta.max_len_tensor_cpu[1])
         if forward_meta.max_len_tensor_cpu[1]:  # max_enc_len_this_time
-            print("===RyanDebug, in #372, forward_meta.max_len_tensor_cpu[1] is:", forward_meta.max_len_tensor_cpu[1])
             key_value = self.kv_b_proj(compressed_kv)
             key_value.reshape_(
                 [
@@ -402,12 +389,8 @@ def forward(
             fmha_out_prefill = fmha_out_prefill * mask_encoder_batch.cast(fmha_out_prefill.dtype)

             fmha_out = fmha_out_prefill
-            print("====RYanDebug, #404, fmha_out after MLA is: ", fmha_out)

         if forward_meta.max_len_tensor_cpu[2]:  # max_dec_len_this_time
-            print("===RyanDebug, D in dsv3 !!!!=====")
-            paddle.device.synchronize()
-
             q_nope_out = self.kv_b_proj_bmm(query_nope.transpose([1, 0, 2]), proj_type="k").transpose([1, 0, 2])

             q_input = paddle.concat([q_nope_out, query_pe], axis=-1)
@@ -418,18 +401,6 @@ def forward(
                 ]
             )

-            print("===RyanDebug, the q_input # 435 is:", q_input)
-            print(" >>> RyanDebug, q_input # 435 contains NaN:", paddle.any(paddle.isnan(q_input)).item())
-
-            print("===RyanDebug, the compressed_kv # 435 is:", compressed_kv)
-            print(
-                " >>> RyanDebug, compressed_kv # 435 contains NaN:", paddle.any(paddle.isnan(compressed_kv)).item()
-            )
-
-            print("===RyanDebug, the key_pe # 435 is:", q_input)
-            print(" >>> RyanDebug, key_pe # 435 contains NaN:", paddle.any(paddle.isnan(key_pe)).item())
-
-            paddle.device.synchronize()
             fmha_out_decode = self.mla_attn(
                 q=q_input,
                 k=None,
@@ -439,39 +410,23 @@ def forward(
                 k_pe=key_pe,
                 forward_meta=forward_meta,
             )
-            paddle.device.synchronize()
-            # --- NaN Check Start ---
-            print("===RyanDebug, the fmha_out_decode # 448 is:", fmha_out_decode)
-            print(
-                " >>> RyanDebug, fmha_out_decode # 448 contains NaN:",
-                paddle.any(paddle.isnan(fmha_out_decode)).item(),
-            )

             fmha_out_decode = fmha_out_decode.reshape([-1, self.num_attention_heads_tp, self.kv_lora_rank]).transpose(
                 [1, 0, 2]
             )

-            paddle.device.synchronize()
-
             fmha_out_decode = (
                 self.kv_b_proj_bmm(fmha_out_decode, proj_type="v")
                 .transpose([1, 0, 2])
                 .reshape([-1, self.num_attention_heads_tp * self.v_head_dim])
             )

-            # --- NaN Check Start ---
-            print("===RyanDebug, the fmha_out_decode is:", fmha_out_decode)
-            print(" >>> RyanDebug, fmha_out_decode contains NaN:", paddle.any(paddle.isnan(fmha_out_decode)).item())
-            # --- NaN Check End ---
-
-            paddle.device.synchronize()
             if fmha_out is None:
                 fmha_out = fmha_out_decode
             else:
                 fmha_out = fmha_out + fmha_out_decode

         output = self.o_proj(fmha_out)
-        paddle.device.synchronize()
         return output

     def load_state_dict(self, state_dict):
@@ -559,19 +514,11 @@ def forward(
             hidden_states, residual_input=residual, forward_meta=forward_meta
         )

-        print("===RyanDebug, the hidden_states before self_attn is :", hidden_states)
         hidden_states = self.self_attn(forward_meta, hidden_states, position_ids, mask_encoder_batch)

-        print("==RyanDebug, #563 hidden_states contains NaN:", paddle.any(paddle.isnan(hidden_states)).item())
-
         hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
-        print("==RyanDebug, #566 hidden_states contains NaN:", paddle.any(paddle.isnan(hidden_states)).item())
         hidden_states = self.mlp(hidden_states)

-        print("===RyanDebug, the hidden_states after mlp is :", hidden_states)
-        print(
-            "==RyanDebug, #570 hidden_states after mlp contains NaN:", paddle.any(paddle.isnan(hidden_states)).item()
-        )
         return hidden_states, residual


@@ -731,7 +678,6 @@ def load_weights(self, weights_iterator) -> None:
         process_weights_after_loading_fn = process_weights_after_loading(dict(self.named_sublayers()), self.fd_config)
         for loaded_weight_name, loaded_weight in weights_iterator:
             loaded_weight_name = loaded_weight_name.replace("deepseek_v3", "model")
-            print(f"loaded_weight_name:{loaded_weight_name}")
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in loaded_weight_name:
                     continue
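
Note: the debug lines removed above all follow the same pattern: synchronize the device, print the tensor, and report NaN presence via `paddle.any(paddle.isnan(...)).item()`. A minimal sketch of how such checks could be kept around behind an environment flag instead of ad-hoc prints (the helper name and the `FD_DEBUG_NAN` variable are hypothetical, not part of this change):

```python
import os

import paddle


def debug_check_nan(name: str, tensor: paddle.Tensor) -> None:
    """Report whether `tensor` contains NaN, only when FD_DEBUG_NAN=1 is set."""
    if os.environ.get("FD_DEBUG_NAN") != "1":
        return
    # Make sure pending GPU kernels have finished before inspecting the tensor.
    paddle.device.synchronize()
    has_nan = paddle.any(paddle.isnan(tensor)).item()
    print(f"[nan-check] {name}: contains NaN = {has_nan}")


# Hypothetical usage, mirroring the checks deleted in this diff:
# debug_check_nan("qkv_a_out", qkv_a_out)
# debug_check_nan("fmha_out_decode", fmha_out_decode)
```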