Commit d50aacb1 by Ben Clayton

Update Marl to 539094011

Includes new Thread Safety Analysis helpers Changes: 539094011 CMake: Export MARL_THREAD_SAFETY_ANALYSIS_SUPPORTED c7f70ba7a CMake: Bump min version + include(CheckCXXSourceCompiles) 658a204fc Replace SAL annotations with clang's TSA annotations 9630bec2f Fix CMake warning: "Policy CMP0023 is not set" 9f369ad5d Update yarn:: to marl:: in an example Commands: ./third_party/update-marl.sh --squash Bug: b/140546382 Change-Id: Idb4253c11cece99ea4b22f965b974b63e26b51a7
parents b4a27407 8787897e
...@@ -3,8 +3,3 @@ BasedOnStyle: Chromium ...@@ -3,8 +3,3 @@ BasedOnStyle: Chromium
--- ---
Language: Cpp Language: Cpp
StatementMacros:
- _Acquires_lock_
- _Releases_lock_
- _Requires_lock_held_
- _When_
\ No newline at end of file
...@@ -12,12 +12,14 @@ ...@@ -12,12 +12,14 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
cmake_minimum_required(VERSION 2.8) cmake_minimum_required(VERSION 3.0)
set (CMAKE_CXX_STANDARD 11) set(CMAKE_CXX_STANDARD 11)
project(Marl C CXX ASM) project(Marl C CXX ASM)
include(CheckCXXSourceCompiles)
########################################################### ###########################################################
# Options # Options
########################################################### ###########################################################
...@@ -76,6 +78,32 @@ if(MARL_BUILD_BENCHMARKS) ...@@ -76,6 +78,32 @@ if(MARL_BUILD_BENCHMARKS)
endif(MARL_BUILD_BENCHMARKS) endif(MARL_BUILD_BENCHMARKS)
########################################################### ###########################################################
# Compiler feature tests
###########################################################
# Check that the Clang Thread Safety Analysis' try_acquire_capability behaves
# correctly. This is broken on some earlier versions of clang.
# See: https://bugs.llvm.org/show_bug.cgi?id=32954
set(SAVE_CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS})
set(CMAKE_REQUIRED_FLAGS "-Wthread-safety -Werror")
check_cxx_source_compiles(
"int main() {
struct __attribute__((capability(\"mutex\"))) Mutex {
void Unlock() __attribute__((release_capability)) {};
bool TryLock() __attribute__((try_acquire_capability(true))) { return true; };
};
Mutex m;
if (m.TryLock()) {
m.Unlock(); // Should not warn.
}
return 0;
}"
MARL_THREAD_SAFETY_ANALYSIS_SUPPORTED)
set(CMAKE_REQUIRED_FLAGS ${SAVE_CMAKE_REQUIRED_FLAGS})
# Export MARL_THREAD_SAFETY_ANALYSIS_SUPPORTED as this may be useful to parent projects
set(MARL_THREAD_SAFETY_ANALYSIS_SUPPORTED PARENT_SCOPE ${MARL_THREAD_SAFETY_ANALYSIS_SUPPORTED})
###########################################################
# File lists # File lists
########################################################### ###########################################################
set(MARL_LIST set(MARL_LIST
...@@ -111,6 +139,10 @@ find_package(Threads REQUIRED) ...@@ -111,6 +139,10 @@ find_package(Threads REQUIRED)
# Functions # Functions
########################################################### ###########################################################
function(marl_set_target_options target) function(marl_set_target_options target)
if (MARL_THREAD_SAFETY_ANALYSIS_SUPPORTED)
target_compile_options(${target} PRIVATE "-Wthread-safety")
endif()
# Enable all warnings # Enable all warnings
if(MSVC) if(MSVC)
target_compile_options(${target} PRIVATE target_compile_options(${target} PRIVATE
...@@ -137,13 +169,13 @@ function(marl_set_target_options target) ...@@ -137,13 +169,13 @@ function(marl_set_target_options target)
if(MARL_ASAN) if(MARL_ASAN)
target_compile_options(${target} PUBLIC "-fsanitize=address") target_compile_options(${target} PUBLIC "-fsanitize=address")
target_link_libraries(${target} "-fsanitize=address") target_link_libraries(${target} PUBLIC "-fsanitize=address")
elseif(MARL_MSAN) elseif(MARL_MSAN)
target_compile_options(${target} PUBLIC "-fsanitize=memory") target_compile_options(${target} PUBLIC "-fsanitize=memory")
target_link_libraries(${target} "-fsanitize=memory") target_link_libraries(${target} PUBLIC "-fsanitize=memory")
elseif(MARL_TSAN) elseif(MARL_TSAN)
target_compile_options(${target} PUBLIC "-fsanitize=thread") target_compile_options(${target} PUBLIC "-fsanitize=thread")
target_link_libraries(${target} "-fsanitize=thread") target_link_libraries(${target} PUBLIC "-fsanitize=thread")
endif() endif()
target_include_directories(${target} PUBLIC $<BUILD_INTERFACE:${MARL_INCLUDE_DIR}>) target_include_directories(${target} PUBLIC $<BUILD_INTERFACE:${MARL_INCLUDE_DIR}>)
...@@ -232,7 +264,7 @@ if(MARL_BUILD_TESTS) ...@@ -232,7 +264,7 @@ if(MARL_BUILD_TESTS)
marl_set_target_options(marl-unittests) marl_set_target_options(marl-unittests)
target_link_libraries(marl-unittests marl) target_link_libraries(marl-unittests PRIVATE marl)
endif(MARL_BUILD_TESTS) endif(MARL_BUILD_TESTS)
# benchmarks # benchmarks
...@@ -252,7 +284,7 @@ if(MARL_BUILD_BENCHMARKS) ...@@ -252,7 +284,7 @@ if(MARL_BUILD_BENCHMARKS)
marl_set_target_options(marl-benchmarks) marl_set_target_options(marl-benchmarks)
target_link_libraries(marl-benchmarks benchmark::benchmark marl) target_link_libraries(marl-benchmarks PRIVATE benchmark::benchmark marl)
endif(MARL_BUILD_BENCHMARKS) endif(MARL_BUILD_BENCHMARKS)
# examples # examples
...@@ -263,7 +295,7 @@ if(MARL_BUILD_EXAMPLES) ...@@ -263,7 +295,7 @@ if(MARL_BUILD_EXAMPLES)
FOLDER "Examples" FOLDER "Examples"
) )
marl_set_target_options(${target}) marl_set_target_options(${target})
target_link_libraries(${target} marl) target_link_libraries(${target} PRIVATE marl)
endfunction(build_example) endfunction(build_example)
build_example(fractal) build_example(fractal)
......
...@@ -85,10 +85,10 @@ class OnNewThread<void> { ...@@ -85,10 +85,10 @@ class OnNewThread<void> {
// void runABlockingFunctionOnATask() // void runABlockingFunctionOnATask()
// { // {
// // Schedule a task that calls a blocking, non-yielding function. // // Schedule a task that calls a blocking, non-yielding function.
// yarn::schedule([=] { // marl::schedule([=] {
// // call_blocking_function() may block indefinitely. // // call_blocking_function() may block indefinitely.
// // Ensure this call does not block other tasks from running. // // Ensure this call does not block other tasks from running.
// auto result = yarn::blocking_call(call_blocking_function); // auto result = marl::blocking_call(call_blocking_function);
// // call_blocking_function() has now returned. // // call_blocking_function() has now returned.
// // result holds the return value of the blocking function call. // // result holds the return value of the blocking function call.
// }); // });
......
...@@ -19,11 +19,12 @@ ...@@ -19,11 +19,12 @@
#include "debug.h" #include "debug.h"
#include "defer.h" #include "defer.h"
#include "memory.h" #include "memory.h"
#include "mutex.h"
#include "scheduler.h" #include "scheduler.h"
#include "tsa.h"
#include <atomic> #include <atomic>
#include <condition_variable> #include <condition_variable>
#include <mutex>
namespace marl { namespace marl {
...@@ -47,14 +48,14 @@ class ConditionVariable { ...@@ -47,14 +48,14 @@ class ConditionVariable {
// wait() blocks the current fiber or thread until the predicate is satisfied // wait() blocks the current fiber or thread until the predicate is satisfied
// and the ConditionVariable is notified. // and the ConditionVariable is notified.
template <typename Predicate> template <typename Predicate>
inline void wait(std::unique_lock<std::mutex>& lock, Predicate&& pred); inline void wait(marl::lock& lock, Predicate&& pred);
// wait_for() blocks the current fiber or thread until the predicate is // wait_for() blocks the current fiber or thread until the predicate is
// satisfied, and the ConditionVariable is notified, or the timeout has been // satisfied, and the ConditionVariable is notified, or the timeout has been
// reached. Returns false if pred still evaluates to false after the timeout // reached. Returns false if pred still evaluates to false after the timeout
// has been reached, otherwise true. // has been reached, otherwise true.
template <typename Rep, typename Period, typename Predicate> template <typename Rep, typename Period, typename Predicate>
bool wait_for(std::unique_lock<std::mutex>& lock, bool wait_for(marl::lock& lock,
const std::chrono::duration<Rep, Period>& duration, const std::chrono::duration<Rep, Period>& duration,
Predicate&& pred); Predicate&& pred);
...@@ -63,7 +64,7 @@ class ConditionVariable { ...@@ -63,7 +64,7 @@ class ConditionVariable {
// reached. Returns false if pred still evaluates to false after the timeout // reached. Returns false if pred still evaluates to false after the timeout
// has been reached, otherwise true. // has been reached, otherwise true.
template <typename Clock, typename Duration, typename Predicate> template <typename Clock, typename Duration, typename Predicate>
bool wait_until(std::unique_lock<std::mutex>& lock, bool wait_until(marl::lock& lock,
const std::chrono::time_point<Clock, Duration>& timeout, const std::chrono::time_point<Clock, Duration>& timeout,
Predicate&& pred); Predicate&& pred);
...@@ -73,7 +74,7 @@ class ConditionVariable { ...@@ -73,7 +74,7 @@ class ConditionVariable {
ConditionVariable& operator=(const ConditionVariable&) = delete; ConditionVariable& operator=(const ConditionVariable&) = delete;
ConditionVariable& operator=(ConditionVariable&&) = delete; ConditionVariable& operator=(ConditionVariable&&) = delete;
std::mutex mutex; marl::mutex mutex;
containers::list<Scheduler::Fiber*> waiting; containers::list<Scheduler::Fiber*> waiting;
std::condition_variable condition; std::condition_variable condition;
std::atomic<int> numWaiting = {0}; std::atomic<int> numWaiting = {0};
...@@ -89,7 +90,7 @@ void ConditionVariable::notify_one() { ...@@ -89,7 +90,7 @@ void ConditionVariable::notify_one() {
return; return;
} }
{ {
std::unique_lock<std::mutex> lock(mutex); marl::lock lock(mutex);
for (auto fiber : waiting) { for (auto fiber : waiting) {
fiber->notify(); fiber->notify();
} }
...@@ -104,7 +105,7 @@ void ConditionVariable::notify_all() { ...@@ -104,7 +105,7 @@ void ConditionVariable::notify_all() {
return; return;
} }
{ {
std::unique_lock<std::mutex> lock(mutex); marl::lock lock(mutex);
for (auto fiber : waiting) { for (auto fiber : waiting) {
fiber->notify(); fiber->notify();
} }
...@@ -115,8 +116,7 @@ void ConditionVariable::notify_all() { ...@@ -115,8 +116,7 @@ void ConditionVariable::notify_all() {
} }
template <typename Predicate> template <typename Predicate>
void ConditionVariable::wait(std::unique_lock<std::mutex>& lock, void ConditionVariable::wait(marl::lock& lock, Predicate&& pred) {
Predicate&& pred) {
if (pred()) { if (pred()) {
return; return;
} }
...@@ -137,7 +137,7 @@ void ConditionVariable::wait(std::unique_lock<std::mutex>& lock, ...@@ -137,7 +137,7 @@ void ConditionVariable::wait(std::unique_lock<std::mutex>& lock,
// Currently running outside of the scheduler. // Currently running outside of the scheduler.
// Delegate to the std::condition_variable. // Delegate to the std::condition_variable.
numWaitingOnCondition++; numWaitingOnCondition++;
condition.wait(lock, pred); lock.wait(condition, pred);
numWaitingOnCondition--; numWaitingOnCondition--;
} }
numWaiting--; numWaiting--;
...@@ -145,7 +145,7 @@ void ConditionVariable::wait(std::unique_lock<std::mutex>& lock, ...@@ -145,7 +145,7 @@ void ConditionVariable::wait(std::unique_lock<std::mutex>& lock,
template <typename Rep, typename Period, typename Predicate> template <typename Rep, typename Period, typename Predicate>
bool ConditionVariable::wait_for( bool ConditionVariable::wait_for(
std::unique_lock<std::mutex>& lock, marl::lock& lock,
const std::chrono::duration<Rep, Period>& duration, const std::chrono::duration<Rep, Period>& duration,
Predicate&& pred) { Predicate&& pred) {
return wait_until(lock, std::chrono::system_clock::now() + duration, pred); return wait_until(lock, std::chrono::system_clock::now() + duration, pred);
...@@ -153,7 +153,7 @@ bool ConditionVariable::wait_for( ...@@ -153,7 +153,7 @@ bool ConditionVariable::wait_for(
template <typename Clock, typename Duration, typename Predicate> template <typename Clock, typename Duration, typename Predicate>
bool ConditionVariable::wait_until( bool ConditionVariable::wait_until(
std::unique_lock<std::mutex>& lock, marl::lock& lock,
const std::chrono::time_point<Clock, Duration>& timeout, const std::chrono::time_point<Clock, Duration>& timeout,
Predicate&& pred) { Predicate&& pred) {
if (pred()) { if (pred()) {
...@@ -181,7 +181,7 @@ bool ConditionVariable::wait_until( ...@@ -181,7 +181,7 @@ bool ConditionVariable::wait_until(
// Delegate to the std::condition_variable. // Delegate to the std::condition_variable.
numWaitingOnCondition++; numWaitingOnCondition++;
defer(numWaitingOnCondition--); defer(numWaitingOnCondition--);
return condition.wait_until(lock, timeout, pred); return lock.wait_until(condition, timeout, pred);
} }
} }
......
...@@ -113,7 +113,7 @@ class Event { ...@@ -113,7 +113,7 @@ class Event {
inline bool wait_until( inline bool wait_until(
const std::chrono::time_point<Clock, Duration>& timeout); const std::chrono::time_point<Clock, Duration>& timeout);
std::mutex mutex; marl::mutex mutex;
ConditionVariable cv; ConditionVariable cv;
const Mode mode; const Mode mode;
bool signalled; bool signalled;
...@@ -127,7 +127,7 @@ Event::Shared::Shared(Allocator* allocator, Mode mode, bool initialState) ...@@ -127,7 +127,7 @@ Event::Shared::Shared(Allocator* allocator, Mode mode, bool initialState)
: cv(allocator), mode(mode), signalled(initialState) {} : cv(allocator), mode(mode), signalled(initialState) {}
void Event::Shared::signal() { void Event::Shared::signal() {
std::unique_lock<std::mutex> lock(mutex); marl::lock lock(mutex);
if (signalled) { if (signalled) {
return; return;
} }
...@@ -143,7 +143,7 @@ void Event::Shared::signal() { ...@@ -143,7 +143,7 @@ void Event::Shared::signal() {
} }
void Event::Shared::wait() { void Event::Shared::wait() {
std::unique_lock<std::mutex> lock(mutex); marl::lock lock(mutex);
cv.wait(lock, [&] { return signalled; }); cv.wait(lock, [&] { return signalled; });
if (mode == Mode::Auto) { if (mode == Mode::Auto) {
signalled = false; signalled = false;
...@@ -153,7 +153,7 @@ void Event::Shared::wait() { ...@@ -153,7 +153,7 @@ void Event::Shared::wait() {
template <typename Rep, typename Period> template <typename Rep, typename Period>
bool Event::Shared::wait_for( bool Event::Shared::wait_for(
const std::chrono::duration<Rep, Period>& duration) { const std::chrono::duration<Rep, Period>& duration) {
std::unique_lock<std::mutex> lock(mutex); marl::lock lock(mutex);
if (!cv.wait_for(lock, duration, [&] { return signalled; })) { if (!cv.wait_for(lock, duration, [&] { return signalled; })) {
return false; return false;
} }
...@@ -166,7 +166,7 @@ bool Event::Shared::wait_for( ...@@ -166,7 +166,7 @@ bool Event::Shared::wait_for(
template <typename Clock, typename Duration> template <typename Clock, typename Duration>
bool Event::Shared::wait_until( bool Event::Shared::wait_until(
const std::chrono::time_point<Clock, Duration>& timeout) { const std::chrono::time_point<Clock, Duration>& timeout) {
std::unique_lock<std::mutex> lock(mutex); marl::lock lock(mutex);
if (!cv.wait_until(lock, timeout, [&] { return signalled; })) { if (!cv.wait_until(lock, timeout, [&] { return signalled; })) {
return false; return false;
} }
...@@ -186,7 +186,7 @@ void Event::signal() const { ...@@ -186,7 +186,7 @@ void Event::signal() const {
} }
void Event::clear() const { void Event::clear() const {
std::unique_lock<std::mutex> lock(shared->mutex); marl::lock lock(shared->mutex);
shared->signalled = false; shared->signalled = false;
} }
...@@ -206,7 +206,7 @@ bool Event::wait_until( ...@@ -206,7 +206,7 @@ bool Event::wait_until(
} }
bool Event::test() const { bool Event::test() const {
std::unique_lock<std::mutex> lock(shared->mutex); marl::lock lock(shared->mutex);
if (!shared->signalled) { if (!shared->signalled) {
return false; return false;
} }
...@@ -217,7 +217,7 @@ bool Event::test() const { ...@@ -217,7 +217,7 @@ bool Event::test() const {
} }
bool Event::isSignalled() const { bool Event::isSignalled() const {
std::unique_lock<std::mutex> lock(shared->mutex); marl::lock lock(shared->mutex);
return shared->signalled; return shared->signalled;
} }
...@@ -226,7 +226,7 @@ Event Event::any(Mode mode, const Iterator& begin, const Iterator& end) { ...@@ -226,7 +226,7 @@ Event Event::any(Mode mode, const Iterator& begin, const Iterator& end) {
Event any(mode, false); Event any(mode, false);
for (auto it = begin; it != end; it++) { for (auto it = begin; it != end; it++) {
auto s = it->shared; auto s = it->shared;
std::unique_lock<std::mutex> lock(s->mutex); marl::lock lock(s->mutex);
if (s->signalled) { if (s->signalled) {
any.signal(); any.signal();
} }
......
// Copyright 2020 The Marl Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Wrappers around std::mutex and std::unique_lock that provide clang's
// Thread Safety Analysis annotations.
// See: https://clang.llvm.org/docs/ThreadSafetyAnalysis.html
#ifndef marl_mutex_h
#define marl_mutex_h
#include "tsa.h"
#include <condition_variable>
#include <mutex>
namespace marl {
// mutex is a wrapper around std::mutex that offers Thread Safety Analysis
// annotations.
// mutex also holds methods for performing std::condition_variable::wait() calls
// as these require a std::unique_lock<> which are unsupported by the TSA.
class CAPABILITY("mutex") mutex {
public:
inline void lock() ACQUIRE() { _.lock(); }
inline void unlock() RELEASE() { _.unlock(); }
inline bool try_lock() TRY_ACQUIRE(true) { return _.try_lock(); }
// wait_locked calls cv.wait() on this already locked mutex.
template <typename Predicate>
inline void wait_locked(std::condition_variable& cv, Predicate&& p)
REQUIRES(this) {
std::unique_lock<std::mutex> lock(_, std::adopt_lock);
cv.wait(lock, std::forward<Predicate>(p));
lock.release(); // Keep lock held.
}
// wait_until_locked calls cv.wait() on this already locked mutex.
template <typename Predicate, typename Time>
inline bool wait_until_locked(std::condition_variable& cv,
Time&& time,
Predicate&& p) REQUIRES(this) {
std::unique_lock<std::mutex> lock(_, std::adopt_lock);
auto res = cv.wait_until(lock, std::forward<Time>(time),
std::forward<Predicate>(p));
lock.release(); // Keep lock held.
return res;
}
private:
friend class lock;
std::mutex _;
};
// lock is a RAII lock helper that offers Thread Safety Analysis annotations.
// lock also holds methods for performing std::condition_variable::wait()
// calls as these require a std::unique_lock<> which are unsupported by the TSA.
class SCOPED_CAPABILITY lock {
public:
inline lock(mutex& m) ACQUIRE(m) : _(m._) {}
inline ~lock() RELEASE() {}
// wait calls cv.wait() on this lock.
template <typename Predicate>
inline void wait(std::condition_variable& cv, Predicate&& p) {
cv.wait(_, std::forward<Predicate>(p));
}
// wait_until calls cv.wait() on this lock.
template <typename Predicate, typename Time>
inline bool wait_until(std::condition_variable& cv,
Time&& time,
Predicate&& p) {
return cv.wait_until(_, std::forward<Time>(time),
std::forward<Predicate>(p));
}
inline bool owns_lock() const { return _.owns_lock(); }
// lock_no_tsa locks the mutex outside of the visiblity of the thread
// safety analysis. Use with caution.
inline void lock_no_tsa() { _.lock(); }
// unlock_no_tsa unlocks the mutex outside of the visiblity of the thread
// safety analysis. Use with caution.
inline void unlock_no_tsa() { _.unlock(); }
private:
std::unique_lock<std::mutex> _;
};
} // namespace marl
#endif // marl_mutex_h
...@@ -17,9 +17,9 @@ ...@@ -17,9 +17,9 @@
#include "conditionvariable.h" #include "conditionvariable.h"
#include "memory.h" #include "memory.h"
#include "mutex.h"
#include <atomic> #include <atomic>
#include <mutex>
namespace marl { namespace marl {
...@@ -237,7 +237,7 @@ class BoundedPool : public Pool<T> { ...@@ -237,7 +237,7 @@ class BoundedPool : public Pool<T> {
inline void return_(Item*) override; inline void return_(Item*) override;
Item items[N]; Item items[N];
std::mutex mutex; marl::mutex mutex;
ConditionVariable returned; ConditionVariable returned;
Item* free = nullptr; Item* free = nullptr;
}; };
...@@ -281,7 +281,7 @@ typename BoundedPool<T, N, POLICY>::Loan BoundedPool<T, N, POLICY>::borrow() ...@@ -281,7 +281,7 @@ typename BoundedPool<T, N, POLICY>::Loan BoundedPool<T, N, POLICY>::borrow()
template <typename T, int N, PoolPolicy POLICY> template <typename T, int N, PoolPolicy POLICY>
template <typename F> template <typename F>
void BoundedPool<T, N, POLICY>::borrow(size_t n, const F& f) const { void BoundedPool<T, N, POLICY>::borrow(size_t n, const F& f) const {
std::unique_lock<std::mutex> lock(storage->mutex); marl::lock lock(storage->mutex);
for (size_t i = 0; i < n; i++) { for (size_t i = 0; i < n; i++) {
storage->returned.wait(lock, [&] { return storage->free != nullptr; }); storage->returned.wait(lock, [&] { return storage->free != nullptr; });
auto item = storage->free; auto item = storage->free;
...@@ -296,14 +296,16 @@ void BoundedPool<T, N, POLICY>::borrow(size_t n, const F& f) const { ...@@ -296,14 +296,16 @@ void BoundedPool<T, N, POLICY>::borrow(size_t n, const F& f) const {
template <typename T, int N, PoolPolicy POLICY> template <typename T, int N, PoolPolicy POLICY>
std::pair<typename BoundedPool<T, N, POLICY>::Loan, bool> std::pair<typename BoundedPool<T, N, POLICY>::Loan, bool>
BoundedPool<T, N, POLICY>::tryBorrow() const { BoundedPool<T, N, POLICY>::tryBorrow() const {
std::unique_lock<std::mutex> lock(storage->mutex); Item* item = nullptr;
if (storage->free == nullptr) { {
return std::make_pair(Loan(), false); marl::lock lock(storage->mutex);
if (storage->free == nullptr) {
return std::make_pair(Loan(), false);
}
item = storage->free;
storage->free = storage->free->next;
item->pool = this;
} }
auto item = storage->free;
storage->free = storage->free->next;
item->pool = this;
lock.unlock();
if (POLICY == PoolPolicy::Reconstruct) { if (POLICY == PoolPolicy::Reconstruct) {
item->construct(); item->construct();
} }
...@@ -315,10 +317,11 @@ void BoundedPool<T, N, POLICY>::Storage::return_(Item* item) { ...@@ -315,10 +317,11 @@ void BoundedPool<T, N, POLICY>::Storage::return_(Item* item) {
if (POLICY == PoolPolicy::Reconstruct) { if (POLICY == PoolPolicy::Reconstruct) {
item->destruct(); item->destruct();
} }
std::unique_lock<std::mutex> lock(mutex); {
item->next = free; marl::lock lock(mutex);
free = item; item->next = free;
lock.unlock(); free = item;
}
returned.notify_one(); returned.notify_one();
} }
...@@ -359,7 +362,7 @@ class UnboundedPool : public Pool<T> { ...@@ -359,7 +362,7 @@ class UnboundedPool : public Pool<T> {
inline void return_(Item*) override; inline void return_(Item*) override;
Allocator* allocator; Allocator* allocator;
std::mutex mutex; marl::mutex mutex;
std::vector<Item*> items; std::vector<Item*> items;
Item* free = nullptr; Item* free = nullptr;
}; };
...@@ -398,7 +401,7 @@ Loan<T> UnboundedPool<T, POLICY>::borrow() const { ...@@ -398,7 +401,7 @@ Loan<T> UnboundedPool<T, POLICY>::borrow() const {
template <typename T, PoolPolicy POLICY> template <typename T, PoolPolicy POLICY>
template <typename F> template <typename F>
inline void UnboundedPool<T, POLICY>::borrow(size_t n, const F& f) const { inline void UnboundedPool<T, POLICY>::borrow(size_t n, const F& f) const {
std::unique_lock<std::mutex> lock(storage->mutex); marl::lock lock(storage->mutex);
for (size_t i = 0; i < n; i++) { for (size_t i = 0; i < n; i++) {
if (storage->free == nullptr) { if (storage->free == nullptr) {
auto count = std::max<size_t>(storage->items.size(), 32); auto count = std::max<size_t>(storage->items.size(), 32);
...@@ -427,10 +430,9 @@ void UnboundedPool<T, POLICY>::Storage::return_(Item* item) { ...@@ -427,10 +430,9 @@ void UnboundedPool<T, POLICY>::Storage::return_(Item* item) {
if (POLICY == PoolPolicy::Reconstruct) { if (POLICY == PoolPolicy::Reconstruct) {
item->destruct(); item->destruct();
} }
std::unique_lock<std::mutex> lock(mutex); marl::lock lock(mutex);
item->next = free; item->next = free;
free = item; free = item;
lock.unlock();
} }
} // namespace marl } // namespace marl
......
// Copyright 2019 The Marl Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Stubs SAL annotation macros for platforms that do not support them.
// See
// https://docs.microsoft.com/en-us/visualstudio/code-quality/annotating-locking-behavior?view=vs-2019
#ifndef marl_sal_h
#define marl_sal_h
#ifndef _Acquires_lock_
#define _Acquires_lock_(...)
#endif
#ifndef _Guarded_by_
#define _Guarded_by_(...)
#endif
#ifndef _Releases_lock_
#define _Releases_lock_(...)
#endif
#ifndef _Requires_lock_held_
#define _Requires_lock_held_(...)
#endif
#ifndef _When_
#define _When_(...)
#endif
#endif // marl_sal_h
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
#include "debug.h" #include "debug.h"
#include "memory.h" #include "memory.h"
#include "sal.h" #include "mutex.h"
#include "task.h" #include "task.h"
#include "thread.h" #include "thread.h"
...@@ -28,7 +28,6 @@ ...@@ -28,7 +28,6 @@
#include <deque> #include <deque>
#include <functional> #include <functional>
#include <map> #include <map>
#include <mutex>
#include <set> #include <set>
#include <thread> #include <thread>
#include <unordered_map> #include <unordered_map>
...@@ -104,8 +103,6 @@ class Scheduler { ...@@ -104,8 +103,6 @@ class Scheduler {
// thread that previously executed it. // thread that previously executed it.
class Fiber { class Fiber {
public: public:
using Lock = std::unique_lock<std::mutex>;
// current() returns the currently executing fiber, or nullptr if called // current() returns the currently executing fiber, or nullptr if called
// without a bound scheduler. // without a bound scheduler.
static Fiber* current(); static Fiber* current();
...@@ -122,8 +119,7 @@ class Scheduler { ...@@ -122,8 +119,7 @@ class Scheduler {
// will be locked before wait() returns. // will be locked before wait() returns.
// pred will be always be called with the lock held. // pred will be always be called with the lock held.
// wait() must only be called on the currently executing fiber. // wait() must only be called on the currently executing fiber.
_Requires_lock_held_(lock) void wait(marl::lock& lock, const Predicate& pred);
void wait(Lock& lock, const Predicate& pred);
// wait() suspends execution of this Fiber until the Fiber is woken up with // wait() suspends execution of this Fiber until the Fiber is woken up with
// a call to notify() and the predicate pred returns true, or sometime after // a call to notify() and the predicate pred returns true, or sometime after
...@@ -139,9 +135,8 @@ class Scheduler { ...@@ -139,9 +135,8 @@ class Scheduler {
// will be locked before wait() returns. // will be locked before wait() returns.
// pred will be always be called with the lock held. // pred will be always be called with the lock held.
// wait() must only be called on the currently executing fiber. // wait() must only be called on the currently executing fiber.
_Requires_lock_held_(lock)
template <typename Clock, typename Duration> template <typename Clock, typename Duration>
inline bool wait(Lock& lock, inline bool wait(marl::lock& lock,
const std::chrono::time_point<Clock, Duration>& timeout, const std::chrono::time_point<Clock, Duration>& timeout,
const Predicate& pred); const Predicate& pred);
...@@ -307,56 +302,51 @@ class Scheduler { ...@@ -307,56 +302,51 @@ class Scheduler {
Worker(Scheduler* scheduler, Mode mode, uint32_t id); Worker(Scheduler* scheduler, Mode mode, uint32_t id);
// start() begins execution of the worker. // start() begins execution of the worker.
void start(); void start() EXCLUDES(work.mutex);
// stop() ceases execution of the worker, blocking until all pending // stop() ceases execution of the worker, blocking until all pending
// tasks have fully finished. // tasks have fully finished.
void stop(); void stop() EXCLUDES(work.mutex);
// wait() suspends execution of the current task until the predicate pred // wait() suspends execution of the current task until the predicate pred
// returns true or the optional timeout is reached. // returns true or the optional timeout is reached.
// See Fiber::wait() for more information. // See Fiber::wait() for more information.
_Requires_lock_held_(lock) bool wait(marl::lock& lock, const TimePoint* timeout, const Predicate& pred)
bool wait(Fiber::Lock& lock, EXCLUDES(work.mutex);
const TimePoint* timeout,
const Predicate& pred);
// wait() suspends execution of the current task until the fiber is // wait() suspends execution of the current task until the fiber is
// notified, or the optional timeout is reached. // notified, or the optional timeout is reached.
// See Fiber::wait() for more information. // See Fiber::wait() for more information.
bool wait(const TimePoint* timeout); bool wait(const TimePoint* timeout) EXCLUDES(work.mutex);
// suspend() suspends the currenetly executing Fiber until the fiber is // suspend() suspends the currenetly executing Fiber until the fiber is
// woken with a call to enqueue(Fiber*), or automatically sometime after the // woken with a call to enqueue(Fiber*), or automatically sometime after the
// optional timeout. // optional timeout.
_Requires_lock_held_(work.mutex) void suspend(const TimePoint* timeout) REQUIRES(work.mutex);
void suspend(const TimePoint* timeout);
// enqueue(Fiber*) enqueues resuming of a suspended fiber. // enqueue(Fiber*) enqueues resuming of a suspended fiber.
void enqueue(Fiber* fiber); void enqueue(Fiber* fiber) EXCLUDES(work.mutex);
// enqueue(Task&&) enqueues a new, unstarted task. // enqueue(Task&&) enqueues a new, unstarted task.
void enqueue(Task&& task); void enqueue(Task&& task) EXCLUDES(work.mutex);
// tryLock() attempts to lock the worker for task enqueing. // tryLock() attempts to lock the worker for task enqueing.
// If the lock was successful then true is returned, and the caller must // If the lock was successful then true is returned, and the caller must
// call enqueueAndUnlock(). // call enqueueAndUnlock().
_When_(return == true, _Acquires_lock_(work.mutex)) bool tryLock() EXCLUDES(work.mutex) TRY_ACQUIRE(true, work.mutex);
bool tryLock();
// enqueueAndUnlock() enqueues the task and unlocks the worker. // enqueueAndUnlock() enqueues the task and unlocks the worker.
// Must only be called after a call to tryLock() which returned true. // Must only be called after a call to tryLock() which returned true.
_Requires_lock_held_(work.mutex) // _Releases_lock_(work.mutex)
_Releases_lock_(work.mutex) void enqueueAndUnlock(Task&& task) REQUIRES(work.mutex) RELEASE(work.mutex);
void enqueueAndUnlock(Task&& task);
// runUntilShutdown() processes all tasks and fibers until there are no more // runUntilShutdown() processes all tasks and fibers until there are no more
// and shutdown is true, upon runUntilShutdown() returns. // and shutdown is true, upon runUntilShutdown() returns.
void runUntilShutdown(); void runUntilShutdown() REQUIRES(work.mutex);
// steal() attempts to steal a Task from the worker for another worker. // steal() attempts to steal a Task from the worker for another worker.
// Returns true if a task was taken and assigned to out, otherwise false. // Returns true if a task was taken and assigned to out, otherwise false.
bool steal(Task& out); bool steal(Task& out) EXCLUDES(work.mutex);
// getCurrent() returns the Worker currently bound to the current // getCurrent() returns the Worker currently bound to the current
// thread. // thread.
...@@ -371,27 +361,22 @@ class Scheduler { ...@@ -371,27 +361,22 @@ class Scheduler {
private: private:
// run() is the task processing function for the worker. // run() is the task processing function for the worker.
// run() processes tasks until stop() is called. // run() processes tasks until stop() is called.
_Requires_lock_held_(work.mutex) void run() REQUIRES(work.mutex);
void run();
// createWorkerFiber() creates a new fiber that when executed calls // createWorkerFiber() creates a new fiber that when executed calls
// run(). // run().
_Requires_lock_held_(work.mutex) Fiber* createWorkerFiber() REQUIRES(work.mutex);
Fiber* createWorkerFiber();
// switchToFiber() switches execution to the given fiber. The fiber // switchToFiber() switches execution to the given fiber. The fiber
// must belong to this worker. // must belong to this worker.
_Requires_lock_held_(work.mutex) void switchToFiber(Fiber*) REQUIRES(work.mutex);
void switchToFiber(Fiber*);
// runUntilIdle() executes all pending tasks and then returns. // runUntilIdle() executes all pending tasks and then returns.
_Requires_lock_held_(work.mutex) void runUntilIdle() REQUIRES(work.mutex);
void runUntilIdle();
// waitForWork() blocks until new work is available, potentially calling // waitForWork() blocks until new work is available, potentially calling
// spinForWork(). // spinForWork().
_Requires_lock_held_(work.mutex) void waitForWork() REQUIRES(work.mutex);
void waitForWork();
// spinForWork() attempts to steal work from another Worker, and keeps // spinForWork() attempts to steal work from another Worker, and keeps
// the thread awake for a short duration. This reduces overheads of // the thread awake for a short duration. This reduces overheads of
...@@ -400,31 +385,28 @@ class Scheduler { ...@@ -400,31 +385,28 @@ class Scheduler {
// enqueueFiberTimeouts() enqueues all the fibers that have finished // enqueueFiberTimeouts() enqueues all the fibers that have finished
// waiting. // waiting.
_Requires_lock_held_(work.mutex) void enqueueFiberTimeouts() REQUIRES(work.mutex);
void enqueueFiberTimeouts();
_Requires_lock_held_(work.mutex)
inline void changeFiberState(Fiber* fiber, inline void changeFiberState(Fiber* fiber,
Fiber::State from, Fiber::State from,
Fiber::State to) const; Fiber::State to) const REQUIRES(work.mutex);
_Requires_lock_held_(work.mutex) inline void setFiberState(Fiber* fiber, Fiber::State to) const
inline void setFiberState(Fiber* fiber, Fiber::State to) const; REQUIRES(work.mutex);
// Work holds tasks and fibers that are enqueued on the Worker. // Work holds tasks and fibers that are enqueued on the Worker.
struct Work { struct Work {
std::atomic<uint64_t> num = {0}; // tasks.size() + fibers.size() std::atomic<uint64_t> num = {0}; // tasks.size() + fibers.size()
_Guarded_by_(mutex) uint64_t numBlockedFibers = 0; GUARDED_BY(mutex) uint64_t numBlockedFibers = 0;
_Guarded_by_(mutex) TaskQueue tasks; GUARDED_BY(mutex) TaskQueue tasks;
_Guarded_by_(mutex) FiberQueue fibers; GUARDED_BY(mutex) FiberQueue fibers;
_Guarded_by_(mutex) WaitingFibers waiting; GUARDED_BY(mutex) WaitingFibers waiting;
_Guarded_by_(mutex) bool notifyAdded = true; GUARDED_BY(mutex) bool notifyAdded = true;
std::condition_variable added; std::condition_variable added;
std::mutex mutex; marl::mutex mutex;
_Requires_lock_held_(mutex)
template <typename F> template <typename F>
inline void wait(F&&); inline void wait(F&&) REQUIRES(mutex);
}; };
// https://en.wikipedia.org/wiki/Xorshift // https://en.wikipedia.org/wiki/Xorshift
...@@ -472,7 +454,7 @@ class Scheduler { ...@@ -472,7 +454,7 @@ class Scheduler {
Allocator* const allocator; Allocator* const allocator;
std::function<void()> threadInitFunc; std::function<void()> threadInitFunc;
std::mutex threadInitFuncMutex; mutex threadInitFuncMutex;
std::array<std::atomic<int>, 8> spinningWorkers; std::array<std::atomic<int>, 8> spinningWorkers;
std::atomic<unsigned int> nextSpinningWorkerIdx = {0x8000000}; std::atomic<unsigned int> nextSpinningWorkerIdx = {0x8000000};
...@@ -484,17 +466,18 @@ class Scheduler { ...@@ -484,17 +466,18 @@ class Scheduler {
std::array<Worker*, MaxWorkerThreads> workerThreads; std::array<Worker*, MaxWorkerThreads> workerThreads;
struct SingleThreadedWorkers { struct SingleThreadedWorkers {
std::mutex mutex; using WorkerByTid =
std::condition_variable unbind; std::unordered_map<std::thread::id, Allocator::unique_ptr<Worker>>;
std::unordered_map<std::thread::id, Allocator::unique_ptr<Worker>> byTid; marl::mutex mutex;
GUARDED_BY(mutex) std::condition_variable unbind;
GUARDED_BY(mutex) WorkerByTid byTid;
}; };
SingleThreadedWorkers singleThreadedWorkers; SingleThreadedWorkers singleThreadedWorkers;
}; };
_Requires_lock_held_(lock)
template <typename Clock, typename Duration> template <typename Clock, typename Duration>
bool Scheduler::Fiber::wait( bool Scheduler::Fiber::wait(
Lock& lock, marl::lock& lock,
const std::chrono::time_point<Clock, Duration>& timeout, const std::chrono::time_point<Clock, Duration>& timeout,
const Predicate& pred) { const Predicate& pred) {
using ToDuration = typename TimePoint::duration; using ToDuration = typename TimePoint::duration;
......
...@@ -105,14 +105,14 @@ class Ticket { ...@@ -105,14 +105,14 @@ class Ticket {
inline ~Record(); inline ~Record();
inline void done(); inline void done();
inline void callAndUnlock(std::unique_lock<std::mutex>& lock); inline void callAndUnlock(marl::lock& lock);
inline void unlink(); // guarded by shared->mutex
ConditionVariable isCalledCondVar; ConditionVariable isCalledCondVar;
std::shared_ptr<Shared> shared; std::shared_ptr<Shared> shared;
Record* next = nullptr; // guarded by shared->mutex Record* next = nullptr; // guarded by shared->mutex
Record* prev = nullptr; // guarded by shared->mutex Record* prev = nullptr; // guarded by shared->mutex
inline void unlink(); // guarded by shared->mutex
OnCall onCall; // guarded by shared->mutex OnCall onCall; // guarded by shared->mutex
bool isCalled = false; // guarded by shared->mutex bool isCalled = false; // guarded by shared->mutex
std::atomic<bool> isDone = {false}; std::atomic<bool> isDone = {false};
...@@ -120,7 +120,7 @@ class Ticket { ...@@ -120,7 +120,7 @@ class Ticket {
// Data shared between all tickets and the queue. // Data shared between all tickets and the queue.
struct Shared { struct Shared {
std::mutex mutex; marl::mutex mutex;
Record tail; Record tail;
}; };
...@@ -136,7 +136,7 @@ class Ticket { ...@@ -136,7 +136,7 @@ class Ticket {
Ticket::Ticket(Loan<Record>&& record) : record(std::move(record)) {} Ticket::Ticket(Loan<Record>&& record) : record(std::move(record)) {}
void Ticket::wait() const { void Ticket::wait() const {
std::unique_lock<std::mutex> lock(record->shared->mutex); marl::lock lock(record->shared->mutex);
record->isCalledCondVar.wait(lock, [this] { return record->isCalled; }); record->isCalledCondVar.wait(lock, [this] { return record->isCalled; });
} }
...@@ -146,7 +146,7 @@ void Ticket::done() const { ...@@ -146,7 +146,7 @@ void Ticket::done() const {
template <typename Function> template <typename Function>
void Ticket::onCall(Function&& f) const { void Ticket::onCall(Function&& f) const {
std::unique_lock<std::mutex> lock(record->shared->mutex); marl::lock lock(record->shared->mutex);
if (record->isCalled) { if (record->isCalled) {
marl::schedule(std::move(f)); marl::schedule(std::move(f));
return; return;
...@@ -192,7 +192,7 @@ void Ticket::Queue::take(size_t n, const F& f) { ...@@ -192,7 +192,7 @@ void Ticket::Queue::take(size_t n, const F& f) {
f(std::move(Ticket(std::move(rec)))); f(std::move(Ticket(std::move(rec))));
}); });
last->next = &shared->tail; last->next = &shared->tail;
std::unique_lock<std::mutex> lock(shared->mutex); marl::lock lock(shared->mutex);
first->prev = shared->tail.prev; first->prev = shared->tail.prev;
shared->tail.prev = last.get(); shared->tail.prev = last.get();
if (first->prev == nullptr) { if (first->prev == nullptr) {
...@@ -216,7 +216,7 @@ void Ticket::Record::done() { ...@@ -216,7 +216,7 @@ void Ticket::Record::done() {
if (isDone.exchange(true)) { if (isDone.exchange(true)) {
return; return;
} }
std::unique_lock<std::mutex> lock(shared->mutex); marl::lock lock(shared->mutex);
auto callNext = (prev == nullptr && next != nullptr) ? next : nullptr; auto callNext = (prev == nullptr && next != nullptr) ? next : nullptr;
unlink(); unlink();
if (callNext != nullptr) { if (callNext != nullptr) {
...@@ -225,7 +225,7 @@ void Ticket::Record::done() { ...@@ -225,7 +225,7 @@ void Ticket::Record::done() {
} }
} }
void Ticket::Record::callAndUnlock(std::unique_lock<std::mutex>& lock) { void Ticket::Record::callAndUnlock(marl::lock& lock) {
if (isCalled) { if (isCalled) {
return; return;
} }
...@@ -233,7 +233,7 @@ void Ticket::Record::callAndUnlock(std::unique_lock<std::mutex>& lock) { ...@@ -233,7 +233,7 @@ void Ticket::Record::callAndUnlock(std::unique_lock<std::mutex>& lock) {
OnCall callback; OnCall callback;
std::swap(callback, onCall); std::swap(callback, onCall);
isCalledCondVar.notify_all(); isCalledCondVar.notify_all();
lock.unlock(); lock.unlock_no_tsa();
if (callback) { if (callback) {
marl::schedule(std::move(callback)); marl::schedule(std::move(callback));
......
// Copyright 2020 The Marl Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Stubs Thread-Safty-Analysis annotation macros for platforms that do not
// support them.
// See https://clang.llvm.org/docs/ThreadSafetyAnalysis.html
#ifndef marl_tsa_h
#define marl_tsa_h
// Enable thread safety attributes only with clang.
// The attributes can be safely erased when compiling with other compilers.
#if defined(__clang__) && (!defined(SWIG))
#define THREAD_ANNOTATION_ATTRIBUTE__(x) __attribute__((x))
#else
#define THREAD_ANNOTATION_ATTRIBUTE__(x) // no-op
#endif
#define CAPABILITY(x) THREAD_ANNOTATION_ATTRIBUTE__(capability(x))
#define SCOPED_CAPABILITY THREAD_ANNOTATION_ATTRIBUTE__(scoped_lockable)
#define GUARDED_BY(x) THREAD_ANNOTATION_ATTRIBUTE__(guarded_by(x))
#define PT_GUARDED_BY(x) THREAD_ANNOTATION_ATTRIBUTE__(pt_guarded_by(x))
#define ACQUIRED_BEFORE(...) \
THREAD_ANNOTATION_ATTRIBUTE__(acquired_before(__VA_ARGS__))
#define ACQUIRED_AFTER(...) \
THREAD_ANNOTATION_ATTRIBUTE__(acquired_after(__VA_ARGS__))
#define REQUIRES(...) \
THREAD_ANNOTATION_ATTRIBUTE__(requires_capability(__VA_ARGS__))
#define REQUIRES_SHARED(...) \
THREAD_ANNOTATION_ATTRIBUTE__(requires_shared_capability(__VA_ARGS__))
#define ACQUIRE(...) \
THREAD_ANNOTATION_ATTRIBUTE__(acquire_capability(__VA_ARGS__))
#define ACQUIRE_SHARED(...) \
THREAD_ANNOTATION_ATTRIBUTE__(acquire_shared_capability(__VA_ARGS__))
#define RELEASE(...) \
THREAD_ANNOTATION_ATTRIBUTE__(release_capability(__VA_ARGS__))
#define RELEASE_SHARED(...) \
THREAD_ANNOTATION_ATTRIBUTE__(release_shared_capability(__VA_ARGS__))
#define TRY_ACQUIRE(...) \
THREAD_ANNOTATION_ATTRIBUTE__(try_acquire_capability(__VA_ARGS__))
#define TRY_ACQUIRE_SHARED(...) \
THREAD_ANNOTATION_ATTRIBUTE__(try_acquire_shared_capability(__VA_ARGS__))
#define EXCLUDES(...) THREAD_ANNOTATION_ATTRIBUTE__(locks_excluded(__VA_ARGS__))
#define ASSERT_CAPABILITY(x) THREAD_ANNOTATION_ATTRIBUTE__(assert_capability(x))
#define ASSERT_SHARED_CAPABILITY(x) \
THREAD_ANNOTATION_ATTRIBUTE__(assert_shared_capability(x))
#define RETURN_CAPABILITY(x) THREAD_ANNOTATION_ATTRIBUTE__(lock_returned(x))
#define NO_THREAD_SAFETY_ANALYSIS \
THREAD_ANNOTATION_ATTRIBUTE__(no_thread_safety_analysis)
#endif // marl_tsa_h
...@@ -70,7 +70,7 @@ class WaitGroup { ...@@ -70,7 +70,7 @@ class WaitGroup {
std::atomic<unsigned int> count = {0}; std::atomic<unsigned int> count = {0};
ConditionVariable cv; ConditionVariable cv;
std::mutex mutex; marl::mutex mutex;
}; };
const std::shared_ptr<Data> data; const std::shared_ptr<Data> data;
}; };
...@@ -91,7 +91,7 @@ bool WaitGroup::done() const { ...@@ -91,7 +91,7 @@ bool WaitGroup::done() const {
MARL_ASSERT(data->count > 0, "marl::WaitGroup::done() called too many times"); MARL_ASSERT(data->count > 0, "marl::WaitGroup::done() called too many times");
auto count = --data->count; auto count = --data->count;
if (count == 0) { if (count == 0) {
std::unique_lock<std::mutex> lock(data->mutex); marl::lock lock(data->mutex);
data->cv.notify_all(); data->cv.notify_all();
return true; return true;
} }
...@@ -99,7 +99,7 @@ bool WaitGroup::done() const { ...@@ -99,7 +99,7 @@ bool WaitGroup::done() const {
} }
void WaitGroup::wait() const { void WaitGroup::wait() const {
std::unique_lock<std::mutex> lock(data->mutex); marl::lock lock(data->mutex);
data->cv.wait(lock, [this] { return data->count == 0; }); data->cv.wait(lock, [this] { return data->count == 0; });
} }
......
...@@ -18,17 +18,16 @@ ...@@ -18,17 +18,16 @@
#include "marl_test.h" #include "marl_test.h"
#include <condition_variable> #include <condition_variable>
#include <mutex>
TEST_F(WithoutBoundScheduler, ConditionVariable) { TEST_F(WithoutBoundScheduler, ConditionVariable) {
bool trigger[3] = {false, false, false}; bool trigger[3] = {false, false, false};
bool signal[3] = {false, false, false}; bool signal[3] = {false, false, false};
std::mutex mutex; marl::mutex mutex;
marl::ConditionVariable cv; marl::ConditionVariable cv;
std::thread thread([&] { std::thread thread([&] {
for (int i = 0; i < 3; i++) { for (int i = 0; i < 3; i++) {
std::unique_lock<std::mutex> lock(mutex); marl::lock lock(mutex);
cv.wait(lock, [&] { cv.wait(lock, [&] {
EXPECT_TRUE(lock.owns_lock()); EXPECT_TRUE(lock.owns_lock());
return trigger[i]; return trigger[i];
...@@ -45,7 +44,7 @@ TEST_F(WithoutBoundScheduler, ConditionVariable) { ...@@ -45,7 +44,7 @@ TEST_F(WithoutBoundScheduler, ConditionVariable) {
for (int i = 0; i < 3; i++) { for (int i = 0; i < 3; i++) {
{ {
std::unique_lock<std::mutex> lock(mutex); marl::lock lock(mutex);
trigger[i] = true; trigger[i] = true;
cv.notify_one(); cv.notify_one();
cv.wait(lock, [&] { cv.wait(lock, [&] {
...@@ -66,12 +65,12 @@ TEST_F(WithoutBoundScheduler, ConditionVariable) { ...@@ -66,12 +65,12 @@ TEST_F(WithoutBoundScheduler, ConditionVariable) {
TEST_P(WithBoundScheduler, ConditionVariable) { TEST_P(WithBoundScheduler, ConditionVariable) {
bool trigger[3] = {false, false, false}; bool trigger[3] = {false, false, false};
bool signal[3] = {false, false, false}; bool signal[3] = {false, false, false};
std::mutex mutex; marl::mutex mutex;
marl::ConditionVariable cv; marl::ConditionVariable cv;
std::thread thread([&] { std::thread thread([&] {
for (int i = 0; i < 3; i++) { for (int i = 0; i < 3; i++) {
std::unique_lock<std::mutex> lock(mutex); marl::lock lock(mutex);
cv.wait(lock, [&] { cv.wait(lock, [&] {
EXPECT_TRUE(lock.owns_lock()); EXPECT_TRUE(lock.owns_lock());
return trigger[i]; return trigger[i];
...@@ -88,7 +87,7 @@ TEST_P(WithBoundScheduler, ConditionVariable) { ...@@ -88,7 +87,7 @@ TEST_P(WithBoundScheduler, ConditionVariable) {
for (int i = 0; i < 3; i++) { for (int i = 0; i < 3; i++) {
{ {
std::unique_lock<std::mutex> lock(mutex); marl::lock lock(mutex);
trigger[i] = true; trigger[i] = true;
cv.notify_one(); cv.notify_one();
cv.wait(lock, [&] { cv.wait(lock, [&] {
...@@ -113,14 +112,14 @@ TEST_P(WithBoundScheduler, ConditionVariable) { ...@@ -113,14 +112,14 @@ TEST_P(WithBoundScheduler, ConditionVariable) {
// they are early-unblocked, along with expected lock state. // they are early-unblocked, along with expected lock state.
TEST_P(WithBoundScheduler, ConditionVariableTimeouts) { TEST_P(WithBoundScheduler, ConditionVariableTimeouts) {
for (int i = 0; i < 10; i++) { for (int i = 0; i < 10; i++) {
std::mutex mutex; marl::mutex mutex;
marl::ConditionVariable cv; marl::ConditionVariable cv;
bool signaled = false; // guarded by mutex bool signaled = false; // guarded by mutex
auto wg = marl::WaitGroup(100); auto wg = marl::WaitGroup(100);
for (int j = 0; j < 100; j++) { for (int j = 0; j < 100; j++) {
marl::schedule([=, &mutex, &cv, &signaled] { marl::schedule([=, &mutex, &cv, &signaled] {
{ {
std::unique_lock<std::mutex> lock(mutex); marl::lock lock(mutex);
cv.wait_for(lock, std::chrono::milliseconds(j), [&] { cv.wait_for(lock, std::chrono::milliseconds(j), [&] {
EXPECT_TRUE(lock.owns_lock()); EXPECT_TRUE(lock.owns_lock());
return signaled; return signaled;
...@@ -134,7 +133,7 @@ TEST_P(WithBoundScheduler, ConditionVariableTimeouts) { ...@@ -134,7 +133,7 @@ TEST_P(WithBoundScheduler, ConditionVariableTimeouts) {
} }
std::this_thread::sleep_for(std::chrono::milliseconds(50)); std::this_thread::sleep_for(std::chrono::milliseconds(50));
{ {
std::unique_lock<std::mutex> lock(mutex); marl::lock lock(mutex);
signaled = true; signaled = true;
cv.notify_all(); cv.notify_all();
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment