diff --git a/src/gallium/frontends/rusticl/api/memory.rs b/src/gallium/frontends/rusticl/api/memory.rs index ee6eb62437d..6ba9411c5c4 100644 --- a/src/gallium/frontends/rusticl/api/memory.rs +++ b/src/gallium/frontends/rusticl/api/memory.rs @@ -26,12 +26,21 @@ use std::os::raw::c_void; use std::ptr; use std::sync::Arc; -fn validate_mem_flags(flags: cl_mem_flags, import: bool) -> CLResult<()> { +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +enum MemFlagValidationType { + /// For plain object creation. + Plain, + + /// For when the memory object is imported. + Imported, +} + +fn validate_mem_flags(flags: cl_mem_flags, validation: MemFlagValidationType) -> CLResult<()> { let mut valid_flags = cl_bitfield::from( CL_MEM_READ_WRITE | CL_MEM_WRITE_ONLY | CL_MEM_READ_ONLY | CL_MEM_KERNEL_READ_AND_WRITE, ); - if !import { + if validation != MemFlagValidationType::Imported { valid_flags |= cl_bitfield::from( CL_MEM_USE_HOST_PTR | CL_MEM_ALLOC_HOST_PTR @@ -49,7 +58,10 @@ fn validate_mem_flags(flags: cl_mem_flags, import: bool) -> CLResult<()> { Ok(()) } -fn validate_mem_flags_create(flags: cl_mem_flags, import: bool) -> CLResult<()> { +fn validate_mem_flags_create( + flags: cl_mem_flags, + validation: MemFlagValidationType, +) -> CLResult<()> { let read_write_group = cl_bitfield::from(CL_MEM_READ_WRITE | CL_MEM_WRITE_ONLY | CL_MEM_READ_ONLY); @@ -68,7 +80,7 @@ fn validate_mem_flags_create(flags: cl_mem_flags, import: bool) -> CLResult<()> return Err(CL_INVALID_VALUE); } - validate_mem_flags(flags, import) + validate_mem_flags(flags, validation) } fn validate_map_flags_common(map_flags: cl_mem_flags) -> CLResult<()> { @@ -288,7 +300,7 @@ fn create_buffer_with_properties( let c = Context::arc_from_raw(context)?; // CL_INVALID_VALUE if values specified in flags are not valid as defined in the Memory Flags table. - validate_mem_flags_create(flags, false)?; + validate_mem_flags_create(flags, MemFlagValidationType::Plain)?; // CL_INVALID_BUFFER_SIZE if size is 0 if size == 0 { @@ -365,7 +377,7 @@ fn create_sub_buffer( validate_matching_buffer_flags(&b, flags)?; flags = inherit_mem_flags(flags, &b); - validate_mem_flags_create(flags, false)?; + validate_mem_flags_create(flags, MemFlagValidationType::Plain)?; let (offset, size) = match buffer_create_type { CL_BUFFER_CREATE_TYPE_REGION => { @@ -818,7 +830,7 @@ fn create_image_with_properties( flags = CL_MEM_READ_WRITE.into(); } - validate_mem_flags_create(flags, false)?; + validate_mem_flags_create(flags, MemFlagValidationType::Plain)?; let filtered_flags = filter_image_access_flags(flags); // CL_IMAGE_FORMAT_NOT_SUPPORTED if there are no devices in context that support image_format. @@ -919,7 +931,7 @@ fn get_supported_image_formats( let c = Context::ref_from_raw(context)?; // CL_INVALID_VALUE if flags - validate_mem_flags(flags, false)?; + validate_mem_flags(flags, MemFlagValidationType::Plain)?; // or image_type are not valid if !image_type_valid(image_type) { @@ -3054,7 +3066,7 @@ fn create_from_gl( // CL_INVALID_VALUE if values specified in flags are not valid or if value specified in // texture_target is not one of the values specified in the description of texture_target. - validate_mem_flags_create(flags, true)?; + validate_mem_flags_create(flags, MemFlagValidationType::Imported)?; // CL_INVALID_MIP_LEVEL if miplevel is greather than zero and the OpenGL // implementation does not support creating from non-zero mipmap levels.