Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .idea/P_wavenet_vocoder.iml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 6 additions & 0 deletions .idea/libraries/R_User_Library.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions .idea/modules.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion .idea/vcs.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

458 changes: 397 additions & 61 deletions .idea/workspace.xml

Large diffs are not rendered by default.

Binary file removed generate.zip
Binary file not shown.
2 changes: 1 addition & 1 deletion preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def write_metadata(metadata, out_dir):
in_dir = '/home/jinqiangzeng/work/data/speech/cmu_arctic'
out_dir = './data/cmu_arctic'
elif name == 'ljspeech':
in_dir = '/home/jinqiangzeng/work/data/speech/ljspeech/LJSpeech-1.0'
in_dir = '/home/tesla/work/data/LJSpeech-1.0'
out_dir = './data/ljspeech'
num_workers = None # args["--num_workers"]
num_workers = cpu_count() - 1 if num_workers is None else int(num_workers)
Expand Down
71 changes: 35 additions & 36 deletions train_student.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def get_power_loss_torch(y, y1, n_fft=1024, hop_length=256, cuda=True):
s = torch.stft(x, n_fft, hop_length, window=torch.hann_window(n_fft, periodic=True).cuda())
s1 = torch.stft(x1, n_fft, hop_length, window=torch.hann_window(n_fft, periodic=True).cuda())
ss = torch.log(torch.sqrt(torch.sum(s ** 2, -1) + 1e-5)) - torch.log(torch.sqrt(torch.sum(s1 ** 2, -1) + 1e-5))
return torch.sum(ss**2)/batch
return torch.sum(ss ** 2) / batch


def to_numpy(x):
Expand Down Expand Up @@ -330,44 +330,39 @@ def __train_step(phase, epoch, global_step, global_test_step,
mask = sequence_mask(input_lengths, max_len=x.size(-1)).unsqueeze(-1)
mask = mask[:, 1:, :]
# apply the student model with stacked iaf layers and return mu,scale
# u = Variable(torch.from_numpy(np.random.uniform(1e-5, 1 - 1e-5, x.size())).float().cuda(), requires_grad=False)
# z = torch.log(u) - torch.log(1 - u)
u = Variable(torch.zeros(*x.size()).uniform_(1e-5, 1 - 1e-5), requires_grad=False).cuda()
z = torch.log(u) - torch.log(1 - u)
predict, mu, scale = student(z, c=c, g=g, softmax=False)
m, s = mu, scale
# mu, scale = to_numpy(mu), to_numpy(scale)
# TODO sample times, change to 300 or 400
sample_T, kl_loss_sum = 16, 0
sample_T, kl_loss_sum = 8, 0
power_loss_sum = 0
y_hat = teacher(predict, c=c, g=g) # y_hat: (B x C x T) teacher: 10-mixture-logistic
y_hat = teacher(predict_, c=c, g=g) # y_hat: (B x C x T) teacher: 10-mixture-logistic
h_pt_ps = 0
# TODO add some constrain on scale ,we want it to be small?
for i in range(sample_T):
# https://en.wikipedia.org/wiki/Logistic_distribution
u = Variable(torch.zeros(*x.size()).uniform_(1e-5,1-1e-5),requires_grad=False).cuda()
z = torch.log(u) - torch.log(1 - u)
student_predict = m + s * z # predicted wave
# student_predict.clamp(-0.99, 0.99)
student_predict = student_predict.permute(0, 2, 1)
_, teacher_log_p = discretized_mix_logistic_loss(y_hat[:, :, :-1], student_predict[:, 1:, :], reduce=False)
h_pt_ps += torch.sum(teacher_log_p * mask) / mask.sum()
student_predict = student_predict.permute(0, 2, 1)
power_loss_sum += get_power_loss_torch(student_predict, x, n_fft=512, hop_length=128)
power_loss_sum += get_power_loss_torch(student_predict, x, n_fft=256, hop_length=64)
power_loss_sum += get_power_loss_torch(student_predict, x, n_fft=2048, hop_length=512)
power_loss_sum += get_power_loss_torch(student_predict, x, n_fft=1024, hop_length=256)
power_loss_sum += get_power_loss_torch(student_predict, x, n_fft=128, hop_length=32)
# https://en.wikipedia.org/wiki/Logistic_distribution
u = Variable(torch.zeros(*x.size()).uniform_(1e-5, 1 - 1e-5), requires_grad=False).cuda()
z = torch.log(u) - torch.log(1 - u)
student_predict = m + s * z # predicted wave
# student_predict.clamp(-0.99, 0.99)
student_predict = student_predict.permute(0, 2, 1)
_, teacher_log_p = discretized_mix_logistic_loss(y_hat[:, :, :-1], student_predict[:, 1:, :], reduce=False)
h_pt_ps += torch.sum(teacher_log_p * mask) / mask.sum()
student_predict = student_predict.permute(0, 2, 1)
power_loss_sum += get_power_loss_torch(student_predict, x, n_fft=512, hop_length=128)
power_loss_sum += get_power_loss_torch(student_predict, x, n_fft=256, hop_length=64)
power_loss_sum += get_power_loss_torch(student_predict, x, n_fft=2048, hop_length=512)
power_loss_sum += get_power_loss_torch(student_predict, x, n_fft=4096, hop_length=1024)
power_loss_sum += get_power_loss_torch(student_predict, x, n_fft=1024, hop_length=256)
power_loss_sum += get_power_loss_torch(student_predict, x, n_fft=128, hop_length=32)
power_loss_sum += get_power_loss_torch(student_predict, x, n_fft=64, hop_length=16)
a = s.permute(0, 2, 1)
h_ps = torch.sum((torch.log(a[:, 1:, :]) + 2) * mask) / ( mask.sum())
cross_entropy = h_pt_ps /(sample_T)
kl_loss = cross_entropy - 2*h_ps
# power_loss_sum += get_power_loss_torch(predict, x, n_fft=1024, hop_length=64)
# power_loss_sum += get_power_loss_torch(predict, x, n_fft=1024, hop_length=128)
# power_loss_sum += get_power_loss_torch(predict, x, n_fft=1024, hop_length=256)
# power_loss_sum += get_power_loss_torch(predict, x, n_fft=1024, hop_length=512)
power_loss = power_loss_sum / (5 * sample_T)
loss = kl_loss + power_loss
h_ps = torch.sum((torch.log(a[:, 1:, :]) + 2) * mask) / (mask.sum())
cross_entropy = h_pt_ps / (sample_T)
kl_loss = cross_entropy - h_ps
power_loss = power_loss_sum / (7 * sample_T)
loss = kl_loss + power_loss
if step > 0 and step % 20 == 0:
print('power_loss={}, mean_scale={}, mean_mu={},kl_loss={},loss={}'.format(to_numpy(power_loss),
np.mean(to_numpy(s)),
Expand Down Expand Up @@ -501,7 +496,7 @@ def save_checkpoint(model, optimizer, step, checkpoint_dir, epoch, ema=None):
print("Saved averaged checkpoint:", checkpoint_path)


def build_model(name='teacher'):
def build_model(name='teacher', use_group_norm=False):
if is_mulaw_quantize(hparams.input_type):
if hparams.out_channels != hparams.quantize_channels:
raise RuntimeError(
Expand Down Expand Up @@ -536,6 +531,7 @@ def build_model(name='teacher'):
gin_channels=hparams.gin_channels,
upsample_conditional_features=hparams.upsample_conditional_features,
upsample_scales=hparams.upsample_scales,
use_group_norm=use_group_norm
)


Expand Down Expand Up @@ -629,17 +625,17 @@ def get_data_loaders(data_root, speaker_id, test_shuffle=True):
"--checkpoint-dir": 'checkpoints_student',
"--checkpoint_teacher": './checkpoints_teacher/20180127_mixture_lj_checkpoint_step000410000_ema.pth',
# the pre-trained teacher model
"--checkpoint_student": '/home/jinqiangzeng/work/pycharm/P_wavenet_vocoder/checkpoints_student/checkpoint_step000056000.pth', # 是否加载
#"--checkpoint_student": None, # 是否加载
"--checkpoint_student": './checkpoints_student/checkpoint_step000009000.pth', # 是否加载
# "--checkpoint_student": None, # 是否加载
"--checkpoint": None,
"--restore-parts": None,
"--data-root": './data/ljspeech', # dataset
"--log-event-path": None, # if continue training, reload the checkpoint
"--speaker-id": None,
"--reset-optimizer": None,
"--hparams": "cin_channels=80,gin_channels=-1",
"--gpu": 0 # 指定gpu

"--gpu": 0, # 指定gpu
"--use_group_norm": True
}
print("Command line args:\n", args)
checkpoint_dir = args["--checkpoint-dir"]
Expand All @@ -649,7 +645,7 @@ def get_data_loaders(data_root, speaker_id, test_shuffle=True):
checkpoint_restore_parts = args["--restore-parts"]
speaker_id = args["--speaker-id"]
speaker_id = int(speaker_id) if speaker_id is not None else None

use_group_norm = args['--use_group_norm']
data_root = args["--data-root"]
if data_root is None:
data_root = join(dirname(__file__), "data", "ljspeech")
Expand Down Expand Up @@ -677,7 +673,10 @@ def get_data_loaders(data_root, speaker_id, test_shuffle=True):
data_loaders = get_data_loaders(data_root, speaker_id, test_shuffle=True)

# Model
student_model = build_model(name='student')
if use_group_norm:
student_model = build_model(name='student', use_group_norm=True)
else:
student_model = build_model(name='student')
teacher_model = build_model(name='teacher')

if use_cuda:
Expand Down
39 changes: 35 additions & 4 deletions wavenet_vocoder/student_wavenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,21 +69,41 @@ def __init__(self, out_channels=2,
freq_axis_kernel_size=3,
scalar_input=True,
gpu=0,
use_group_norm = False,
group_size=2
):
super(StudentWaveNet, self).__init__()
self.scalar_input = scalar_input
self.out_channels = out_channels
self.cin_channels = cin_channels
self.gpu = gpu
self.last_layers = []
self.use_group_norm = use_group_norm
self.group_size = group_size
# 噪声
assert layers % stacks == 0
layers_per_stack = layers // stacks
if scalar_input:
self.first_conv = nn.ModuleList([Conv1d1x1(1, residual_channels) for _ in range(len(iaf_layer_size))])
if use_group_norm:
first_layer = nn.ModuleList()
for _ in range(len(iaf_layer_size)):
first_layer.append(nn.ModuleList([
Conv1d1x1(1, residual_channels),
nn.GroupNorm(self.group_size,residual_channels)
]))
self.first_conv = first_layer
else:
self.first_conv = nn.ModuleList(
[Conv1d1x1(out_channels, residual_channels) for _ in range(len(iaf_layer_size))])
if use_group_norm:
first_layer = nn.ModuleList()
for _ in range(len(iaf_layer_size)):
first_layer.append(nn.ModuleList([
Conv1d1x1(out_channels, residual_channels),
nn.GroupNorm(self.group_size,residual_channels)
]))
self.first_conv = first_layer
self.iaf_layers = nn.ModuleList() # iaf层
self.last_layers = nn.ModuleList()

Expand All @@ -102,12 +122,15 @@ def __init__(self, out_channels=2,
gin_channels=gin_channels,
weight_normalization=weight_normalization)
iaf_layer.append(conv)
if self.use_group_norm:
iaf_layer.append(nn.GroupNorm(self.group_size,residual_channels))
self.iaf_layers.append(iaf_layer)
self.last_layers.append(nn.ModuleList([ # iaf的最后一层
nn.ReLU(inplace=True),
Conv1d1x1(residual_channels, residual_channels, weight_normalization=weight_normalization),
nn.GroupNorm(self.group_size,residual_channels),
nn.ReLU(inplace=True),
Conv1d1x1(residual_channels, out_channels, weight_normalization=weight_normalization),
# nn.ReLU(inplace=True),
# Conv1d1x1(residual_channels, out_channels, weight_normalization=weight_normalization),
]))

if gin_channels > 0:
Expand Down Expand Up @@ -178,9 +201,17 @@ def forward(self, z, c=None, g=None, softmax=False):
m = []
index = 0
for first_con, iaf, last_layer in zip(self.first_conv, self.iaf_layers, self.last_layers): # iaf layer forward
new_z = first_con(z)
new_z = z
for f in first_con:
new_z = f(new_z)
for f in iaf:
new_z, h = f(new_z, c, g_bct)
if self.use_group_norm:
if 'GroupNorm' in str(f):
new_z = f(new_z)
else:
new_z, h = f(new_z, c, g_bct)
else:
new_z, h = f(new_z, c, g_bct)
for f in last_layer:
new_z = f(new_z)
mu_f, scale_f = new_z[:, :1, :], torch.exp(new_z[:, 1:, :])
Expand Down