diff --git a/include/hyprutils/memory/Atomic.hpp b/include/hyprutils/memory/Atomic.hpp index 531e0d4..de0753e 100644 --- a/include/hyprutils/memory/Atomic.hpp +++ b/include/hyprutils/memory/Atomic.hpp @@ -23,12 +23,11 @@ namespace Hyprutils::Memory { namespace Atomic_ { - template - class impl : public Impl_::impl { + class impl : public Impl_::impl_base { std::recursive_mutex m_mutex; public: - impl(T* data, bool lock = true) noexcept : Impl_::impl(data, lock) { + impl(void* data, DeleteFn deleter) noexcept : Impl_::impl_base(data, deleter) { ; } @@ -55,7 +54,13 @@ namespace Hyprutils::Memory { using validHierarchy = std::enable_if_t&, X>, CAtomicSharedPointer&>; public: - explicit CAtomicSharedPointer(Impl_::impl_base* impl) noexcept : m_ptr(impl) {} + explicit CAtomicSharedPointer(T* object) noexcept : m_ptr(new Atomic_::impl(sc(object), _delete)) { + ; + } + + CAtomicSharedPointer(Impl_::impl_base* impl) noexcept : m_ptr(impl) { + ; + } CAtomicSharedPointer(const CAtomicSharedPointer& ref) { if (!ref.m_ptr.impl_) @@ -141,7 +146,7 @@ namespace Hyprutils::Memory { // -> must unlock BEFORE reset // not last ref? // -> must unlock AFTER reset - auto& mutex = sc*>(m_ptr.impl_)->getMutex(); + auto& mutex = sc(m_ptr.impl_)->getMutex(); mutex.lock(); if (m_ptr.impl_->ref() > 1) { @@ -152,7 +157,11 @@ namespace Hyprutils::Memory { if (m_ptr.impl_->wref() == 0) { mutex.unlock(); // Don't hold the mutex when destroying it - m_ptr.reset(); + + m_ptr.impl_->destroy(); + delete sc(m_ptr.impl_); + m_ptr.impl_ = nullptr; + // mutex invalid return; } else { @@ -163,12 +172,18 @@ namespace Hyprutils::Memory { // To avoid this altogether, keep a weak pointer here. // This guarantees that impl_ is still valid after the reset. CWeakPointer guard = m_ptr; - m_ptr.reset(); + m_ptr.reset(); // destroys the data // Now we can safely check if guard is the last wref. if (guard.impl_->wref() == 1) { mutex.unlock(); - return; // ~guard destroys impl_ and mutex + + // destroy impl_ (includes the mutex) + delete sc(guard.impl_); + guard.impl_ = nullptr; + + // mutex invalid + return; } guard.reset(); @@ -205,8 +220,12 @@ namespace Hyprutils::Memory { } private: + static void _delete(void* p) { + std::default_delete{}(sc(p)); + } + std::lock_guard implLockGuard() const { - return sc*>(m_ptr.impl_)->lockGuard(); + return sc(m_ptr.impl_)->lockGuard(); } CSharedPointer m_ptr; @@ -312,11 +331,13 @@ namespace Hyprutils::Memory { // -> must unlock BEFORE reset // not last ref? // -> must unlock AFTER reset - auto& mutex = sc*>(m_ptr.impl_)->getMutex(); + auto& mutex = sc(m_ptr.impl_)->getMutex(); mutex.lock(); if (m_ptr.impl_->ref() == 0 && m_ptr.impl_->wref() == 1) { mutex.unlock(); - m_ptr.reset(); + + delete sc(m_ptr.impl_); + m_ptr.impl_ = nullptr; // mutex invalid return; } @@ -375,7 +396,7 @@ namespace Hyprutils::Memory { private: std::lock_guard implLockGuard() const { - return sc*>(m_ptr.impl_)->lockGuard(); + return sc(m_ptr.impl_)->lockGuard(); } CWeakPointer m_ptr; @@ -387,7 +408,7 @@ namespace Hyprutils::Memory { }; template - static CAtomicSharedPointer makeAtomicShared(Args&&... args) { - return CAtomicSharedPointer(new Atomic_::impl(new U(std::forward(args)...))); + [[nodiscard]] inline CAtomicSharedPointer makeAtomicShared(Args&&... args) { + return CAtomicSharedPointer(new U(std::forward(args)...)); } } diff --git a/include/hyprutils/memory/ImplBase.hpp b/include/hyprutils/memory/ImplBase.hpp index 8994768..7f988e4 100644 --- a/include/hyprutils/memory/ImplBase.hpp +++ b/include/hyprutils/memory/ImplBase.hpp @@ -8,44 +8,71 @@ namespace Hyprutils { namespace Impl_ { class impl_base { public: - virtual ~impl_base() = default; + using DeleteFn = void (*)(void*); - virtual void inc() noexcept = 0; - virtual void dec() noexcept = 0; - virtual void incWeak() noexcept = 0; - virtual void decWeak() noexcept = 0; - virtual unsigned int ref() noexcept = 0; - virtual unsigned int wref() noexcept = 0; - virtual void destroy() noexcept = 0; - virtual bool destroying() noexcept = 0; - virtual bool dataNonNull() noexcept = 0; - virtual bool lockable() noexcept = 0; - virtual void* getData() noexcept = 0; - }; - - template - class impl : public impl_base { - public: - impl(T* data, bool lock = true) noexcept : _lockable(lock), _data(data) { + impl_base(void* data, DeleteFn deleter, bool lock = true) noexcept : _lockable(lock), _data(data), _deleter(deleter) { ; } + void inc() noexcept { + _ref++; + } + + void dec() noexcept { + _ref--; + } + + void incWeak() noexcept { + _weak++; + } + + void decWeak() noexcept { + _weak--; + } + + unsigned int ref() noexcept { + return _ref; + } + + unsigned int wref() noexcept { + return _weak; + } + + void destroy() noexcept { + _destroy(); + } + + bool destroying() noexcept { + return _destroying; + } + + bool lockable() noexcept { + return _lockable; + } + + bool dataNonNull() noexcept { + return _data != nullptr; + } + + void* getData() noexcept { + return _data; + } + + ~impl_base() { + destroy(); + } + + private: /* strong refcount */ unsigned int _ref = 0; /* weak refcount */ unsigned int _weak = 0; /* if this is lockable (shared) */ - bool _lockable = true; + bool _lockable = true; - T* _data = nullptr; + void* _data = nullptr; - friend void swap(impl*& a, impl*& b) { - impl* tmp = a; - a = b; - b = tmp; - } - - /* if the destructor was called, + /* if the destructor was called, creating shared_ptrs is no longer valid */ bool _destroying = false; @@ -63,56 +90,7 @@ namespace Hyprutils { _destroying = false; } - std::default_delete _deleter{}; - - // - virtual void inc() noexcept { - _ref++; - } - - virtual void dec() noexcept { - _ref--; - } - - virtual void incWeak() noexcept { - _weak++; - } - - virtual void decWeak() noexcept { - _weak--; - } - - virtual unsigned int ref() noexcept { - return _ref; - } - - virtual unsigned int wref() noexcept { - return _weak; - } - - virtual void destroy() noexcept { - _destroy(); - } - - virtual bool destroying() noexcept { - return _destroying; - } - - virtual bool lockable() noexcept { - return _lockable; - } - - virtual bool dataNonNull() noexcept { - return _data != nullptr; - } - - virtual void* getData() noexcept { - return _data; - } - - virtual ~impl() { - destroy(); - } + DeleteFn _deleter = nullptr; }; } } diff --git a/include/hyprutils/memory/SharedPtr.hpp b/include/hyprutils/memory/SharedPtr.hpp index dcdd7da..bd1284d 100644 --- a/include/hyprutils/memory/SharedPtr.hpp +++ b/include/hyprutils/memory/SharedPtr.hpp @@ -28,7 +28,7 @@ namespace Hyprutils { /* creates a new shared pointer managing a resource avoid calling. Could duplicate ownership. Prefer makeShared */ - explicit CSharedPointer(T* object) noexcept : impl_(new Impl_::impl(object)) { + explicit CSharedPointer(T* object) noexcept : impl_(new Impl_::impl_base(sc(object), _delete)) { increment(); } @@ -140,6 +140,10 @@ namespace Hyprutils { Impl_::impl_base* impl_ = nullptr; private: + static void _delete(void* p) { + std::default_delete{}(sc(p)); + } + /* no-op if there is no impl_ may delete the stored object if ref == 0 diff --git a/include/hyprutils/memory/UniquePtr.hpp b/include/hyprutils/memory/UniquePtr.hpp index 7fdd7eb..64f14e2 100644 --- a/include/hyprutils/memory/UniquePtr.hpp +++ b/include/hyprutils/memory/UniquePtr.hpp @@ -22,7 +22,7 @@ namespace Hyprutils { /* creates a new unique pointer managing a resource avoid calling. Could duplicate ownership. Prefer makeUnique */ - explicit CUniquePointer(T* object) noexcept : impl_(new Impl_::impl(object, false)) { + explicit CUniquePointer(T* object) noexcept : impl_(new Impl_::impl_base(sc(object), [](void* p) { std::default_delete{}(sc(p)); }, false)) { increment(); } diff --git a/tests/memory.cpp b/tests/memory.cpp index 8f33ca8..fa27c01 100644 --- a/tests/memory.cpp +++ b/tests/memory.cpp @@ -92,6 +92,38 @@ static int testAtomicImpl() { foo->bar = foo; } + { // This tests destroying the data when storing the base class of a type + class ITest { + public: + size_t num = 0; + ITest() : num(1234) {}; + }; + + class CA : public ITest { + public: + size_t num2 = 0; + CA() : ITest(), num2(4321) {}; + }; + + class CB : public ITest { + public: + int num2 = 0; + CB() : ITest(), num2(-1) {}; + }; + + ASP genericAtomic = nullptr; + SP genericNormal = nullptr; + { + auto derivedAtomic = makeAtomicShared(); + auto derivedNormal = makeShared(); + genericAtomic = derivedAtomic; + genericNormal = derivedNormal; + } + + EXPECT(!!genericAtomic, true); + EXPECT(!!genericNormal, true); + } + return ret; } @@ -113,6 +145,7 @@ int main(int argc, char** argv, char** envp) { EXPECT(intPtr.strongRef(), 1); EXPECT(*weak, 10); EXPECT(weak.expired(), false); + EXPECT(!!weak.lock(), true); EXPECT(*weakUnique, 420); EXPECT(weakUnique.expired(), false); EXPECT(intUnique.impl_->wref(), 1); @@ -149,6 +182,5 @@ int main(int argc, char** argv, char** envp) { EXPECT(*intPtr2, 10); EXPECT(testAtomicImpl(), 0); - return ret; }