Skip to content

Commit

Permalink
Pre-warm the model on creation over the size of the receptive field. …
Browse files Browse the repository at this point in the history
…Removed no longer needed anti-pop code. (#71)
  • Loading branch information
mikeoliphant authored Sep 13, 2023
1 parent bad4928 commit 885a535
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 44 deletions.
52 changes: 17 additions & 35 deletions NAM/wavenet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,23 @@ wavenet::WaveNet::WaveNet(const double loudness, const std::vector<wavenet::Laye
}
this->_head_output.resize(1, 0); // Mono output!
this->set_params_(params);
this->_reset_anti_pop_();

long receptive_field = 1;
for (size_t i = 0; i < this->_layer_arrays.size(); i++)
receptive_field += this->_layer_arrays[i].get_receptive_field();

NAM_SAMPLE sample = 0;
NAM_SAMPLE* sample_ptr = &sample;

std::unordered_map<std::string, double> param_dict = {};

// pre-warm the model over the size of the receptive field
for (long i = 0; i < receptive_field; i++)
{
this->process(&sample_ptr, &sample_ptr, 1, 1, 1.0, 1.0, param_dict);
this->finalize_(1);
sample = 0;
}
}

void wavenet::WaveNet::finalize_(const int num_frames)
Expand Down Expand Up @@ -325,11 +341,6 @@ void wavenet::WaveNet::_process_core_()
this->_set_num_frames_(num_frames);
this->_prepare_for_frames_(num_frames);

// NOTE: During warm-up, weird things can happen that NaN out the layers.
// We could solve this by anti-popping the *input*. But, it's easier to check
// the outputs for NaNs and zero them out.
// They'll flush out eventually because the model doesn't use any feedback.

// Fill into condition array:
// Clumsy...
for (int j = 0; j < num_frames; j++)
Expand Down Expand Up @@ -361,13 +372,8 @@ void wavenet::WaveNet::_process_core_()
for (int s = 0; s < num_frames; s++)
{
float out = this->_head_scale * this->_head_arrays[final_head_array](0, s);
// This is the NaN check that we could fix with anti-popping the input
if (isnan(out))
out = 0.0;
this->_core_dsp_output[s] = out;
}
// Apply anti-pop
this->_anti_pop_();
}

void wavenet::WaveNet::_set_num_frames_(const long num_frames)
Expand All @@ -388,27 +394,3 @@ void wavenet::WaveNet::_set_num_frames_(const long num_frames)
// this->_head.set_num_frames_(num_frames);
this->_num_frames = num_frames;
}

// Apply the start-up fade-in to the current contents of _core_dsp_output.
//
// Each sample visited is multiplied by a gain of
// (_anti_pop_countdown / _anti_pop_ramp), clamped below at zero, and the
// countdown advances by one per sample. While the countdown is negative the
// gain is therefore zero; once the countdown reaches _anti_pop_ramp the
// remaining samples (and all later calls) are left untouched.
void wavenet::WaveNet::_anti_pop_()
{
  const long ramp_length = this->_anti_pop_ramp;

  // Ramp already finished: nothing to scale.
  if (this->_anti_pop_countdown >= ramp_length)
    return;

  const float gain_per_step = 1.0f / float(ramp_length);

  // Walk the output, stopping early as soon as the ramp completes mid-buffer.
  for (size_t idx = 0; idx < this->_core_dsp_output.size() && this->_anti_pop_countdown < ramp_length; idx++)
  {
    const float ramp_gain = std::max(gain_per_step * float(this->_anti_pop_countdown), 0.0f);
    this->_core_dsp_output[idx] *= ramp_gain;
    this->_anti_pop_countdown++;
  }
}

// Restart the anti-pop countdown.
//
// The countdown is set to the negative of the model's total receptive field,
// computed as 1 plus the sum of get_receptive_field() over every entry of
// _layer_arrays. This is the "real" receptive field of the network, not the
// size of any internal buffer.
void wavenet::WaveNet::_reset_anti_pop_()
{
  long total_receptive_field = 1;
  for (auto& layer_array : this->_layer_arrays)
  {
    total_receptive_field += layer_array.get_receptive_field();
  }
  this->_anti_pop_countdown = -total_receptive_field;
}
9 changes: 0 additions & 9 deletions NAM/wavenet.h
Original file line number Diff line number Diff line change
Expand Up @@ -207,14 +207,5 @@ class WaveNet : public DSP

// Ensure that all buffer arrays are the right size for this num_frames
void _set_num_frames_(const long num_frames);

// The net starts with random parameters inside; we need to wait for a full
// receptive field to pass through before we can count on the output being
// ok. This implements a gentle "ramp-up" so that there's no "pop" at the
// start.
long _anti_pop_countdown;
const long _anti_pop_ramp = 4000;
void _anti_pop_();
void _reset_anti_pop_();
};
}; // namespace wavenet

0 comments on commit 885a535

Please sign in to comment.