Skip to content

Commit

Permalink
Fix/wgpu queue (#262)
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanielsimard authored Nov 14, 2024
1 parent a1471a7 commit 99df093
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 2 deletions.
83 changes: 82 additions & 1 deletion crates/cubecl-wgpu/src/compute/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ pub struct WgpuStream {
queue: Arc<wgpu::Queue>,
poll: WgpuPoll,
sync_buffer: Option<wgpu::Buffer>,
submission_load: SubmissionLoad,
}

pub enum PipelineDispatch {
Expand Down Expand Up @@ -56,6 +57,7 @@ impl WgpuStream {
tasks_max,
poll,
sync_buffer,
submission_load: SubmissionLoad::default(),
}
}

Expand Down Expand Up @@ -307,11 +309,90 @@ impl WgpuStream {
let new_encoder = create_encoder(&self.device);
let encoder = std::mem::replace(&mut self.encoder, new_encoder);

self.queue.submit([encoder.finish()]);
let index = self.queue.submit([encoder.finish()]);

self.submission_load
.regulate(&self.device, self.tasks_count, index);

self.tasks_count = 0;
}
}

#[cfg(not(target_family = "wasm"))]
mod __submission_load {

    /// Tracks how many tasks have been submitted since a reference submission so that
    /// [`SubmissionLoad::regulate`] can throttle the producer.
    #[derive(Default, Debug)]
    pub enum SubmissionLoad {
        /// A reference submission has been recorded.
        Init {
            /// Submission index recorded at the last reset point: either the first call to
            /// `regulate`, or the most recent call that reached the task threshold.
            last_index: wgpu::SubmissionIndex,
            /// Number of tasks accumulated by `regulate` calls since `last_index` was recorded.
            tasks_count_submitted: usize,
        },
        /// No submission has been recorded yet.
        #[default]
        Empty,
    }

    impl SubmissionLoad {
        /// Registers a new submission and, once enough tasks have accumulated, polls the
        /// device with `WaitForSubmissionIndex` on the previously recorded index.
        ///
        /// * `device` - device that is polled when the threshold is reached.
        /// * `tasks_count` - number of tasks contained in the submission identified by `index`.
        /// * `index` - submission index of the work that was just submitted.
        ///
        /// The very first call only records `index` as the reference point; its
        /// `tasks_count` is not added to the running total.
        pub fn regulate(
            &mut self,
            device: &wgpu::Device,
            tasks_count: usize,
            index: wgpu::SubmissionIndex,
        ) {
            // Enough to keep the GPU busy.
            //
            // - Too much can hang the GPU and create slowdown.
            // - Too little and GPU utilization is really bad.
            //
            // TODO: Could be smarter and dynamic based on stats.
            const MAX_TOTAL_TASKS: usize = 512;

            match self {
                Self::Empty => {
                    *self = Self::Init {
                        last_index: index,
                        tasks_count_submitted: 0,
                    };
                }
                Self::Init {
                    last_index,
                    tasks_count_submitted,
                } => {
                    *tasks_count_submitted += tasks_count;
                    if *tasks_count_submitted < MAX_TOTAL_TASKS {
                        return;
                    }

                    // Store the newest index as the next reference point and wait on the
                    // one it replaces, not on the submission that was just made.
                    let previous_index = core::mem::replace(last_index, index);
                    device.poll(wgpu::MaintainBase::WaitForSubmissionIndex(previous_index));

                    *tasks_count_submitted = 0;
                }
            }
        }
    }
}

#[cfg(target_family = "wasm")]
mod __submission_load_wasm {
    /// Stateless stand-in for the submission regulator on `wasm` targets.
    ///
    /// It exposes the same `regulate` signature as the non-wasm variant so the call site
    /// does not need any target-specific code.
    #[derive(Default, Debug)]
    pub struct SubmissionLoad;

    impl SubmissionLoad {
        /// Accepts a submission and ignores it; every argument is unused.
        pub fn regulate(
            &mut self,
            _device: &wgpu::Device,
            _tasks_count: usize,
            _index: wgpu::SubmissionIndex,
        ) {
            // Intentionally empty: no regulation is performed on this target.
        }
    }
}

#[cfg(not(target_family = "wasm"))]
use __submission_load::*;
#[cfg(target_family = "wasm")]
use __submission_load_wasm::*;

fn create_encoder(device: &wgpu::Device) -> wgpu::CommandEncoder {
device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some("CubeCL Command Encoder"),
Expand Down
2 changes: 1 addition & 1 deletion crates/cubecl-wgpu/src/runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ impl Default for RuntimeOptions {
#[cfg(test)]
const DEFAULT_MAX_TASKS: usize = 1;
#[cfg(not(test))]
const DEFAULT_MAX_TASKS: usize = 16;
const DEFAULT_MAX_TASKS: usize = 32;

let tasks_max = match std::env::var("CUBECL_WGPU_MAX_TASKS") {
Ok(value) => value
Expand Down

0 comments on commit 99df093

Please sign in to comment.