diff --git a/src/gallium/frontends/rusticl/api/icd.rs b/src/gallium/frontends/rusticl/api/icd.rs index ce7dae85065..15337c68d15 100644 --- a/src/gallium/frontends/rusticl/api/icd.rs +++ b/src/gallium/frontends/rusticl/api/icd.rs @@ -235,8 +235,46 @@ pub trait ReferenceCountedAPIPointer { fn from_ptr(ptr: *const T) -> Self; } -pub trait CLObject<'a, const ERR: i32, CL: ReferenceCountedAPIPointer + 'a>: +pub trait BaseCLObject<'a, const ERR: i32, CL: ReferenceCountedAPIPointer + 'a>: Sized +{ + fn ref_from_raw(obj: CL) -> CLResult<&'a Self> { + let obj = obj.get_ptr()?; + // SAFETY: `get_ptr` already checks if it's one of our pointers and not null + Ok(unsafe { &*obj }) + } + + fn refs_from_arr(objs: *const CL, count: u32) -> CLResult> + where + CL: Copy, + { + // CL spec requires validation for obj arrays, both values have to make sense + if objs.is_null() && count > 0 || !objs.is_null() && count == 0 { + return Err(CL_INVALID_VALUE); + } + + let mut res = Vec::new(); + if objs.is_null() || count == 0 { + return Ok(res); + } + + for i in 0..count as usize { + res.push(Self::ref_from_raw(unsafe { *objs.add(i) })?); + } + Ok(res) + } +} + +pub trait CLObject<'a, const ERR: i32, CL: ReferenceCountedAPIPointer + 'a>: + Sized + BaseCLObject<'a, ERR, CL> +{ + fn as_cl(&self) -> CL { + CL::from_ptr(self) + } +} + +pub trait ArcedCLObject<'a, const ERR: i32, CL: ReferenceCountedAPIPointer + 'a>: + Sized + BaseCLObject<'a, ERR, CL> { /// Note: this operation increases the internal ref count as `ref_from_raw` is the better option /// when an Arc is not needed. @@ -281,30 +319,8 @@ pub trait CLObject<'a, const ERR: i32, CL: ReferenceCountedAPIPointer Ok(res as u32) } - fn ref_from_raw(obj: CL) -> CLResult<&'a Self> { - let obj = obj.get_ptr()?; - // SAFETY: `get_ptr` already checks if it's one of our pointers and not null - Ok(unsafe { &*obj }) - } - - fn refs_from_arr(objs: *const CL, count: u32) -> CLResult> - where - CL: Copy, - { - // CL spec requires validation for obj arrays, both values have to make sense - if objs.is_null() && count > 0 || !objs.is_null() && count == 0 { - return Err(CL_INVALID_VALUE); - } - - let mut res = Vec::new(); - if objs.is_null() || count == 0 { - return Ok(res); - } - - for i in 0..count as usize { - res.push(Self::ref_from_raw(unsafe { *objs.add(i) })?); - } - Ok(res) + fn into_cl(self: Arc) -> CL { + CL::from_ptr(Arc::into_raw(self)) } fn release(ptr: CL) -> CLResult<()> { @@ -320,15 +336,11 @@ pub trait CLObject<'a, const ERR: i32, CL: ReferenceCountedAPIPointer unsafe { Arc::increment_strong_count(ptr) }; Ok(()) } - - fn into_cl(self: Arc) -> CL { - CL::from_ptr(Arc::into_raw(self)) - } } #[macro_export] -macro_rules! impl_cl_type_trait { - ($cl: ident, $t: ident, $err: ident, $($field:ident).+) => { +macro_rules! impl_cl_type_trait_base { + (@BASE $cl: ident, $t: ident, $err: ident, $($field:ident).+) => { impl $crate::api::icd::ReferenceCountedAPIPointer<$t, $err> for $cl { fn get_ptr(&self) -> CLResult<*const $t> { type Base = $crate::api::icd::CLObjectBase<$err>; @@ -359,7 +371,7 @@ macro_rules! impl_cl_type_trait { } } - impl $crate::api::icd::CLObject<'_, $err, $cl> for $t {} + impl $crate::api::icd::BaseCLObject<'_, $err, $cl> for $t {} // there are two reason to implement those traits for all objects // 1. it speeds up operations @@ -379,6 +391,23 @@ macro_rules! impl_cl_type_trait { } }; + ($cl: ident, $t: ident, $err: ident, $($field:ident).+) => { + $crate::impl_cl_type_trait_base!(@BASE $cl, $t, $err, $($field).+); + impl $crate::api::icd::CLObject<'_, $err, $cl> for $t {} + }; + + ($cl: ident, $t: ident, $err: ident) => { + $crate::impl_cl_type_trait_base!($cl, $t, $err, base); + }; +} + +#[macro_export] +macro_rules! impl_cl_type_trait { + ($cl: ident, $t: ident, $err: ident, $($field:ident).+) => { + $crate::impl_cl_type_trait_base!(@BASE $cl, $t, $err, $($field).+); + impl $crate::api::icd::ArcedCLObject<'_, $err, $cl> for $t {} + }; + ($cl: ident, $t: ident, $err: ident) => { $crate::impl_cl_type_trait!($cl, $t, $err, base); }; diff --git a/src/gallium/frontends/rusticl/api/util.rs b/src/gallium/frontends/rusticl/api/util.rs index c16582cbf00..903457678ab 100644 --- a/src/gallium/frontends/rusticl/api/util.rs +++ b/src/gallium/frontends/rusticl/api/util.rs @@ -1,4 +1,4 @@ -use crate::api::icd::{CLObject, CLResult}; +use crate::api::icd::{ArcedCLObject, CLResult}; use crate::api::types::*; use crate::core::event::*; use crate::core::queue::*;