diff --git a/.pick_status.json b/.pick_status.json index 95304eb483a..b1a0f83f4f4 100644 --- a/.pick_status.json +++ b/.pick_status.json @@ -894,7 +894,7 @@ "description": "rusticl/event: use Weak refs for dependencies", "nominated": true, "nomination_type": 0, - "resolution": 0, + "resolution": 1, "main_sha": null, "because_sha": null, "notes": null diff --git a/src/gallium/frontends/rusticl/core/event.rs b/src/gallium/frontends/rusticl/core/event.rs index 6389d6de831..bc1284750d1 100644 --- a/src/gallium/frontends/rusticl/core/event.rs +++ b/src/gallium/frontends/rusticl/core/event.rs @@ -15,6 +15,7 @@ use std::sync::Arc; use std::sync::Condvar; use std::sync::Mutex; use std::sync::MutexGuard; +use std::sync::Weak; use std::time::Duration; // we assert that those are a continous range of numbers so we won't have to use HashMaps @@ -48,7 +49,8 @@ pub struct Event { pub context: Arc, pub queue: Option>, pub cmd_type: cl_command_type, - pub deps: Vec>, + // using a Weak ref so we don't cause stack overflows in the `drop` impl + pub deps: Vec>, state: Mutex, cv: Condvar, } @@ -66,6 +68,7 @@ impl Event { deps: Vec>, work: EventSig, ) -> Arc { + let deps = deps.iter().map(Arc::downgrade).collect(); Arc::new(Self { base: CLObjectBase::new(), context: queue.context.clone(), @@ -238,14 +241,18 @@ impl Event { } } - fn deep_unflushed_deps_impl<'a>(&'a self, result: &mut HashSet<&'a Event>) { + pub fn deps(&self) -> impl Iterator> + Clone + '_ { + self.deps.iter().filter_map(Weak::upgrade) + } + + fn deep_unflushed_deps_impl(self: &Arc, result: &mut HashSet>) { if self.status() <= CL_SUBMITTED as i32 { return; } // only scan dependencies if it's a new one - if result.insert(self) { - for e in &self.deps { + if result.insert(Arc::clone(self)) { + for e in self.deps() { e.deep_unflushed_deps_impl(result); } } @@ -253,7 +260,7 @@ impl Event { /// does a deep search and returns a list of all dependencies including `events` which haven't /// been flushed out yet - pub fn deep_unflushed_deps(events: &[Arc]) -> HashSet<&Event> { + pub fn deep_unflushed_deps(events: &[Arc]) -> HashSet> { let mut result = HashSet::new(); for e in events { diff --git a/src/gallium/frontends/rusticl/core/queue.rs b/src/gallium/frontends/rusticl/core/queue.rs index b1c37c1767a..2b87d6fb899 100644 --- a/src/gallium/frontends/rusticl/core/queue.rs +++ b/src/gallium/frontends/rusticl/core/queue.rs @@ -127,16 +127,17 @@ impl Queue { let mut flushed = Vec::new(); for e in new_events { + let deps_iter = e.deps(); + // If we hit any deps from another queue, flush so we don't risk a dead - // lock. - if e.deps.iter().any(|ev| ev.queue != e.queue) { + // lock. Also clone the iter here as we'll iterate again later + if deps_iter.clone().any(|ev| ev.queue != e.queue) { + // this flush _has_ to happen before we wait on any of the deps flush_events(&mut flushed, &ctx); } // We have to wait on user events or events from other queues. - let err = e - .deps - .iter() + let err = deps_iter .filter(|ev| ev.is_user() || ev.queue != e.queue) .map(|e| e.wait()) .find(|s| *s < 0);