#!/usr/bin/python2
# -*- coding: utf-8 -*-
'''
By kyubyong park. [email protected].
https://www.github.com/kyubyong/tacotron
'''
from __future__ import print_function
from hyperparams import Hyperparams as hp
from modules import *
import tensorflow as tf

def encoder(inputs, is_training=True, scope="encoder", reuse=None):
    '''
    Args:
      inputs: A 3d tensor with shape of [N, T_x, E], with dtype of float32. Embedded encoder inputs.
      is_training: Whether or not the layer is in training mode.
      scope: Optional scope for `variable_scope`.
      reuse: Boolean, whether to reuse the weights of a previous layer
        by the same name.

    Returns:
      A collection of hidden vectors, the so-called memory, with shape of (N, T_x, E).
    '''
    with tf.variable_scope(scope, reuse=reuse):
        # Encoder pre-net
        prenet_out = prenet(inputs, is_training=is_training) # (N, T_x, E/2)

        # Encoder CBHG
        ## Conv1D banks
        enc = conv1d_banks(prenet_out, K=hp.encoder_num_banks, is_training=is_training) # (N, T_x, K*E/2)

        ## Max pooling
        enc = tf.layers.max_pooling1d(enc, pool_size=2, strides=1, padding="same") # (N, T_x, K*E/2)

        ## Conv1D projections
        enc = conv1d(enc, filters=hp.embed_size//2, size=3, scope="conv1d_1") # (N, T_x, E/2)
        enc = bn(enc, is_training=is_training, activation_fn=tf.nn.relu, scope="conv1d_1")
        enc = conv1d(enc, filters=hp.embed_size//2, size=3, scope="conv1d_2") # (N, T_x, E/2)
        enc = bn(enc, is_training=is_training, scope="conv1d_2")

        enc += prenet_out # residual connections # (N, T_x, E/2)

        ## Highway Nets
        for i in range(hp.num_highwaynet_blocks):
            enc = highwaynet(enc, num_units=hp.embed_size//2,
                             scope='highwaynet_{}'.format(i)) # (N, T_x, E/2)

        ## Bidirectional GRU
        memory = gru(enc, num_units=hp.embed_size//2, bidirection=True) # (N, T_x, E)

    return memory
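
# Usage sketch (illustrative, not part of the original code): the encoder takes
# already-embedded character inputs, e.g. produced by the embed() helper in
# modules.py (its signature is assumed here, not defined in this file):
#
#   x = embed(char_ids, vocab_size=len(hp.vocab), num_units=hp.embed_size)  # (N, T_x, E)
#   memory = encoder(x, is_training=True)                                   # (N, T_x, E)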

def decoder1(inputs, memory, is_training=True, scope="decoder1", reuse=None):
    '''
    Args:
      inputs: A 3d tensor with shape of [N, T_y/r, n_mels*r]. Shifted log melspectrogram of sound files.
      memory: A 3d tensor with shape of [N, T_x, E]. Encoder outputs.
      is_training: Whether or not the layer is in training mode.
      scope: Optional scope for `variable_scope`.
      reuse: Boolean, whether to reuse the weights of a previous layer
        by the same name.

    Returns:
      Predicted log melspectrogram tensor with shape of [N, T_y/r, n_mels*r].
    '''
    with tf.variable_scope(scope, reuse=reuse):
        # Decoder pre-net
        inputs = prenet(inputs, is_training=is_training) # (N, T_y/r, E/2)

        # Attention RNN
        dec, state = attention_decoder(inputs, memory, num_units=hp.embed_size) # (N, T_y/r, E)

        ## for attention monitoring
        alignments = tf.transpose(state.alignment_history.stack(), [1, 2, 0])

        # Decoder RNNs with residual connections
        dec += gru(dec, hp.embed_size, bidirection=False, scope="decoder_gru1") # (N, T_y/r, E)
        dec += gru(dec, hp.embed_size, bidirection=False, scope="decoder_gru2") # (N, T_y/r, E)

        # Outputs => (N, T_y/r, n_mels*r)
        mel_hats = tf.layers.dense(dec, hp.n_mels*hp.r)

    return mel_hats, alignments
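
# Note (illustrative, mirroring how the training graph in this repo builds
# teacher-forcing inputs): the "shifted" decoder input is the target mel
# sequence with a zero frame prepended and the last frame dropped, e.g.:
#
#   decoder_inputs = tf.concat((tf.zeros_like(y[:, :1, :]), y[:, :-1, :]), 1)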

def decoder2(inputs, is_training=True, scope="decoder2", reuse=None):
    '''Decoder Post-processing net = CBHG

    Args:
      inputs: A 3d tensor with shape of [N, T_y/r, n_mels*r]. Predicted log melspectrogram
        from decoder1. It is first restored to its original shape (N, T_y, n_mels).
      is_training: Whether or not the layer is in training mode.
      scope: Optional scope for `variable_scope`.
      reuse: Boolean, whether to reuse the weights of a previous layer
        by the same name.

    Returns:
      Predicted linear spectrogram tensor with shape of [N, T_y, 1+n_fft//2].
    '''
    with tf.variable_scope(scope, reuse=reuse):
        # Restore shape -> (N, T_y, n_mels)
        inputs = tf.reshape(inputs, [tf.shape(inputs)[0], -1, hp.n_mels])

        # Conv1D bank
        dec = conv1d_banks(inputs, K=hp.decoder_num_banks, is_training=is_training) # (N, T_y, K*E/2)

        # Max pooling
        dec = tf.layers.max_pooling1d(dec, pool_size=2, strides=1, padding="same") # (N, T_y, K*E/2)

        ## Conv1D projections
        dec = conv1d(dec, filters=hp.embed_size//2, size=3, scope="conv1d_1") # (N, T_y, E/2)
        dec = bn(dec, is_training=is_training, activation_fn=tf.nn.relu, scope="conv1d_1")
        dec = conv1d(dec, filters=hp.n_mels, size=3, scope="conv1d_2") # (N, T_y, n_mels)
        dec = bn(dec, is_training=is_training, scope="conv1d_2")

        # Extra affine transformation for dimensionality sync
        dec = tf.layers.dense(dec, hp.embed_size//2) # (N, T_y, E/2)

        # Highway Nets
        for i in range(4):
            dec = highwaynet(dec, num_units=hp.embed_size//2,
                             scope='highwaynet_{}'.format(i)) # (N, T_y, E/2)

        # Bidirectional GRU
        dec = gru(dec, hp.embed_size//2, bidirection=True) # (N, T_y, E)

        # Outputs => (N, T_y, 1+n_fft//2)
        outputs = tf.layers.dense(dec, 1+hp.n_fft//2)

    return outputs
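
if __name__ == "__main__":
    # Shape-check sketch (not part of the original file): wires the three
    # networks together the way the training graph does, just to verify that
    # the graph builds. Assumes hp.embed_size, hp.n_mels, hp.r, and hp.n_fft
    # are defined in hyperparams.py.
    x = tf.placeholder(tf.float32, shape=(None, None, hp.embed_size))  # embedded text, (N, T_x, E)
    y = tf.placeholder(tf.float32, shape=(None, None, hp.n_mels*hp.r)) # shifted mels, (N, T_y/r, n_mels*r)

    memory = encoder(x, is_training=True)       # (N, T_x, E)
    mel_hats, alignments = decoder1(y, memory)  # (N, T_y/r, n_mels*r)
    mag_hats = decoder2(mel_hats)               # (N, T_y, 1+n_fft//2)

    print(memory, mel_hats, mag_hats)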