diff --git a/src/gallium/frontends/rusticl/core/kernel.rs b/src/gallium/frontends/rusticl/core/kernel.rs
index e576d8af8b9..6865ebec658 100644
--- a/src/gallium/frontends/rusticl/core/kernel.rs
+++ b/src/gallium/frontends/rusticl/core/kernel.rs
@@ -21,6 +21,7 @@ use rusticl_opencl_gen::*;
 use std::cmp;
 use std::collections::HashMap;
 use std::convert::TryInto;
+use std::mem::size_of;
 use std::os::raw::c_void;
 use std::ptr;
 use std::slice;
@@ -307,16 +308,17 @@ pub struct Kernel {
 
 impl_cl_type_trait!(cl_kernel, Kernel, CL_INVALID_KERNEL);
 
-fn create_kernel_arr<T>(vals: &[usize], val: T) -> [T; 3]
+fn create_kernel_arr<T>(vals: &[usize], val: T) -> CLResult<[T; 3]>
 where
     T: std::convert::TryFrom<usize> + Copy,
     <T as std::convert::TryFrom<usize>>::Error: std::fmt::Debug,
 {
     let mut res = [val; 3];
     for (i, v) in vals.iter().enumerate() {
-        res[i] = (*v).try_into().expect("64 bit work groups not supported");
+        res[i] = (*v).try_into().ok().ok_or(CL_OUT_OF_RESOURCES)?;
     }
-    res
+
+    Ok(res)
 }
 
 fn opt_nir(nir: &mut NirShader, dev: &Device, has_explicit_types: bool) {
@@ -401,18 +403,21 @@ fn lower_and_optimize_nir(
     args: &[spirv::SPIRVKernelArg],
     lib_clc: &NirShader,
 ) -> (Vec<KernelArg>, Vec<InternalKernelArg>) {
-    let address_bits_base_type;
     let address_bits_ptr_type;
     let global_address_format;
     let shared_address_format;
 
+    let host_bits_base_type = if size_of::<usize>() == 8 {
+        glsl_base_type::GLSL_TYPE_UINT64
+    } else {
+        glsl_base_type::GLSL_TYPE_UINT
+    };
+
     if dev.address_bits() == 64 {
-        address_bits_base_type = glsl_base_type::GLSL_TYPE_UINT64;
         address_bits_ptr_type = unsafe { glsl_uint64_t_type() };
         global_address_format = nir_address_format::nir_address_format_64bit_global;
         shared_address_format = nir_address_format::nir_address_format_32bit_offset_as_64bit;
     } else {
-        address_bits_base_type = glsl_base_type::GLSL_TYPE_UINT;
         address_bits_ptr_type = unsafe { glsl_uint_type() };
         global_address_format = nir_address_format::nir_address_format_32bit_global;
         shared_address_format = nir_address_format::nir_address_format_32bit_offset;
@@ -533,12 +538,12 @@ fn lower_and_optimize_nir(
         internal_args.push(InternalKernelArg {
             kind: InternalKernelArgType::GlobalWorkOffsets,
             offset: 0,
-            size: (3 * dev.address_bits() / 8) as usize,
+            size: 3 * size_of::<usize>(),
         });
         lower_state.base_global_invoc_id_loc = args.len() + internal_args.len() - 1;
         nir.add_var(
             nir_variable_mode::nir_var_uniform,
-            unsafe { glsl_vector_type(address_bits_base_type, 3) },
+            unsafe { glsl_vector_type(host_bits_base_type, 3) },
             lower_state.base_global_invoc_id_loc,
             "base_global_invocation_id",
         );
@@ -866,27 +871,23 @@ impl Kernel {
         }
     }
 
-    fn optimize_local_size(&self, d: &Device, grid: &mut [u32; 3], block: &mut [u32; 3]) {
+    fn optimize_local_size(&self, d: &Device, grid: &mut [usize; 3], block: &mut [u32; 3]) {
         if !block.contains(&0) {
             for i in 0..3 {
                 // we already made sure everything is fine
-                grid[i] /= block[i];
+                grid[i] /= block[i] as usize;
             }
             return;
         }
 
-        let mut usize_grid = [0usize; 3];
         let mut usize_block = [0usize; 3];
-
         for i in 0..3 {
-            usize_grid[i] = grid[i] as usize;
             usize_block[i] = block[i] as usize;
         }
 
-        self.suggest_local_size(d, 3, &mut usize_grid, &mut usize_block);
+        self.suggest_local_size(d, 3, grid, &mut usize_block);
 
         for i in 0..3 {
-            grid[i] = usize_grid[i] as u32;
             block[i] = usize_block[i] as u32;
         }
     }
@@ -902,9 +903,9 @@ impl Kernel {
         offsets: &[usize],
     ) -> CLResult<EventSig> {
         let nir_kernel_build = self.builds.get(q.device).unwrap().clone();
-        let mut block = create_kernel_arr::<u32>(block, 1);
-        let mut grid = create_kernel_arr::<u32>(grid, 1);
-        let offsets = create_kernel_arr::<u64>(offsets, 0);
+        let mut block = create_kernel_arr::<u32>(block, 1)?;
+        let mut grid = create_kernel_arr::<usize>(grid, 1)?;
+        let offsets = create_kernel_arr::<usize>(offsets, 0)?;
         let mut input: Vec<u8> = Vec::new();
         let mut resource_info = Vec::new();
         // Set it once so we get the alignment padding right
@@ -1040,17 +1041,7 @@ impl Kernel {
                     ));
                 }
                 InternalKernelArgType::GlobalWorkOffsets => {
-                    if q.device.address_bits() == 64 {
-                        input.extend_from_slice(unsafe { as_byte_slice(&offsets) });
-                    } else {
-                        input.extend_from_slice(unsafe {
-                            as_byte_slice(&[
-                                offsets[0] as u32,
-                                offsets[1] as u32,
-                                offsets[2] as u32,
-                            ])
-                        });
-                    }
+                    input.extend_from_slice(unsafe { as_byte_slice(&offsets) });
                 }
                 InternalKernelArgType::PrintfBuffer => {
                     let buf = Arc::new(
@@ -1132,6 +1123,12 @@ impl Kernel {
             ctx.set_global_binding(resources.as_slice(), &mut globals);
             ctx.update_cb0(&input);
 
+            let grid = [
+                grid[0].try_into().ok().ok_or(CL_OUT_OF_HOST_MEMORY)?,
+                grid[1].try_into().ok().ok_or(CL_OUT_OF_HOST_MEMORY)?,
+                grid[2].try_into().ok().ok_or(CL_OUT_OF_HOST_MEMORY)?,
+            ];
+
             ctx.launch_grid(work_dim, block, grid, variable_local_size as u32);
 
             ctx.clear_global_binding(globals.len() as u32);
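
Note on the core pattern above (not part of the patch): `create_kernel_arr` stops panicking via `.expect(...)` and instead propagates a CL error code through `CLResult`. A minimal, self-contained sketch of that pattern follows; `CLResult` and `CL_OUT_OF_RESOURCES` are stand-in definitions so the example compiles on its own, not the real rusticl items:

    use std::convert::{TryFrom, TryInto};

    type CLResult<T> = Result<T, i32>; // stand-in for rusticl's CLResult
    const CL_OUT_OF_RESOURCES: i32 = -5; // OpenCL's CL_OUT_OF_RESOURCES value

    fn create_kernel_arr<T>(vals: &[usize], val: T) -> CLResult<[T; 3]>
    where
        T: TryFrom<usize> + Copy,
    {
        let mut res = [val; 3];
        for (i, v) in vals.iter().enumerate() {
            // `.ok().ok_or(...)` discards the conversion error and maps an
            // out-of-range dimension to a CL error instead of panicking.
            res[i] = (*v).try_into().ok().ok_or(CL_OUT_OF_RESOURCES)?;
        }
        Ok(res)
    }

    fn main() {
        // In-range sizes convert cleanly on any host.
        assert_eq!(create_kernel_arr::<u32>(&[64, 8, 1], 1), Ok([64, 8, 1]));
        // A dimension that cannot fit the target type now surfaces an error;
        // previously this path would have hit the `.expect(...)` panic.
        assert_eq!(create_kernel_arr::<u16>(&[1 << 20], 1), Err(CL_OUT_OF_RESOURCES));
    }

With `.ok()` discarding the conversion error, the `Debug` bound on `<T as TryFrom<usize>>::Error` is no longer strictly needed, though the patch leaves the existing bound in place. The rest of the diff follows the same direction: the grid and global work offsets stay in host-sized `usize` (with the matching GLSL uniform base type picked from `size_of::<usize>()` at compile time, rather than from the device's address bits), and the one consumer that still needs a `u32` grid, `ctx.launch_grid`, converts at the call site and reports CL_OUT_OF_HOST_MEMORY on overflow instead of panicking.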