diff --git a/include/hyprutils/animation/AnimationManager.hpp b/include/hyprutils/animation/AnimationManager.hpp index 1266e2d..9179968 100644 --- a/include/hyprutils/animation/AnimationManager.hpp +++ b/include/hyprutils/animation/AnimationManager.hpp @@ -35,8 +35,8 @@ namespace Hyprutils { const std::unordered_map>& getAllBeziers(); struct SAnimationManagerSignals { - Signal::CSignal connect; // WP - Signal::CSignal disconnect; // WP + Signal::CSignalT> connect; + Signal::CSignalT> disconnect; }; Memory::CWeakPointer getSignals() const; @@ -48,9 +48,6 @@ namespace Hyprutils { bool m_bTickScheduled = false; - void onConnect(std::any data); - void onDisconnect(std::any data); - struct SAnimVarListeners { Signal::CHyprSignalListener connect; Signal::CHyprSignalListener disconnect; diff --git a/include/hyprutils/signal/Listener.hpp b/include/hyprutils/signal/Listener.hpp index 4453896..14f1ca1 100644 --- a/include/hyprutils/signal/Listener.hpp +++ b/include/hyprutils/signal/Listener.hpp @@ -6,7 +6,7 @@ namespace Hyprutils { namespace Signal { - class CUntypedSignal; + class CSignalBase; class CSignalListener { public: @@ -24,7 +24,7 @@ namespace Hyprutils { std::function m_fHandler; - friend class CUntypedSignal; + friend class CSignalBase; }; typedef Hyprutils::Memory::CSharedPointer CHyprSignalListener; diff --git a/include/hyprutils/signal/Signal.hpp b/include/hyprutils/signal/Signal.hpp index e78b48a..31262be 100644 --- a/include/hyprutils/signal/Signal.hpp +++ b/include/hyprutils/signal/Signal.hpp @@ -2,53 +2,107 @@ #include #include +#include +#include #include #include #include +#include #include #include "./Listener.hpp" namespace Hyprutils { namespace Signal { - class CUntypedSignal { + class CSignalBase { protected: - CHyprSignalListener registerListenerInternal(std::function handler); - void registerStaticListenerInternal(std::function handler); - void emitInternal(void* args); + CHyprSignalListener registerListenerInternal(std::function handler); + void registerStaticListenerInternal(std::function handler); + void emitInternal(void* args); - std::vector> m_vListeners; - std::vector> m_vStaticListeners; + std::vector> m_vListeners; + std::vector> m_vStaticListeners; }; template - class CSignalT : public CUntypedSignal { + class CSignalT : public CSignalBase { + template + using RefArg = std::conditional_t || std::is_arithmetic_v, T, const T&>; + public: - void emit(Args... args) { - auto argsTuple = std::make_tuple(args...); - emitInternal(&argsTuple); + void emit(RefArg... args) { + if constexpr (sizeof...(Args) == 0) + emitInternal(nullptr); + else { + auto argsTuple = std::tuple...>(args...); + + if constexpr (sizeof...(Args) == 1) + // NOLINTNEXTLINE: const is reapplied by handler invocation if required + emitInternal(const_cast(static_cast(&std::get<0>(argsTuple)))); + else + emitInternal(&argsTuple); + } } - [[nodiscard("Listener is unregistered when the ptr is lost")]] CHyprSignalListener registerListener(std::function handler) { - return registerListenerInternal([handler](void* argsPtr) { std::apply(handler, *static_cast*>(argsPtr)); }); + [[nodiscard("Listener is unregistered when the ptr is lost")]] CHyprSignalListener listen(std::function...)> handler) { + return registerListenerInternal(mkHandler(handler)); + } + + [[nodiscard("Listener is unregistered when the ptr is lost")]] CHyprSignalListener listen(std::function handler) + requires(sizeof...(Args) != 0) + { + return listen([handler](RefArg... args) { handler(); }); + } + + template + [[nodiscard("Listener is unregistered when the ptr is lost")]] CHyprSignalListener forward(CSignalT& signal) { + if constexpr (sizeof...(OtherArgs) == 0) + return listen([&signal](RefArg... args) { signal.emit(); }); + else + return listen([&signal](RefArg... args) { signal.emit(args...); }); + } + + [[deprecated("Use listener()")]] CHyprSignalListener registerListener(std::function handler) { + return listen([handler](const Args&... args) { + constexpr auto mkAny = [](std::any d = {}) { return d; }; + handler(mkAny(args...)); + }); } // this is for static listeners. They die with this signal. - void registerStaticListener(std::function handler) { - registerStaticListenerInternal([handler](void* argsPtr) { std::apply(handler, *static_cast*>(argsPtr)); }); + void listenStatic(std::function...)> handler) { + registerStaticListenerInternal(mkHandler(handler)); } - template - void registerStaticListener(std::function handler, Owner* owner) { - registerStaticListener([owner, handler](Args... args) { handler(owner, args...); }); + void listenStatic(std::function handler) + requires(sizeof...(Args) != 0) + { + return listenStatic([handler](RefArg... args) { handler(); }); + } + + [[deprecated("Use staticListener()")]] void registerStaticListener(std::function handler, void* owner) { + return listenStatic([handler, owner](const RefArg&... args) { + constexpr auto mkAny = [](std::any d = {}) { return d; }; + handler(owner, mkAny(args...)); + }); + } + + private: + std::function mkHandler(std::function...)> handler) { + return [handler](void* args) { + if constexpr (sizeof...(Args) == 0) + handler(); + else if constexpr (sizeof...(Args) == 1) + handler(*static_cast...>>>*>(args)); + else + std::apply(handler, *static_cast...>*>(args)); + }; } }; // compat - class CSignal : public CSignalT { + class [[deprecated("Use CSignalT")]] CSignal : public CSignalT { public: - void emit(std::any data = {}); - [[nodiscard("Listener is unregistered when the ptr is lost")]] CHyprSignalListener registerListener(std::function handler); - void registerStaticListener(std::function handler, void* owner); + void emit(std::any data = {}); }; } } diff --git a/src/animation/AnimationManager.cpp b/src/animation/AnimationManager.cpp index cc62616..446b2d9 100644 --- a/src/animation/AnimationManager.cpp +++ b/src/animation/AnimationManager.cpp @@ -20,31 +20,18 @@ CAnimationManager::CAnimationManager() { m_events = makeUnique(); m_listeners = makeUnique(); - m_listeners->connect = m_events->connect.registerListener([this](std::any data) { onConnect(data); }); - m_listeners->disconnect = m_events->disconnect.registerListener([this](std::any data) { onDisconnect(data); }); -} + m_listeners->connect = m_events->connect.listen([this](const WP& animVar) { + if (!m_bTickScheduled) + scheduleTick(); -void CAnimationManager::onConnect(std::any data) { - if (!m_bTickScheduled) - scheduleTick(); + if (animVar) + m_vActiveAnimatedVariables.emplace_back(animVar); + }); - try { - const auto PAV = std::any_cast>(data); - if (!PAV) - return; - - m_vActiveAnimatedVariables.emplace_back(PAV); - } catch (const std::bad_any_cast&) { return; } -} - -void CAnimationManager::onDisconnect(std::any data) { - try { - const auto PAV = std::any_cast>(data); - if (!PAV) - return; - - std::erase_if(m_vActiveAnimatedVariables, [&](const auto& other) { return !other || other == PAV; }); - } catch (const std::bad_any_cast&) { return; } + m_listeners->disconnect = m_events->disconnect.listen([this](const WP& animVar) { + if (animVar) + std::erase_if(m_vActiveAnimatedVariables, [&](const auto& other) { return !other || other == animVar; }); + }); } void CAnimationManager::removeAllBeziers() { diff --git a/src/signal/Signal.cpp b/src/signal/Signal.cpp index e9bfea9..6a410c0 100644 --- a/src/signal/Signal.cpp +++ b/src/signal/Signal.cpp @@ -1,3 +1,4 @@ +#include "hyprutils/memory/SharedPtr.hpp" #include #include #include @@ -8,7 +9,7 @@ using namespace Hyprutils::Memory; #define SP CSharedPointer #define WP CWeakPointer -void Hyprutils::Signal::CUntypedSignal::emitInternal(void* args) { +void Hyprutils::Signal::CSignalBase::emitInternal(void* args) { std::vector> listeners; for (auto& l : m_vListeners) { if (l.expired()) @@ -17,11 +18,7 @@ void Hyprutils::Signal::CUntypedSignal::emitInternal(void* args) { listeners.emplace_back(l.lock()); } - std::vector statics; - statics.reserve(m_vStaticListeners.size()); - for (auto& l : m_vStaticListeners) { - statics.emplace_back(l.get()); - } + auto statics = m_vStaticListeners; for (auto& l : listeners) { // if there is only one lock, it means the event is only held by the listeners @@ -43,7 +40,7 @@ void Hyprutils::Signal::CUntypedSignal::emitInternal(void* args) { // as such we'd be doing a UAF } -CHyprSignalListener Hyprutils::Signal::CUntypedSignal::registerListenerInternal(std::function handler) { +CHyprSignalListener Hyprutils::Signal::CSignalBase::registerListenerInternal(std::function handler) { CHyprSignalListener listener = SP(new CSignalListener(handler)); m_vListeners.emplace_back(listener); @@ -53,18 +50,10 @@ CHyprSignalListener Hyprutils::Signal::CUntypedSignal::registerListenerInternal( return listener; } -void Hyprutils::Signal::CUntypedSignal::registerStaticListenerInternal(std::function handler) { - m_vStaticListeners.emplace_back(std::unique_ptr(new CSignalListener(handler))); +void Hyprutils::Signal::CSignalBase::registerStaticListenerInternal(std::function handler) { + m_vStaticListeners.emplace_back(SP(new CSignalListener(handler))); } void Hyprutils::Signal::CSignal::emit(std::any data) { CSignalT::emit(data); } - -CHyprSignalListener Hyprutils::Signal::CSignal::registerListener(std::function handler) { - return CSignalT::registerListener(handler); -} - -void Hyprutils::Signal::CSignal::registerStaticListener(std::function handler, void* owner) { - CSignalT::registerStaticListener(handler, owner); -} diff --git a/tests/signal.cpp b/tests/signal.cpp index df8177b..99a8274 100644 --- a/tests/signal.cpp +++ b/tests/signal.cpp @@ -1,11 +1,18 @@ #include #include #include +#include +#include "hyprutils/memory/SharedPtr.hpp" +#include "hyprutils/signal/Listener.hpp" #include "shared.hpp" using namespace Hyprutils::Signal; using namespace Hyprutils::Memory; +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wdeprecated-declarations" +// + void legacy(int& ret) { CSignal signal; int data = 0; @@ -33,11 +40,32 @@ void legacyListenerEmit(int& ret) { EXPECT(data, 1); } +void legacyListeners(int& ret) { + int data = 0; + + CSignalT<> signal0; + CSignalT signal1; + + auto listener0 = signal0.registerListener([&](std::any d) { data += 1; }); + auto listener1 = signal1.registerListener([&](std::any d) { data += std::any_cast(d); }); + + signal0.registerStaticListener([&](void* o, std::any d) { data += 10; }, nullptr); + signal1.registerStaticListener([&](void* o, std::any d) { data += std::any_cast(d) * 10; }, nullptr); + + signal0.emit(); + signal1.emit(2); + + EXPECT(data, 33); +} + +#pragma GCC diagnostic pop +// + void empty(int& ret) { int data = 0; CSignalT<> signal; - auto listener = signal.registerListener([&] { data = 1; }); + auto listener = signal.listen([&] { data = 1; }); signal.emit(); EXPECT(data, 1); @@ -52,19 +80,31 @@ void typed(int& ret) { int data = 0; CSignalT signal; - auto listener = signal.registerListener([&](int newData) { data = newData; }); + auto listener = signal.listen([&](int newData) { data = newData; }); signal.emit(1); EXPECT(data, 1); } +void ignoreParams(int& ret) { + int data = 0; + + CSignalT signal; + auto listener = signal.listen([&] { data += 1; }); + + signal.listenStatic([&] { data += 1; }); + + signal.emit(2); + EXPECT(data, 2); +} + void typedMany(int& ret) { int data1 = 0; int data2 = 0; int data3 = 0; CSignalT signal; - auto listener = signal.registerListener([&](int d1, int d2, int d3) { + auto listener = signal.listen([&](int d1, int d2, int d3) { data1 = d1; data2 = d2; data3 = d3; @@ -76,25 +116,255 @@ void typedMany(int& ret) { EXPECT(data3, 3); } +void ref(int& ret) { + int count = 0; + int data = 0; + + CSignalT signal; + auto l1 = signal.listen([&](int& v) { v += 1; }); + auto l2 = signal.listen([&](int v) { count += v; }); + signal.emit(data); + + CSignalT constSignal; + auto l3 = constSignal.listen([&](const int& v) { count += v; }); + auto l4 = constSignal.listen([&](int v) { count += v; }); + constSignal.emit(data); + + EXPECT(data, 1); + EXPECT(count, 3); +} + +void refMany(int& ret) { + int count = 0; + int data1 = 0; + int data2 = 10; + + CSignalT signal; + auto l1 = signal.listen([&](int& v, const int&) { v += 1; }); + auto l2 = signal.listen([&](int v1, int v2) { count += v1 + v2; }); + + signal.emit(data1, data2); + EXPECT(data1, 1); + EXPECT(count, 11); +} + +void autoRefTypes(int& ret) { + class CCopyCounter { + public: + CCopyCounter(int& createCount, int& destroyCount) : createCount(createCount), destroyCount(destroyCount) { + createCount += 1; + } + + CCopyCounter(CCopyCounter&& other) noexcept : CCopyCounter(other.createCount, other.destroyCount) {} + CCopyCounter(const CCopyCounter& other) noexcept : CCopyCounter(other.createCount, other.destroyCount) {} + + ~CCopyCounter() { + destroyCount += 1; + } + + private: + int& createCount; + int& destroyCount; + }; + + auto createCount = 0; + auto destroyCount = 0; + + CSignalT signal; + auto listener = signal.listen([](const CCopyCounter& counter) {}); + + signal.emit(CCopyCounter(createCount, destroyCount)); + EXPECT(createCount, 1); + EXPECT(destroyCount, 1); +} + +void forward(int& ret) { + int count = 0; + + CSignalT sig; + CSignalT connected1; + CSignalT<> connected2; + + auto conn1 = sig.forward(connected1); + auto conn2 = sig.forward(connected2); + + auto listener1 = connected1.listen([&](int v) { count += v; }); + auto listener2 = connected2.listen([&] { count += 1; }); + + sig.emit(2); + + EXPECT(count, 3); +} + +void listenerAdded(int& ret) { + int count = 0; + + CSignalT<> signal; + CHyprSignalListener secondListener; + + auto listener = signal.listen([&] { + count += 1; + + if (!secondListener) + secondListener = signal.listen([&] { count += 1; }); + }); + + signal.emit(); + EXPECT(count, 1); // second should NOT be invoked as it was registed during emit + + signal.emit(); + EXPECT(count, 3); // second should be invoked +} + +void lastListenerSwapped(int& ret) { + int count = 0; + + CSignalT<> signal; + CHyprSignalListener removedListener; + CHyprSignalListener addedListener; + + auto firstListener = signal.listen([&] { + removedListener.reset(); // dropped and should NOT be invoked + + if (!addedListener) + addedListener = signal.listen([&] { count += 2; }); + }); + + removedListener = signal.listen([&] { count += 1; }); + + signal.emit(); + EXPECT(count, 0); // neither the removed nor added listeners should fire + + signal.emit(); + EXPECT(count, 2); // only the new listener should fire +} + +void signalDestroyed(int& ret) { + int count = 0; + + auto signal = std::make_unique>(); + + // This ensures a destructor of a listener called before signal reset is safe. + auto preListener = signal->listen([&] { count += 1; }); + + auto listener = signal->listen([&] { signal.reset(); }); + + // This ensures a destructor of a listener called after signal reset is safe + // and gets called. + auto postListener = signal->listen([&] { count += 1; }); + + signal->emit(); + EXPECT(count, 2); // all listeners should fire regardless of signal deletion +} + +// purely an asan test +void signalDestroyedBeforeListener() { + CHyprSignalListener listener1; + CHyprSignalListener listener2; + + CSignalT<> signal; + + listener1 = signal.listen([] {}); + listener2 = signal.listen([] {}); +} + +void signalDestroyedWithAddedListener(int& ret) { + int count = 0; + + auto signal = std::make_unique>(); + CHyprSignalListener shouldNotRun; + + auto listener = signal->listen([&] { + shouldNotRun = signal->listen([&] { count += 2; }); + signal.reset(); + }); + + signal->emit(); + EXPECT(count, 0); +} + +void signalDestroyedWithRemovedAndAddedListener(int& ret) { + int count = 0; + + auto signal = std::make_unique>(); + CHyprSignalListener removed; + CHyprSignalListener shouldNotRun; + + auto listener = signal->listen([&] { + removed.reset(); + shouldNotRun = signal->listen([&] { count += 2; }); + signal.reset(); + }); + + removed = signal->listen([&] { count += 1; }); + + signal->emit(); + EXPECT(count, 0); +} + void staticListener(int& ret) { - struct STestOwner { - int data = 0; - } owner; + int data = 0; CSignalT signal; - signal.registerStaticListener([&](STestOwner* owner, int newData) { owner->data = newData; }, &owner); + signal.listenStatic([&](int newData) { data = newData; }); signal.emit(1); - EXPECT(owner.data, 1); + EXPECT(data, 1); +} + +void staticListenerDestroy(int& ret) { + int count = 0; + + auto signal = makeShared>(); + signal->listenStatic([&] { count += 1; }); + + signal->listenStatic([&] { + // should not fire but SHOULD be freed + signal->listenStatic([&] { count += 3; }); + + signal.reset(); + }); + + signal->listenStatic([&] { count += 1; }); + + signal->emit(); + EXPECT(count, 2); +} + +// purely an asan test +void listenerDestroysSelf() { + CSignalT<> signal; + + CHyprSignalListener listener; + listener = signal.listen([&] { listener.reset(); }); + + // the static signal case is taken care of above + + signal.emit(); } int main(int argc, char** argv, char** envp) { int ret = 0; legacy(ret); legacyListenerEmit(ret); + legacyListeners(ret); empty(ret); typed(ret); + ignoreParams(ret); typedMany(ret); + ref(ret); + refMany(ret); + autoRefTypes(ret); + forward(ret); + listenerAdded(ret); + lastListenerSwapped(ret); + signalDestroyed(ret); + signalDestroyedBeforeListener(); + signalDestroyedWithAddedListener(ret); + signalDestroyedWithRemovedAndAddedListener(ret); staticListener(ret); + staticListenerDestroy(ret); + signalDestroyed(ret); + listenerDestroysSelf(); return ret; }