diff --git a/.gitignore b/.gitignore index 041d4c5..f5c0222 100644 --- a/.gitignore +++ b/.gitignore @@ -34,6 +34,7 @@ build/ .vscode/ .cache/ +.direnv/ .cmake/ CMakeCache.txt @@ -44,3 +45,4 @@ Makefile cmake_install.cmake compile_commands.json hyprutils.pc +.envrc diff --git a/include/hyprutils/memory/ImplAtomic.hpp b/include/hyprutils/memory/ImplAtomic.hpp new file mode 100644 index 0000000..3105536 --- /dev/null +++ b/include/hyprutils/memory/ImplAtomic.hpp @@ -0,0 +1,144 @@ +#pragma once + +#include "./ImplBase.hpp" +#include + +/* + Impl for Hyprutils shared pointers that does thread-safe reference counting. + Instead of using atomic counters (e.g. std::atomic), this implementation just uses a mutex. + That helps to keep the implementation simple and avoids compare_exchange loops. + + The implementation in ImplBase.hpp is perferred for single threaded contexts as it will be a bit faster than this one. + + Keep in mind that this implementation only ensures thread-safe reference counting of impl_. It does not protect the data itself. + Doing a lock() in multithreaded context will guaranty that the data still exists. + However, checking valid() on a weakpointer in a multithreaded context and accessing the data in the next line without any synchronization is not safe. +*/ + +namespace Hyprutils::Memory::Impl_ { + template + class CAtomicImpl : public impl_base { + public: + CAtomicImpl(T* data, bool lock = true) noexcept : _lockable(lock), _data(data) { + ; + } + + CAtomicImpl(const CAtomicImpl&) = delete; + CAtomicImpl(CAtomicImpl&&) = delete; + + /* strong refcount */ + unsigned int _ref = 1; + /* weak refcount */ + unsigned int _weak = 0; + /* if this is lockable (shared) */ + bool _lockable = true; + + std::mutex _mtx; + + T* _data = nullptr; + + friend void swap(CAtomicImpl*& a, CAtomicImpl*& b) noexcept { + CAtomicImpl* tmp = a; + a = b; + b = tmp; + } + + /* if the destructor was called, + creating shared_ptrs is no longer valid */ + bool _destroying = false; + + void _destroy() { + if (!_data || _destroying) + return; + + // first, we destroy the data, but keep the pointer. + // this way, weak pointers will still be able to + // reference and use, but no longer create shared ones. + _destroying = true; + __deleter(_data); + // now, we can reset the data and call it a day. + _data = nullptr; + _destroying = false; + } + + std::default_delete __deleter{}; + + // + virtual bool inc() { + std::lock_guard lg(_mtx); + if (_ref == 0) + return false; + + _ref++; + return true; + } + + virtual bool dec() { + std::lock_guard lg(_mtx); + _ref--; + + if (_ref == 0) { + // if ref == 0, we can destroy impl + _mtx.unlock(); + destroy(); + _mtx.lock(); + // if weak == 0, we tell the actual impl to delete this + return _weak == 0; + } + + return false; + } + + virtual bool incWeak() { + std::lock_guard lg(_mtx); + if (_ref == 0) + return false; + + _weak++; + return true; + } + + virtual bool decWeak() { + std::lock_guard lg(_mtx); + _weak--; + + // we need to check for _destroying, + // because otherwise we could destroy here + // and have a shared_ptr destroy the same thing + // later (in situations where we have a weak_ptr to self) + return _ref == 0 && _weak == 0 && !_destroying; + } + + virtual unsigned int ref() noexcept { + return _ref; + } + + virtual unsigned int wref() noexcept { + return _weak; + } + + virtual void destroy() noexcept { + _destroy(); + } + + virtual bool destroying() noexcept { + return _destroying; + } + + virtual bool lockable() noexcept { + return _lockable; + } + + virtual bool dataNonNull() noexcept { + return _data != nullptr; + } + + virtual void* getData() noexcept { + return _data; + } + + virtual ~CAtomicImpl() { + _destroy(); + } + }; +} diff --git a/include/hyprutils/memory/ImplBase.hpp b/include/hyprutils/memory/ImplBase.hpp index 8c32540..180575f 100644 --- a/include/hyprutils/memory/ImplBase.hpp +++ b/include/hyprutils/memory/ImplBase.hpp @@ -5,14 +5,23 @@ namespace Hyprutils { namespace Memory { namespace Impl_ { + // Control block implementation interface for hyprutils smart pointers. class impl_base { public: - virtual ~impl_base() {}; + virtual ~impl_base() = default; + + // If inc returns false, remove the pointer to this impl (but don't delete). + virtual bool inc() = 0; + // If dec returns true, the callee owned the last reference (strong and weak). + // Thus the impl must be deleted and any reference to it must be invalidated. + virtual bool dec() = 0; + + // If incWeak returns false, remove the pointer to this impl (but don't delete). + virtual bool incWeak() = 0; + // If dec returns true, there is no strong ref and the callee owned the last weak reference. + // Thus the impl must be deleted and any reference to it must be invalidated. + virtual bool decWeak() = 0; - virtual void inc() noexcept = 0; - virtual void dec() noexcept = 0; - virtual void incWeak() noexcept = 0; - virtual void decWeak() noexcept = 0; virtual unsigned int ref() noexcept = 0; virtual unsigned int wref() noexcept = 0; virtual void destroy() noexcept = 0; @@ -29,8 +38,11 @@ namespace Hyprutils { ; } + impl(const impl&) = delete; + impl(impl&&) = delete; + /* strong refcount */ - unsigned int _ref = 0; + unsigned int _ref = 1; /* weak refcount */ unsigned int _weak = 0; /* if this is lockable (shared) */ @@ -38,7 +50,7 @@ namespace Hyprutils { T* _data = nullptr; - friend void swap(impl*& a, impl*& b) { + friend void swap(impl*& a, impl*& b) noexcept { impl* tmp = a; a = b; b = tmp; @@ -65,20 +77,43 @@ namespace Hyprutils { std::default_delete __deleter{}; // - virtual void inc() noexcept { + virtual bool inc() { + if (_ref == 0) + return false; + _ref++; + return true; } - virtual void dec() noexcept { + virtual bool dec() { _ref--; + + if (_ref == 0) { + // if ref == 0, we can destroy impl + destroy(); + // if weak == 0, we tell the actual impl to delete this + return _weak == 0; + } + + return false; } - virtual void incWeak() noexcept { + virtual bool incWeak() { + if (_ref == 0) + return false; + _weak++; + return true; } - virtual void decWeak() noexcept { + virtual bool decWeak() { _weak--; + + // we need to check for _destroying, + // because otherwise we could destroy here + // and have a shared_ptr destroy the same thing + // later (in situations where we have a weak_ptr to self) + return _ref == 0 && _weak == 0 && !_destroying; } virtual unsigned int ref() noexcept { @@ -110,7 +145,7 @@ namespace Hyprutils { } virtual ~impl() { - destroy(); + _destroy(); } }; } diff --git a/include/hyprutils/memory/SharedPtr.hpp b/include/hyprutils/memory/SharedPtr.hpp index b86dfb2..2cd8bf7 100644 --- a/include/hyprutils/memory/SharedPtr.hpp +++ b/include/hyprutils/memory/SharedPtr.hpp @@ -2,16 +2,19 @@ #include #include "ImplBase.hpp" +#include "ImplAtomic.hpp" /* This is a custom impl of std::shared_ptr. - It is not thread-safe like the STL one, - but Hyprland is single-threaded anyways. It differs a bit from how the STL one works, namely in the fact that it keeps the T* inside the control block, and that you can still make a CWeakPtr or deref an existing one inside the destructor. + + Hyprutils comes with two different control block implementations. + The default one you get via makeShared is not thread-safe like STL shared pointers. + Use makeAtomicShared get a SharedPointer using a thread-safe control block. */ namespace Hyprutils { @@ -27,20 +30,15 @@ namespace Hyprutils { /* creates a new shared pointer managing a resource avoid calling. Could duplicate ownership. Prefer makeShared */ - explicit CSharedPointer(T* object) noexcept { - impl_ = new Impl_::impl(object); - increment(); - } + explicit CSharedPointer(T* object) noexcept : impl_(new Impl_::impl(object)) {} /* creates a shared pointer from a reference */ template > - CSharedPointer(const CSharedPointer& ref) noexcept { - impl_ = ref.impl_; + CSharedPointer(const CSharedPointer& ref) noexcept : impl_(ref.impl_) { increment(); } - CSharedPointer(const CSharedPointer& ref) noexcept { - impl_ = ref.impl_; + CSharedPointer(const CSharedPointer& ref) noexcept : impl_(ref.impl_) { increment(); } @@ -53,11 +51,8 @@ namespace Hyprutils { std::swap(impl_, ref.impl_); } - /* allows weakPointer to create from an impl */ - CSharedPointer(Impl_::impl_base* implementation) noexcept { - impl_ = implementation; - increment(); - } + /* allows weakPointer to create from an impl; impl_->inc() must be called before using it*/ + explicit CSharedPointer(Impl_::impl_base* implementation) noexcept : impl_(implementation) {} /* creates an empty shared pointer with no implementation */ CSharedPointer() noexcept { @@ -154,31 +149,18 @@ namespace Hyprutils { if (!impl_) return; - impl_->dec(); - - // if ref == 0, we can destroy impl - if (impl_->ref() == 0) - destroyImpl(); + if (impl_->dec()) { + delete impl_; + impl_ = nullptr; + } } /* no-op if there is no impl_ */ void increment() { if (!impl_) return; - impl_->inc(); - } - - /* destroy the pointed-to object - if able, will also destroy impl */ - void destroyImpl() { - // destroy the impl contents - impl_->destroy(); - - // check for weak refs, if zero, we can also delete impl_ - if (impl_->wref() == 0) { - delete impl_; + if (!impl_->inc()) impl_ = nullptr; - } } }; @@ -187,8 +169,17 @@ namespace Hyprutils { return CSharedPointer(new U(std::forward(args)...)); } + // Use instead of makeShared if thread-safe refcounting is desired. + template + static CSharedPointer makeAtomicShared(Args&&... args) { + return CSharedPointer(new Impl_::CAtomicImpl(new U(std::forward(args)...))); + } + template CSharedPointer reinterpretPointerCast(const CSharedPointer& ref) { + if (!ref.impl_->inc()) + return {}; + return CSharedPointer(ref.impl_); } } diff --git a/include/hyprutils/memory/UniquePtr.hpp b/include/hyprutils/memory/UniquePtr.hpp index 8588560..f879563 100644 --- a/include/hyprutils/memory/UniquePtr.hpp +++ b/include/hyprutils/memory/UniquePtr.hpp @@ -21,10 +21,7 @@ namespace Hyprutils { /* creates a new unique pointer managing a resource avoid calling. Could duplicate ownership. Prefer makeUnique */ - explicit CUniquePointer(T* object) noexcept { - impl_ = new Impl_::impl(object, false); - increment(); - } + explicit CUniquePointer(T* object) : impl_(new Impl_::impl(object, false)) {} /* creates a shared pointer from a reference */ template > @@ -106,31 +103,18 @@ namespace Hyprutils { if (!impl_) return; - impl_->dec(); - - // if ref == 0, we can destroy impl - if (impl_->ref() == 0) - destroyImpl(); + if (impl_->dec()) { + delete impl_; + impl_ = nullptr; + } } /* no-op if there is no impl_ */ void increment() { if (!impl_) return; - impl_->inc(); - } - - /* destroy the pointed-to object - if able, will also destroy impl */ - void destroyImpl() { - // destroy the impl contents - impl_->destroy(); - - // check for weak refs, if zero, we can also delete impl_ - if (impl_->wref() == 0) { - delete impl_; + if (!impl_->inc()) impl_ = nullptr; - } } }; @@ -146,4 +130,4 @@ struct std::hash> { std::size_t operator()(const Hyprutils::Memory::CUniquePointer& p) const noexcept { return std::hash{}(p.impl_); } -}; \ No newline at end of file +}; diff --git a/include/hyprutils/memory/WeakPtr.hpp b/include/hyprutils/memory/WeakPtr.hpp index 023bf78..51cb162 100644 --- a/include/hyprutils/memory/WeakPtr.hpp +++ b/include/hyprutils/memory/WeakPtr.hpp @@ -21,39 +21,23 @@ namespace Hyprutils { /* create a weak ptr from a reference */ template > - CWeakPointer(const CSharedPointer& ref) noexcept { - if (!ref.impl_) - return; - - impl_ = ref.impl_; + CWeakPointer(const CSharedPointer& ref) noexcept : impl_(ref.impl_) { incrementWeak(); } /* create a weak ptr from a reference */ template > - CWeakPointer(const CUniquePointer& ref) noexcept { - if (!ref.impl_) - return; - - impl_ = ref.impl_; + CWeakPointer(const CUniquePointer& ref) noexcept : impl_(ref.impl_) { incrementWeak(); } /* create a weak ptr from another weak ptr */ template > - CWeakPointer(const CWeakPointer& ref) noexcept { - if (!ref.impl_) - return; - - impl_ = ref.impl_; + CWeakPointer(const CWeakPointer& ref) noexcept : impl_(ref.impl_) { incrementWeak(); } - CWeakPointer(const CWeakPointer& ref) noexcept { - if (!ref.impl_) - return; - - impl_ = ref.impl_; + CWeakPointer(const CWeakPointer& ref) noexcept : impl_(ref.impl_) { incrementWeak(); } @@ -130,7 +114,10 @@ namespace Hyprutils { } CSharedPointer lock() const { - if (!impl_ || !impl_->dataNonNull() || impl_->destroying() || !impl_->lockable()) + if (!impl_ || !impl_->lockable()) + return {}; + + if (!impl_->inc()) return {}; return CSharedPointer(impl_); @@ -189,13 +176,7 @@ namespace Hyprutils { if (!impl_) return; - impl_->decWeak(); - - // we need to check for ->destroying, - // because otherwise we could destroy here - // and have a shared_ptr destroy the same thing - // later (in situations where we have a weak_ptr to self) - if (impl_->wref() == 0 && impl_->ref() == 0 && !impl_->destroying()) { + if (impl_->decWeak()) { delete impl_; impl_ = nullptr; } @@ -205,7 +186,8 @@ namespace Hyprutils { if (!impl_) return; - impl_->incWeak(); + if (!impl_->incWeak()) + impl_ = nullptr; } }; } diff --git a/prmsg.txt b/prmsg.txt new file mode 100644 index 0000000..a408fd4 --- /dev/null +++ b/prmsg.txt @@ -0,0 +1,17 @@ + +This PR was motivated by https://github.com/hyprwm/hyprlock/pull/799. + +It's goal is to get a thread-safe smart pointer implementation into Hyprutils. + +Vaxry suggested the following: +> No. Instead, make hyprutils pointers thread safe. Wrap a shared ptr into an ARC (just dont name it that cuz rust cring) and make it thread safe with a simple std::mutex + +I first tried to wrap a shared pointer. I have an implementation with a global mutex where that works ok, but it really wasn't ideal and a bit messy. I also never figured out how to not make it a recursive_mutex in that case. + +I experimented with different ways of wrapping the shared pointer, but i wasn't able to come up with a decent implementation. +Then I decided to try to rework the interface to the control block slightly, such that thread safety could be achieved by using a different implementation for the controlblock. +I liked that idea, because it would allow to share the interface for shared and weak pointers for default ones and thread-safe ones. + +First I experimented with atomic counters, but I ditched that and just added a std::mutex to the thread-safe control block implementation. + + diff --git a/tests/memory.cpp b/tests/memory.cpp index 118e8c8..86954e2 100644 --- a/tests/memory.cpp +++ b/tests/memory.cpp @@ -1,6 +1,10 @@ -#include #include +#include +#include #include "shared.hpp" +#include +#include +#include #include using namespace Hyprutils::Memory; @@ -9,6 +13,75 @@ using namespace Hyprutils::Memory; #define WP CWeakPointer #define UP CUniquePointer +#define NTHREADS 8 +#define ITERATIONS 10000 + +static int testAtomicity() { + int ret = 0; + + { + // Using makeShared here could lead to invalid refcounts. + SP shared = makeAtomicShared(0); + std::vector threads; + + threads.reserve(NTHREADS); + for (size_t i = 0; i < NTHREADS; i++) { + threads.emplace_back([shared]() { + for (size_t j = 0; j < ITERATIONS; j++) { + SP strongRef = shared; + (*shared)++; + strongRef.reset(); + } + }); + } + + for (auto& thread : threads) { + thread.join(); + } + + // Actual count is not incremented in a thread-safe manner here, so we can't check it. + // We just want to check that the concurent refcounting doesn't cause any memory corruption. + shared.reset(); + EXPECT(shared, false); + } + + { + SP shared = makeAtomicShared(0); + WP ref = shared; + std::vector threads; + + threads.reserve(NTHREADS); + for (size_t i = 0; i < NTHREADS; i++) { + threads.emplace_back([ref]() { + for (size_t j = 0; j < ITERATIONS; j++) { + if (auto s = ref.lock(); s) { + (*s)++; + } + } + }); + } + + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + shared.reset(); + + for (auto& thread : threads) { + thread.join(); + } + + EXPECT(shared.strongRef(), 0); + EXPECT(ref.valid(), false); + + auto shared2 = ref.lock(); + EXPECT(shared, false); + EXPECT(shared2.get(), nullptr); + EXPECT(shared.strongRef(), 0); + EXPECT(ref.valid(), false); + EXPECT(ref.expired(), true); + } + + return ret; +} + int main(int argc, char** argv, char** envp) { SP intPtr = makeShared(10); SP intPtr2 = makeShared(-1337); @@ -62,5 +135,7 @@ int main(int argc, char** argv, char** envp) { EXPECT(*intPtr2AsUint, 10); EXPECT(*intPtr2, 10); + EXPECT(testAtomicity(), 0); + return ret; -} \ No newline at end of file +}