forked from junhwanjang/visemenet-inference
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- clean the original code - be compatible with tf 2 version
- Loading branch information
1 parent
6e9b684
commit 2d29fc1
Showing
3 changed files
with
317 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,186 @@ | ||
import tensorflow as tf | ||
import numpy as np | ||
import scipy.io.wavfile as wav | ||
from python_speech_features import logfbank, mfcc, ssc | ||
|
||
|
||
class VisemeRegressor(object): | ||
def __init__(self, pb_filepath): | ||
# Load forzen graph | ||
self.pb_filepath = pb_filepath | ||
self.graph = self._load_graph(self.pb_filepath) | ||
|
||
# Define Hpyer-params | ||
## Sampling | ||
self.fps = 25 | ||
self.mfcc_win_step_per_frame = 1 | ||
self.up_sample_rate = 4 | ||
self.win_length = 0.025 | ||
self.winstep = 1.0 / self.fps / self.mfcc_win_step_per_frame / self.up_sample_rate | ||
self.window_size = 24 | ||
|
||
## Num Signal features | ||
self.num_mfcc = 13 | ||
self.num_logfbank = 26 | ||
self.num_ssc = 26 | ||
self.num_total_features = 65 | ||
|
||
## Model Params | ||
self.n_steps = 8 | ||
self.n_input = int(self.num_total_features * self.mfcc_win_step_per_frame * self.window_size / self.n_steps) | ||
self.n_landmark = 76 | ||
self.n_face_id = 76 | ||
self.n_phoneme = 21 | ||
self.n_maya_params = 22 | ||
|
||
def predict_outputs(self, wav_file_path, mean_std_csv_path='./saved_params/wav_mean_std.csv', close_face_txt_path='./saved_params/maya_close_face.txt'): | ||
# Define Input | ||
## Preprocess wav file | ||
concat_feat = self._preprocess_wav( | ||
wav_file_path=wav_file_path, is_debug=False | ||
) | ||
normalized_feat = self._normalize_input( | ||
concat_features=concat_feat, mean_std_csv_path=mean_std_csv_path | ||
) | ||
target_wav_idxs = self._get_padded_indexes( | ||
normalized_feat=normalized_feat, window_size=self.window_size | ||
) | ||
## Prepare model input | ||
batch_size = concat_feat.shape[0] # Num Frames | ||
batch_x, batch_x_face_id = self._prepare_model_input( | ||
normalized_feat=normalized_feat, | ||
target_wav_idxs=target_wav_idxs, | ||
batch_size=batch_size, | ||
close_face_txt_path=close_face_txt_path | ||
) | ||
|
||
# Predict Outputs | ||
## Input nodes | ||
x = self.graph.get_tensor_by_name('input/Placeholder_1:0') | ||
x_face_id = self.graph.get_tensor_by_name('input/Placeholder_2:0') | ||
phase = self.graph.get_tensor_by_name('input/phase:0') | ||
dropout = self.graph.get_tensor_by_name('net1_shared_rnn/Placeholder:0') | ||
|
||
## Output nodes | ||
v_cls = self.graph.get_tensor_by_name('net2_output/add_1:0') | ||
v_reg = self.graph.get_tensor_by_name('net2_output/add_4:0') | ||
jali = self.graph.get_tensor_by_name('net2_output/add_6:0') | ||
|
||
with tf.compat.v1.Session(graph=self.graph) as sess: | ||
pred_v_cls, pred_v_reg, pred_jali = sess.run( | ||
[v_cls, v_reg, jali], | ||
feed_dict={ | ||
x: batch_x, | ||
x_face_id: batch_x_face_id, | ||
dropout: 0, phase: 0 | ||
} | ||
) | ||
pred_v_cls = self.sigmoid(pred_v_cls) | ||
|
||
return pred_jali, pred_v_reg, pred_v_cls | ||
|
||
|
||
def _prepare_model_input(self, normalized_feat, target_wav_idxs, batch_size, close_face_txt_path): | ||
batch_x = np.zeros((batch_size, self.n_steps, self.n_input)) | ||
batch_x_face_id = np.zeros((batch_size, self.n_face_id)) | ||
# batch_x_pose = np.zeros((batch_size, 3)) | ||
# batch_y_landmark = np.zeros((batch_size, self.n_landmark)) | ||
# batch_y_phoneme = np.zeros((batch_size, self.n_phoneme)) | ||
# batch_y_lipS = np.zeros((batch_size, 1)) | ||
# batch_y_maya_param = np.zeros((batch_size, self.n_maya_params)) | ||
|
||
batch_x = normalized_feat[target_wav_idxs].reshape(-1, self.n_steps, self.n_input) | ||
|
||
close_face = np.loadtxt(close_face_txt_path) | ||
batch_x_face_id = np.tile(close_face, (batch_size, 1)) | ||
|
||
return batch_x, batch_x_face_id | ||
|
||
def _get_padded_indexes(self, normalized_feat, window_size): | ||
# Get Padded indexes based on the given window size | ||
num_frames = normalized_feat.shape[0] | ||
wav_idxs = [i for i in range(0, num_frames)] | ||
|
||
half_win_size = window_size // 2 | ||
pad_head = [0 for _ in range(half_win_size)] | ||
pad_tail = [wav_idxs[-1] for _ in range(half_win_size)] | ||
padded_idxs = np.array(pad_head + wav_idxs + pad_tail) | ||
|
||
target_wav_idxs = np.zeros(shape=(num_frames, window_size)) | ||
for i in range(0, num_frames): | ||
target_wav_idxs[i] = padded_idxs[i:i+window_size].reshape(num_frames, window_size) | ||
|
||
return target_wav_idxs | ||
|
||
def _normalize_input(self, concat_features, mean_std_csv_path): | ||
# Normalize input using the pre-calculated mean, std values | ||
num_features = self.num_mfcc + self.num_logfbank + self.num_ssc | ||
|
||
mean_std = np.loadtxt(mean_std_csv_path) | ||
mean_vals = mean_std[:num_features] | ||
std_vals = mean_std[num_features:] | ||
|
||
normalized_feat = (concat_features - mean_vals) / std_vals | ||
|
||
return normalized_feat | ||
|
||
def _preprocess_wav(self, wav_file_path, is_debug=False): | ||
sample_rate, signal = wav.read(wav_file_path) | ||
|
||
if (signal.ndim > 1): | ||
signal = signal[:, 0] | ||
|
||
# Get concatentated features | ||
## 1. mfcc_features | ||
mfcc_feat = mfcc( | ||
signal, numcep=self.num_mfcc, | ||
samplerate=sample_rate, | ||
winlen=self.win_length, winstep=self.winstep | ||
) | ||
|
||
## 2. logfbank_features | ||
logfbank_feat = logfbank( | ||
signal, nfilt=self.num_logfbank, | ||
samplerate=sample_rate, | ||
winlen=self.win_length, winstep=self.winstep | ||
) | ||
|
||
## 3. ssc_features | ||
ssc_feat = ssc( | ||
signal, nfilt=self.num_ssc, | ||
samplerate=sample_rate, | ||
winlen=self.win_length, winstep=self.winstep | ||
) | ||
|
||
concat_features = np.concatenate( | ||
[mfcc_feat, logfbank_feat, ssc_feat], axis=1 | ||
) | ||
|
||
if is_debug: | ||
print("Sample Rate: {}".format(sample_rate)) | ||
print("Signal Shape: {}".format(signal.shape)) | ||
print("") | ||
print("Collect Features") | ||
print("[mfcc feat shape]: {}".format(mfcc_feat.shape)) | ||
print("[logfbank feat shape]: {}".format(logfbank_feat.shape)) | ||
print("[ssc feat shape]: {}".format(ssc_feat.shape)) | ||
print("--> Concat Features Shape: {}".format(concat_features.shape)) | ||
|
||
return concat_features | ||
|
||
def _load_graph(self, pb_filepath): | ||
with tf.io.gfile.GFile(pb_filepath, 'rb') as f: | ||
graph_def = tf.compat.v1.GraphDef() | ||
graph_def.ParseFromString(f.read()) | ||
|
||
with tf.Graph().as_default() as graph: | ||
tf.import_graph_def(graph_def, name='') | ||
|
||
for op in graph.get_operations(): | ||
if op.type == 'Placeholder': | ||
print(op.name) | ||
|
||
return graph | ||
|
||
def sigmoid(self, x): | ||
return 1/(1+np.exp(-x)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
-0.63522 -0.63037 -0.43071 -0.80431 -0.21347 -0.95468 -0.00000 -0.98242 0.21347 -0.95468 0.43071 -0.80456 0.63522 -0.63037 -0.20890 0.05236 -0.11252 0.02298 0.00000 0.00000 0.11252 0.02211 0.20890 0.05198 -0.05739 -0.20821 -0.16401 -0.21268 -0.22784 -0.23076 -0.32045 -0.27540 -0.22764 -0.42029 -0.10126 -0.46787 -0.00000 -0.48066 0.10126 -0.46797 0.22764 -0.42056 0.33245 -0.27810 0.22784 -0.23076 0.16401 -0.21282 0.05739 -0.20839 0.00000 -0.23967 -0.17054 -0.30394 -0.28709 -0.28996 -0.20255 -0.30883 -0.08933 -0.32174 -0.00000 -0.32644 0.08933 -0.32183 0.20255 -0.30948 0.29165 -0.28881 0.17054 -0.30394 0.06859 -0.31126 -0.00000 -0.32164 -0.06859 -0.31113 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,130 @@ | ||
16.68886 | ||
3.97479 | ||
-20.30682 | ||
5.80699 | ||
-15.43429 | ||
5.92388 | ||
-28.63684 | ||
4.14146 | ||
-19.75630 | ||
-1.72316 | ||
-10.24419 | ||
5.74067 | ||
-3.86981 | ||
9.38637 | ||
10.30149 | ||
11.62886 | ||
12.49301 | ||
12.46678 | ||
11.90186 | ||
11.84886 | ||
11.95487 | ||
12.28070 | ||
12.54658 | ||
12.54819 | ||
12.67694 | ||
12.96691 | ||
12.98765 | ||
12.97866 | ||
12.62430 | ||
12.24909 | ||
11.83247 | ||
11.93958 | ||
11.98562 | ||
12.08022 | ||
11.97421 | ||
11.62867 | ||
10.93464 | ||
8.95501 | ||
7.04081 | ||
87.12891 | ||
173.25781 | ||
292.47222 | ||
440.17648 | ||
596.93069 | ||
769.05987 | ||
974.09208 | ||
1206.45156 | ||
1504.64536 | ||
1807.83956 | ||
2177.74605 | ||
2595.55055 | ||
3045.21528 | ||
3550.92647 | ||
4109.56849 | ||
4749.31868 | ||
5456.19422 | ||
6408.52870 | ||
7440.71526 | ||
8554.42553 | ||
9794.58080 | ||
11161.64710 | ||
12757.65047 | ||
14405.03992 | ||
15984.00054 | ||
19289.26037 | ||
2.63713 | ||
15.46299 | ||
13.12569 | ||
17.44463 | ||
16.31907 | ||
16.44532 | ||
17.10880 | ||
17.89937 | ||
16.92003 | ||
15.19639 | ||
14.47684 | ||
13.30815 | ||
12.08762 | ||
2.62696 | ||
2.75437 | ||
2.83862 | ||
3.02088 | ||
3.12501 | ||
3.06590 | ||
2.99671 | ||
2.98436 | ||
2.99794 | ||
3.01999 | ||
2.97766 | ||
2.96631 | ||
3.07834 | ||
2.99264 | ||
2.93484 | ||
2.91424 | ||
2.81031 | ||
2.71607 | ||
2.81370 | ||
2.88491 | ||
2.96294 | ||
3.01507 | ||
3.03082 | ||
2.86959 | ||
2.59370 | ||
2.74709 | ||
1000.00000 | ||
1000.00000 | ||
26.24620 | ||
36.81842 | ||
38.78619 | ||
35.94704 | ||
41.87735 | ||
51.41814 | ||
68.31984 | ||
78.18270 | ||
85.86550 | ||
94.36517 | ||
99.69723 | ||
118.70094 | ||
136.14025 | ||
141.08837 | ||
179.08942 | ||
178.79265 | ||
180.67445 | ||
197.02229 | ||
204.96015 | ||
266.20027 | ||
254.49708 | ||
270.17589 | ||
540.48172 | ||
192.19627 |