diff --git a/src/gallium/frontends/rusticl/api/memory.rs b/src/gallium/frontends/rusticl/api/memory.rs index 949e6c606c5..59b23e6411e 100644 --- a/src/gallium/frontends/rusticl/api/memory.rs +++ b/src/gallium/frontends/rusticl/api/memory.rs @@ -290,7 +290,7 @@ fn create_buffer_with_properties( let diff = unsafe { host_ptr.offset_from(svm_ptr) } as usize; // technically we don't have to account for the offset, but it's almost for free. - if size > svm_layout.size() - diff { + if size > svm_layout - diff { return Err(CL_INVALID_BUFFER_SIZE); } } @@ -2929,16 +2929,16 @@ fn enqueue_svm_migrate_mem( // CL_INVALID_VALUE if sizes[i] is non-zero range [svm_pointers[i], svm_pointers[i]+sizes[i]) is // not contained within an existing clSVMAlloc allocation. for (ptr, size) in svm_pointers.iter_mut().zip(&mut sizes) { - if let Some((alloc, layout)) = q.context.find_svm_alloc(*ptr) { + if let Some((alloc, alloc_size)) = q.context.find_svm_alloc(*ptr) { let ptr_addr = *ptr; let alloc_addr = alloc as usize; // if the offset + size is bigger than the allocation we are out of bounds - if (ptr_addr - alloc_addr) + *size <= layout.size() { + if (ptr_addr - alloc_addr) + *size <= alloc_size { // if the size is 0, the entire allocation should be migrated if *size == 0 { *ptr = alloc as usize; - *size = layout.size(); + *size = alloc_size; } continue; } diff --git a/src/gallium/frontends/rusticl/core/context.rs b/src/gallium/frontends/rusticl/core/context.rs index 1aa2f30ea10..fb5575fc0f8 100644 --- a/src/gallium/frontends/rusticl/core/context.rs +++ b/src/gallium/frontends/rusticl/core/context.rs @@ -11,10 +11,10 @@ use mesa_rust::pipe::resource::*; use mesa_rust::pipe::screen::ResourceType; use mesa_rust_gen::*; use mesa_rust_util::properties::Properties; +use mesa_rust_util::ptr::TrackedPointers; use rusticl_opencl_gen::*; use std::alloc::Layout; -use std::collections::BTreeMap; use std::collections::HashMap; use std::convert::TryInto; use std::mem; @@ -27,7 +27,7 @@ pub struct Context { pub devs: Vec<&'static Device>, pub properties: Properties, pub dtors: Mutex>, - pub svm_ptrs: Mutex>, + svm_ptrs: Mutex>, pub gl_ctx_manager: Option, } @@ -44,7 +44,7 @@ impl Context { devs: devs, properties: properties, dtors: Mutex::new(Vec::new()), - svm_ptrs: Mutex::new(BTreeMap::new()), + svm_ptrs: Mutex::new(TrackedPointers::new()), gl_ctx_manager: gl_ctx_manager, }) } @@ -191,20 +191,12 @@ impl Context { self.svm_ptrs.lock().unwrap().insert(ptr, layout); } - pub fn find_svm_alloc(&self, ptr: usize) -> Option<(*const c_void, Layout)> { - let lock = self.svm_ptrs.lock().unwrap(); - if let Some((&base, layout)) = lock.range(..=ptr).next_back() { - // SAFETY: we really just do some pointer math here... - unsafe { - let base = base as *const c_void; - // we check if ptr is within [base..base+size) - // means we can check if ptr - (base + size) < 0 - if ptr < (base.add(layout.size()) as usize) { - return Some((base, *layout)); - } - } - } - None + pub fn find_svm_alloc(&self, ptr: usize) -> Option<(*const c_void, usize)> { + self.svm_ptrs + .lock() + .unwrap() + .find_alloc(ptr) + .map(|(ptr, layout)| (ptr as *const c_void, layout.size())) } pub fn remove_svm_ptr(&self, ptr: usize) -> Option { diff --git a/src/gallium/frontends/rusticl/util/ptr.rs b/src/gallium/frontends/rusticl/util/ptr.rs index 10347e3fd4f..32b43883b05 100644 --- a/src/gallium/frontends/rusticl/util/ptr.rs +++ b/src/gallium/frontends/rusticl/util/ptr.rs @@ -1,7 +1,9 @@ use std::{ + alloc::Layout, + collections::BTreeMap, hash::{Hash, Hasher}, mem, - ops::Deref, + ops::{Add, Deref}, ptr::{self, NonNull}, }; @@ -133,3 +135,56 @@ pub const fn addr(ptr: *const T) -> usize { mem::transmute(ptr.cast::<()>()) } } + +pub trait AllocSize

{ + fn size(&self) -> P; +} + +impl AllocSize for Layout { + fn size(&self) -> usize { + Self::size(self) + } +} + +pub struct TrackedPointers> { + ptrs: BTreeMap, +} + +impl> TrackedPointers { + pub fn new() -> Self { + Self { + ptrs: BTreeMap::new(), + } + } +} + +impl> TrackedPointers +where + P: Ord + Add + Copy, +{ + pub fn find_alloc(&self, ptr: P) -> Option<(P, &T)> { + if let Some((&base, val)) = self.ptrs.range(..=ptr).next_back() { + let size = val.size(); + // we check if ptr is within [base..base+size) + // means we can check if ptr - (base + size) < 0 + if ptr < (base + size) { + return Some((base, val)); + } + } + None + } + + pub fn insert(&mut self, ptr: P, val: T) -> Option { + self.ptrs.insert(ptr, val) + } + + pub fn remove(&mut self, ptr: &P) -> Option { + self.ptrs.remove(ptr) + } +} + +impl> Default for TrackedPointers { + fn default() -> Self { + Self::new() + } +}