Made thread pool itself thread safe. Besides, added some extra safety for StreamMgr singleton instance allocation.
Some checks failed
continuous-integration/drone/push Build is failing

This commit is contained in:
Anne de Jong 2023-06-10 15:47:52 +02:00
parent 21df1bc6cf
commit 9b724ab9d5
11 changed files with 110 additions and 53 deletions

View File

@ -9,6 +9,7 @@
#include <functional> #include <functional>
#include <iostream> #include <iostream>
#include <memory> #include <memory>
#include <mutex>
using std::cerr; using std::cerr;
using std::endl; using std::endl;
@ -21,8 +22,9 @@ using rte = std::runtime_error;
* to it has been destroyed (no global stuff left). * to it has been destroyed (no global stuff left).
*/ */
std::weak_ptr<StreamMgr> _mgr; std::weak_ptr<StreamMgr> _mgr;
std::mutex _mgr_mutex;
using Lck = std::scoped_lock<std::mutex>;
/** /**
* @brief The only way to obtain a stream manager, can only be called from the * @brief The only way to obtain a stream manager, can only be called from the
@ -35,6 +37,14 @@ SmgrHandle StreamMgr::getInstance() {
auto mgr = _mgr.lock(); auto mgr = _mgr.lock();
if (!mgr) { if (!mgr) {
// Double Check Locking Pattern, if two threads would simultaneously
// instantiate the singleton instance.
Lck lck(_mgr_mutex);
auto mgr = _mgr.lock();
if (mgr) {
return mgr;
}
mgr = SmgrHandle(new StreamMgr()); mgr = SmgrHandle(new StreamMgr());
if (!mgr) { if (!mgr) {
@ -54,7 +64,7 @@ SmgrHandle StreamMgr::getInstance() {
StreamMgr::StreamMgr() StreamMgr::StreamMgr()
#if LASP_DEBUG == 1 #if LASP_DEBUG == 1
: main_thread_id(std::this_thread::get_id()) : main_thread_id(std::this_thread::get_id())
#endif #endif
{ {
DEBUGTRACE_ENTER; DEBUGTRACE_ENTER;
@ -70,7 +80,6 @@ void StreamMgr::checkRightThread() const {
void StreamMgr::rescanDAQDevices(bool background, void StreamMgr::rescanDAQDevices(bool background,
std::function<void()> callback) { std::function<void()> callback) {
DEBUGTRACE_ENTER; DEBUGTRACE_ENTER;
auto &pool = getPool();
checkRightThread(); checkRightThread();
if (_inputStream || _outputStream) { if (_inputStream || _outputStream) {
@ -87,7 +96,7 @@ void StreamMgr::rescanDAQDevices(bool background,
rescanDAQDevices_impl(callback); rescanDAQDevices_impl(callback);
} else { } else {
DEBUGTRACE_PRINT("Rescanning DAQ devices on different thread..."); DEBUGTRACE_PRINT("Rescanning DAQ devices on different thread...");
pool.push_task(&StreamMgr::rescanDAQDevices_impl, this, callback); _pool.push_task(&StreamMgr::rescanDAQDevices_impl, this, callback);
} }
} }
void StreamMgr::rescanDAQDevices_impl(std::function<void()> callback) { void StreamMgr::rescanDAQDevices_impl(std::function<void()> callback) {

View File

@ -1,6 +1,7 @@
#pragma once #pragma once
#include "lasp_daq.h" #include "lasp_daq.h"
#include "lasp_siggen.h" #include "lasp_siggen.h"
#include "lasp_thread.h"
#include <list> #include <list>
#include <memory> #include <memory>
#include <mutex> #include <mutex>
@ -30,6 +31,8 @@ class StreamMgr {
*/ */
std::unique_ptr<Daq> _inputStream, _outputStream; std::unique_ptr<Daq> _inputStream, _outputStream;
ThreadSafeThreadPool _pool;
/** /**
* @brief All indata handlers are called when input data is available. Note * @brief All indata handlers are called when input data is available. Note
* that they can be called from different threads and should take care of * that they can be called from different threads and should take care of

View File

@ -44,19 +44,21 @@ SeriesBiquad::SeriesBiquad(const vd &filter_coefs) {
SeriesBiquad SeriesBiquad::firstOrderHighPass(const d fs, const d cuton_Hz) { SeriesBiquad SeriesBiquad::firstOrderHighPass(const d fs, const d cuton_Hz) {
if(fs <= 0) { if (fs <= 0) {
throw rte("Invalid sampling frequency: " + std::to_string(fs) + " [Hz]"); throw rte("Invalid sampling frequency: " + std::to_string(fs) + " [Hz]");
} }
if(cuton_Hz <= 0) { if (cuton_Hz <= 0) {
throw rte("Invalid cuton frequency: " + std::to_string(cuton_Hz) + " [Hz]"); throw rte("Invalid cuton frequency: " + std::to_string(cuton_Hz) + " [Hz]");
} }
if(cuton_Hz >= 0.98*fs/2) { if (cuton_Hz >= 0.98 * fs / 2) {
throw rte("Invalid cuton frequency. We limit this to 0.98* fs / 2. Given value" + std::to_string(cuton_Hz) + " [Hz]"); throw rte(
"Invalid cuton frequency. We limit this to 0.98* fs / 2. Given value" +
std::to_string(cuton_Hz) + " [Hz]");
} }
const d tau = 1/(2*arma::datum::pi*cuton_Hz); const d tau = 1 / (2 * arma::datum::pi * cuton_Hz);
const d facnum = 2*fs*tau/(1+2*fs*tau); const d facnum = 2 * fs * tau / (1 + 2 * fs * tau);
const d facden = (1-2*fs*tau)/(1+2*fs*tau); const d facden = (1 - 2 * fs * tau) / (1 + 2 * fs * tau);
vd coefs(6); vd coefs(6);
// b0 // b0
@ -76,10 +78,8 @@ SeriesBiquad SeriesBiquad::firstOrderHighPass(const d fs, const d cuton_Hz) {
coefs(5) = 0; coefs(5) = 0;
return SeriesBiquad(coefs); return SeriesBiquad(coefs);
} }
std::unique_ptr<Filter> SeriesBiquad::clone() const { std::unique_ptr<Filter> SeriesBiquad::clone() const {
// sos.as_col() concatenates all columns, exactly what we want. // sos.as_col() concatenates all columns, exactly what we want.
return std::make_unique<SeriesBiquad>(sos.as_col()); return std::make_unique<SeriesBiquad>(sos.as_col());
@ -124,7 +124,6 @@ BiquadBank::BiquadBank(const dmat &filters, const vd *gains) {
* for use. * for use.
*/ */
lock lck(_mtx); lock lck(_mtx);
getPool();
for (us i = 0; i < filters.n_cols; i++) { for (us i = 0; i < filters.n_cols; i++) {
_filters.emplace_back(filters.col(i)); _filters.emplace_back(filters.col(i));
@ -153,16 +152,15 @@ void BiquadBank::filter(vd &inout) {
std::vector<std::future<vd>> futs; std::vector<std::future<vd>> futs;
#if 1 #if 1
auto &pool = getPool();
vd inout_cpy = inout; vd inout_cpy = inout;
for (us i = 0; i < _filters.size(); i++) { for (us i = 0; i < _filters.size(); i++) {
futs.emplace_back(pool.submit( futs.emplace_back(_pool.submit(
[&](vd inout, us i) { [&](vd inout, us i) {
_filters[i].filter(inout); _filters[i].filter(inout);
return inout; return inout;
}, // Launch a task to filter. }, // Launch a task to filter.
inout_cpy, i // Column i as argument to the lambda function above. inout_cpy, i // Column i as argument to the lambda function above.
)); ));
} }
// Zero-out in-out and sum-up the filtered values // Zero-out in-out and sum-up the filtered values

View File

@ -1,5 +1,6 @@
#pragma once #pragma once
#include "lasp_filter.h" #include "lasp_filter.h"
#include "lasp_thread.h"
/** /**
* \addtogroup dsp * \addtogroup dsp
@ -60,6 +61,7 @@ public:
class BiquadBank : public Filter { class BiquadBank : public Filter {
std::vector<SeriesBiquad> _filters; std::vector<SeriesBiquad> _filters;
vd _gains; vd _gains;
ThreadSafeThreadPool _pool;
mutable std::mutex _mtx; mutable std::mutex _mtx;
public: public:

View File

@ -37,9 +37,6 @@ SLM::SLM(const d fs, const d Lref, const us downsampling_fac, const d tau,
DEBUGTRACE_ENTER; DEBUGTRACE_ENTER;
DEBUGTRACE_PRINT(_alpha); DEBUGTRACE_PRINT(_alpha);
// Make sure thread pool is running
getPool();
if (Lref <= 0) { if (Lref <= 0) {
throw rte("Invalid reference level"); throw rte("Invalid reference level");
} }

View File

@ -1,6 +1,7 @@
#pragma once #pragma once
#include "lasp_biquadbank.h" #include "lasp_biquadbank.h"
#include "lasp_filter.h" #include "lasp_filter.h"
#include "lasp_thread.h"
#include <memory> #include <memory>
#include <optional> #include <optional>
@ -14,6 +15,7 @@
* channel. A channel is the result of a filtered signal * channel. A channel is the result of a filtered signal
*/ */
class SLM { class SLM {
ThreadSafeThreadPool _pool;
/** /**
* @brief A, C or Z weighting, depending on the pre-filter installed. * @brief A, C or Z weighting, depending on the pre-filter installed.
*/ */

View File

@ -5,21 +5,31 @@
#include <memory> #include <memory>
/** /**
* @brief It seems to work much better in cooperation with Pybind11 when this * @brief Store a global weak_ptr, that is used to create new shared pointers
* singleton is implemented with a unique_ptr. * if any other shared pointers are still alive. If not, we create a new
* instance.
*/ */
std::unique_ptr<BS::thread_pool> _static_storage_threadpool; std::weak_ptr<BS::thread_pool> _global_weak_pool;
void destroyThreadPool() { /**
* @brief Static storage for the mutex.
*/
std::mutex ThreadSafeThreadPool::_mtx;
using Lck = std::scoped_lock<std::mutex>;
using rte = std::runtime_error;
ThreadSafeThreadPool::ThreadSafeThreadPool() {
DEBUGTRACE_ENTER; DEBUGTRACE_ENTER;
_static_storage_threadpool = nullptr; Lck lck(_mtx);
} /// See if we can get it from the global ptr. If not, time to allocate it.
_pool = _global_weak_pool.lock();
BS::thread_pool &getPool() { if (!_pool) {
/* DEBUGTRACE_ENTER; */ _pool = std::make_shared<BS::thread_pool>();
if (!_static_storage_threadpool) { if (!_pool) {
DEBUGTRACE_PRINT("Creating new thread pool"); throw rte("Fatal: could not allocate thread pool!");
_static_storage_threadpool = std::make_unique<BS::thread_pool>(); }
// Update global weak pointer
_global_weak_pool = _pool;
} }
return *_static_storage_threadpool;
} }

View File

@ -2,18 +2,54 @@
#include "BS_thread_pool.hpp" #include "BS_thread_pool.hpp"
/** /**
* @brief Return reference to global (singleton) thread pool. The threadpool is * @brief Simple wrapper around BS::thread_pool that makes a BS::thread_pool a
* created using the default argument, which results in exactly * singleton, such that a thread pool can be used around in the code, and
* hardware_concurrency() amount of threads. * safely spawn threads also from other threads. Only wraps a submit() and
* * push_task for now.
* @return Thread pool ref.
*/ */
BS::thread_pool& getPool(); class ThreadSafeThreadPool {
/**
* @brief Shared access to the thread pool.
*/
std::shared_ptr<BS::thread_pool> _pool;
/**
* @brief Global mutex, used to restrict pool access to a single thread at
* once.
*/
static std::mutex _mtx;
using Lck = std::scoped_lock<std::mutex>;
ThreadSafeThreadPool(const ThreadSafeThreadPool&) = delete;
ThreadSafeThreadPool &
operator=(const ThreadSafeThreadPool&) = delete;
public:
/**
* @brief Instantiate handle to the thread pool.
*/
ThreadSafeThreadPool();
/**
* @brief Wrapper around BS::thread_pool::submit(...)
*/
template <
typename F, typename... A,
typename R = std::invoke_result_t<std::decay_t<F>, std::decay_t<A>...>>
[[nodiscard]] std::future<R> submit(F &&task, A &&...args) {
/// Lock access to pool
Lck lck(_mtx);
return _pool->submit(task, args...);
}
/**
* @brief Wrapper around BS::thread_pool::push_task(...)
*/
template <typename F, typename... A> void push_task(F &&task, A &&...args) {
/// Lock access to pool
Lck lck(_mtx);
_pool->push_task(task, args...);
}
};
/**
* @brief The global thread pool is stored in a unique_ptr, so in normal C++
* code the thread pool is deleted at the end of main(). However this does not
* hold when LASP code is run
*/
void destroyThreadPool();

View File

@ -62,8 +62,6 @@ ThreadedInDataHandlerBase::ThreadedInDataHandlerBase(SmgrHandle mgr,
DEBUGTRACE_ENTER; DEBUGTRACE_ENTER;
// Initialize thread pool, if not already done
getPool();
} }
void ThreadedInDataHandlerBase::startThread() { void ThreadedInDataHandlerBase::startThread() {
DEBUGTRACE_ENTER; DEBUGTRACE_ENTER;
@ -82,10 +80,9 @@ void ThreadedInDataHandlerBase::_inCallbackFromInDataHandler(
_queue->push(daqdata); _queue->push(daqdata);
if (!_thread_running) { if (!_thread_running) {
auto &pool = getPool();
DEBUGTRACE_PRINT("Pushing new thread in pool"); DEBUGTRACE_PRINT("Pushing new thread in pool");
_thread_running = true; _thread_running = true;
pool.push_task(&ThreadedInDataHandlerBase::threadFcn, this); _pool.push_task(&ThreadedInDataHandlerBase::threadFcn, this);
} }
} }

View File

@ -1,6 +1,7 @@
#pragma once #pragma once
#include "debugtrace.hpp" #include "debugtrace.hpp"
#include "lasp_indatahandler.h" #include "lasp_indatahandler.h"
#include "lasp_thread.h"
#include <atomic> #include <atomic>
#include <memory> #include <memory>
#include <mutex> #include <mutex>
@ -36,6 +37,8 @@ class ThreadedInDataHandlerBase {
std::atomic<bool> _thread_running{false}; std::atomic<bool> _thread_running{false};
std::atomic<bool> _thread_can_safely_run{false}; std::atomic<bool> _thread_can_safely_run{false};
ThreadSafeThreadPool _pool;
/** /**
* @brief Function pointer that is called when new DaqData arrives. * @brief Function pointer that is called when new DaqData arrives.
*/ */

@ -1 +1 @@
Subproject commit 4720a2980a30da085b4ddb4a0ea2a71af7351a48 Subproject commit a8c7e5bbbd08841836f9b92d72747fb8769dbec4