diff --git a/.pick_status.json b/.pick_status.json index 1960501c363..fd57c610bab 100644 --- a/.pick_status.json +++ b/.pick_status.json @@ -684,7 +684,7 @@ "description": "rusticl/kernel: Do not run kernels with a workgroup size beyond work_dim", "nominated": true, "nomination_type": 2, - "resolution": 0, + "resolution": 1, "main_sha": null, "because_sha": "376d1e6667a80e1811d4c25115633817f16666a7", "notes": null diff --git a/src/gallium/frontends/rusticl/core/kernel.rs b/src/gallium/frontends/rusticl/core/kernel.rs index fce71923840..a723e00b4cb 100644 --- a/src/gallium/frontends/rusticl/core/kernel.rs +++ b/src/gallium/frontends/rusticl/core/kernel.rs @@ -1478,7 +1478,13 @@ impl Kernel { } } - fn optimize_local_size(&self, d: &Device, grid: &mut [usize; 3], block: &mut [u32; 3]) { + fn optimize_local_size( + &self, + d: &Device, + work_dim: u32, + grid: &mut [usize; 3], + block: &mut [u32; 3], + ) { if !block.contains(&0) { for i in 0..3 { // we already made sure everything is fine @@ -1492,10 +1498,10 @@ impl Kernel { usize_block[i] = block[i] as usize; } - self.suggest_local_size(d, 3, grid, &mut usize_block); + self.suggest_local_size(d, work_dim as usize, grid, &mut usize_block); for i in 0..3 { - block[i] = usize_block[i] as u32; + block[i] = 1.max(usize_block[i] as u32); } } @@ -1549,7 +1555,7 @@ impl Kernel { let api_grid = grid; - self.optimize_local_size(q.device, &mut grid, &mut block); + self.optimize_local_size(q.device, work_dim, &mut grid, &mut block); Ok(Box::new(move |cl_ctx, ctx| { let hw_max_grid = ctx.dev.max_grid_size();