@@ -297,22 +297,25 @@ def _prepare_generation_outputs(
297297 hiddens : List [torch .Tensor ],
298298 infer_text : bool ,
299299 ) -> GenerationOutputs :
300- inputs_ids = [
301- inputs_ids [idx ].narrow (0 , start_idx , i ) for idx , i in enumerate (end_idx )
300+ end_idx_int = end_idx .int ()
301+
302+ inputs_ids_lst = [
303+ inputs_ids [idx ].narrow (0 , start_idx , int (i )) for idx , i in enumerate (end_idx_int )
302304 ]
303305 if infer_text :
304- inputs_ids = [i .narrow (1 , 0 , 1 ).squeeze_ (1 ) for i in inputs_ids ]
306+ inputs_ids_lst = [i .narrow (1 , 0 , 1 ).squeeze_ (1 ) for i in inputs_ids_lst ]
305307
308+ hiddens_lst = []
306309 if len (hiddens ) > 0 :
307- hiddens = torch .stack (hiddens , 1 )
308- hiddens = [
309- hiddens [idx ].narrow (0 , 0 , i ) for idx , i in enumerate (end_idx . int () )
310+ hiddens_lst = torch .stack (hiddens , 1 )
311+ hiddens_lst = [
312+ hiddens_lst [idx ].narrow (0 , 0 , int ( i )) for idx , i in enumerate (end_idx_int )
310313 ]
311314
312315 return self .GenerationOutputs (
313- ids = inputs_ids ,
316+ ids = inputs_ids_lst ,
314317 attentions = attentions ,
315- hiddens = hiddens ,
318+ hiddens = hiddens_lst ,
316319 )
317320
318321 @torch .no_grad ()
@@ -338,6 +341,8 @@ def generate(
338341 manual_seed : Optional [int ] = None ,
339342 context = Context (),
340343 ):
344+
345+ self .logger .debug ("start generate" )
341346
342347 attentions : List [Optional [Tuple [torch .FloatTensor , ...]]] = []
343348 hiddens = []
@@ -348,6 +353,8 @@ def generate(
348353 )
349354 finish = torch .zeros (inputs_ids .shape [0 ], device = inputs_ids .device ).bool ()
350355
356+ self .logger .debug (f"set start_idx: { start_idx } , end_idx and finish with all zeros, len { inputs_ids .shape [0 ]} " )
357+
351358 old_temperature = temperature
352359
353360 temperature = (
@@ -357,6 +364,8 @@ def generate(
357364 .view (- 1 , 1 )
358365 )
359366
367+ self .logger .debug (f"expand temperature from shape { old_temperature .shape } to { temperature .shape } " )
368+
360369 attention_mask_cache = torch .ones (
361370 (
362371 inputs_ids .shape [0 ],
@@ -365,10 +374,12 @@ def generate(
365374 dtype = torch .bool ,
366375 device = inputs_ids .device ,
367376 )
377+ self .logger .debug (f"init attention_mask_cache with shape { attention_mask_cache .shape } " )
368378 if attention_mask is not None :
369379 attention_mask_cache .narrow (1 , 0 , attention_mask .shape [1 ]).copy_ (
370380 attention_mask
371381 )
382+ self .logger .debug (f"copy attention_mask with shape { attention_mask .shape } " )
372383
373384 progress = inputs_ids .size (1 )
374385 # pre-allocate inputs_ids
@@ -380,6 +391,7 @@ def generate(
380391 device = inputs_ids .device ,
381392 )
382393 inputs_ids_buf .narrow (1 , 0 , progress ).copy_ (inputs_ids )
394+ self .logger .debug (f"expand inputs_ids buf from shape { inputs_ids .shape } to { inputs_ids_buf .shape } " )
383395 del inputs_ids
384396 inputs_ids = inputs_ids_buf .narrow (1 , 0 , progress )
385397
@@ -396,28 +408,36 @@ def generate(
396408
397409 for i in range (max_new_token ):
398410
411+ self .logger .debug ("start _prepare_generation_inputs" )
399412 model_input = self ._prepare_generation_inputs (
400413 inputs_ids ,
401414 past_key_values ,
402415 attention_mask_cache .narrow (1 , 0 , inputs_ids .shape [1 ]),
403416 )
417+ self .logger .debug ("finis _prepare_generation_inputs" )
404418
405419 if i > 0 :
406420 del emb
407421 inputs_ids_emb = model_input .input_ids .to (self .device_gpt )
408422 if infer_text :
423+ self .logger .debug ("start emb_text" )
409424 emb : torch .Tensor = self .emb_text (inputs_ids_emb [:, :, 0 ])
425+ self .logger .debug ("finis emb_text" )
410426 else :
427+ self .logger .debug ("start code_emb" )
411428 code_emb = [
412- self .emb_code [i ](inputs_ids_emb [:, :, i ])
429+ self .emb_code [i ](inputs_ids_emb [:, :, i ]). to ( self . device )
413430 for i in range (self .num_vq )
414431 ]
415432 emb = torch .stack (code_emb , 3 ).sum (3 )
433+ self .logger .debug ("finis code_emb" )
416434 del inputs_ids_emb , model_input .input_ids
417435 model_input .inputs_embeds = emb
418436
437+ self .logger .debug (f"move model_input to device_gpt: { str (self .device_gpt )} " )
419438 model_input .to (self .device_gpt , self .gpt .dtype )
420439
440+ self .logger .debug ("start gpt..." )
421441 outputs : BaseModelOutputWithPast = self .gpt (
422442 attention_mask = model_input .attention_mask ,
423443 position_ids = model_input .position_ids ,
@@ -427,6 +447,7 @@ def generate(
427447 output_attentions = return_attn ,
428448 cache_position = model_input .cache_position ,
429449 )
450+ self .logger .debug ("finis gpt" )
430451 del_all (model_input )
431452 attentions .append (outputs .attentions )
432453 hidden_states = outputs .last_hidden_state .to (
@@ -439,8 +460,11 @@ def generate(
439460
440461 with P .cached ():
441462 if infer_text :
463+ self .logger .debug ("start head_text" )
442464 logits : torch .Tensor = self .head_text (hidden_states )
465+ self .logger .debug ("finis head_text" )
443466 else :
467+ self .logger .debug ("start head_code" )
444468 # logits = torch.stack([self.head_code[i](hidden_states) for i in range(self.num_vq)], 3)
445469 logits = torch .empty (
446470 hidden_states .size (0 ),
@@ -454,9 +478,11 @@ def generate(
454478 x : torch .Tensor = self .head_code [num_vq_iter ](hidden_states )
455479 logits [..., num_vq_iter ] = x
456480 del x
481+ self .logger .debug ("finis head_code" )
457482
458483 del hidden_states
459484
485+ self .logger .debug ("start logits" )
460486 # logits = logits[:, -1].float()
461487 logits = logits .narrow (1 , - 1 , 1 ).squeeze_ (1 ).float ()
462488
@@ -500,6 +526,9 @@ def generate(
500526
501527 del logits
502528
529+ self .logger .debug ("finis logits" )
530+
531+ self .logger .debug ("start seed" )
503532 if manual_seed is None :
504533 idx_next = torch .multinomial (scores , num_samples = 1 ).to (finish .device )
505534 else :
@@ -511,6 +540,10 @@ def generate(
511540
512541 del scores
513542
543+ self .logger .debug ("finis seed" )
544+
545+ self .logger .debug ("start finish" )
546+
514547 if not infer_text :
515548 # idx_next = rearrange(idx_next, "(b n) 1 -> b n", n=self.num_vq)
516549 idx_next = idx_next .view (- 1 , self .num_vq )
@@ -526,6 +559,8 @@ def generate(
526559 idx_next .unsqueeze_ (- 1 ).expand (- 1 , - 1 , self .num_vq ),
527560 )
528561
562+ self .logger .debug ("finis finish" )
563+
529564 if i == 0 and finish .any ():
530565 self .logger .warning (
531566 "unexpected end at index %s" ,
@@ -570,6 +605,8 @@ def generate(
570605 yield result
571606 del inputs_ids
572607 return
608+
609+ self .logger .debug ("start output" )
573610
574611 del idx_next
575612 progress += 1
@@ -591,6 +628,8 @@ def generate(
591628 )
592629 del not_finished
593630
631+ self .logger .debug ("finis output" )
632+
594633 if finish .all () or context .get ():
595634 break
596635
0 commit comments