diff --git a/src/gallium/frontends/rusticl/core/program.rs b/src/gallium/frontends/rusticl/core/program.rs index 81b13dc6f12..0c70cf7973c 100644 --- a/src/gallium/frontends/rusticl/core/program.rs +++ b/src/gallium/frontends/rusticl/core/program.rs @@ -378,6 +378,36 @@ impl Program { }) } + fn spirv_from_bin_for_dev(bin: &[u8]) -> (SPIRVBin, cl_program_binary_type) { + let mut ptr = bin.as_ptr(); + unsafe { + // 1. version + let version = ptr.cast::().read(); + ptr = ptr.add(size_of::()); + + match version { + 1 => { + // 2. size of the spirv + let spirv_size = ptr.cast::().read(); + ptr = ptr.add(size_of::()); + + // 3. binary_type + let bin_type = ptr.cast::().read(); + ptr = ptr.add(size_of::()); + + // 4. the spirv + assert!(bin.as_ptr().add(BIN_HEADER_SIZE_V1) == ptr); + assert!(bin.len() == BIN_HEADER_SIZE_V1 + spirv_size as usize); + let spirv = + spirv::SPIRVBin::from_bin(slice::from_raw_parts(ptr, spirv_size as usize)); + + (spirv, bin_type) + } + _ => panic!("unknown version"), + } + } + } + pub fn from_bins( context: Arc, devs: Vec<&'static Device>, @@ -387,47 +417,16 @@ impl Program { let mut kernels = HashSet::new(); for (&d, b) in devs.iter().zip(bins) { - let mut ptr = b.as_ptr(); - let bin_type; - let spirv; + let (spirv, bin_type) = Self::spirv_from_bin_for_dev(b); - unsafe { - // 1. version - let version = ptr.cast::().read(); - ptr = ptr.add(size_of::()); - - match version { - 1 => { - // 2. size of the spirv - let spirv_size = ptr.cast::().read(); - ptr = ptr.add(size_of::()); - - // 3. binary_type - bin_type = ptr.cast::().read(); - ptr = ptr.add(size_of::()); - - // 4. the spirv - assert!(b.as_ptr().add(BIN_HEADER_SIZE_V1) == ptr); - assert!(b.len() == BIN_HEADER_SIZE_V1 + spirv_size as usize); - spirv = Some(spirv::SPIRVBin::from_bin(slice::from_raw_parts( - ptr, - spirv_size as usize, - ))); - } - _ => panic!("unknown version"), - } - } - - if let Some(spirv) = &spirv { - for k in spirv.kernels() { - kernels.insert(k); - } + for k in spirv.kernels() { + kernels.insert(k); } builds.insert( d, ProgramDevBuild { - spirv: spirv, + spirv: Some(spirv), status: CL_BUILD_SUCCESS as cl_build_status, log: String::from(""), options: String::from(""),