diff --git a/configs/default.yaml b/configs/default.yaml index 92ed7a9..d6a33e2 100644 --- a/configs/default.yaml +++ b/configs/default.yaml @@ -1,6 +1,6 @@ data: - data_dir: 'H:\Deepsync\backup\fastspeech\data\' - wav_dir: 'H:\Deepsync\backup\deepsync\LJSpeech-1.1\wavs\' + data_dir: '/workspace/data/' + wav_dir: '/workspace/LJSpeech-1.1/wavs/' # Compute statistics e_mean: 21.578571319580078 e_std: 18.916799545288086 @@ -106,10 +106,12 @@ model: train: + discriminator_start: 20000 + rep_discriminator: 1 # optimization related eos: False #True opt: 'noam' - accum_grad: 4 + accum_grad: 1 grad_clip: 1.0 weight_decay: 0.001 patience: 0 @@ -125,7 +127,7 @@ train: seed: 1 # random seed number resume: "" # the snapshot path to resume (if set empty, no effect) use_phonemes: True - batch_size : 16 + batch_size : 24 # other melgan_vocoder : True save_interval : 1000 @@ -134,4 +136,4 @@ train: summary_interval : 200 validation_step : 500 tts_max_mel_len : 870 # if you have a couple of extremely long spectrograms you might want to use this - tts_bin_lengths : True # bins the spectrogram lengths before sampling in data loader - speeds up training \ No newline at end of file + tts_bin_lengths : True # bins the spectrogram lengths before sampling in data loader - speeds up training diff --git a/core/discriminator.py b/core/discriminator.py new file mode 100644 index 0000000..08b2a30 --- /dev/null +++ b/core/discriminator.py @@ -0,0 +1,62 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class Discriminator(nn.Module): + def __init__(self): + super(Discriminator, self).__init__() + + self.discriminator = nn.Sequential( + nn.Conv2d(1, 16, kernel_size=3, stride=1, padding = 1), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(16, 32, kernel_size=3, stride=1, padding = 1), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(32, 64, kernel_size=3, stride=1, padding = 1), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(64, 1, kernel_size=3, stride=1, padding = 1) + #nn.Flatten(), # add conv2d a 1 channel + #nn.Linear(46240,256) + ) + + def forward(self, x): + ''' + we directly predict score without last sigmoid function + since we're using Least Squares GAN (https://arxiv.org/abs/1611.04076) + ''' + # print(x.shape, "Input to Discriminator") + return self.discriminator(x) + +def weights_init(m): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(0.0, 0.02) + elif classname.find("BatchNorm2d") != -1: + m.weight.data.normal_(1.0, 0.02) + m.bias.data.fill_(0) + +class SFDiscriminator(nn.Module): + def __init__(self): + super().__init__() + self.disc1 = Discriminator() + self.disc2 = Discriminator() + self.disc3 = Discriminator() + self.apply(weights_init) + def forward(self, x, start): + results = [] + results.append(self.disc1(x[:, :, start: start + 40, 0:40])) + results.append(self.disc2(x[:, :, start: start + 40, 20:60])) + results.append(self.disc3(x[:, :, start: start + 40, 40:80, ])) + return results + +if __name__ == '__main__': + model = SFDiscriminator() + + x = torch.randn(16, 1, 40, 80) + print(x.shape) + + out = model(x) + print(len(out), "Shape of output") + + pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + print(pytorch_total_params) diff --git a/dataset/texts/__init__.py b/dataset/texts/__init__.py index 24f35e3..a9f5930 100644 --- a/dataset/texts/__init__.py +++ b/dataset/texts/__init__.py @@ -1,14 +1,7 @@ """ from https://github.com/keithito/tacotron """ import re from dataset.texts import cleaners -from dataset.texts.symbols import ( - symbols, - _eos, - phonemes_symbols, - PAD, - EOS, - _PHONEME_SEP, -) +from dataset.texts.symbols import symbols, _eos, phonemes_symbols, PAD, EOS, _PHONEME_SEP from dataset.texts.dict_ import symbols_ import nltk from g2p_en import G2p @@ -18,125 +11,64 @@ _id_to_symbol = {i: s for i, s in enumerate(symbols)} # Regular expression matching text enclosed in curly braces: -_curly_re = re.compile(r"(.*?)\{(.+?)\}(.*)") +_curly_re = re.compile(r'(.*?)\{(.+?)\}(.*)') symbols_inv = {v: k for k, v in symbols_.items()} -valid_symbols = [ - "AA", - "AA1", - "AE", - "AE0", - "AE1", - "AH", - "AH0", - "AH1", - "AO", - "AO1", - "AW", - "AW0", - "AW1", - "AY", - "AY0", - "AY1", - "B", - "CH", - "D", - "DH", - "EH", - "EH0", - "EH1", - "ER", - "EY", - "EY0", - "EY1", - "F", - "G", - "HH", - "IH", - "IH0", - "IH1", - "IY", - "IY0", - "IY1", - "JH", - "K", - "L", - "M", - "N", - "NG", - "OW", - "OW0", - "OW1", - "OY", - "OY0", - "OY1", - "P", - "R", - "S", - "SH", - "T", - "TH", - "UH", - "UH0", - "UH1", - "UW", - "UW0", - "UW1", - "V", - "W", - "Y", - "Z", - "ZH", - "pau", - "sil", -] - +valid_symbols = ['AA', 'AA1', 'AE', 'AE0', 'AE1', 'AH', 'AH0', 'AH1', + 'AO', 'AO1', 'AW', 'AW0', 'AW1', 'AY', 'AY0', 'AY1', + 'B', 'CH', 'D', 'DH', 'EH', 'EH0', 'EH1', 'ER', 'EY', + 'EY0', 'EY1', 'F', 'G', 'HH', 'IH', 'IH0', 'IH1', 'IY', + 'IY0', 'IY1', 'JH', 'K', 'L', 'M', 'N', 'NG', 'OW', 'OW0', + 'OW1', 'OY', 'OY0','OY1', 'P', 'R', 'S', 'SH', 'T', 'TH', + 'UH', 'UH0', 'UH1', 'UW','UW0', 'UW1', 'V', 'W', 'Y', 'Z', + 'ZH', 'pau', 'sil', 'spn'] def pad_with_eos_bos(_sequence): return _sequence + [_symbol_to_id[_eos]] + def text_to_sequence(text, cleaner_names, eos): - """Converts a string of text to a sequence of IDs corresponding to the symbols in the text. - The text can optionally have ARPAbet sequences enclosed in curly braces embedded - in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street." - Args: - text: string to convert to a sequence - cleaner_names: names of the cleaner functions to run the text through - Returns: - List of integers corresponding to the symbols in the text - """ + '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text. + The text can optionally have ARPAbet sequences enclosed in curly braces embedded + in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street." + Args: + text: string to convert to a sequence + cleaner_names: names of the cleaner functions to run the text through + Returns: + List of integers corresponding to the symbols in the text + ''' sequence = [] if eos: - text = text + "~" + text = text + '~' try: sequence += _symbols_to_sequence(_clean_text(text, cleaner_names)) except KeyError: - print("text : ", text) + print("text : ",text) exit(0) return sequence def sequence_to_text(sequence): - """Converts a sequence of IDs back to a string""" - result = "" + '''Converts a sequence of IDs back to a string''' + result = '' for symbol_id in sequence: if symbol_id in symbols_inv: s = symbols_inv[symbol_id] # Enclose ARPAbet back in curly braces: - if len(s) > 1 and s[0] == "@": - s = "{%s}" % s[1:] + if len(s) > 1 and s[0] == '@': + s = '{%s}' % s[1:] result += s - return result.replace("}{", " ") + return result.replace('}{', ' ') def _clean_text(text, cleaner_names): for name in cleaner_names: cleaner = getattr(cleaners, name) if not cleaner: - raise Exception("Unknown cleaner: %s" % name) + raise Exception('Unknown cleaner: %s' % name) text = cleaner(text) return text @@ -146,11 +78,7 @@ def _symbols_to_sequence(symbols): def _arpabet_to_sequence(text): - return _symbols_to_sequence(["@" + s for s in text.split()]) - - -def _should_keep_symbol(s): - return s in _symbol_to_id and s != "_" and s != "~" + return _symbols_to_sequence(['@' + s for s in text.split()]) # For phonemes @@ -159,58 +87,55 @@ def _should_keep_symbol(s): def _should_keep_token(token, token_dict): - return ( - token in token_dict - and token != PAD - and token != EOS - and token != _phoneme_to_id[PAD] - and token != _phoneme_to_id[EOS] - ) - + return token in token_dict \ + and token != PAD and token != EOS \ + and token != _phoneme_to_id[PAD] \ + and token != _phoneme_to_id[EOS] def phonemes_to_sequence(phonemes): string = phonemes.split() if isinstance(phonemes, str) else phonemes - # string.append(EOS) + #string.append(EOS) sequence = list(map(convert_phoneme_CMU, string)) - sequence = [_phoneme_to_id[s] for s in string] - # if _should_keep_token(s, _phoneme_to_id)] + sequence = [_phoneme_to_id[s] for s in sequence] + #if _should_keep_token(s, _phoneme_to_id)] return sequence def sequence_to_phonemes(sequence, use_eos=False): string = [_id_to_phoneme[idx] for idx in sequence] - # if _should_keep_token(idx, _id_to_phoneme)] + #if _should_keep_token(idx, _id_to_phoneme)] string = _PHONEME_SEP.join(string) if use_eos: - string = string.replace(EOS, "") + string = string.replace(EOS, '') return string def convert_phoneme_CMU(phoneme): REMAPPING = { - "AA0": "AA1", - "AA2": "AA1", - "AE2": "AE1", - "AH2": "AH1", - "AO0": "AO1", - "AO2": "AO1", - "AW2": "AW1", - "AY2": "AY1", - "EH2": "EH1", - "ER0": "EH1", - "ER1": "EH1", - "ER2": "EH1", - "EY2": "EY1", - "IH2": "IH1", - "IY2": "IY1", - "OW2": "OW1", - "OY2": "OY1", - "UH2": "UH1", - "UW2": "UW1", + 'AA0': 'AA1', + 'AA2': 'AA1', + 'AE2': 'AE1', + 'AH2': 'AH1', + 'AO0': 'AO1', + 'AO2': 'AO1', + 'AW2': 'AW1', + 'AY2': 'AY1', + 'EH2': 'EH1', + 'ER0': 'EH1', + 'ER1': 'EH1', + 'ER2': 'EH1', + 'EY2': 'EY1', + 'IH2': 'IH1', + 'IY2': 'IY1', + 'OW2': 'OW1', + 'OY2': 'OY1', + 'UH2': 'UH1', + 'UW2': 'UW1', } return REMAPPING.get(phoneme, phoneme) + def text_to_phonemes(text, custom_words={}): """ Convert text into ARPAbet. @@ -224,7 +149,7 @@ def text_to_phonemes(text, custom_words={}): """ g2p = G2p() - """def convert_phoneme_CMU(phoneme): + '''def convert_phoneme_CMU(phoneme): REMAPPING = { 'AA0': 'AA1', 'AA2': 'AA1', @@ -247,18 +172,17 @@ def text_to_phonemes(text, custom_words={}): 'UW2': 'UW1', } return REMAPPING.get(phoneme, phoneme) - """ - + ''' def convert_phoneme_listener(phoneme): - VOWELS = ["A", "E", "I", "O", "U"] + VOWELS = ['A', 'E', 'I', 'O', 'U'] if phoneme[0] in VOWELS: - phoneme += "1" - return phoneme # convert_phoneme_CMU(phoneme) + phoneme += '1' + return phoneme #convert_phoneme_CMU(phoneme) try: known_words = nltk.corpus.cmudict.dict() except LookupError: - nltk.download("cmudict") + nltk.download('cmudict') known_words = nltk.corpus.cmudict.dict() for word, phonemes in custom_words.items(): @@ -267,20 +191,16 @@ def convert_phoneme_listener(phoneme): words = nltk.tokenize.WordPunctTokenizer().tokenize(text.lower()) phonemes = [] - PUNCTUATION = "!?.,-:;\"'()" + PUNCTUATION = '!?.,-:;"\'()' for word in words: if all(c in PUNCTUATION for c in word): - pronounciation = ["pau"] + pronounciation = ['pau'] elif word in known_words: pronounciation = known_words[word][0] - pronounciation = list( - pronounciation - ) # map(convert_phoneme_CMU, pronounciation)) + pronounciation = list(pronounciation)#map(convert_phoneme_CMU, pronounciation)) else: pronounciation = g2p(word) - pronounciation = list( - pronounciation - ) # (map(convert_phoneme_CMU, pronounciation)) + pronounciation = list(pronounciation)#(map(convert_phoneme_CMU, pronounciation)) phonemes += pronounciation diff --git a/fastspeech.py b/fastspeech.py index 0202677..f4658a6 100644 --- a/fastspeech.py +++ b/fastspeech.py @@ -269,7 +269,7 @@ def forward( before_outs, after_outs, d_outs, e_outs, p_outs = self._forward( xs, ilens, olens, ds, es, ps, is_inference=False ) - + out_mels = after_outs.detach() # modifiy mod part of groundtruth # if hp.model.reduction_factor > 1: # olens = olens.new([olen - olen % self.reduction_factor for olen in olens]) @@ -332,8 +332,9 @@ def forward( ] # self.reporter.report(report_keys) + #print(out_mels.shape, "Shape of out_mels in Fs") - return loss, report_keys + return loss, report_keys, out_mels def inference(self, x: torch.Tensor) -> torch.Tensor: """Generate the sequence of features given the sequences of characters. diff --git a/train_fastspeech.py b/train_fastspeech.py index d7f4b5c..4bf3518 100644 --- a/train_fastspeech.py +++ b/train_fastspeech.py @@ -18,6 +18,8 @@ from utils.util import get_commit_hash from utils.hparams import HParam +from core.discriminator import SFDiscriminator + BATCH_COUNT_CHOICES = ["auto", "seq", "bin", "frame"] BATCH_SORT_KEY_CHOICES = ["input", "output", "shuffle"] @@ -34,6 +36,9 @@ def train(args, hp, hp_str, logger, vocoder): idim = len(valid_symbols) odim = hp.audio.num_mels model = fastspeech.FeedForwardTransformer(idim, odim, hp) + model_d = SFDiscriminator().cuda() + criterion_d = torch.nn.MSELoss().cuda() + # set torch device model = model.to(device) print("Model is loaded ...") @@ -49,6 +54,12 @@ def train(args, hp, hp_str, logger, vocoder): hp.model.transformer_warmup_steps, hp.model.transformer_lr, ) + optim_d = get_std_opt( + model_d, + hp.model.adim, + hp.model.transformer_warmup_steps, + hp.model.transformer_lr, + ) optimizer.load_state_dict(checkpoint["optim"]) global_step = checkpoint["step"] @@ -74,10 +85,18 @@ def train(args, hp, hp_str, logger, vocoder): hp.model.transformer_warmup_steps, hp.model.transformer_lr, ) + optim_d = get_std_opt( + model_d, + hp.model.adim, + hp.model.transformer_warmup_steps, + hp.model.transformer_lr, + ) + print("Batch Size :", hp.train.batch_size) num_params(model) + num_params(model_d) os.makedirs(os.path.join(hp.train.log_dir, args.name), exist_ok=True) writer = SummaryWriter(os.path.join(hp.train.log_dir, args.name)) @@ -88,6 +107,8 @@ def train(args, hp, hp_str, logger, vocoder): start = time.time() running_loss = 0 j = 0 + d_loss = [] + pbar = tqdm.tqdm(dataloader, desc="Loading train data") for data in pbar: @@ -96,7 +117,7 @@ def train(args, hp, hp_str, logger, vocoder): # x : [batch , num_char], input_length : [batch], y : [batch, T_in, num_mel] # # stop_token : [batch, T_in], out_length : [batch] - loss, report_dict = model( + loss, report_dict, mel = model( x.cuda(), input_length.cuda(), y.cuda(), @@ -108,6 +129,18 @@ def train(args, hp, hp_str, logger, vocoder): loss = loss.mean() / hp.train.accum_grad running_loss += loss.item() + adv_loss = 0 + + if global_step >= hp.train.discriminator_start: + start_disc = np.random.randint(0, out_length.min()-40) + + disc_fake = model_d(mel.unsqueeze(1).cuda(), start_disc) + for score_fake in disc_fake: + # adv_loss += torch.mean(torch.sum(torch.pow(score_fake - 1.0, 2), dim=[1, 2])) + adv_loss += criterion_d(score_fake, torch.ones_like(score_fake)) + adv_loss = adv_loss / len(disc_fake) # len(disc_fake) = 3 + loss = loss + adv_loss + loss.backward() # update parameters @@ -129,6 +162,41 @@ def train(args, hp, hp_str, logger, vocoder): optimizer.step() optimizer.zero_grad() + + # Discriminator + loss_d_avg = 0.0 + loss_d_sum = 0.0 + + if step > hp.train.discriminator_start: + loss, report_dict, mel = model( + x.cuda(), + input_length.cuda(), + y.cuda(), + out_length.cuda(), + dur.cuda(), + e.cuda(), + p.cuda(), + ) + + optim_d.zero_grad() + start_disc = np.random.randint(0, out_length.min()-40) + disc_fake = model_d(mel.unsqueeze(1).cuda(), start_disc) + disc_real = model_d(y.unsqueeze(1).cuda(), start_disc) + loss_d = 0.0 + loss_d_real = 0.0 + loss_d_fake = 0.0 + for score_fake, score_real in zip(disc_fake, disc_real): + loss_d_real += criterion_d(score_real, torch.ones_like(score_real)) + loss_d_fake += criterion_d(score_fake, torch.zeros_like(score_fake)) + loss_d_real = loss_d_real / len(disc_real) # len(disc_real) = 3 + loss_d_fake = loss_d_fake / len(disc_fake) # len(disc_fake) = 3 + loss_d = loss_d_real + loss_d_fake + loss_d.backward() + optim_d.step() + loss_d_sum += loss_d + loss_d_avg = loss_d_sum.item() + writer.add_scalar("Advverserial Loss", loss_d_avg, step) + if step % hp.train.summary_interval == 0: pbar.set_description( "Average Loss %.04f Loss %.04f | step %d" @@ -150,7 +218,7 @@ def train(args, hp, hp_str, logger, vocoder): x_, input_length_, y_, _, out_length_, ids_, dur_, e_, p_ = valid model.eval() with torch.no_grad(): - loss_, report_dict_ = model( + loss_, report_dict_, _ = model( x_.cuda(), input_length_.cuda(), y_.cuda(),