memory/shared: add dynamicPointerCast (#92)

This commit is contained in:
Vaxry 2025-12-01 21:18:31 +00:00 committed by GitHub
parent 7e6346f84b
commit 9f8e158dbd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 207 additions and 24 deletions

View file

@ -54,11 +54,11 @@ namespace Hyprutils::Memory {
using validHierarchy = std::enable_if_t<std::is_assignable_v<CAtomicSharedPointer<T>&, X>, CAtomicSharedPointer&>; using validHierarchy = std::enable_if_t<std::is_assignable_v<CAtomicSharedPointer<T>&, X>, CAtomicSharedPointer&>;
public: public:
explicit CAtomicSharedPointer(T* object) noexcept : m_ptr(new Atomic_::impl(sc<void*>(object), _delete)) { explicit CAtomicSharedPointer(T* object) noexcept : m_ptr(new Atomic_::impl(sc<void*>(object), _delete), sc<void*>(object)) {
; ;
} }
CAtomicSharedPointer(Impl_::impl_base* impl) noexcept : m_ptr(impl) { CAtomicSharedPointer(Impl_::impl_base* impl, void* data) noexcept : m_ptr(impl, data) {
; ;
} }
@ -219,13 +219,17 @@ namespace Hyprutils::Memory {
return m_ptr.impl_ ? m_ptr.impl_->ref() : 0; return m_ptr.impl_ ? m_ptr.impl_->ref() : 0;
} }
Atomic_::impl* impl() const {
return sc<Atomic_::impl*>(m_ptr.impl_);
}
private: private:
static void _delete(void* p) { static void _delete(void* p) {
std::default_delete<T>{}(sc<T*>(p)); std::default_delete<T>{}(sc<T*>(p));
} }
std::lock_guard<std::recursive_mutex> implLockGuard() const { std::lock_guard<std::recursive_mutex> implLockGuard() const {
return sc<Atomic_::impl*>(m_ptr.impl_)->lockGuard(); return impl()->lockGuard();
} }
CSharedPointer<T> m_ptr; CSharedPointer<T> m_ptr;
@ -391,12 +395,16 @@ namespace Hyprutils::Memory {
if (!m_ptr.impl_->dataNonNull() || m_ptr.impl_->destroying() || !m_ptr.impl_->lockable()) if (!m_ptr.impl_->dataNonNull() || m_ptr.impl_->destroying() || !m_ptr.impl_->lockable())
return {}; return {};
return CAtomicSharedPointer<T>(m_ptr.impl_); return CAtomicSharedPointer<T>(m_ptr.impl_, m_ptr.m_data);
}
Atomic_::impl* impl() const {
return sc<Atomic_::impl*>(m_ptr.impl_);
} }
private: private:
std::lock_guard<std::recursive_mutex> implLockGuard() const { std::lock_guard<std::recursive_mutex> implLockGuard() const {
return sc<Atomic_::impl*>(m_ptr.impl_)->lockGuard(); return impl()->lockGuard();
} }
CWeakPointer<T> m_ptr; CWeakPointer<T> m_ptr;
@ -411,4 +419,19 @@ namespace Hyprutils::Memory {
[[nodiscard]] inline CAtomicSharedPointer<U> makeAtomicShared(Args&&... args) { [[nodiscard]] inline CAtomicSharedPointer<U> makeAtomicShared(Args&&... args) {
return CAtomicSharedPointer<U>(new U(std::forward<Args>(args)...)); return CAtomicSharedPointer<U>(new U(std::forward<Args>(args)...));
} }
template <typename T, typename U>
CAtomicSharedPointer<T> reinterpretPointerCast(const CAtomicSharedPointer<U>& ref) {
return CAtomicSharedPointer<T>(ref.impl(), ref.m_data);
}
template <typename T, typename U>
CAtomicSharedPointer<T> dynamicPointerCast(const CAtomicSharedPointer<U>& ref) {
if (!ref)
return nullptr;
T* newPtr = dynamic_cast<T*>(sc<U*>(ref.impl()->getData()));
if (!newPtr)
return nullptr;
return CAtomicSharedPointer<T>(ref.impl(), newPtr);
}
} }

View file

@ -28,31 +28,33 @@ namespace Hyprutils {
/* creates a new shared pointer managing a resource /* creates a new shared pointer managing a resource
avoid calling. Could duplicate ownership. Prefer makeShared */ avoid calling. Could duplicate ownership. Prefer makeShared */
explicit CSharedPointer(T* object) noexcept : impl_(new Impl_::impl_base(sc<void*>(object), _delete)) { explicit CSharedPointer(T* object) noexcept : impl_(new Impl_::impl_base(sc<void*>(object), _delete)), m_data(sc<void*>(object)) {
increment(); increment();
} }
/* creates a shared pointer from a reference */ /* creates a shared pointer from a reference */
template <typename U, typename = isConstructible<U>> template <typename U, typename = isConstructible<U>>
CSharedPointer(const CSharedPointer<U>& ref) noexcept : impl_(ref.impl_) { CSharedPointer(const CSharedPointer<U>& ref) noexcept : impl_(ref.impl_), m_data(ref.m_data) {
increment(); increment();
} }
CSharedPointer(const CSharedPointer& ref) noexcept : impl_(ref.impl_) { CSharedPointer(const CSharedPointer& ref) noexcept : impl_(ref.impl_), m_data(ref.m_data) {
increment(); increment();
} }
template <typename U, typename = isConstructible<U>> template <typename U, typename = isConstructible<U>>
CSharedPointer(CSharedPointer<U>&& ref) noexcept { CSharedPointer(CSharedPointer<U>&& ref) noexcept {
std::swap(impl_, ref.impl_); std::swap(impl_, ref.impl_);
std::swap(m_data, ref.m_data);
} }
CSharedPointer(CSharedPointer&& ref) noexcept { CSharedPointer(CSharedPointer&& ref) noexcept {
std::swap(impl_, ref.impl_); std::swap(impl_, ref.impl_);
std::swap(m_data, ref.m_data);
} }
/* allows weakPointer to create from an impl */ /* allows weakPointer to create from an impl */
CSharedPointer(Impl_::impl_base* implementation) noexcept : impl_(implementation) { CSharedPointer(Impl_::impl_base* implementation, void* data) noexcept : impl_(implementation), m_data(data) {
increment(); increment();
} }
@ -74,7 +76,8 @@ namespace Hyprutils {
return *this; return *this;
decrement(); decrement();
impl_ = rhs.impl_; impl_ = rhs.impl_;
m_data = rhs.m_data;
increment(); increment();
return *this; return *this;
} }
@ -84,7 +87,8 @@ namespace Hyprutils {
return *this; return *this;
decrement(); decrement();
impl_ = rhs.impl_; impl_ = rhs.impl_;
m_data = rhs.m_data;
increment(); increment();
return *this; return *this;
} }
@ -92,11 +96,13 @@ namespace Hyprutils {
template <typename U> template <typename U>
validHierarchy<const CSharedPointer<U>&> operator=(CSharedPointer<U>&& rhs) { validHierarchy<const CSharedPointer<U>&> operator=(CSharedPointer<U>&& rhs) {
std::swap(impl_, rhs.impl_); std::swap(impl_, rhs.impl_);
std::swap(m_data, rhs.m_data);
return *this; return *this;
} }
CSharedPointer& operator=(CSharedPointer&& rhs) noexcept { CSharedPointer& operator=(CSharedPointer&& rhs) noexcept {
std::swap(impl_, rhs.impl_); std::swap(impl_, rhs.impl_);
std::swap(m_data, rhs.m_data);
return *this; return *this;
} }
@ -104,6 +110,8 @@ namespace Hyprutils {
return impl_ && impl_->dataNonNull(); return impl_ && impl_->dataNonNull();
} }
// this compares that the pointed-to object is the same, but in multiple inheritance,
// different typed pointers can be equal if the object is the same
bool operator==(const CSharedPointer& rhs) const { bool operator==(const CSharedPointer& rhs) const {
return impl_ == rhs.impl_; return impl_ == rhs.impl_;
} }
@ -126,11 +134,12 @@ namespace Hyprutils {
void reset() { void reset() {
decrement(); decrement();
impl_ = nullptr; impl_ = nullptr;
m_data = nullptr;
} }
T* get() const { T* get() const {
return impl_ ? sc<T*>(impl_->getData()) : nullptr; return impl_ && impl_->dataNonNull() ? sc<T*>(m_data) : nullptr;
} }
unsigned int strongRef() const { unsigned int strongRef() const {
@ -139,6 +148,9 @@ namespace Hyprutils {
Impl_::impl_base* impl_ = nullptr; Impl_::impl_base* impl_ = nullptr;
// Never use directly: raw data ptr, could be UAF
void* m_data = nullptr;
private: private:
static void _delete(void* p) { static void _delete(void* p) {
std::default_delete<T>{}(sc<T*>(p)); std::default_delete<T>{}(sc<T*>(p));
@ -188,7 +200,17 @@ namespace Hyprutils {
template <typename T, typename U> template <typename T, typename U>
CSharedPointer<T> reinterpretPointerCast(const CSharedPointer<U>& ref) { CSharedPointer<T> reinterpretPointerCast(const CSharedPointer<U>& ref) {
return CSharedPointer<T>(ref.impl_); return CSharedPointer<T>(ref.impl_, ref.m_data);
}
template <typename T, typename U>
CSharedPointer<T> dynamicPointerCast(const CSharedPointer<U>& ref) {
if (!ref)
return nullptr;
T* newPtr = dynamic_cast<T*>(sc<U*>(ref.impl_->getData()));
if (!newPtr)
return nullptr;
return CSharedPointer<T>(ref.impl_, newPtr);
} }
} }
} }

View file

@ -26,7 +26,8 @@ namespace Hyprutils {
if (!ref.impl_) if (!ref.impl_)
return; return;
impl_ = ref.impl_; impl_ = ref.impl_;
m_data = ref.m_data;
incrementWeak(); incrementWeak();
} }
@ -36,7 +37,8 @@ namespace Hyprutils {
if (!ref.impl_) if (!ref.impl_)
return; return;
impl_ = ref.impl_; impl_ = ref.impl_;
m_data = ref.impl_->getData();
incrementWeak(); incrementWeak();
} }
@ -46,7 +48,8 @@ namespace Hyprutils {
if (!ref.impl_) if (!ref.impl_)
return; return;
impl_ = ref.impl_; impl_ = ref.impl_;
m_data = ref.m_data;
incrementWeak(); incrementWeak();
} }
@ -54,17 +57,20 @@ namespace Hyprutils {
if (!ref.impl_) if (!ref.impl_)
return; return;
impl_ = ref.impl_; impl_ = ref.impl_;
m_data = ref.m_data;
incrementWeak(); incrementWeak();
} }
template <typename U, typename = isConstructible<U>> template <typename U, typename = isConstructible<U>>
CWeakPointer(CWeakPointer<U>&& ref) noexcept { CWeakPointer(CWeakPointer<U>&& ref) noexcept {
std::swap(impl_, ref.impl_); std::swap(impl_, ref.impl_);
std::swap(m_data, ref.m_data);
} }
CWeakPointer(CWeakPointer&& ref) noexcept { CWeakPointer(CWeakPointer&& ref) noexcept {
std::swap(impl_, ref.impl_); std::swap(impl_, ref.impl_);
std::swap(m_data, ref.m_data);
} }
/* create a weak ptr from another weak ptr with assignment */ /* create a weak ptr from another weak ptr with assignment */
@ -74,7 +80,8 @@ namespace Hyprutils {
return *this; return *this;
decrementWeak(); decrementWeak();
impl_ = rhs.impl_; impl_ = rhs.impl_;
m_data = rhs.m_data;
incrementWeak(); incrementWeak();
return *this; return *this;
} }
@ -84,7 +91,8 @@ namespace Hyprutils {
return *this; return *this;
decrementWeak(); decrementWeak();
impl_ = rhs.impl_; impl_ = rhs.impl_;
m_data = rhs.m_data;
incrementWeak(); incrementWeak();
return *this; return *this;
} }
@ -96,7 +104,8 @@ namespace Hyprutils {
return *this; return *this;
decrementWeak(); decrementWeak();
impl_ = rhs.impl_; impl_ = rhs.impl_;
m_data = rhs.m_data;
incrementWeak(); incrementWeak();
return *this; return *this;
} }
@ -125,14 +134,15 @@ namespace Hyprutils {
void reset() { void reset() {
decrementWeak(); decrementWeak();
impl_ = nullptr; impl_ = nullptr;
m_data = nullptr;
} }
CSharedPointer<T> lock() const { CSharedPointer<T> lock() const {
if (!impl_ || !impl_->dataNonNull() || impl_->destroying() || !impl_->lockable()) if (!impl_ || !impl_->dataNonNull() || impl_->destroying() || !impl_->lockable())
return {}; return {};
return CSharedPointer<T>(impl_); return CSharedPointer<T>(impl_, m_data);
} }
/* this returns valid() */ /* this returns valid() */
@ -169,7 +179,7 @@ namespace Hyprutils {
} }
T* get() const { T* get() const {
return impl_ ? sc<T*>(impl_->getData()) : nullptr; return impl_ && impl_->dataNonNull() ? sc<T*>(m_data) : nullptr;
} }
T* operator->() const { T* operator->() const {
@ -182,6 +192,9 @@ namespace Hyprutils {
Impl_::impl_base* impl_ = nullptr; Impl_::impl_base* impl_ = nullptr;
// Never use directly: raw data ptr, could be UAF
void* m_data = nullptr;
private: private:
/* no-op if there is no impl_ */ /* no-op if there is no impl_ */
void decrementWeak() { void decrementWeak() {
@ -207,6 +220,16 @@ namespace Hyprutils {
impl_->incWeak(); impl_->incWeak();
} }
}; };
template <typename T, typename U>
CWeakPointer<T> dynamicPointerCast(const CWeakPointer<U>& ref) {
if (!ref)
return nullptr;
T* newPtr = dynamic_cast<T*>(sc<U*>(ref.impl_->getData()));
if (!newPtr)
return nullptr;
return CWeakPointer<T>(ref.impl_, newPtr);
}
} }
} }

View file

@ -123,6 +123,119 @@ static void testAtomicImpl() {
} }
} }
class InterfaceA {
public:
virtual ~InterfaceA() = default;
int m_ifaceAInt = 69;
int m_ifaceAShit = 1;
};
class InterfaceB {
public:
virtual ~InterfaceB() = default;
int m_ifaceBInt = 2;
int m_ifaceBShit = 3;
};
class CChild : public InterfaceA, public InterfaceB {
public:
virtual ~CChild() = default;
int m_childInt = 4;
};
class CChildA : public InterfaceA {
public:
int m_childAInt = 4;
};
static void testHierarchy() {
// Same test for atomic and non-atomic
{
SP<CChildA> childA = makeShared<CChildA>();
auto ifaceA = SP<InterfaceA>(childA);
EXPECT_TRUE(ifaceA);
EXPECT_EQ(ifaceA->m_ifaceAInt, 69);
auto ifaceB = dynamicPointerCast<InterfaceA>(SP<CChildA>{});
EXPECT_TRUE(!ifaceB);
}
{
SP<CChild> child = makeShared<CChild>();
SP<InterfaceA> ifaceA = dynamicPointerCast<InterfaceA>(child);
SP<InterfaceB> ifaceB = dynamicPointerCast<InterfaceB>(child);
EXPECT_TRUE(ifaceA);
EXPECT_TRUE(ifaceB);
EXPECT_EQ(ifaceA->m_ifaceAInt, 69);
EXPECT_EQ(ifaceB->m_ifaceBInt, 2);
WP<InterfaceA> ifaceAWeak = ifaceA;
child.reset();
EXPECT_TRUE(ifaceAWeak);
EXPECT_TRUE(ifaceA);
EXPECT_EQ(ifaceAWeak->m_ifaceAInt, 69);
EXPECT_EQ(ifaceA->m_ifaceAInt, 69);
ifaceA.reset();
EXPECT_TRUE(ifaceAWeak);
EXPECT_EQ(ifaceAWeak->m_ifaceAInt, 69);
EXPECT_TRUE(ifaceB);
EXPECT_EQ(ifaceB->m_ifaceBInt, 2);
ifaceB.reset();
EXPECT_TRUE(!ifaceAWeak);
}
//
{
ASP<CChildA> childA = makeAtomicShared<CChildA>();
auto ifaceA = ASP<InterfaceA>(childA);
EXPECT_TRUE(ifaceA);
EXPECT_EQ(ifaceA->m_ifaceAInt, 69);
auto ifaceB = dynamicPointerCast<InterfaceA>(ASP<CChildA>{});
EXPECT_TRUE(!ifaceB);
}
{
ASP<CChild> child = makeAtomicShared<CChild>();
ASP<InterfaceA> ifaceA = dynamicPointerCast<InterfaceA>(child);
ASP<InterfaceB> ifaceB = dynamicPointerCast<InterfaceB>(child);
EXPECT_TRUE(ifaceA);
EXPECT_TRUE(ifaceB);
EXPECT_EQ(ifaceA->m_ifaceAInt, 69);
EXPECT_EQ(ifaceB->m_ifaceBInt, 2);
AWP<InterfaceA> ifaceAWeak = ifaceA;
AWP<InterfaceB> ifaceBWeak = dynamicPointerCast<InterfaceB>(ifaceA);
child.reset();
EXPECT_TRUE(ifaceAWeak);
EXPECT_TRUE(ifaceBWeak);
EXPECT_TRUE(ifaceA);
EXPECT_EQ(ifaceAWeak->m_ifaceAInt, 69);
EXPECT_EQ(ifaceA->m_ifaceAInt, 69);
EXPECT_EQ(ifaceBWeak->m_ifaceBInt, 2);
ifaceA.reset();
EXPECT_TRUE(ifaceAWeak);
EXPECT_EQ(ifaceAWeak->m_ifaceAInt, 69);
EXPECT_TRUE(ifaceB);
EXPECT_EQ(ifaceB->m_ifaceBInt, 2);
EXPECT_EQ(ifaceBWeak->m_ifaceBInt, 2);
ifaceB.reset();
EXPECT_TRUE(!ifaceAWeak);
EXPECT_TRUE(!ifaceBWeak);
}
// test for leaks
for (size_t i = 0; i < 10000; ++i) {
auto child = makeAtomicShared<CChild>();
auto child2 = makeShared<CChild>();
}
}
TEST(Memory, memory) { TEST(Memory, memory) {
SP<int> intPtr = makeShared<int>(10); SP<int> intPtr = makeShared<int>(10);
SP<int> intPtr2 = makeShared<int>(-1337); SP<int> intPtr2 = makeShared<int>(-1337);
@ -176,4 +289,6 @@ TEST(Memory, memory) {
EXPECT_EQ(*intPtr2, 10); EXPECT_EQ(*intPtr2, 10);
testAtomicImpl(); testAtomicImpl();
testHierarchy();
} }