
Commit

update
kongqiuqiang committed Sep 18, 2020
1 parent 8a184b8 commit d7008cf
Showing 6 changed files with 126 additions and 68 deletions.
18 changes: 12 additions & 6 deletions README.md
@@ -1,21 +1,27 @@
# Piano transcription inference

This toolbox provide easy to use command for piano transcription inference.
This toolbox is a piano transcription inference package that can be easily installed. Users can transcribe their favorite piano recordings to MIDI files after installation. To see how the piano transcription system is trained, please visit: https://github.com/bytedance/piano_transcription.

# Installation
Install PyTorch (>=1.0) following https://pytorch.org/
## Demos
Here is a demo of our piano transcription system: https://www.youtube.com/watch?v=5U-WL0QvKCg

## Installation
Install PyTorch (>=1.4) following https://pytorch.org/

```
$ python3 setup.py install
```

# Usage
## Usage
```
python3 example.py --audio_path='resources/cut_liszt.mp3' --output_midi_path='cut_liszt.mid' --cuda
```

For example:
```
import librosa
from piano_transcription_inference import PianoTranscription, sample_rate
# Load audio
(audio, _) = librosa.core.load('resources/cut_liszt.mp3', sr=sample_rate, mono=True)
@@ -26,5 +32,5 @@ transcriptor = PianoTranscription(device=device)
transcribed_dict = transcriptor.transcribe(audio, 'cut_liszt.mid')
```
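For reference, a self-contained sketch of the example above. The collapsed lines define `device`; the device selection here is an assumption reconstructed from the `PianoTranscription(device=device)` call in the hunk header:
```
import librosa
import torch
from piano_transcription_inference import PianoTranscription, sample_rate

# Load audio at the sample rate the model expects
(audio, _) = librosa.core.load('resources/cut_liszt.mp3', sr=sample_rate, mono=True)

# Assumed device selection; the collapsed README lines define `device`
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Transcribe and write the result to a MIDI file
transcriptor = PianoTranscription(device=device)
transcribed_dict = transcriptor.transcribe(audio, 'cut_liszt.mid')
```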

# Cite
[1] Q. Kong, et al., High resolution piano transcription by regressing onset and offset time stamps, [To appear], 2020
## Cite
[1] High-resolution Piano Transcription with Pedals by Regressing Onset and Offset Times, [To appear], 2020
17 changes: 9 additions & 8 deletions piano_transcription_inference/inference.py
@@ -28,11 +28,11 @@ def __init__(self, model_type='Note_pedal', checkpoint_path=None,
checkpoint_path='{}/piano_transcription_inference_data/note_F1=0.9677_pedal_F1=0.8658.pth'.format(str(Path.home()))
print('Checkpoint path: {}'.format(checkpoint_path))

if not os.path.exists(checkpoint_path):
if not os.path.exists(checkpoint_path) or os.path.getsize(checkpoint_path) < 1.6e8:
create_folder(os.path.dirname(checkpoint_path))
print('Downloading (please use a VPN in mainland China) ...')
print('Total size: 164 MB')
os.system('gdown -O "{}" --id 15to2oXUIJc1345Koyur8aPwUfd0h5SOT'.format(checkpoint_path))
print('Total size: ~165 MB')
zenodo_path = 'https://zenodo.org/record/4034264/files/CRNN_note_F1%3D0.9677_pedal_F1%3D0.9186.pth?download=1'
os.system('wget -O "{}" "{}"'.format(checkpoint_path, zenodo_path))

print('Using {} for inference.'.format(device))
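If the automatic download fails, the constructor's `checkpoint_path` argument shown above can point at a manually fetched copy; a sketch, assuming the default local path from the code:
```
from pathlib import Path
from piano_transcription_inference import PianoTranscription

# Assumed: the checkpoint was fetched by hand from the Zenodo URL above
ckpt = str(Path.home() / 'piano_transcription_inference_data'
           / 'note_F1=0.9677_pedal_F1=0.8658.pth')

# The download only runs if this file is missing or smaller than ~160 MB
transcriptor = PianoTranscription(model_type='Note_pedal', checkpoint_path=ckpt)
```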

@@ -54,11 +54,12 @@ def __init__(self, model_type='Note_pedal', checkpoint_path=None,
self.model.load_state_dict(checkpoint['model'], strict=False)

# Parallel
print('GPU number: {}'.format(torch.cuda.device_count()))

if 'cuda' in str(device):
if 'cuda' in str(self.device):
self.model.to(self.device)
print('GPU number: {}'.format(torch.cuda.device_count()))
self.model = torch.nn.DataParallel(self.model)
self.model.to(device)
else:
print('Using CPU.')

def transcribe(self, audio, midi_path):
"""Transcribe an audio recording.
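The corrected device handling above resolves everything against `self.device`; in isolation the pattern looks like this (a sketch, not the repository's exact code):
```
import torch

def prepare_model(model, device):
    """Move `model` to `device`; wrap it in DataParallel only on CUDA."""
    device = torch.device(device)
    if device.type == 'cuda':
        model.to(device)
        print('GPU number: {}'.format(torch.cuda.device_count()))
        model = torch.nn.DataParallel(model)
    else:
        print('Using CPU.')
    return model
```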
46 changes: 24 additions & 22 deletions piano_transcription_inference/models.py
@@ -10,7 +10,7 @@
import torch.nn.functional as F

from torchlibrosa.stft import Spectrogram, LogmelFilterBank
from .pytorch_utils import move_data_to_device
from pytorch_utils import move_data_to_device


def init_layer(layer):
@@ -84,20 +84,18 @@ def init_weight(self):


def forward(self, input, pool_size=(2, 2), pool_type='avg'):

x = input
x = F.relu_(self.bn1(self.conv1(x)))
"""
Args:
input: (batch_size, in_channels, time_steps, freq_bins)
Outputs:
output: (batch_size, out_channels, time_steps // pool_size[0], freq_bins // pool_size[1])
"""

x = F.relu_(self.bn1(self.conv1(input)))
x = F.relu_(self.bn2(self.conv2(x)))
if pool_type == 'max':
x = F.max_pool2d(x, kernel_size=pool_size)
elif pool_type == 'avg':

if pool_type == 'avg':
x = F.avg_pool2d(x, kernel_size=pool_size)
elif pool_type == 'avg+max':
x1 = F.avg_pool2d(x, kernel_size=pool_size)
x2 = F.max_pool2d(x, kernel_size=pool_size)
x = x1 + x2
else:
raise Exception('Incorrect argument!')

return x
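As a quick shape check for the pooling in this block (a sketch; the functional call stands in for a full ConvBlock, and the tensor sizes are assumptions):
```
import torch
import torch.nn.functional as F

# (batch_size, out_channels, time_steps, freq_bins), sizes assumed
x = torch.randn(4, 48, 1001, 114)

# pool_size=(1, 2) halves the frequency axis and keeps the time axis
x = F.avg_pool2d(x, kernel_size=(1, 2))
print(x.shape)  # torch.Size([4, 48, 1001, 57])
```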

@@ -128,6 +126,13 @@ def init_weight(self):
init_layer(self.fc)

def forward(self, input):
"""
Args:
input: (batch_size, channels_num, time_steps, freq_bins)
Outputs:
output: (batch_size, time_steps, classes_num)
"""

x = self.conv_block1(input, pool_size=(1, 2), pool_type='avg')
x = F.dropout(x, p=0.2, training=self.training)
x = self.conv_block2(x, pool_size=(1, 2), pool_type='avg')
@@ -206,7 +211,6 @@ def forward(self, input):
"""
Args:
input: (batch_size, data_length)
Outputs:
output_dict: dict, {
'reg_onset_output': (batch_size, time_steps, classes_num),
@@ -228,14 +232,14 @@ def forward(self, input):
reg_offset_output = self.reg_offset_model(x) # (batch_size, time_steps, classes_num)
velocity_output = self.velocity_model(x) # (batch_size, time_steps, classes_num)

# Use velocity to improve onset regression
# Use velocities to condition onset regression
x = torch.cat((reg_onset_output, (reg_onset_output ** 0.5) * velocity_output.detach()), dim=2)
(x, _) = self.reg_onset_gru(x)
x = F.dropout(x, p=0.5, training=self.training, inplace=False)
reg_onset_output = torch.sigmoid(self.reg_onset_fc(x))
"""(batch_size, time_steps, classes_num)"""

# Use onset and offset to improve framewise classification
# Use onsets and offsets to condition frame-wise classification
x = torch.cat((frame_output, reg_onset_output.detach(), reg_offset_output.detach()), dim=2)
(x, _) = self.frame_gru(x)
x = F.dropout(x, p=0.5, training=self.training, inplace=False)
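Both conditioning steps follow one pattern: concatenate detached auxiliary predictions onto the target head's output, refine the result with a bidirectional GRU, and re-apply a sigmoid. A minimal sketch of that pattern, with assumed sizes:
```
import torch
import torch.nn as nn
import torch.nn.functional as F

classes_num, hidden_size = 88, 256  # assumed sizes
gru = nn.GRU(input_size=classes_num * 2, hidden_size=hidden_size,
             batch_first=True, bidirectional=True)
fc = nn.Linear(hidden_size * 2, classes_num)

reg_onset_output = torch.rand(4, 1001, classes_num)  # onset head output
velocity_output = torch.rand(4, 1001, classes_num)   # velocity head output

# Condition onsets on detached velocities, as in the forward pass above
x = torch.cat((reg_onset_output,
               (reg_onset_output ** 0.5) * velocity_output.detach()), dim=2)
(x, _) = gru(x)
x = F.dropout(x, p=0.5, training=False)
refined_onset = torch.sigmoid(fc(x))  # (4, 1001, classes_num)
```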
@@ -251,7 +255,6 @@ def forward(self, input):
return output_dict


####################################
class Regress_pedal_CRNN(nn.Module):
def __init__(self, frames_per_second, classes_num):
super(Regress_pedal_CRNN, self).__init__()
@@ -298,7 +301,6 @@ def forward(self, input):
"""
Args:
input: (batch_size, data_length)
Outputs:
output_dict: dict, {
'reg_onset_output': (batch_size, time_steps, classes_num),
@@ -307,6 +309,7 @@ def forward(self, input):
'velocity_output': (batch_size, time_steps, classes_num)
}
"""

x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins)
x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins)
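The two-stage front end uses torchlibrosa, imported at the top of this file. A standalone sketch; the STFT and mel parameters here are assumptions, since the real values are set in the collapsed constructor:
```
import torch
from torchlibrosa.stft import Spectrogram, LogmelFilterBank

# Assumed parameters; the actual ones live in the collapsed __init__
spectrogram_extractor = Spectrogram(n_fft=2048, hop_length=160,
    win_length=2048, window='hann', freeze_parameters=True)
logmel_extractor = LogmelFilterBank(sr=16000, n_fft=2048, n_mels=229,
    fmin=30, fmax=8000, freeze_parameters=True)

audio = torch.randn(1, 16000)        # (batch_size, data_length)
x = spectrogram_extractor(audio)     # (batch_size, 1, time_steps, freq_bins)
x = logmel_extractor(x)              # (batch_size, 1, time_steps, mel_bins)
```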

@@ -326,11 +329,10 @@ def forward(self, input):
return output_dict


####################################
# This model is not trained, but combined from the pretrained note and pedal models.
# This model is not trained, but is combined from the trained note and pedal models.
class Note_pedal(nn.Module):
def __init__(self, frames_per_second, classes_num):
"""Combination of note and pedal model.
"""The combination of note and pedal model.
"""
super(Note_pedal, self).__init__()

@@ -348,4 +350,4 @@ def forward(self, input):
full_output_dict = {}
full_output_dict.update(note_output_dict)
full_output_dict.update(pedal_output_dict)
return full_output_dict
return full_output_dict
15 changes: 8 additions & 7 deletions piano_transcription_inference/piano_vad.py
@@ -4,8 +4,8 @@
def note_detection_with_onset_offset_regress(frame_output, onset_output,
onset_shift_output, offset_output, offset_shift_output, velocity_output,
frame_threshold):
"""Estimate onset and offset, onset shift, offset shift and velocity of
piano notes. First detect onsets with onset outputs, then detect offsets
"""Process prediction matrices to note events information.
First, detect onsets with onset outputs. Then, detect offsets
with frame and offset outputs.
Args:
@@ -16,11 +16,10 @@ def note_detection_with_onset_offset_regress(frame_output, onset_output,
offset_shift_output: (frames_num,)
velocity_output: (frames_num,)
frame_threshold: float
Returns:
output_tuples: list of [bgn, fin, onset_shift, offset_shift, normalized_velocity],
e.g., [
[1821, 1909, 0.4749851, 0.3048533, 0.72119445],
[1821, 1909, 0.47498, 0.3048533, 0.72119445],
[1909, 1947, 0.30730522, -0.45764327, 0.64200014],
...]
"""
@@ -31,8 +30,10 @@ def note_detection_with_onset_offset_regress(frame_output, onset_output,

for i in range(onset_output.shape[0]):
if onset_output[i] == 1:
"""Onset detected"""
if bgn:
"""Consecutive onsets"""
"""Consecutive onsets. E.g., pedal is not released, but two
consecutive notes being played."""
fin = max(i - 1, 0)
output_tuples.append([bgn, fin, onset_shift_output[bgn],
0, velocity_output[bgn]])
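A toy invocation of this detector with synthetic arrays (the import path, threshold, and shapes are assumptions):
```
import numpy as np
from piano_transcription_inference.piano_vad import (
    note_detection_with_onset_offset_regress)

frames_num = 100
onset_output = np.zeros(frames_num)
onset_output[10] = 1                 # one detected onset
frame_output = np.zeros(frames_num)
frame_output[10:60] = 0.9            # frames where the note sounds
offset_output = np.zeros(frames_num)
offset_output[59] = 1

est_tuples = note_detection_with_onset_offset_regress(
    frame_output=frame_output,
    onset_output=onset_output,
    onset_shift_output=np.zeros(frames_num),
    offset_output=offset_output,
    offset_shift_output=np.zeros(frames_num),
    velocity_output=np.full(frames_num, 0.7),
    frame_threshold=0.1)
# Each tuple: [bgn, fin, onset_shift, offset_shift, normalized_velocity]
```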
@@ -75,14 +76,13 @@ def note_detection_with_onset_offset_regress(frame_output, onset_output,

def pedal_detection_with_onset_offset_regress(frame_output, offset_output,
offset_shift_output, frame_threshold):
"""Estimate onset and offset, onset shift and offset shift of pedals.
"""Process prediction array to pedal events information.
Args:
frame_output: (frames_num,)
offset_output: (frames_num,)
offset_shift_output: (frames_num,)
frame_threshold: float
Returns:
output_tuples: list of [bgn, fin, onset_shift, offset_shift],
e.g., [
@@ -97,6 +97,7 @@ def pedal_detection_with_onset_offset_regress(frame_output, offset_output,

for i in range(1, frame_output.shape[0]):
if frame_output[i] >= frame_threshold and frame_output[i] > frame_output[i - 1]:
"""Pedal onset detected"""
if bgn:
pass
else:
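And the pedal counterpart, again with synthetic arrays; the rising-edge test above picks the onset frame (import path and threshold are assumptions):
```
import numpy as np
from piano_transcription_inference.piano_vad import (
    pedal_detection_with_onset_offset_regress)

frames_num = 100
frame_output = np.zeros(frames_num)
frame_output[20:70] = 0.8            # pedal pressed for frames 20..69
offset_output = np.zeros(frames_num)
offset_output[69] = 1

est_pedals = pedal_detection_with_onset_offset_regress(
    frame_output=frame_output,
    offset_output=offset_output,
    offset_shift_output=np.zeros(frames_num),
    frame_threshold=0.5)
# Each tuple: [bgn, fin, onset_shift, offset_shift]
```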