Skip to content

Commit

Permalink
Add Model Inference module
Browse files Browse the repository at this point in the history
- clean the original code
- be compatible with tf 2 version
  • Loading branch information
junhwanjang committed Mar 3, 2022
1 parent 6e9b684 commit 2d29fc1
Show file tree
Hide file tree
Showing 3 changed files with 317 additions and 0 deletions.
186 changes: 186 additions & 0 deletions inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
import tensorflow as tf
import numpy as np
import scipy.io.wavfile as wav
from python_speech_features import logfbank, mfcc, ssc


class VisemeRegressor(object):
def __init__(self, pb_filepath):
# Load forzen graph
self.pb_filepath = pb_filepath
self.graph = self._load_graph(self.pb_filepath)

# Define Hpyer-params
## Sampling
self.fps = 25
self.mfcc_win_step_per_frame = 1
self.up_sample_rate = 4
self.win_length = 0.025
self.winstep = 1.0 / self.fps / self.mfcc_win_step_per_frame / self.up_sample_rate
self.window_size = 24

## Num Signal features
self.num_mfcc = 13
self.num_logfbank = 26
self.num_ssc = 26
self.num_total_features = 65

## Model Params
self.n_steps = 8
self.n_input = int(self.num_total_features * self.mfcc_win_step_per_frame * self.window_size / self.n_steps)
self.n_landmark = 76
self.n_face_id = 76
self.n_phoneme = 21
self.n_maya_params = 22

def predict_outputs(self, wav_file_path, mean_std_csv_path='./saved_params/wav_mean_std.csv', close_face_txt_path='./saved_params/maya_close_face.txt'):
# Define Input
## Preprocess wav file
concat_feat = self._preprocess_wav(
wav_file_path=wav_file_path, is_debug=False
)
normalized_feat = self._normalize_input(
concat_features=concat_feat, mean_std_csv_path=mean_std_csv_path
)
target_wav_idxs = self._get_padded_indexes(
normalized_feat=normalized_feat, window_size=self.window_size
)
## Prepare model input
batch_size = concat_feat.shape[0] # Num Frames
batch_x, batch_x_face_id = self._prepare_model_input(
normalized_feat=normalized_feat,
target_wav_idxs=target_wav_idxs,
batch_size=batch_size,
close_face_txt_path=close_face_txt_path
)

# Predict Outputs
## Input nodes
x = self.graph.get_tensor_by_name('input/Placeholder_1:0')
x_face_id = self.graph.get_tensor_by_name('input/Placeholder_2:0')
phase = self.graph.get_tensor_by_name('input/phase:0')
dropout = self.graph.get_tensor_by_name('net1_shared_rnn/Placeholder:0')

## Output nodes
v_cls = self.graph.get_tensor_by_name('net2_output/add_1:0')
v_reg = self.graph.get_tensor_by_name('net2_output/add_4:0')
jali = self.graph.get_tensor_by_name('net2_output/add_6:0')

with tf.compat.v1.Session(graph=self.graph) as sess:
pred_v_cls, pred_v_reg, pred_jali = sess.run(
[v_cls, v_reg, jali],
feed_dict={
x: batch_x,
x_face_id: batch_x_face_id,
dropout: 0, phase: 0
}
)
pred_v_cls = self.sigmoid(pred_v_cls)

return pred_jali, pred_v_reg, pred_v_cls


def _prepare_model_input(self, normalized_feat, target_wav_idxs, batch_size, close_face_txt_path):
batch_x = np.zeros((batch_size, self.n_steps, self.n_input))
batch_x_face_id = np.zeros((batch_size, self.n_face_id))
# batch_x_pose = np.zeros((batch_size, 3))
# batch_y_landmark = np.zeros((batch_size, self.n_landmark))
# batch_y_phoneme = np.zeros((batch_size, self.n_phoneme))
# batch_y_lipS = np.zeros((batch_size, 1))
# batch_y_maya_param = np.zeros((batch_size, self.n_maya_params))

batch_x = normalized_feat[target_wav_idxs].reshape(-1, self.n_steps, self.n_input)

close_face = np.loadtxt(close_face_txt_path)
batch_x_face_id = np.tile(close_face, (batch_size, 1))

return batch_x, batch_x_face_id

def _get_padded_indexes(self, normalized_feat, window_size):
# Get Padded indexes based on the given window size
num_frames = normalized_feat.shape[0]
wav_idxs = [i for i in range(0, num_frames)]

half_win_size = window_size // 2
pad_head = [0 for _ in range(half_win_size)]
pad_tail = [wav_idxs[-1] for _ in range(half_win_size)]
padded_idxs = np.array(pad_head + wav_idxs + pad_tail)

target_wav_idxs = np.zeros(shape=(num_frames, window_size))
for i in range(0, num_frames):
target_wav_idxs[i] = padded_idxs[i:i+window_size].reshape(num_frames, window_size)

return target_wav_idxs

def _normalize_input(self, concat_features, mean_std_csv_path):
# Normalize input using the pre-calculated mean, std values
num_features = self.num_mfcc + self.num_logfbank + self.num_ssc

mean_std = np.loadtxt(mean_std_csv_path)
mean_vals = mean_std[:num_features]
std_vals = mean_std[num_features:]

normalized_feat = (concat_features - mean_vals) / std_vals

return normalized_feat

def _preprocess_wav(self, wav_file_path, is_debug=False):
sample_rate, signal = wav.read(wav_file_path)

if (signal.ndim > 1):
signal = signal[:, 0]

# Get concatentated features
## 1. mfcc_features
mfcc_feat = mfcc(
signal, numcep=self.num_mfcc,
samplerate=sample_rate,
winlen=self.win_length, winstep=self.winstep
)

## 2. logfbank_features
logfbank_feat = logfbank(
signal, nfilt=self.num_logfbank,
samplerate=sample_rate,
winlen=self.win_length, winstep=self.winstep
)

## 3. ssc_features
ssc_feat = ssc(
signal, nfilt=self.num_ssc,
samplerate=sample_rate,
winlen=self.win_length, winstep=self.winstep
)

concat_features = np.concatenate(
[mfcc_feat, logfbank_feat, ssc_feat], axis=1
)

if is_debug:
print("Sample Rate: {}".format(sample_rate))
print("Signal Shape: {}".format(signal.shape))
print("")
print("Collect Features")
print("[mfcc feat shape]: {}".format(mfcc_feat.shape))
print("[logfbank feat shape]: {}".format(logfbank_feat.shape))
print("[ssc feat shape]: {}".format(ssc_feat.shape))
print("--> Concat Features Shape: {}".format(concat_features.shape))

return concat_features

def _load_graph(self, pb_filepath):
with tf.io.gfile.GFile(pb_filepath, 'rb') as f:
graph_def = tf.compat.v1.GraphDef()
graph_def.ParseFromString(f.read())

with tf.Graph().as_default() as graph:
tf.import_graph_def(graph_def, name='')

for op in graph.get_operations():
if op.type == 'Placeholder':
print(op.name)

return graph

def sigmoid(self, x):
return 1/(1+np.exp(-x))
1 change: 1 addition & 0 deletions saved_params/maya_close_face.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
-0.63522 -0.63037 -0.43071 -0.80431 -0.21347 -0.95468 -0.00000 -0.98242 0.21347 -0.95468 0.43071 -0.80456 0.63522 -0.63037 -0.20890 0.05236 -0.11252 0.02298 0.00000 0.00000 0.11252 0.02211 0.20890 0.05198 -0.05739 -0.20821 -0.16401 -0.21268 -0.22784 -0.23076 -0.32045 -0.27540 -0.22764 -0.42029 -0.10126 -0.46787 -0.00000 -0.48066 0.10126 -0.46797 0.22764 -0.42056 0.33245 -0.27810 0.22784 -0.23076 0.16401 -0.21282 0.05739 -0.20839 0.00000 -0.23967 -0.17054 -0.30394 -0.28709 -0.28996 -0.20255 -0.30883 -0.08933 -0.32174 -0.00000 -0.32644 0.08933 -0.32183 0.20255 -0.30948 0.29165 -0.28881 0.17054 -0.30394 0.06859 -0.31126 -0.00000 -0.32164 -0.06859 -0.31113
130 changes: 130 additions & 0 deletions saved_params/wav_mean_std.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
16.68886
3.97479
-20.30682
5.80699
-15.43429
5.92388
-28.63684
4.14146
-19.75630
-1.72316
-10.24419
5.74067
-3.86981
9.38637
10.30149
11.62886
12.49301
12.46678
11.90186
11.84886
11.95487
12.28070
12.54658
12.54819
12.67694
12.96691
12.98765
12.97866
12.62430
12.24909
11.83247
11.93958
11.98562
12.08022
11.97421
11.62867
10.93464
8.95501
7.04081
87.12891
173.25781
292.47222
440.17648
596.93069
769.05987
974.09208
1206.45156
1504.64536
1807.83956
2177.74605
2595.55055
3045.21528
3550.92647
4109.56849
4749.31868
5456.19422
6408.52870
7440.71526
8554.42553
9794.58080
11161.64710
12757.65047
14405.03992
15984.00054
19289.26037
2.63713
15.46299
13.12569
17.44463
16.31907
16.44532
17.10880
17.89937
16.92003
15.19639
14.47684
13.30815
12.08762
2.62696
2.75437
2.83862
3.02088
3.12501
3.06590
2.99671
2.98436
2.99794
3.01999
2.97766
2.96631
3.07834
2.99264
2.93484
2.91424
2.81031
2.71607
2.81370
2.88491
2.96294
3.01507
3.03082
2.86959
2.59370
2.74709
1000.00000
1000.00000
26.24620
36.81842
38.78619
35.94704
41.87735
51.41814
68.31984
78.18270
85.86550
94.36517
99.69723
118.70094
136.14025
141.08837
179.08942
178.79265
180.67445
197.02229
204.96015
266.20027
254.49708
270.17589
540.48172
192.19627

0 comments on commit 2d29fc1

Please sign in to comment.