forked from facebookresearch/svoice
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathseparate.py
135 lines (115 loc) · 4.24 KB
/
separate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# Authors: Yossi Adi (adiyoss)
import argparse
import logging
import os
import sys
import librosa
import torch
import tqdm
import soundfile as sf
from svoice.data.data import EvalDataLoader, EvalDataset
from svoice import distrib
from svoice.utils import remove_pad
from svoice.utils import bold, deserialize_model, LogProgress
# Module-level logger; its level is configured by the script entry point.
logger = logging.getLogger(__name__)

# Command-line interface of the separation script. The parsed namespace is
# what the rest of this module receives as ``args``.
parser = argparse.ArgumentParser(prog="Speech separation using MulCat blocks")
parser.add_argument(
    "--model_path",
    default="checkpoint.th",
    type=str,
    help="Model name",
)
parser.add_argument(
    "--out_dir",
    default="exp/result",
    type=str,
    help="Directory putting enhanced wav files",
)
parser.add_argument(
    "--mix_dir",
    default=None,
    type=str,
    help="Directory including mix wav files",
)
parser.add_argument(
    "--mix_json",
    default="test/mix.json",
    type=str,
    help="Json file including mix wav files",
)
parser.add_argument("--device", default="cuda")
parser.add_argument(
    "--sample_rate",
    type=int,
    default=16000,
    help="Sample rate",
)
parser.add_argument(
    "--batch_size",
    type=int,
    default=1,
    help="Batch size",
)
# ``-v`` stores the DEBUG level constant; without it the level is INFO.
parser.add_argument(
    "-v",
    "--verbose",
    action="store_const",
    const=logging.DEBUG,
    default=logging.INFO,
    help="More loggging",
)
def save_wavs(estimate_source, mix_sig, lengths, filenames, out_dir, sr=16000):
    """Write every estimated source of every item in the batch as a wav file.

    For each entry of ``filenames`` one file per estimated source is written
    into ``out_dir``, named ``<basename>_s<k>.wav`` where ``<basename>`` is
    the input file name without its ``.wav`` extension and ``k`` starts at 1.

    Args:
        estimate_source: Batch of estimated sources; un-padded with
            ``remove_pad`` using ``lengths``.
            NOTE(review): presumably shaped (batch, sources, time) — confirm
            against ``remove_pad``.
        mix_sig: Unused. Kept only so existing callers do not break.
        lengths: Per-item lengths forwarded to ``remove_pad``.
        filenames: Paths of the mixture files, one per batch item; only the
            base name is used to build the output names.
        out_dir: Directory the separated files are written to.
        sr: Sample rate forwarded to ``write``.
    """
    wav_ext = ".wav"
    # Remove padding and flat
    flat_estimate = remove_pad(estimate_source, lengths)
    # Write result
    for i, filename in enumerate(filenames):
        stem = os.path.basename(filename)
        # Drop only a trailing ".wav". ``str.strip(".wav")`` would instead
        # remove any of the characters '.', 'w', 'a', 'v' from BOTH ends and
        # mangle names such as "wave1.wav" (-> "e1").
        if stem.endswith(wav_ext):
            stem = stem[:-len(wav_ext)]
        out_prefix = os.path.join(out_dir, stem)
        sources = flat_estimate[i]
        # The first axis of each un-padded estimate indexes the sources.
        for c in range(sources.shape[0]):
            write(sources[c], out_prefix + f"_s{c + 1}.wav", sr=sr)
def write(inputs, filename, sr=8000):
    """Normalize ``inputs`` and store the result in ``filename``.

    Args:
        inputs: Signal to save; passed through ``librosa.util.normalize``.
        filename: Destination path handed to ``soundfile.write``.
        sr: Sample rate handed to ``soundfile.write``.
    """
    # Normalizing here and writing with soundfile stands in for the former
    # ``librosa.output.write_wav(filename, inputs, sr, norm=True)`` call.
    normalized = librosa.util.normalize(inputs)
    sf.write(filename, normalized, sr)
def get_mix_paths(args):
    """Return the ``(mix_dir, mix_json)`` pair to separate.

    Each value is resolved independently:

    * if ``args.dset.<name>`` can be read, it is used when truthy and
      ``None`` is returned otherwise (``args.<name>`` is NOT consulted);
    * if ``args.dset`` or ``args.dset.<name>`` is missing, ``args.<name>``
      is returned unchanged.

    Args:
        args: Configuration object exposing ``mix_dir`` / ``mix_json``
            either under ``args.dset`` or directly as attributes.

    Returns:
        Tuple ``(mix_dir, mix_json)``.

    Raises:
        AttributeError: If a value is found neither under ``args.dset`` nor
            directly on ``args``.
    """
    return _resolve_mix_option(args, "mix_dir"), _resolve_mix_option(args, "mix_json")


def _resolve_mix_option(args, name):
    """Look up ``name`` under ``args.dset``, falling back to ``args`` itself."""
    try:
        value = getattr(args.dset, name)
    except (AttributeError, KeyError):
        # Only a missing attribute/key triggers the fallback; anything else
        # (KeyboardInterrupt, SystemExit, genuine bugs) must propagate
        # instead of being swallowed by a bare ``except``.
        return getattr(args, name)
    # A falsy dataset entry means "not set"; report it as None.
    return value if value else None
def separate(args, model=None, local_out_dir=None):
    """Run the separation model over the evaluation mixtures and save wavs.

    Args:
        args: Configuration object. This function reads ``model_path``,
            ``device``, ``batch_size``, ``sample_rate`` and ``out_dir`` from
            it, plus whatever ``get_mix_paths`` reads.
        model: Model to use. When ``None`` the model is loaded from
            ``args.model_path``.
        local_out_dir: Output directory; when falsy ``args.out_dir`` is used.
    """
    mix_dir, mix_json = get_mix_paths(args)
    if not mix_json and not mix_dir:
        # NOTE(review): the problem is only logged and execution continues,
        # exactly as before; presumably the dataset constructor rejects the
        # missing inputs — confirm, or fail fast here.
        logger.error("Must provide mix_dir or mix_json! "
                     "When providing mix_dir, mix_json is ignored.")

    # Load model
    # ``is None`` rather than truthiness: a module that defines ``__len__``
    # (e.g. an empty container) is falsy and would be reloaded by mistake.
    if model is None:
        # SECURITY: ``torch.load`` unpickles the file; only load checkpoints
        # from trusted sources.
        # Load onto the CPU first so a checkpoint saved on a GPU can still be
        # opened on a CPU-only host; the model is moved to ``args.device``
        # just below.
        pkg = torch.load(args.model_path, map_location="cpu")
        # The checkpoint either wraps the model under the 'model' key or is
        # the serialized model itself.
        model = pkg['model'] if 'model' in pkg else pkg
        model = deserialize_model(model)
        logger.debug(model)
    model.eval()
    model.to(args.device)

    out_dir = local_out_dir if local_out_dir else args.out_dir

    # Load data
    eval_dataset = EvalDataset(
        mix_dir,
        mix_json,
        batch_size=args.batch_size,
        sample_rate=args.sample_rate,
    )
    eval_loader = distrib.loader(
        eval_dataset, batch_size=1, klass=EvalDataLoader)

    # Only rank 0 creates the output directory; every rank then waits at the
    # barrier before writing.
    if distrib.rank == 0:
        os.makedirs(out_dir, exist_ok=True)
    distrib.barrier()

    with torch.no_grad():
        for mixture, lengths, filenames in tqdm.tqdm(eval_loader, ncols=120):
            # Get batch data
            mixture = mixture.to(args.device)
            lengths = lengths.to(args.device)
            # Forward: the last element of the model output is the estimate
            # that gets saved.
            estimate_sources = model(mixture)[-1]
            # save wav files
            save_wavs(estimate_sources, mixture, lengths,
                      filenames, out_dir, sr=args.sample_rate)
if __name__ == "__main__":
    # Script entry point: parse the command line, send log records to stderr
    # at the level selected by -v/--verbose, then run the separation with the
    # output directory taken from --out_dir.
    args = parser.parse_args()
    logging.basicConfig(level=args.verbose, stream=sys.stderr)
    logger.debug(args)
    separate(args, model=None, local_out_dir=args.out_dir)