Skip to content

Commit

Permalink
update README.md
Browse files Browse the repository at this point in the history
  • Loading branch information
funcwj committed Jun 20, 2018
1 parent c570bdc commit 8a88896
Show file tree
Hide file tree
Showing 8 changed files with 50 additions and 89 deletions.
8 changes: 7 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,10 @@ see requirements.txt

```shell
python ./train_dcnet.py --config conf/train.yaml --num-epoches 20 > train.log 2>&1 &
```
```

### Experiments

| Configure | Epoch | FM | FF | MM | FF/MM | AVG |
| :-------: | :---: | :---: | :--: | :--: | :---: | :--: |
| [config-1](conf/1.config.yaml) | 26 | 11.34 | 6.41 | 7.89 | 7.15 | 9.44 |
13 changes: 6 additions & 7 deletions conf/train.yaml → conf/1.config.yaml
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
# config for training

trainer:
checkpoint: "./tune/2spk_dcnet_a"
optimizer: "adam"
lr: 1e-3
checkpoint: "./tune/2spk_dcnet_d"
optimizer: "rmsprop"
lr: 1e-5
momentum: 0.9
weight_decay: 0
clip_norm: 200
Expand All @@ -12,8 +12,8 @@ trainer:
dcnet:
rnn: "lstm"
embedding_dim: 20
num_layers: 4
hidden_size: 300
num_layers: 2
hidden_size: 600
dropout: 0.5
non_linear: "tanh"
bidirectional: true
Expand Down Expand Up @@ -43,8 +43,7 @@ debug_scp_conf:

dataloader:
shuffle: true
batch_size: 16
batch_size: 1
drop_last: false
vad_threshold: 40
mvn_dict: "data/cmvn.dict"

52 changes: 0 additions & 52 deletions conf/train_3spk.yaml

This file was deleted.

2 changes: 1 addition & 1 deletion scripts/run_demo.sh
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#!/usr/bin/env bash
# wujian@2018

mix_scp=./data/tune/mix.scp
mix_scp=./data/2spk/test/mix.scp
mdl_dir=./tune/2spk_dcnet_a

set -eu
Expand Down
2 changes: 1 addition & 1 deletion scripts/train.sh
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,6 @@ echo "start training --> $checkpoint ..."

cp $conf $checkpoint/train.yaml

CUDA_VISIBLE_DEVICES=1 python ./train_dcnet.py --config $conf --num-epoches 20 > $checkpoint/train.log 2>&1
CUDA_VISIBLE_DEVICES=0 python ./train_dcnet.py --config $conf --num-epoches 20 > $checkpoint/train.log 2>&1

echo "done"
53 changes: 32 additions & 21 deletions separate.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,22 +11,22 @@
import torch as th
import scipy.io as sio

from utils import stft, istft, parse_scps, compute_vad_mask, apply_cmvn
from utils import stft, istft, parse_scps, compute_vad_mask, apply_cmvn, parse_yaml
from dcnet import DCNet
from train_dcnet import parse_yaml


class DeepCluster(object):
def __init__(self, dcnet, state_dict, num_spks):
if not os.path.exists(state_dict):
def __init__(self, dcnet, dcnet_state, num_spks, pca=False, cuda=False):
if not os.path.exists(dcnet_state):
raise RuntimeError(
"Could not find state file {}".format(state_dict))
"Could not find state file {}".format(dcnet_state))
self.dcnet = dcnet
# compute on cpu
self.dcnet.load_state_dict(th.load(state_dict, map_location="cpu"))

self.location = "cuda" if args.cuda else "cpu"
self.dcnet.load_state_dict(
th.load(dcnet_state, map_location=self.location))
self.dcnet.eval()
self.kmeans = sklearn.cluster.KMeans(n_clusters=num_spks)
self.pca = sklearn.decomposition.PCA(n_components=3)
self.pca = sklearn.decomposition.PCA(n_components=3) if pca else None
self.num_spks = num_spks

def _cluster(self, spectra, vad_mask):
Expand All @@ -40,19 +40,21 @@ def _cluster(self, spectra, vad_mask):
"""
# TF x D
net_embed = self.dcnet(
th.tensor(spectra, dtype=th.float32),
th.tensor(spectra, dtype=th.float32, device=self.location),
train=False).cpu().data.numpy()
# filter silence embeddings: TF x D => N x D
active_embed = net_embed[vad_mask.reshape(-1)]
# classes: N x D
# pca_mat: N x 3
classes = self.kmeans.fit_predict(active_embed)
pca_mat = self.pca.fit_transform(active_embed)

pca_mat = None
if self.pca:
pca_mat = self.pca.fit_transform(active_embed)

def form_mask(classes, spkid, vad_mask):
# or give silence bins to each speaker
# mask = ~vad_mask
mask = np.zeros_like(vad_mask)
mask = ~vad_mask
# mask = np.zeros_like(vad_mask)
mask[vad_mask] = (classes == spkid)
return mask

Expand Down Expand Up @@ -98,7 +100,12 @@ def run(args):
frame_shift = config_dict["spectrogram_reader"]["frame_shift"]
window = config_dict["spectrogram_reader"]["window"]

cluster = DeepCluster(dcnet, args.dcnet_state, args.num_spks)
cluster = DeepCluster(
dcnet,
args.dcnet_state,
args.num_spks,
pca=args.dump_pca,
cuda=args.cuda)

utt_dict = parse_scps(args.wave_scp)
num_utts = 0
Expand All @@ -121,9 +128,6 @@ def run(args):
stft_mat, cmvn=dict_mvn)

for index, stft_mat in enumerate(spk_spectrogram):
# NOTE: bss_eval_sources.m is sensitive to shift of samples,
# so it's better to center frames and keep separated speech
# the same length as mixture's.
istft(
os.path.join(args.dump_dir, '{}.spk{}.wav'.format(
key, index + 1)),
Expand All @@ -136,7 +140,6 @@ def run(args):
fs=8000,
nsamps=samps.size)
if args.dump_mask:
# compatible with matlab
sio.savemat(
os.path.join(args.dump_dir, '{}.spk{}.mat'.format(
key, index + 1)), {"mask": spk_mask[index]})
Expand All @@ -157,11 +160,19 @@ def run(args):
parser.add_argument(
"dcnet_state", type=str, help="Location of networks state file")
parser.add_argument(
"wave_scp", type=str, help="Input wave scripts, in kaldi format")
"wave_scp",
type=str,
help="Location of input wave scripts in kaldi format")
parser.add_argument(
"--cuda",
default=False,
action="store_true",
dest="cuda",
help="If true, inference on GPUs")
parser.add_argument(
"--num-spks",
type=int,
default=3,
default=2,
dest="num_spks",
help="Number of speakers to be seperated")
parser.add_argument(
Expand Down
4 changes: 2 additions & 2 deletions train_dcnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import argparse
import os

from trainer import PerUttTrainer
from trainer import Trainer
from dataset import SpectrogramReader, Dataset, DataLoader, logger
from dcnet import DCNet
from utils import nfft, parse_yaml
Expand Down Expand Up @@ -61,7 +61,7 @@ def train(args):
if checkpoint is None else checkpoint))

dcnet = DCNet(num_bins, **dcnnet_conf)
trainer = PerUttTrainer(dcnet, **config_dict["trainer"])
trainer = Trainer(dcnet, **config_dict["trainer"])
trainer.run(train_loader, valid_loader, num_epoches=args.num_epoches)


Expand Down
5 changes: 1 addition & 4 deletions trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def create_optimizer(optimizer, params, **kwargs):
return opt


class PerUttTrainer(object):
class Trainer(object):
def __init__(self,
dcnet,
checkpoint="checkpoint",
Expand All @@ -43,9 +43,6 @@ def __init__(self,
num_spks=2):
self.nnet = dcnet
logger.info("DCNet:\n{}".format(self.nnet))
if type(lr) is str:
lr = float(lr)
logger.info("Transfrom lr from str to float => {}".format(lr))
self.optimizer = create_optimizer(
optimizer,
self.nnet.parameters(),
Expand Down

0 comments on commit 8a88896

Please sign in to comment.