Skip to content

Commit b3d2953

Browse files
committed
feat(gpt&web0): add more logs&flags
1 parent c26573a commit b3d2953

File tree

4 files changed

+77
-17
lines changed

4 files changed

+77
-17
lines changed

ChatTTS/model/gpt.py

Lines changed: 48 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -297,22 +297,25 @@ def _prepare_generation_outputs(
297297
hiddens: List[torch.Tensor],
298298
infer_text: bool,
299299
) -> GenerationOutputs:
300-
inputs_ids = [
301-
inputs_ids[idx].narrow(0, start_idx, i) for idx, i in enumerate(end_idx)
300+
end_idx_int = end_idx.int()
301+
302+
inputs_ids_lst = [
303+
inputs_ids[idx].narrow(0, start_idx, int(i)) for idx, i in enumerate(end_idx_int)
302304
]
303305
if infer_text:
304-
inputs_ids = [i.narrow(1, 0, 1).squeeze_(1) for i in inputs_ids]
306+
inputs_ids_lst = [i.narrow(1, 0, 1).squeeze_(1) for i in inputs_ids_lst]
305307

308+
hiddens_lst = []
306309
if len(hiddens) > 0:
307-
hiddens = torch.stack(hiddens, 1)
308-
hiddens = [
309-
hiddens[idx].narrow(0, 0, i) for idx, i in enumerate(end_idx.int())
310+
hiddens_lst = torch.stack(hiddens, 1)
311+
hiddens_lst = [
312+
hiddens_lst[idx].narrow(0, 0, int(i)) for idx, i in enumerate(end_idx_int)
310313
]
311314

312315
return self.GenerationOutputs(
313-
ids=inputs_ids,
316+
ids=inputs_ids_lst,
314317
attentions=attentions,
315-
hiddens=hiddens,
318+
hiddens=hiddens_lst,
316319
)
317320

318321
@torch.no_grad()
@@ -338,6 +341,8 @@ def generate(
338341
manual_seed: Optional[int] = None,
339342
context=Context(),
340343
):
344+
345+
self.logger.debug("start generate")
341346

342347
attentions: List[Optional[Tuple[torch.FloatTensor, ...]]] = []
343348
hiddens = []
@@ -348,6 +353,8 @@ def generate(
348353
)
349354
finish = torch.zeros(inputs_ids.shape[0], device=inputs_ids.device).bool()
350355

356+
self.logger.debug(f"set start_idx: {start_idx}, end_idx and finish with all zeros, len {inputs_ids.shape[0]}")
357+
351358
old_temperature = temperature
352359

353360
temperature = (
@@ -357,6 +364,8 @@ def generate(
357364
.view(-1, 1)
358365
)
359366

367+
self.logger.debug(f"expand temperature from shape {old_temperature.shape} to {temperature.shape}")
368+
360369
attention_mask_cache = torch.ones(
361370
(
362371
inputs_ids.shape[0],
@@ -365,10 +374,12 @@ def generate(
365374
dtype=torch.bool,
366375
device=inputs_ids.device,
367376
)
377+
self.logger.debug(f"init attention_mask_cache with shape {attention_mask_cache.shape}")
368378
if attention_mask is not None:
369379
attention_mask_cache.narrow(1, 0, attention_mask.shape[1]).copy_(
370380
attention_mask
371381
)
382+
self.logger.debug(f"copy attention_mask with shape {attention_mask.shape}")
372383

373384
progress = inputs_ids.size(1)
374385
# pre-allocate inputs_ids
@@ -380,6 +391,7 @@ def generate(
380391
device=inputs_ids.device,
381392
)
382393
inputs_ids_buf.narrow(1, 0, progress).copy_(inputs_ids)
394+
self.logger.debug(f"expand inputs_ids buf from shape {inputs_ids.shape} to {inputs_ids_buf.shape}")
383395
del inputs_ids
384396
inputs_ids = inputs_ids_buf.narrow(1, 0, progress)
385397

@@ -396,28 +408,36 @@ def generate(
396408

397409
for i in range(max_new_token):
398410

411+
self.logger.debug("start _prepare_generation_inputs")
399412
model_input = self._prepare_generation_inputs(
400413
inputs_ids,
401414
past_key_values,
402415
attention_mask_cache.narrow(1, 0, inputs_ids.shape[1]),
403416
)
417+
self.logger.debug("finis _prepare_generation_inputs")
404418

405419
if i > 0:
406420
del emb
407421
inputs_ids_emb = model_input.input_ids.to(self.device_gpt)
408422
if infer_text:
423+
self.logger.debug("start emb_text")
409424
emb: torch.Tensor = self.emb_text(inputs_ids_emb[:, :, 0])
425+
self.logger.debug("finis emb_text")
410426
else:
427+
self.logger.debug("start code_emb")
411428
code_emb = [
412-
self.emb_code[i](inputs_ids_emb[:, :, i])
429+
self.emb_code[i](inputs_ids_emb[:, :, i]).to(self.device)
413430
for i in range(self.num_vq)
414431
]
415432
emb = torch.stack(code_emb, 3).sum(3)
433+
self.logger.debug("finis code_emb")
416434
del inputs_ids_emb, model_input.input_ids
417435
model_input.inputs_embeds = emb
418436

437+
self.logger.debug(f"move model_input to device_gpt: {str(self.device_gpt)}")
419438
model_input.to(self.device_gpt, self.gpt.dtype)
420439

440+
self.logger.debug("start gpt...")
421441
outputs: BaseModelOutputWithPast = self.gpt(
422442
attention_mask=model_input.attention_mask,
423443
position_ids=model_input.position_ids,
@@ -427,6 +447,7 @@ def generate(
427447
output_attentions=return_attn,
428448
cache_position=model_input.cache_position,
429449
)
450+
self.logger.debug("finis gpt")
430451
del_all(model_input)
431452
attentions.append(outputs.attentions)
432453
hidden_states = outputs.last_hidden_state.to(
@@ -439,8 +460,11 @@ def generate(
439460

440461
with P.cached():
441462
if infer_text:
463+
self.logger.debug("start head_text")
442464
logits: torch.Tensor = self.head_text(hidden_states)
465+
self.logger.debug("finis head_text")
443466
else:
467+
self.logger.debug("start head_code")
444468
# logits = torch.stack([self.head_code[i](hidden_states) for i in range(self.num_vq)], 3)
445469
logits = torch.empty(
446470
hidden_states.size(0),
@@ -454,9 +478,11 @@ def generate(
454478
x: torch.Tensor = self.head_code[num_vq_iter](hidden_states)
455479
logits[..., num_vq_iter] = x
456480
del x
481+
self.logger.debug("finis head_code")
457482

458483
del hidden_states
459484

485+
self.logger.debug("start logits")
460486
# logits = logits[:, -1].float()
461487
logits = logits.narrow(1, -1, 1).squeeze_(1).float()
462488

@@ -500,6 +526,9 @@ def generate(
500526

501527
del logits
502528

529+
self.logger.debug("finis logits")
530+
531+
self.logger.debug("start seed")
503532
if manual_seed is None:
504533
idx_next = torch.multinomial(scores, num_samples=1).to(finish.device)
505534
else:
@@ -511,6 +540,10 @@ def generate(
511540

512541
del scores
513542

543+
self.logger.debug("finis seed")
544+
545+
self.logger.debug("start finish")
546+
514547
if not infer_text:
515548
# idx_next = rearrange(idx_next, "(b n) 1 -> b n", n=self.num_vq)
516549
idx_next = idx_next.view(-1, self.num_vq)
@@ -526,6 +559,8 @@ def generate(
526559
idx_next.unsqueeze_(-1).expand(-1, -1, self.num_vq),
527560
)
528561

562+
self.logger.debug("finis finish")
563+
529564
if i == 0 and finish.any():
530565
self.logger.warning(
531566
"unexpected end at index %s",
@@ -570,6 +605,8 @@ def generate(
570605
yield result
571606
del inputs_ids
572607
return
608+
609+
self.logger.debug("start output")
573610

574611
del idx_next
575612
progress += 1
@@ -591,6 +628,8 @@ def generate(
591628
)
592629
del not_finished
593630

631+
self.logger.debug("finis output")
632+
594633
if finish.all() or context.get():
595634
break
596635

ChatTTS/utils/gpu.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,17 @@ def select_device(min_memory=2047, experimental=False):
4646
logger.get_logger().info("found Apple GPU, but use CPU.")
4747
device = torch.device("cpu")
4848
elif importlib.util.find_spec("torch_directml") is not None:
49-
import torch_directml
50-
51-
device = torch_directml.device(torch_directml.default_device())
49+
"""
50+
Currently DML is under developing and may output wrong result,
51+
so only enable this for experimental use.
52+
"""
53+
if experimental:
54+
logger.get_logger().warning("experimental: using DML.")
55+
import torch_directml
56+
device = torch_directml.device(torch_directml.default_device())
57+
else:
58+
logger.get_logger().info("found DML, but use CPU.")
59+
device = torch.device("cpu")
5260
else:
5361
logger.get_logger().warning("no GPU or NPU found, use CPU instead")
5462
device = torch.device("cpu")

examples/web/funcs.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@
2525
has_interrupted = False
2626
is_in_generate = False
2727

28+
# Web-UI runtime flags; overridden via set_params() before the model is loaded.
enable_cache = True
experimental = False
30+
2831
seed_min = 1
2932
seed_max = 4294967295
3033

@@ -61,14 +64,21 @@ def on_audio_seed_change(audio_seed_input):
6164
rand_spk = chat.sample_random_speaker()
6265
return rand_spk
6366

67+
def set_params(en_cache, exp):
    """Record the web UI's runtime flags in module-level globals.

    en_cache -- whether model caching is enabled when loading
    exp -- whether experimental behavior (e.g. the DML device path) is allowed
    """
    global enable_cache, experimental
    enable_cache, experimental = en_cache, exp
72+
73+
def load_chat(cust_path: Optional[str], coef: Optional[str]) -> bool:
74+
global enable_cache, experimental
6475

65-
def load_chat(cust_path: Optional[str], coef: Optional[str], enable_cache=True) -> bool:
6676
if cust_path == None:
67-
ret = chat.load(coef=coef, enable_cache=enable_cache)
77+
ret = chat.load(coef=coef, enable_cache=enable_cache, experimental=experimental)
6878
else:
6979
logger.info("local model path: %s", cust_path)
7080
ret = chat.load(
71-
"custom", custom_path=cust_path, coef=coef, enable_cache=enable_cache
81+
"custom", custom_path=cust_path, coef=coef, enable_cache=enable_cache, experimental=experimental
7282
)
7383
global custom_path
7484
custom_path = cust_path

examples/web/webui.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -264,11 +264,14 @@ def make_audio(autoplay, stream):
264264
# Fix: help text previously said "enable model cache" although this flag disables it.
parser.add_argument(
    "--disable_cache", action="store_true", help="disable model cache"
)
267+
# Fix: help text was copy-pasted from the cache flag ("enable model cache");
# this flag actually toggles experimental features (e.g. the DML device path).
parser.add_argument(
    "--experimental", action="store_true", help="enable experimental features"
)
267270
args = parser.parse_args()
268-
271+
set_params(not args.disable_cache, args.experimental)
269272
logger.info("loading ChatTTS model...")
270273

271-
if load_chat(args.custom_path, args.coef, not args.disable_cache):
274+
if load_chat(args.custom_path, args.coef):
272275
logger.info("Models loaded successfully.")
273276
else:
274277
logger.error("Models load failed.")

0 commit comments

Comments
 (0)