v0.0.1 add all
kongqiuqiang committed May 6, 2020
0 parents commit 9300eaa
Showing 11 changed files with 1,001 additions and 0 deletions.
30 changes: 30 additions & 0 deletions README.md
@@ -0,0 +1,30 @@
# Piano transcription inference

This toolbox provides easy-to-use commands for piano transcription inference.

# Installation
Install PyTorch (>=1.0) following the instructions at https://pytorch.org/, then install this package:

```
$ python3 setup.py install
```

# Usage
```
python3 example.py --audio_path='examples/cut_liszt.wav' --output_midi_path='cut_liszt.mid' --cuda
```

Or call the Python API directly:
```
import librosa
from piano_transcription_inference import PianoTranscription, sample_rate

# Load audio
(audio, _) = librosa.core.load('examples/cut_liszt.wav', sr=sample_rate, mono=True)

# Transcriptor ('cuda' | 'cpu')
transcriptor = PianoTranscription(device='cpu')

# Transcribe and write out to MIDI file
transcribed_dict = transcriptor.transcribe(audio, 'cut_liszt.mid')
```
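The returned `transcribed_dict` has two entries: `output_dict` holds the framewise regression outputs, and `est_note_events` holds the post-processed notes (see `inference.py` below). A minimal sketch of inspecting it; the fields inside each note event are whatever `RegressionPostProcessor` emits, so only their count is shown:

```
# Framewise outputs, shaped (frames_num, 88) per the docstrings in inference.py
print(transcribed_dict['output_dict']['frame_output'].shape)

# Post-processed note events, ready to be written to MIDI
print('{} notes detected'.format(len(transcribed_dict['est_note_events'])))
```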

# Cite
[1] Q. Kong, et al., High resolution piano transcription by regressing onset and offset time stamps, [To appear], 2020
45 changes: 45 additions & 0 deletions example.py
@@ -0,0 +1,45 @@
import os
import argparse
import torch
import librosa
import time

from piano_transcription_inference import PianoTranscription, sample_rate


def inference(args):
    """Inference template.

    Args:
      audio_path: str
      output_midi_path: str
      cuda: bool
    """

    # Arguments & parameters
    device = 'cuda' if args.cuda else 'cpu'  # honor the --cuda flag
audio_path = args.audio_path
output_midi_path = args.output_midi_path

# Load audio
(audio, _) = librosa.core.load(audio_path, sr=sample_rate, mono=True)

# Transcriptor
transcriptor = PianoTranscription(device=device)

# Transcribe and write out to MIDI file
    transcribe_start = time.time()
    transcribed_dict = transcriptor.transcribe(audio, output_midi_path)
    print('Transcribe time: {:.3f} s'.format(time.time() - transcribe_start))


if __name__ == '__main__':

parser = argparse.ArgumentParser(description='')
parser.add_argument('--audio_path', type=str, required=True)
parser.add_argument('--output_midi_path', type=str, required=True)
parser.add_argument('--cuda', action='store_true', default=False)

args = parser.parse_args()
inference(args)
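Building on the example above, a minimal sketch (not part of the package; the folder names are illustrative) that batch-transcribes every .wav file in a directory with the same API:

```
import os

import librosa

from piano_transcription_inference import PianoTranscription, sample_rate

# Hypothetical folders; adjust to your layout.
audio_dir = 'examples'
midi_dir = 'transcribed_midis'
os.makedirs(midi_dir, exist_ok=True)

# The checkpoint is downloaded on first instantiation (see inference.py).
transcriptor = PianoTranscription(device='cpu')

for name in sorted(os.listdir(audio_dir)):
    if not name.endswith('.wav'):
        continue
    (audio, _) = librosa.core.load(os.path.join(audio_dir, name),
        sr=sample_rate, mono=True)
    midi_path = os.path.join(midi_dir, name.replace('.wav', '.mid'))
    transcriptor.transcribe(audio, midi_path)
```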
Binary file added examples/cut_liszt.wav
Binary file not shown.
2 changes: 2 additions & 0 deletions piano_transcription_inference/__init__.py
@@ -0,0 +1,2 @@
from .inference import PianoTranscription
from .config import sample_rate
7 changes: 7 additions & 0 deletions piano_transcription_inference/config.py
@@ -0,0 +1,7 @@
sample_rate = 16000
classes_num = 88 # Number of piano notes (88 keys)
begin_note = 21 # MIDI note of A0, the lowest note on a piano
segment_seconds = 10. # Training segment duration
hop_seconds = 1.
frames_per_second = 100
velocity_scale = 128
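As a sanity check on how these constants combine (illustrative, not part of the package): the default `segment_samples=16000*10` in `inference.py` below is exactly `sample_rate * segment_seconds`.

```
from piano_transcription_inference import config

segment_samples = int(config.sample_rate * config.segment_seconds)
segment_frames = int(config.segment_seconds * config.frames_per_second)
hop_samples = int(config.sample_rate * config.hop_seconds)

assert segment_samples == 16000 * 10  # matches the default in inference.py
print(segment_samples, segment_frames, hop_samples)  # 160000 1000 16000
```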
167 changes: 167 additions & 0 deletions piano_transcription_inference/inference.py
@@ -0,0 +1,167 @@
import os
import numpy as np
import time
import librosa
from pathlib import Path

import torch

from .utilities import (create_folder, get_filename, RegressionPostProcessor,
write_events_to_midi)
from .models import Regress_onset_offset_frame_velocity_CRNN
from .pytorch_utils import move_data_to_device, forward
from . import config


class PianoTranscription(object):
def __init__(self, model_type='Regress_onset_offset_frame_velocity_CRNN',
checkpoint_path=None, segment_samples=16000*10, device='cuda'):
"""Class for transcribing piano solo recording.
Args:
model_type: str
checkpoint_path: str
segment_samples: int
device: 'cuda' | 'cpu'
"""
        if not checkpoint_path:
            checkpoint_path = '{}/piano_transcription_inference_data/Regress_onset_offset_frame_velocity_CRNN_onset_F1=0.9677.pth'.format(str(Path.home()))
        print('Checkpoint path: {}'.format(checkpoint_path))

if not os.path.exists(checkpoint_path):
create_folder(os.path.dirname(checkpoint_path))
            print('Downloading checkpoint (a VPN may be needed in mainland China) ...')
print('Total size: 331 MB')
os.system('gdown -O "{}" --id 1lTDHkBUbp-69ta0uo6r5kzwEpEIUh-PQ'.format(checkpoint_path))

if device == 'cuda' and torch.cuda.is_available():
self.device = 'cuda'
else:
self.device = 'cpu'

self.segment_samples = segment_samples
self.frames_per_second = config.frames_per_second
self.classes_num = config.classes_num
self.onset_threshold = 0.3
        self.offset_threshold = 0.3
self.frame_threshold = 0.1

        # Build model (model_type names a model class imported above)
        Model = eval(model_type)
self.model = Model(frames_per_second=self.frames_per_second,
classes_num=self.classes_num)

# Load model
checkpoint = torch.load(checkpoint_path, map_location=self.device)
self.model.load_state_dict(checkpoint['model'], strict=False)

# Parallel
print('Using {}'.format(self.device))
print('GPU number: {}'.format(torch.cuda.device_count()))
self.model = torch.nn.DataParallel(self.model)

if 'cuda' in str(self.device):
self.model.to(self.device)

def transcribe(self, audio, midi_path):
"""Transcribe an audio recording.
Args:
audio: (audio_samples,)
midi_path: str, path to write out the transcribed MIDI.
        Returns:
          transcribed_dict: dict, {'output_dict': ..., 'est_note_events': ...}
"""
audio = audio[None, :] # (1, audio_samples)

# Pad audio to be evenly divided by segment_samples
audio_len = audio.shape[1]
pad_len = int(np.ceil(audio_len / self.segment_samples))\
* self.segment_samples - audio_len

audio = np.concatenate((audio, np.zeros((1, pad_len))), axis=1)

# Enframe to segments
segments = self.enframe(audio, self.segment_samples)
"""(N, segment_samples)"""

# Forward
output_dict = forward(self.model, segments, batch_size=12)
"""{'reg_onset_output': (N, segment_frames, classes_num), ...}"""

# Deframe to original length
for key in output_dict.keys():
output_dict[key] = self.deframe(output_dict[key])[0 : audio_len]
"""output_dict: {
'reg_onset_output': (N, segment_frames, classes_num),
'reg_offset_output': (N, segment_frames, classes_num),
'frame_output': (N, segment_frames, classes_num),
'velocity_output': (N, segment_frames, classes_num)}"""

# Post processor
        post_processor = RegressionPostProcessor(self.frames_per_second,
            classes_num=self.classes_num, onset_threshold=self.onset_threshold,
            offset_threshold=self.offset_threshold,
            frame_threshold=self.frame_threshold)

# Post process output_dict to MIDI events
est_note_events = post_processor.output_dict_to_midi_events(output_dict)

# Write MIDI events to file
if midi_path:
write_events_to_midi(start_time=0, note_events=est_note_events, midi_path=midi_path)
print('Write out to {}'.format(midi_path))

transcribed_dict = {'output_dict': output_dict, 'est_note_events': est_note_events}
return transcribed_dict


def enframe(self, x, segment_samples):
"""Enframe long sequence to short segments.
Args:
x: (1, audio_samples)
segment_samples: int
Returns:
batch: (N, segment_samples)
"""
assert x.shape[1] % segment_samples == 0
batch = []

pointer = 0
while pointer + segment_samples <= x.shape[1]:
batch.append(x[:, pointer : pointer + segment_samples])
pointer += segment_samples // 2

batch = np.concatenate(batch, axis=0)
return batch

def deframe(self, x):
"""Deframe predicted segments to original sequence.
Args:
x: (N, segment_frames, classes_num)
Returns:
y: (audio_frames, classes_num)
"""
if x.shape[0] == 1:
return x[0]

else:
x = x[:, 0 : -1, :]
"""Remove an extra frame in the end of each segment caused by the
'center=True' argument when calculating spectrogram."""
(N, segment_samples, classes_num) = x.shape
assert segment_samples % 4 == 0

y = []
y.append(x[0, 0 : int(segment_samples * 0.75)])
for i in range(1, N - 1):
y.append(x[i, int(segment_samples * 0.25) : int(segment_samples * 0.75)])
y.append(x[-1, int(segment_samples * 0.25) :])
y = np.concatenate(y, axis=0)
return y
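The 0.75/0.25 slicing in `deframe` stitches the half-overlapping segments produced by `enframe` back into one contiguous sequence: the first segment contributes its first three quarters, each interior segment its central half, and the last segment everything from its first quarter on. A standalone numpy sketch (toy data, no spectrogram frame trimming) reproducing that round trip:

```
import numpy as np

def enframe(x, segment_samples):
    # Half-overlapping segments, as in PianoTranscription.enframe.
    assert x.shape[1] % segment_samples == 0
    batch, pointer = [], 0
    while pointer + segment_samples <= x.shape[1]:
        batch.append(x[:, pointer : pointer + segment_samples])
        pointer += segment_samples // 2
    return np.concatenate(batch, axis=0)

def deframe(x):
    # Overlap stitching, as in PianoTranscription.deframe (minus the
    # 'center=True' frame trim, which toy data does not need).
    if x.shape[0] == 1:
        return x[0]
    (N, seg, _) = x.shape
    y = [x[0, 0 : seg * 3 // 4]]
    for i in range(1, N - 1):
        y.append(x[i, seg // 4 : seg * 3 // 4])
    y.append(x[-1, seg // 4 :])
    return np.concatenate(y, axis=0)

x = np.arange(32, dtype=np.float64).reshape(1, 32)
segments = enframe(x, 8)                  # (7, 8)
restored = deframe(segments[..., None])   # (32, 1)
assert np.allclose(restored[:, 0], x[0])  # exact round trip
```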