diff --git a/.idea/P_wavenet_vocoder.iml b/.idea/P_wavenet_vocoder.iml
index b61ff71..ffe6a8a 100644
--- a/.idea/P_wavenet_vocoder.iml
+++ b/.idea/P_wavenet_vocoder.iml
@@ -4,5 +4,8 @@
+
+
+
\ No newline at end of file
diff --git a/.idea/libraries/R_User_Library.xml b/.idea/libraries/R_User_Library.xml
new file mode 100644
index 0000000..71f5ff7
--- /dev/null
+++ b/.idea/libraries/R_User_Library.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
index 50c02da..62eb76a 100644
--- a/.idea/modules.xml
+++ b/.idea/modules.xml
@@ -3,6 +3,7 @@
+
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
index 94a25f7..35eb1dd 100644
--- a/.idea/vcs.xml
+++ b/.idea/vcs.xml
@@ -1,6 +1,6 @@
-
+
\ No newline at end of file
diff --git a/.idea/workspace.xml b/.idea/workspace.xml
index 7e9acaf..63cdb05 100644
--- a/.idea/workspace.xml
+++ b/.idea/workspace.xml
@@ -2,17 +2,15 @@
-
-
-
-
-
-
-
-
-
-
-
+
+
+
+
+
+
+
+
+
@@ -21,70 +19,125 @@
+
+
+
+
-
+
+
+
+
+
+
+
+
+
+
+
+
+
-
-
-
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ build
+ bat
+ 10
+
+
+
+
+
true
DEFINITION_ORDER
+
+
+
+
-
-
-
-
-
-
-
-
-
-
-
+
+
-
-
+
+
+
+
+
+
+
-
-
-
-
-
+
+
+
+
+
+
+
+
@@ -98,8 +151,77 @@
-
-
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
@@ -115,28 +237,31 @@
-
+
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
@@ -146,15 +271,226 @@
-
-
+
+
+
+
+
+
+
+ f
+ Python
+ EXPRESSION
+
+
+ new_z
+ Python
+ EXPRESSION
+
+
+ last_layer(new_z)
+ Python
+ EXPRESSION
+
+
+ 'GroupNorm' in str(f)
+ Python
+ EXPRESSION
+
+
+ print(str(f))
+ Python
+ EXPRESSION
+
+
+ str(f)
+ Python
+ EXPRESSION
+
+
+ first_con
+ Python
+ EXPRESSION
+
+
+ u.size()
+ Python
+ EXPRESSION
+
+
+ u
+ Python
+ EXPRESSION
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
-
-
-
+
+
+
+
+
diff --git a/generate.zip b/generate.zip
deleted file mode 100644
index 719f047..0000000
Binary files a/generate.zip and /dev/null differ
diff --git a/preprocess.py b/preprocess.py
index d0be71c..4a4c87f 100644
--- a/preprocess.py
+++ b/preprocess.py
@@ -44,7 +44,7 @@ def write_metadata(metadata, out_dir):
in_dir = '/home/jinqiangzeng/work/data/speech/cmu_arctic'
out_dir = './data/cmu_arctic'
elif name == 'ljspeech':
- in_dir = '/home/jinqiangzeng/work/data/speech/ljspeech/LJSpeech-1.0'
+ in_dir = '/home/tesla/work/data/LJSpeech-1.0'
out_dir = './data/ljspeech'
num_workers = None # args["--num_workers"]
num_workers = cpu_count() - 1 if num_workers is None else int(num_workers)
diff --git a/train_student.py b/train_student.py
index e9067d9..33501e9 100644
--- a/train_student.py
+++ b/train_student.py
@@ -127,7 +127,7 @@ def get_power_loss_torch(y, y1, n_fft=1024, hop_length=256, cuda=True):
s = torch.stft(x, n_fft, hop_length, window=torch.hann_window(n_fft, periodic=True).cuda())
s1 = torch.stft(x1, n_fft, hop_length, window=torch.hann_window(n_fft, periodic=True).cuda())
ss = torch.log(torch.sqrt(torch.sum(s ** 2, -1) + 1e-5)) - torch.log(torch.sqrt(torch.sum(s1 ** 2, -1) + 1e-5))
- return torch.sum(ss**2)/batch
+ return torch.sum(ss ** 2) / batch
def to_numpy(x):
@@ -330,44 +330,39 @@ def __train_step(phase, epoch, global_step, global_test_step,
mask = sequence_mask(input_lengths, max_len=x.size(-1)).unsqueeze(-1)
mask = mask[:, 1:, :]
# apply the student model with stacked iaf layers and return mu,scale
- # u = Variable(torch.from_numpy(np.random.uniform(1e-5, 1 - 1e-5, x.size())).float().cuda(), requires_grad=False)
- # z = torch.log(u) - torch.log(1 - u)
u = Variable(torch.zeros(*x.size()).uniform_(1e-5, 1 - 1e-5), requires_grad=False).cuda()
z = torch.log(u) - torch.log(1 - u)
predict, mu, scale = student(z, c=c, g=g, softmax=False)
m, s = mu, scale
# mu, scale = to_numpy(mu), to_numpy(scale)
# TODO sample times, change to 300 or 400
- sample_T, kl_loss_sum = 16, 0
+ sample_T, kl_loss_sum = 8, 0
power_loss_sum = 0
- y_hat = teacher(predict, c=c, g=g) # y_hat: (B x C x T) teacher: 10-mixture-logistic
+ y_hat = teacher(predict, c=c, g=g) # y_hat: (B x C x T) teacher: 10-mixture-logistic
h_pt_ps = 0
# TODO add some constrain on scale ,we want it to be small?
- for i in range(sample_T):
- # https://en.wikipedia.org/wiki/Logistic_distribution
- u = Variable(torch.zeros(*x.size()).uniform_(1e-5,1-1e-5),requires_grad=False).cuda()
- z = torch.log(u) - torch.log(1 - u)
- student_predict = m + s * z # predicted wave
- # student_predict.clamp(-0.99, 0.99)
- student_predict = student_predict.permute(0, 2, 1)
- _, teacher_log_p = discretized_mix_logistic_loss(y_hat[:, :, :-1], student_predict[:, 1:, :], reduce=False)
- h_pt_ps += torch.sum(teacher_log_p * mask) / mask.sum()
- student_predict = student_predict.permute(0, 2, 1)
- power_loss_sum += get_power_loss_torch(student_predict, x, n_fft=512, hop_length=128)
- power_loss_sum += get_power_loss_torch(student_predict, x, n_fft=256, hop_length=64)
- power_loss_sum += get_power_loss_torch(student_predict, x, n_fft=2048, hop_length=512)
- power_loss_sum += get_power_loss_torch(student_predict, x, n_fft=1024, hop_length=256)
- power_loss_sum += get_power_loss_torch(student_predict, x, n_fft=128, hop_length=32)
+ # https://en.wikipedia.org/wiki/Logistic_distribution
+ u = Variable(torch.zeros(*x.size()).uniform_(1e-5, 1 - 1e-5), requires_grad=False).cuda()
+ z = torch.log(u) - torch.log(1 - u)
+ student_predict = m + s * z # predicted wave
+ # student_predict.clamp(-0.99, 0.99)
+ student_predict = student_predict.permute(0, 2, 1)
+ _, teacher_log_p = discretized_mix_logistic_loss(y_hat[:, :, :-1], student_predict[:, 1:, :], reduce=False)
+ h_pt_ps += torch.sum(teacher_log_p * mask) / mask.sum()
+ student_predict = student_predict.permute(0, 2, 1)
+ power_loss_sum += get_power_loss_torch(student_predict, x, n_fft=512, hop_length=128)
+ power_loss_sum += get_power_loss_torch(student_predict, x, n_fft=256, hop_length=64)
+ power_loss_sum += get_power_loss_torch(student_predict, x, n_fft=2048, hop_length=512)
+ power_loss_sum += get_power_loss_torch(student_predict, x, n_fft=4096, hop_length=1024)
+ power_loss_sum += get_power_loss_torch(student_predict, x, n_fft=1024, hop_length=256)
+ power_loss_sum += get_power_loss_torch(student_predict, x, n_fft=128, hop_length=32)
+ power_loss_sum += get_power_loss_torch(student_predict, x, n_fft=64, hop_length=16)
a = s.permute(0, 2, 1)
- h_ps = torch.sum((torch.log(a[:, 1:, :]) + 2) * mask) / ( mask.sum())
- cross_entropy = h_pt_ps /(sample_T)
- kl_loss = cross_entropy - 2*h_ps
- # power_loss_sum += get_power_loss_torch(predict, x, n_fft=1024, hop_length=64)
- # power_loss_sum += get_power_loss_torch(predict, x, n_fft=1024, hop_length=128)
- # power_loss_sum += get_power_loss_torch(predict, x, n_fft=1024, hop_length=256)
- # power_loss_sum += get_power_loss_torch(predict, x, n_fft=1024, hop_length=512)
- power_loss = power_loss_sum / (5 * sample_T)
- loss = kl_loss + power_loss
+ h_ps = torch.sum((torch.log(a[:, 1:, :]) + 2) * mask) / (mask.sum())
+ cross_entropy = h_pt_ps / (sample_T)
+ kl_loss = cross_entropy - h_ps
+ power_loss = power_loss_sum / (7 * sample_T)
+ loss = kl_loss + power_loss
if step > 0 and step % 20 == 0:
print('power_loss={}, mean_scale={}, mean_mu={},kl_loss={},loss={}'.format(to_numpy(power_loss),
np.mean(to_numpy(s)),
@@ -501,7 +496,7 @@ def save_checkpoint(model, optimizer, step, checkpoint_dir, epoch, ema=None):
print("Saved averaged checkpoint:", checkpoint_path)
-def build_model(name='teacher'):
+def build_model(name='teacher', use_group_norm=False):
if is_mulaw_quantize(hparams.input_type):
if hparams.out_channels != hparams.quantize_channels:
raise RuntimeError(
@@ -536,6 +531,7 @@ def build_model(name='teacher'):
gin_channels=hparams.gin_channels,
upsample_conditional_features=hparams.upsample_conditional_features,
upsample_scales=hparams.upsample_scales,
+ use_group_norm=use_group_norm
)
@@ -629,8 +625,8 @@ def get_data_loaders(data_root, speaker_id, test_shuffle=True):
"--checkpoint-dir": 'checkpoints_student',
"--checkpoint_teacher": './checkpoints_teacher/20180127_mixture_lj_checkpoint_step000410000_ema.pth',
# the pre-trained teacher model
- "--checkpoint_student": '/home/jinqiangzeng/work/pycharm/P_wavenet_vocoder/checkpoints_student/checkpoint_step000056000.pth', # 是否加载
- #"--checkpoint_student": None, # 是否加载
+ "--checkpoint_student": './checkpoints_student/checkpoint_step000009000.pth', # 是否加载
+ # "--checkpoint_student": None, # 是否加载
"--checkpoint": None,
"--restore-parts": None,
"--data-root": './data/ljspeech', # dataset
@@ -638,8 +634,8 @@ def get_data_loaders(data_root, speaker_id, test_shuffle=True):
"--speaker-id": None,
"--reset-optimizer": None,
"--hparams": "cin_channels=80,gin_channels=-1",
- "--gpu": 0 # 指定gpu
-
+ "--gpu": 0, # 指定gpu
+ "--use_group_norm": True
}
print("Command line args:\n", args)
checkpoint_dir = args["--checkpoint-dir"]
@@ -649,7 +645,7 @@ def get_data_loaders(data_root, speaker_id, test_shuffle=True):
checkpoint_restore_parts = args["--restore-parts"]
speaker_id = args["--speaker-id"]
speaker_id = int(speaker_id) if speaker_id is not None else None
-
+ use_group_norm = args['--use_group_norm']
data_root = args["--data-root"]
if data_root is None:
data_root = join(dirname(__file__), "data", "ljspeech")
@@ -677,7 +673,10 @@ def get_data_loaders(data_root, speaker_id, test_shuffle=True):
data_loaders = get_data_loaders(data_root, speaker_id, test_shuffle=True)
# Model
- student_model = build_model(name='student')
+ if use_group_norm:
+ student_model = build_model(name='student', use_group_norm=True)
+ else:
+ student_model = build_model(name='student')
teacher_model = build_model(name='teacher')
if use_cuda:
diff --git a/wavenet_vocoder/student_wavenet.py b/wavenet_vocoder/student_wavenet.py
index 6af96de..c80ea34 100644
--- a/wavenet_vocoder/student_wavenet.py
+++ b/wavenet_vocoder/student_wavenet.py
@@ -69,6 +69,8 @@ def __init__(self, out_channels=2,
freq_axis_kernel_size=3,
scalar_input=True,
gpu=0,
+ use_group_norm = False,
+ group_size=2
):
super(StudentWaveNet, self).__init__()
self.scalar_input = scalar_input
@@ -76,14 +78,32 @@ def __init__(self, out_channels=2,
self.cin_channels = cin_channels
self.gpu = gpu
self.last_layers = []
+ self.use_group_norm = use_group_norm
+ self.group_size = group_size
# 噪声
assert layers % stacks == 0
layers_per_stack = layers // stacks
if scalar_input:
self.first_conv = nn.ModuleList([Conv1d1x1(1, residual_channels) for _ in range(len(iaf_layer_size))])
+ if use_group_norm:
+ first_layer = nn.ModuleList()
+ for _ in range(len(iaf_layer_size)):
+ first_layer.append(nn.ModuleList([
+ Conv1d1x1(1, residual_channels),
+ nn.GroupNorm(self.group_size,residual_channels)
+ ]))
+ self.first_conv = first_layer
else:
self.first_conv = nn.ModuleList(
[Conv1d1x1(out_channels, residual_channels) for _ in range(len(iaf_layer_size))])
+ if use_group_norm:
+ first_layer = nn.ModuleList()
+ for _ in range(len(iaf_layer_size)):
+ first_layer.append(nn.ModuleList([
+ Conv1d1x1(out_channels, residual_channels),
+ nn.GroupNorm(self.group_size,residual_channels)
+ ]))
+ self.first_conv = first_layer
self.iaf_layers = nn.ModuleList() # iaf层
self.last_layers = nn.ModuleList()
@@ -102,12 +122,15 @@ def __init__(self, out_channels=2,
gin_channels=gin_channels,
weight_normalization=weight_normalization)
iaf_layer.append(conv)
+ if self.use_group_norm:
+ iaf_layer.append(nn.GroupNorm(self.group_size,residual_channels))
self.iaf_layers.append(iaf_layer)
self.last_layers.append(nn.ModuleList([ # iaf的最后一层
+ nn.ReLU(inplace=True),
+ Conv1d1x1(residual_channels, residual_channels, weight_normalization=weight_normalization),
+ nn.GroupNorm(self.group_size,residual_channels),
nn.ReLU(inplace=True),
Conv1d1x1(residual_channels, out_channels, weight_normalization=weight_normalization),
- # nn.ReLU(inplace=True),
- # Conv1d1x1(residual_channels, out_channels, weight_normalization=weight_normalization),
]))
if gin_channels > 0:
@@ -178,9 +201,17 @@ def forward(self, z, c=None, g=None, softmax=False):
m = []
index = 0
for first_con, iaf, last_layer in zip(self.first_conv, self.iaf_layers, self.last_layers): # iaf layer forward
- new_z = first_con(z)
+ new_z = z
+ for f in first_con:
+ new_z = f(new_z)
for f in iaf:
- new_z, h = f(new_z, c, g_bct)
+ if self.use_group_norm:
+ if 'GroupNorm' in str(f):
+ new_z = f(new_z)
+ else:
+ new_z, h = f(new_z, c, g_bct)
+ else:
+ new_z, h = f(new_z, c, g_bct)
for f in last_layer:
new_z = f(new_z)
mu_f, scale_f = new_z[:, :1, :], torch.exp(new_z[:, 1:, :])