From 9b724ab9d545a9254f1dceaa5e957e0821c51b52 Mon Sep 17 00:00:00 2001 From: "J.A. de Jong - Redu-Sone B.V., ASCEE V.O.F" Date: Sat, 10 Jun 2023 15:47:52 +0200 Subject: [PATCH] Made thread pool itself thread safe. Besides, added some extra safety for StreamMgr singleton instance allocation. --- src/lasp/device/lasp_streammgr.cpp | 17 ++++-- src/lasp/device/lasp_streammgr.h | 3 ++ src/lasp/dsp/lasp_biquadbank.cpp | 30 +++++------ src/lasp/dsp/lasp_biquadbank.h | 2 + src/lasp/dsp/lasp_slm.cpp | 3 -- src/lasp/dsp/lasp_slm.h | 2 + src/lasp/dsp/lasp_thread.cpp | 36 ++++++++----- src/lasp/dsp/lasp_thread.h | 60 ++++++++++++++++----- src/lasp/dsp/lasp_threadedindatahandler.cpp | 5 +- src/lasp/dsp/lasp_threadedindatahandler.h | 3 ++ third_party/gsl-lite | 2 +- 11 files changed, 110 insertions(+), 53 deletions(-) diff --git a/src/lasp/device/lasp_streammgr.cpp b/src/lasp/device/lasp_streammgr.cpp index e744d6b..58c7534 100644 --- a/src/lasp/device/lasp_streammgr.cpp +++ b/src/lasp/device/lasp_streammgr.cpp @@ -9,6 +9,7 @@ #include #include #include +#include using std::cerr; using std::endl; @@ -21,8 +22,9 @@ using rte = std::runtime_error; * to it has been destroyed (no global stuff left). */ std::weak_ptr _mgr; +std::mutex _mgr_mutex; - +using Lck = std::scoped_lock; /** * @brief The only way to obtain a stream manager, can only be called from the @@ -35,6 +37,14 @@ SmgrHandle StreamMgr::getInstance() { auto mgr = _mgr.lock(); if (!mgr) { + // Double Check Locking Pattern, if two threads would simultaneously + // instantiate the singleton instance. + Lck lck(_mgr_mutex); + + auto mgr = _mgr.lock(); + if (mgr) { + return mgr; + } mgr = SmgrHandle(new StreamMgr()); if (!mgr) { @@ -54,7 +64,7 @@ SmgrHandle StreamMgr::getInstance() { StreamMgr::StreamMgr() #if LASP_DEBUG == 1 - : main_thread_id(std::this_thread::get_id()) + : main_thread_id(std::this_thread::get_id()) #endif { DEBUGTRACE_ENTER; @@ -70,7 +80,6 @@ void StreamMgr::checkRightThread() const { void StreamMgr::rescanDAQDevices(bool background, std::function callback) { DEBUGTRACE_ENTER; - auto &pool = getPool(); checkRightThread(); if (_inputStream || _outputStream) { @@ -87,7 +96,7 @@ void StreamMgr::rescanDAQDevices(bool background, rescanDAQDevices_impl(callback); } else { DEBUGTRACE_PRINT("Rescanning DAQ devices on different thread..."); - pool.push_task(&StreamMgr::rescanDAQDevices_impl, this, callback); + _pool.push_task(&StreamMgr::rescanDAQDevices_impl, this, callback); } } void StreamMgr::rescanDAQDevices_impl(std::function callback) { diff --git a/src/lasp/device/lasp_streammgr.h b/src/lasp/device/lasp_streammgr.h index 80494b0..8955cc6 100644 --- a/src/lasp/device/lasp_streammgr.h +++ b/src/lasp/device/lasp_streammgr.h @@ -1,6 +1,7 @@ #pragma once #include "lasp_daq.h" #include "lasp_siggen.h" +#include "lasp_thread.h" #include #include #include @@ -30,6 +31,8 @@ class StreamMgr { */ std::unique_ptr _inputStream, _outputStream; + ThreadSafeThreadPool _pool; + /** * @brief All indata handlers are called when input data is available. Note * that they can be called from different threads and should take care of diff --git a/src/lasp/dsp/lasp_biquadbank.cpp b/src/lasp/dsp/lasp_biquadbank.cpp index 09ec834..d31744e 100644 --- a/src/lasp/dsp/lasp_biquadbank.cpp +++ b/src/lasp/dsp/lasp_biquadbank.cpp @@ -44,19 +44,21 @@ SeriesBiquad::SeriesBiquad(const vd &filter_coefs) { SeriesBiquad SeriesBiquad::firstOrderHighPass(const d fs, const d cuton_Hz) { - if(fs <= 0) { + if (fs <= 0) { throw rte("Invalid sampling frequency: " + std::to_string(fs) + " [Hz]"); } - if(cuton_Hz <= 0) { + if (cuton_Hz <= 0) { throw rte("Invalid cuton frequency: " + std::to_string(cuton_Hz) + " [Hz]"); } - if(cuton_Hz >= 0.98*fs/2) { - throw rte("Invalid cuton frequency. We limit this to 0.98* fs / 2. Given value" + std::to_string(cuton_Hz) + " [Hz]"); + if (cuton_Hz >= 0.98 * fs / 2) { + throw rte( + "Invalid cuton frequency. We limit this to 0.98* fs / 2. Given value" + + std::to_string(cuton_Hz) + " [Hz]"); } - const d tau = 1/(2*arma::datum::pi*cuton_Hz); - const d facnum = 2*fs*tau/(1+2*fs*tau); - const d facden = (1-2*fs*tau)/(1+2*fs*tau); + const d tau = 1 / (2 * arma::datum::pi * cuton_Hz); + const d facnum = 2 * fs * tau / (1 + 2 * fs * tau); + const d facden = (1 - 2 * fs * tau) / (1 + 2 * fs * tau); vd coefs(6); // b0 @@ -76,10 +78,8 @@ SeriesBiquad SeriesBiquad::firstOrderHighPass(const d fs, const d cuton_Hz) { coefs(5) = 0; return SeriesBiquad(coefs); - } - std::unique_ptr SeriesBiquad::clone() const { // sos.as_col() concatenates all columns, exactly what we want. return std::make_unique(sos.as_col()); @@ -124,7 +124,6 @@ BiquadBank::BiquadBank(const dmat &filters, const vd *gains) { * for use. */ lock lck(_mtx); - getPool(); for (us i = 0; i < filters.n_cols; i++) { _filters.emplace_back(filters.col(i)); @@ -153,16 +152,15 @@ void BiquadBank::filter(vd &inout) { std::vector> futs; #if 1 - auto &pool = getPool(); vd inout_cpy = inout; for (us i = 0; i < _filters.size(); i++) { - futs.emplace_back(pool.submit( - [&](vd inout, us i) { + futs.emplace_back(_pool.submit( + [&](vd inout, us i) { _filters[i].filter(inout); return inout; - }, // Launch a task to filter. - inout_cpy, i // Column i as argument to the lambda function above. - )); + }, // Launch a task to filter. + inout_cpy, i // Column i as argument to the lambda function above. + )); } // Zero-out in-out and sum-up the filtered values diff --git a/src/lasp/dsp/lasp_biquadbank.h b/src/lasp/dsp/lasp_biquadbank.h index 1426cb5..0328cea 100644 --- a/src/lasp/dsp/lasp_biquadbank.h +++ b/src/lasp/dsp/lasp_biquadbank.h @@ -1,5 +1,6 @@ #pragma once #include "lasp_filter.h" +#include "lasp_thread.h" /** * \addtogroup dsp @@ -60,6 +61,7 @@ public: class BiquadBank : public Filter { std::vector _filters; vd _gains; + ThreadSafeThreadPool _pool; mutable std::mutex _mtx; public: diff --git a/src/lasp/dsp/lasp_slm.cpp b/src/lasp/dsp/lasp_slm.cpp index e934800..74f9b28 100644 --- a/src/lasp/dsp/lasp_slm.cpp +++ b/src/lasp/dsp/lasp_slm.cpp @@ -37,9 +37,6 @@ SLM::SLM(const d fs, const d Lref, const us downsampling_fac, const d tau, DEBUGTRACE_ENTER; DEBUGTRACE_PRINT(_alpha); - // Make sure thread pool is running - getPool(); - if (Lref <= 0) { throw rte("Invalid reference level"); } diff --git a/src/lasp/dsp/lasp_slm.h b/src/lasp/dsp/lasp_slm.h index f4afc0e..1c2d871 100644 --- a/src/lasp/dsp/lasp_slm.h +++ b/src/lasp/dsp/lasp_slm.h @@ -1,6 +1,7 @@ #pragma once #include "lasp_biquadbank.h" #include "lasp_filter.h" +#include "lasp_thread.h" #include #include @@ -14,6 +15,7 @@ * channel. A channel is the result of a filtered signal */ class SLM { + ThreadSafeThreadPool _pool; /** * @brief A, C or Z weighting, depending on the pre-filter installed. */ diff --git a/src/lasp/dsp/lasp_thread.cpp b/src/lasp/dsp/lasp_thread.cpp index 76bc400..a04e051 100644 --- a/src/lasp/dsp/lasp_thread.cpp +++ b/src/lasp/dsp/lasp_thread.cpp @@ -5,21 +5,31 @@ #include /** - * @brief It seems to work much better in cooperation with Pybind11 when this - * singleton is implemented with a unique_ptr. + * @brief Store a global weak_ptr, that is used to create new shared pointers + * if any other shared pointers are still alive. If not, we create a new + * instance. */ -std::unique_ptr _static_storage_threadpool; +std::weak_ptr _global_weak_pool; -void destroyThreadPool() { +/** + * @brief Static storage for the mutex. + */ +std::mutex ThreadSafeThreadPool::_mtx; + +using Lck = std::scoped_lock; +using rte = std::runtime_error; + +ThreadSafeThreadPool::ThreadSafeThreadPool() { DEBUGTRACE_ENTER; - _static_storage_threadpool = nullptr; -} - -BS::thread_pool &getPool() { - /* DEBUGTRACE_ENTER; */ - if (!_static_storage_threadpool) { - DEBUGTRACE_PRINT("Creating new thread pool"); - _static_storage_threadpool = std::make_unique(); + Lck lck(_mtx); + /// See if we can get it from the global ptr. If not, time to allocate it. + _pool = _global_weak_pool.lock(); + if (!_pool) { + _pool = std::make_shared(); + if (!_pool) { + throw rte("Fatal: could not allocate thread pool!"); + } + // Update global weak pointer + _global_weak_pool = _pool; } - return *_static_storage_threadpool; } diff --git a/src/lasp/dsp/lasp_thread.h b/src/lasp/dsp/lasp_thread.h index 8a6c0ca..c28805e 100644 --- a/src/lasp/dsp/lasp_thread.h +++ b/src/lasp/dsp/lasp_thread.h @@ -2,18 +2,54 @@ #include "BS_thread_pool.hpp" /** - * @brief Return reference to global (singleton) thread pool. The threadpool is - * created using the default argument, which results in exactly - * hardware_concurrency() amount of threads. - * - * @return Thread pool ref. + * @brief Simple wrapper around BS::thread_pool that makes a BS::thread_pool a + * singleton, such that a thread pool can be used around in the code, and + * safely spawn threads also from other threads. Only wraps a submit() and + * push_task for now. */ -BS::thread_pool& getPool(); +class ThreadSafeThreadPool { + /** + * @brief Shared access to the thread pool. + */ + std::shared_ptr _pool; + /** + * @brief Global mutex, used to restrict pool access to a single thread at + * once. + */ + static std::mutex _mtx; + + using Lck = std::scoped_lock; + ThreadSafeThreadPool(const ThreadSafeThreadPool&) = delete; + ThreadSafeThreadPool & + operator=(const ThreadSafeThreadPool&) = delete; + +public: + /** + * @brief Instantiate handle to the thread pool. + */ + ThreadSafeThreadPool(); + + + /** + * @brief Wrapper around BS::thread_pool::submit(...) + */ + template < + typename F, typename... A, + typename R = std::invoke_result_t, std::decay_t...>> + [[nodiscard]] std::future submit(F &&task, A &&...args) { + /// Lock access to pool + Lck lck(_mtx); + + return _pool->submit(task, args...); + } + /** + * @brief Wrapper around BS::thread_pool::push_task(...) + */ + template void push_task(F &&task, A &&...args) { + /// Lock access to pool + Lck lck(_mtx); + _pool->push_task(task, args...); + } +}; -/** - * @brief The global thread pool is stored in a unique_ptr, so in normal C++ - * code the thread pool is deleted at the end of main(). However this does not - * hold when LASP code is run - */ -void destroyThreadPool(); diff --git a/src/lasp/dsp/lasp_threadedindatahandler.cpp b/src/lasp/dsp/lasp_threadedindatahandler.cpp index bc60b86..340dc1a 100644 --- a/src/lasp/dsp/lasp_threadedindatahandler.cpp +++ b/src/lasp/dsp/lasp_threadedindatahandler.cpp @@ -62,8 +62,6 @@ ThreadedInDataHandlerBase::ThreadedInDataHandlerBase(SmgrHandle mgr, DEBUGTRACE_ENTER; - // Initialize thread pool, if not already done - getPool(); } void ThreadedInDataHandlerBase::startThread() { DEBUGTRACE_ENTER; @@ -82,10 +80,9 @@ void ThreadedInDataHandlerBase::_inCallbackFromInDataHandler( _queue->push(daqdata); if (!_thread_running) { - auto &pool = getPool(); DEBUGTRACE_PRINT("Pushing new thread in pool"); _thread_running = true; - pool.push_task(&ThreadedInDataHandlerBase::threadFcn, this); + _pool.push_task(&ThreadedInDataHandlerBase::threadFcn, this); } } diff --git a/src/lasp/dsp/lasp_threadedindatahandler.h b/src/lasp/dsp/lasp_threadedindatahandler.h index 0569205..b769ad1 100644 --- a/src/lasp/dsp/lasp_threadedindatahandler.h +++ b/src/lasp/dsp/lasp_threadedindatahandler.h @@ -1,6 +1,7 @@ #pragma once #include "debugtrace.hpp" #include "lasp_indatahandler.h" +#include "lasp_thread.h" #include #include #include @@ -36,6 +37,8 @@ class ThreadedInDataHandlerBase { std::atomic _thread_running{false}; std::atomic _thread_can_safely_run{false}; + ThreadSafeThreadPool _pool; + /** * @brief Function pointer that is called when new DaqData arrives. */ diff --git a/third_party/gsl-lite b/third_party/gsl-lite index 4720a29..a8c7e5b 160000 --- a/third_party/gsl-lite +++ b/third_party/gsl-lite @@ -1 +1 @@ -Subproject commit 4720a2980a30da085b4ddb4a0ea2a71af7351a48 +Subproject commit a8c7e5bbbd08841836f9b92d72747fb8769dbec4