From 45b9fdb4e5178d2d84fb04223be69241d17c6115 Mon Sep 17 00:00:00 2001 From: Karol Herbst Date: Sun, 28 Jan 2024 16:16:07 +0100 Subject: [PATCH] rusticl/icd: move refcnt() and get rid of needless atomic ops The old impl used `get_arc` which internally calls into `Arc::increment_strong_count` in order to protect against Arc::drop deallocating our objects. We could also just not do that :) Part-of: --- src/gallium/frontends/rusticl/api/context.rs | 2 +- src/gallium/frontends/rusticl/api/event.rs | 2 +- src/gallium/frontends/rusticl/api/icd.rs | 14 ++++++++++---- src/gallium/frontends/rusticl/api/kernel.rs | 2 +- src/gallium/frontends/rusticl/api/memory.rs | 4 ++-- src/gallium/frontends/rusticl/api/program.rs | 2 +- src/gallium/frontends/rusticl/api/queue.rs | 2 +- 7 files changed, 17 insertions(+), 11 deletions(-) diff --git a/src/gallium/frontends/rusticl/api/context.rs b/src/gallium/frontends/rusticl/api/context.rs index 07206239875..d60c5182ebe 100644 --- a/src/gallium/frontends/rusticl/api/context.rs +++ b/src/gallium/frontends/rusticl/api/context.rs @@ -34,7 +34,7 @@ impl CLInfo for cl_context { ), CL_CONTEXT_NUM_DEVICES => cl_prop::(ctx.devs.len() as u32), CL_CONTEXT_PROPERTIES => cl_prop::<&Properties>(&ctx.properties), - CL_CONTEXT_REFERENCE_COUNT => cl_prop::(self.refcnt()?), + CL_CONTEXT_REFERENCE_COUNT => cl_prop::(Context::refcnt(*self)?), // CL_INVALID_VALUE if param_name is not one of the supported values _ => return Err(CL_INVALID_VALUE), }) diff --git a/src/gallium/frontends/rusticl/api/event.rs b/src/gallium/frontends/rusticl/api/event.rs index 1b50958bee0..4d696b476a1 100644 --- a/src/gallium/frontends/rusticl/api/event.rs +++ b/src/gallium/frontends/rusticl/api/event.rs @@ -32,7 +32,7 @@ impl CLInfo for cl_event { }; cl_prop::(cl_command_queue::from_ptr(ptr)) } - CL_EVENT_REFERENCE_COUNT => cl_prop::(self.refcnt()?), + CL_EVENT_REFERENCE_COUNT => cl_prop::(Event::refcnt(*self)?), CL_EVENT_COMMAND_TYPE => cl_prop::(event.cmd_type), _ => return Err(CL_INVALID_VALUE), }) diff --git a/src/gallium/frontends/rusticl/api/icd.rs b/src/gallium/frontends/rusticl/api/icd.rs index c64855d7366..b3875e0514a 100644 --- a/src/gallium/frontends/rusticl/api/icd.rs +++ b/src/gallium/frontends/rusticl/api/icd.rs @@ -299,15 +299,21 @@ pub trait ReferenceCountedAPIPointer { Ok(()) } } - - fn refcnt(&self) -> CLResult { - Ok((Arc::strong_count(&self.get_arc()?) - 1) as u32) - } } pub trait CLObject<'a, const ERR: i32, CL: ReferenceCountedAPIPointer + 'a>: Sized { + fn refcnt(ptr: CL) -> CLResult { + let ptr = ptr.get_ptr()?; + // SAFETY: `get_ptr` already checks if it's one of our pointers. + let arc = unsafe { Arc::from_raw(ptr) }; + let res = Arc::strong_count(&arc); + // leak the arc again, so we don't reduce the refcount by dropping `arc` + let _ = Arc::into_raw(arc); + Ok(res as u32) + } + fn refs_from_arr(objs: *const CL, count: u32) -> CLResult> { // CL spec requires validation for obj arrays, both values have to make sense if objs.is_null() && count > 0 || !objs.is_null() && count == 0 { diff --git a/src/gallium/frontends/rusticl/api/kernel.rs b/src/gallium/frontends/rusticl/api/kernel.rs index 89496b788e0..40f29cda873 100644 --- a/src/gallium/frontends/rusticl/api/kernel.rs +++ b/src/gallium/frontends/rusticl/api/kernel.rs @@ -33,7 +33,7 @@ impl CLInfo for cl_kernel { let ptr = Arc::as_ptr(&kernel.prog); cl_prop::(cl_program::from_ptr(ptr)) } - CL_KERNEL_REFERENCE_COUNT => cl_prop::(self.refcnt()?), + CL_KERNEL_REFERENCE_COUNT => cl_prop::(Kernel::refcnt(*self)?), // CL_INVALID_VALUE if param_name is not one of the supported values _ => return Err(CL_INVALID_VALUE), }) diff --git a/src/gallium/frontends/rusticl/api/memory.rs b/src/gallium/frontends/rusticl/api/memory.rs index 4caf5ceb849..d2ce93b2f45 100644 --- a/src/gallium/frontends/rusticl/api/memory.rs +++ b/src/gallium/frontends/rusticl/api/memory.rs @@ -235,7 +235,7 @@ impl CLInfo for cl_mem { CL_MEM_HOST_PTR => cl_prop::<*mut c_void>(mem.host_ptr), CL_MEM_OFFSET => cl_prop::(mem.offset), CL_MEM_PROPERTIES => cl_prop::<&Vec>(&mem.props), - CL_MEM_REFERENCE_COUNT => cl_prop::(self.refcnt()?), + CL_MEM_REFERENCE_COUNT => cl_prop::(Mem::refcnt(*self)?), CL_MEM_SIZE => cl_prop::(mem.size), CL_MEM_TYPE => cl_prop::(mem.mem_type), CL_MEM_USES_SVM_POINTER | CL_MEM_USES_SVM_POINTER_ARM => { @@ -908,7 +908,7 @@ impl CLInfo for cl_sampler { } CL_SAMPLER_FILTER_MODE => cl_prop::(sampler.filter_mode), CL_SAMPLER_NORMALIZED_COORDS => cl_prop::(sampler.normalized_coords), - CL_SAMPLER_REFERENCE_COUNT => cl_prop::(self.refcnt()?), + CL_SAMPLER_REFERENCE_COUNT => cl_prop::(Sampler::refcnt(*self)?), CL_SAMPLER_PROPERTIES => { cl_prop::<&Option>>(&sampler.props) } diff --git a/src/gallium/frontends/rusticl/api/program.rs b/src/gallium/frontends/rusticl/api/program.rs index b4f40edcf25..2a90309219e 100644 --- a/src/gallium/frontends/rusticl/api/program.rs +++ b/src/gallium/frontends/rusticl/api/program.rs @@ -46,7 +46,7 @@ impl CLInfo for cl_program { CL_PROGRAM_KERNEL_NAMES => cl_prop::<&str>(&*prog.kernels().join(";")), CL_PROGRAM_NUM_DEVICES => cl_prop::(prog.devs.len() as cl_uint), CL_PROGRAM_NUM_KERNELS => cl_prop::(prog.kernels().len()), - CL_PROGRAM_REFERENCE_COUNT => cl_prop::(self.refcnt()?), + CL_PROGRAM_REFERENCE_COUNT => cl_prop::(Program::refcnt(*self)?), CL_PROGRAM_SCOPE_GLOBAL_CTORS_PRESENT => cl_prop::(CL_FALSE), CL_PROGRAM_SCOPE_GLOBAL_DTORS_PRESENT => cl_prop::(CL_FALSE), CL_PROGRAM_SOURCE => match &prog.src { diff --git a/src/gallium/frontends/rusticl/api/queue.rs b/src/gallium/frontends/rusticl/api/queue.rs index ff1e09a00d7..d2b9d0a2df2 100644 --- a/src/gallium/frontends/rusticl/api/queue.rs +++ b/src/gallium/frontends/rusticl/api/queue.rs @@ -30,7 +30,7 @@ impl CLInfo for cl_command_queue { CL_QUEUE_PROPERTIES_ARRAY => { cl_prop::<&Option>>(&queue.props_v2) } - CL_QUEUE_REFERENCE_COUNT => cl_prop::(self.refcnt()?), + CL_QUEUE_REFERENCE_COUNT => cl_prop::(Queue::refcnt(*self)?), // clGetCommandQueueInfo, passing CL_QUEUE_SIZE Returns CL_INVALID_COMMAND_QUEUE since // command_queue cannot be a valid device command-queue. CL_QUEUE_SIZE => return Err(CL_INVALID_COMMAND_QUEUE),