Commit 8656bcb0 by Ben Clayton

Update Marl to ca8408f68

Contains a number of optimizations that improve the Subzero coroutine performance up to 10x. Changes: ca8408f68 Scheduler: Reduce the number of mutex locks / unlock. 575b61e76 Fix compilation of marl::Ticket::onCall() e9f312688 waitForWork(): Early out if there is work (work.num > 0) 3196a0539 Scheduler: Use std::deque instead of std::queue 08a820171 Add flags to marl::Task cb3c481d0 Scheduler: Use a separate flag to indicate whether to call notify() 598c993ec marl::ConditionVariable - use containers::list d0c501a9c Add marl::containers::list aa1de9091 Benchmarks: Add EventBaton Commands: git subtree pull --prefix third_party/marl https://github.com/google/marl master --squash Bug: b/140546382 Change-Id: I2b7adc3c624a1f3aef686de7e0e88c52a5666e3a
parents 440fc995 36835bcf
...@@ -15,14 +15,15 @@ ...@@ -15,14 +15,15 @@
#ifndef marl_condition_variable_h #ifndef marl_condition_variable_h
#define marl_condition_variable_h #define marl_condition_variable_h
#include "containers.h"
#include "debug.h" #include "debug.h"
#include "defer.h" #include "defer.h"
#include "memory.h"
#include "scheduler.h" #include "scheduler.h"
#include <atomic> #include <atomic>
#include <condition_variable> #include <condition_variable>
#include <mutex> #include <mutex>
#include <unordered_set>
namespace marl { namespace marl {
...@@ -34,7 +35,7 @@ namespace marl { ...@@ -34,7 +35,7 @@ namespace marl {
// thread will work on other tasks until the ConditionVariable is unblocked. // thread will work on other tasks until the ConditionVariable is unblocked.
class ConditionVariable { class ConditionVariable {
public: public:
inline ConditionVariable(); inline ConditionVariable(Allocator* allocator = Allocator::Default);
// notify_one() notifies and potentially unblocks one waiting fiber or thread. // notify_one() notifies and potentially unblocks one waiting fiber or thread.
inline void notify_one(); inline void notify_one();
...@@ -73,13 +74,15 @@ class ConditionVariable { ...@@ -73,13 +74,15 @@ class ConditionVariable {
ConditionVariable& operator=(ConditionVariable&&) = delete; ConditionVariable& operator=(ConditionVariable&&) = delete;
std::mutex mutex; std::mutex mutex;
std::unordered_set<Scheduler::Fiber*> waiting; containers::list<Scheduler::Fiber*> waiting;
std::condition_variable condition; std::condition_variable condition;
std::atomic<int> numWaiting = {0}; std::atomic<int> numWaiting = {0};
std::atomic<int> numWaitingOnCondition = {0}; std::atomic<int> numWaitingOnCondition = {0};
}; };
ConditionVariable::ConditionVariable() {} ConditionVariable::ConditionVariable(
Allocator* allocator /* = Allocator::Default */)
: waiting(allocator) {}
void ConditionVariable::notify_one() { void ConditionVariable::notify_one() {
if (numWaiting == 0) { if (numWaiting == 0) {
...@@ -122,13 +125,13 @@ void ConditionVariable::wait(std::unique_lock<std::mutex>& lock, ...@@ -122,13 +125,13 @@ void ConditionVariable::wait(std::unique_lock<std::mutex>& lock,
// Currently executing on a scheduler fiber. // Currently executing on a scheduler fiber.
// Yield to let other tasks run that can unblock this fiber. // Yield to let other tasks run that can unblock this fiber.
mutex.lock(); mutex.lock();
waiting.emplace(fiber); auto it = waiting.emplace_front(fiber);
mutex.unlock(); mutex.unlock();
fiber->wait(lock, pred); fiber->wait(lock, pred);
mutex.lock(); mutex.lock();
waiting.erase(fiber); waiting.erase(it);
mutex.unlock(); mutex.unlock();
} else { } else {
// Currently running outside of the scheduler. // Currently running outside of the scheduler.
...@@ -163,13 +166,13 @@ bool ConditionVariable::wait_until( ...@@ -163,13 +166,13 @@ bool ConditionVariable::wait_until(
// Currently executing on a scheduler fiber. // Currently executing on a scheduler fiber.
// Yield to let other tasks run that can unblock this fiber. // Yield to let other tasks run that can unblock this fiber.
mutex.lock(); mutex.lock();
waiting.emplace(fiber); auto it = waiting.emplace_front(fiber);
mutex.unlock(); mutex.unlock();
auto res = fiber->wait(lock, timeout, pred); auto res = fiber->wait(lock, timeout, pred);
mutex.lock(); mutex.lock();
waiting.erase(fiber); waiting.erase(it);
mutex.unlock(); mutex.unlock();
return res; return res;
......
...@@ -19,10 +19,9 @@ ...@@ -19,10 +19,9 @@
#include "memory.h" #include "memory.h"
#include <algorithm> // std::max #include <algorithm> // std::max
#include <cstddef> // size_t
#include <utility> // std::move #include <utility> // std::move
#include <cstddef> // size_t
namespace marl { namespace marl {
namespace containers { namespace containers {
...@@ -243,6 +242,207 @@ void vector<T, BASE_CAPACITY>::free() { ...@@ -243,6 +242,207 @@ void vector<T, BASE_CAPACITY>::free() {
} }
} }
////////////////////////////////////////////////////////////////////////////////
// list<T>
////////////////////////////////////////////////////////////////////////////////

// list is a minimal std::list like container that supports constant time
// insertion and removal of elements.
// list keeps hold of allocations (it only releases allocations on destruction),
// to avoid repeated heap allocations and frees when frequently inserting and
// removing elements.
template <typename T>
class list {
  // Entry is an intrusive doubly-linked node holding one element.
  // Entries are pool-allocated by grow() and recycled through the free chain.
  struct Entry {
    T data;
    Entry* next;
    Entry* prev;
  };

 public:
  // iterator is a minimal mutable iterator over the list's elements.
  // end() is represented by a null entry.
  class iterator {
   public:
    inline iterator(Entry*);
    inline T* operator->();
    inline T& operator*();
    inline iterator& operator++();
    inline bool operator==(const iterator&) const;
    inline bool operator!=(const iterator&) const;

   private:
    friend list;
    Entry* entry;
  };

  // Constructs the list, pre-allocating an initial pool of entries from
  // allocator.
  inline list(Allocator* allocator = Allocator::Default);
  inline ~list();

  inline iterator begin();
  inline iterator end();
  inline size_t size() const;

  // emplace_front() constructs a new element in-place at the front of the
  // list and returns an iterator to it.
  template <typename... Args>
  iterator emplace_front(Args&&... args);

  // erase() removes the element at it, returning its entry to the free pool.
  // The iterator is invalidated.
  inline void erase(iterator);

 private:
  // copy / move is currently unsupported.
  list(const list&) = delete;
  list(list&&) = delete;
  list& operator=(const list&) = delete;
  list& operator=(list&&) = delete;

  // grow() allocates count additional entries and links them into the free
  // chain.
  void grow(size_t count);

  // unlink() splices entry out of the chain rooted at list.
  static void unlink(Entry* entry, Entry*& list);
  // link() prepends entry to the chain rooted at list.
  static void link(Entry* entry, Entry*& list);

  Allocator* const allocator;
  size_t size_ = 0;     // number of live elements
  size_t capacity = 0;  // total entries allocated (live + free)
  vector<Allocation, 8> allocations;  // backing slabs, released in the dtor
  Entry* free = nullptr;  // chain of recycled / unused entries
  Entry* head = nullptr;  // first live element (most recently emplaced)
};
template <typename T>
list<T>::iterator::iterator(Entry* entry) : entry(entry) {}

// operator->() returns a pointer to the element held by the current entry.
template <typename T>
T* list<T>::iterator::operator->() {
  return &entry->data;
}

// operator*() returns a reference to the element held by the current entry.
template <typename T>
T& list<T>::iterator::operator*() {
  return entry->data;
}

// operator++() advances the iterator to the next element and returns *this.
// Incrementing the last element yields end() (a null entry).
template <typename T>
typename list<T>::iterator& list<T>::iterator::operator++() {
  entry = entry->next;
  return *this;
}

// Iterators compare equal iff they refer to the same entry.
template <typename T>
bool list<T>::iterator::operator==(const iterator& rhs) const {
  return entry == rhs.entry;
}

template <typename T>
bool list<T>::iterator::operator!=(const iterator& rhs) const {
  return entry != rhs.entry;
}
// Constructs an empty list, pre-allocating an initial pool of 8 entries from
// allocator so the first few emplace_front() calls do not allocate.
template <typename T>
list<T>::list(Allocator* allocator /* = Allocator::Default */)
    : allocator(allocator), allocations(allocator) {
  grow(8);
}
// Destroys all live elements, then releases every slab allocation made by
// grow(). Entries on the free chain hold no constructed T and need no
// destruction.
template <typename T>
list<T>::~list() {
  auto node = head;
  while (node != nullptr) {
    node->data.~T();
    node = node->next;
  }
  for (auto allocation : allocations) {
    allocator->free(allocation);
  }
}
// begin() returns an iterator to the first (most recently emplaced) element.
template <typename T>
typename list<T>::iterator list<T>::begin() {
  return {head};
}

// end() returns the past-the-end iterator, represented by a null entry.
template <typename T>
typename list<T>::iterator list<T>::end() {
  return {nullptr};
}

// size() returns the number of elements currently held by the list.
template <typename T>
size_t list<T>::size() const {
  return size_;
}
// emplace_front() constructs a new element in-place at the front of the list,
// recycling an entry from the free chain (doubling the pool if the chain is
// empty). Returns an iterator to the new element.
template <typename T>
template <typename... Args>
typename list<T>::iterator list<T>::emplace_front(Args&&... args) {
  if (free == nullptr) {
    grow(capacity);  // double the pool size
  }
  auto entry = free;
  unlink(entry, free);
  link(entry, head);
  // Perfect-forward the constructor arguments into the recycled storage.
  // Note: std::forward<Args>, not std::forward<T> — forwarding the pack as
  // T&& fails to compile for any argument pack that is not exactly a single
  // T&& (e.g. emplace_front("literal") on a list<std::string>).
  new (&entry->data) T(std::forward<Args>(args)...);
  size_++;
  return entry;
}
// erase() detaches the entry at it from the live chain, recycles it onto the
// free chain, then destroys the element it held.
template <typename T>
void list<T>::erase(iterator it) {
  Entry* const entry = it.entry;
  unlink(entry, head);
  link(entry, free);
  entry->data.~T();
  size_--;
}
// grow() allocates a single contiguous slab of count additional entries and
// pushes each one onto the front of the free chain. Existing entries are
// unaffected; the slab is recorded in allocations and freed in the destructor.
template <typename T>
void list<T>::grow(size_t count) {
  Allocation::Request request;
  request.size = sizeof(Entry) * count;
  request.alignment = alignof(Entry);
  request.usage = Allocation::Usage::List;
  auto alloc = allocator->allocate(request);

  // The slab holds raw (unconstructed) entries; the T payload is only
  // constructed later, in-place, by emplace_front().
  auto entries = reinterpret_cast<Entry*>(alloc.ptr);
  for (size_t i = 0; i < count; i++) {
    auto entry = &entries[i];
    entry->prev = nullptr;
    entry->next = free;
    if (free) {
      free->prev = entry;
    }
    free = entry;
  }
  allocations.emplace_back(std::move(alloc));
  capacity += count;
}
// unlink() splices entry out of the doubly-linked chain rooted at list,
// updating the chain head if entry was first, and clears entry's own links so
// it can be re-linked later.
template <typename T>
void list<T>::unlink(Entry* entry, Entry*& list) {
  auto prev = entry->prev;
  auto next = entry->next;
  if (list == entry) {
    list = next;
  }
  if (prev != nullptr) {
    prev->next = next;
  }
  if (next != nullptr) {
    next->prev = prev;
  }
  entry->prev = nullptr;
  entry->next = nullptr;
}
// link() prepends entry to the chain rooted at list. The entry must be fully
// unlinked (both links null) before being linked into a chain.
template <typename T>
void list<T>::link(Entry* entry, Entry*& list) {
  MARL_ASSERT(entry->next == nullptr, "link() called on entry already linked");
  MARL_ASSERT(entry->prev == nullptr, "link() called on entry already linked");
  auto first = list;
  if (first != nullptr) {
    entry->next = first;
    first->prev = entry;
  }
  list = entry;
}
} // namespace containers } // namespace containers
} // namespace marl } // namespace marl
......
...@@ -102,7 +102,7 @@ class Event { ...@@ -102,7 +102,7 @@ class Event {
private: private:
struct Shared { struct Shared {
inline Shared(Mode mode, bool initialState); inline Shared(Allocator* allocator, Mode mode, bool initialState);
inline void signal(); inline void signal();
inline void wait(); inline void wait();
...@@ -123,8 +123,8 @@ class Event { ...@@ -123,8 +123,8 @@ class Event {
const std::shared_ptr<Shared> shared; const std::shared_ptr<Shared> shared;
}; };
Event::Shared::Shared(Mode mode, bool initialState) Event::Shared::Shared(Allocator* allocator, Mode mode, bool initialState)
: mode(mode), signalled(initialState) {} : cv(allocator), mode(mode), signalled(initialState) {}
void Event::Shared::signal() { void Event::Shared::signal() {
std::unique_lock<std::mutex> lock(mutex); std::unique_lock<std::mutex> lock(mutex);
...@@ -179,7 +179,7 @@ bool Event::Shared::wait_until( ...@@ -179,7 +179,7 @@ bool Event::Shared::wait_until(
Event::Event(Mode mode /* = Mode::Auto */, Event::Event(Mode mode /* = Mode::Auto */,
bool initialState /* = false */, bool initialState /* = false */,
Allocator* allocator /* = Allocator::Default */) Allocator* allocator /* = Allocator::Default */)
: shared(allocator->make_shared<Shared>(mode, initialState)) {} : shared(allocator->make_shared<Shared>(allocator, mode, initialState)) {}
void Event::signal() const { void Event::signal() const {
shared->signal(); shared->signal();
......
...@@ -38,7 +38,8 @@ struct Allocation { ...@@ -38,7 +38,8 @@ struct Allocation {
Undefined = 0, Undefined = 0,
Stack, // Fiber stack Stack, // Fiber stack
Create, // Allocator::create(), make_unique(), make_shared() Create, // Allocator::create(), make_unique(), make_shared()
Vector, // marl::vector<T> Vector, // marl::containers::vector<T>
List, // marl::containers::list<T>
Count, // Not intended to be used as a usage type - used for upper bound. Count, // Not intended to be used as a usage type - used for upper bound.
}; };
......
...@@ -232,7 +232,7 @@ class BoundedPool : public Pool<T> { ...@@ -232,7 +232,7 @@ class BoundedPool : public Pool<T> {
private: private:
class Storage : public Pool<T>::Storage { class Storage : public Pool<T>::Storage {
public: public:
inline Storage(); inline Storage(Allocator* allocator);
inline ~Storage(); inline ~Storage();
inline void return_(Item*) override; inline void return_(Item*) override;
...@@ -245,7 +245,8 @@ class BoundedPool : public Pool<T> { ...@@ -245,7 +245,8 @@ class BoundedPool : public Pool<T> {
}; };
template <typename T, int N, PoolPolicy POLICY> template <typename T, int N, PoolPolicy POLICY>
BoundedPool<T, N, POLICY>::Storage::Storage() { BoundedPool<T, N, POLICY>::Storage::Storage(Allocator* allocator)
: returned(allocator) {
for (int i = 0; i < N; i++) { for (int i = 0; i < N; i++) {
if (POLICY == PoolPolicy::Preserve) { if (POLICY == PoolPolicy::Preserve) {
items[i].construct(); items[i].construct();
...@@ -267,7 +268,7 @@ BoundedPool<T, N, POLICY>::Storage::~Storage() { ...@@ -267,7 +268,7 @@ BoundedPool<T, N, POLICY>::Storage::~Storage() {
template <typename T, int N, PoolPolicy POLICY> template <typename T, int N, PoolPolicy POLICY>
BoundedPool<T, N, POLICY>::BoundedPool( BoundedPool<T, N, POLICY>::BoundedPool(
Allocator* allocator /* = Allocator::Default */) Allocator* allocator /* = Allocator::Default */)
: storage(allocator->make_shared<Storage>()) {} : storage(allocator->make_shared<Storage>(allocator)) {}
template <typename T, int N, PoolPolicy POLICY> template <typename T, int N, PoolPolicy POLICY>
typename BoundedPool<T, N, POLICY>::Loan BoundedPool<T, N, POLICY>::borrow() typename BoundedPool<T, N, POLICY>::Loan BoundedPool<T, N, POLICY>::borrow()
......
...@@ -18,28 +18,27 @@ ...@@ -18,28 +18,27 @@
#include "debug.h" #include "debug.h"
#include "memory.h" #include "memory.h"
#include "sal.h" #include "sal.h"
#include "task.h"
#include "thread.h" #include "thread.h"
#include <array> #include <array>
#include <atomic> #include <atomic>
#include <chrono> #include <chrono>
#include <condition_variable> #include <condition_variable>
#include <deque>
#include <functional> #include <functional>
#include <map> #include <map>
#include <mutex> #include <mutex>
#include <queue>
#include <set> #include <set>
#include <thread> #include <thread>
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include <vector>
namespace marl { namespace marl {
class OSFiber; class OSFiber;
// Task is a unit of work for the scheduler.
using Task = std::function<void()>;
// Scheduler asynchronously processes Tasks. // Scheduler asynchronously processes Tasks.
// A scheduler can be bound to one or more threads using the bind() method. // A scheduler can be bound to one or more threads using the bind() method.
// Once bound to a thread, that thread can call marl::schedule() to enqueue // Once bound to a thread, that thread can call marl::schedule() to enqueue
...@@ -257,8 +256,8 @@ class Scheduler { ...@@ -257,8 +256,8 @@ class Scheduler {
// TODO: Implement a queue that recycles elements to reduce number of // TODO: Implement a queue that recycles elements to reduce number of
// heap allocations. // heap allocations.
using TaskQueue = std::queue<Task>; using TaskQueue = std::deque<Task>;
using FiberQueue = std::queue<Fiber*>; using FiberQueue = std::deque<Fiber*>;
using FiberSet = std::unordered_set<Fiber*>; using FiberSet = std::unordered_set<Fiber*>;
// Workers executes Tasks on a single thread. // Workers executes Tasks on a single thread.
...@@ -318,9 +317,9 @@ class Scheduler { ...@@ -318,9 +317,9 @@ class Scheduler {
// flush() processes all pending tasks before returning. // flush() processes all pending tasks before returning.
void flush(); void flush();
// dequeue() attempts to take a Task from the worker. Returns true if // steal() attempts to steal a Task from the worker for another worker.
// a task was taken and assigned to out, otherwise false. // Returns true if a task was taken and assigned to out, otherwise false.
bool dequeue(Task& out); bool steal(Task& out);
// getCurrent() returns the Worker currently bound to the current // getCurrent() returns the Worker currently bound to the current
// thread. // thread.
...@@ -338,14 +337,17 @@ class Scheduler { ...@@ -338,14 +337,17 @@ class Scheduler {
// continue to process tasks until stop() is called. // continue to process tasks until stop() is called.
// If the worker was constructed in Mode::SingleThreaded, run() call // If the worker was constructed in Mode::SingleThreaded, run() call
// flush() and return. // flush() and return.
_Requires_lock_held_(work.mutex)
void run(); void run();
// createWorkerFiber() creates a new fiber that when executed calls // createWorkerFiber() creates a new fiber that when executed calls
// run(). // run().
_Requires_lock_held_(work.mutex)
Fiber* createWorkerFiber(); Fiber* createWorkerFiber();
// switchToFiber() switches execution to the given fiber. The fiber // switchToFiber() switches execution to the given fiber. The fiber
// must belong to this worker. // must belong to this worker.
_Requires_lock_held_(work.mutex)
void switchToFiber(Fiber*); void switchToFiber(Fiber*);
// runUntilIdle() executes all pending tasks and then returns. // runUntilIdle() executes all pending tasks and then returns.
...@@ -387,8 +389,13 @@ class Scheduler { ...@@ -387,8 +389,13 @@ class Scheduler {
_Guarded_by_(mutex) TaskQueue tasks; _Guarded_by_(mutex) TaskQueue tasks;
_Guarded_by_(mutex) FiberQueue fibers; _Guarded_by_(mutex) FiberQueue fibers;
_Guarded_by_(mutex) WaitingFibers waiting; _Guarded_by_(mutex) WaitingFibers waiting;
_Guarded_by_(mutex) bool notifyAdded = true;
std::condition_variable added; std::condition_variable added;
std::mutex mutex; std::mutex mutex;
_Requires_lock_held_(mutex)
template <typename F>
inline void wait(F&&);
}; };
// https://en.wikipedia.org/wiki/Xorshift // https://en.wikipedia.org/wiki/Xorshift
...@@ -418,7 +425,7 @@ class Scheduler { ...@@ -418,7 +425,7 @@ class Scheduler {
std::vector<Allocator::unique_ptr<Fiber>> std::vector<Allocator::unique_ptr<Fiber>>
workerFibers; // All fibers created by this worker. workerFibers; // All fibers created by this worker.
FastRnd rng; FastRnd rng;
std::atomic<bool> shutdown = {false}; bool shutdown = false;
}; };
// stealWork() attempts to steal a task from the worker with the given id. // stealWork() attempts to steal a task from the worker with the given id.
...@@ -472,6 +479,14 @@ Scheduler::Fiber* Scheduler::Worker::getCurrentFiber() const { ...@@ -472,6 +479,14 @@ Scheduler::Fiber* Scheduler::Worker::getCurrentFiber() const {
return currentFiber; return currentFiber;
} }
// schedule() schedules the task T to be asynchronously called using the
// currently bound scheduler.
inline void schedule(Task&& t) {
MARL_ASSERT_HAS_BOUND_SCHEDULER("marl::schedule");
auto scheduler = Scheduler::get();
scheduler->enqueue(std::move(t));
}
// schedule() schedules the function f to be asynchronously called with the // schedule() schedules the function f to be asynchronously called with the
// given arguments using the currently bound scheduler. // given arguments using the currently bound scheduler.
template <typename Function, typename... Args> template <typename Function, typename... Args>
...@@ -488,7 +503,7 @@ template <typename Function> ...@@ -488,7 +503,7 @@ template <typename Function>
inline void schedule(Function&& f) { inline void schedule(Function&& f) {
MARL_ASSERT_HAS_BOUND_SCHEDULER("marl::schedule"); MARL_ASSERT_HAS_BOUND_SCHEDULER("marl::schedule");
auto scheduler = Scheduler::get(); auto scheduler = Scheduler::get();
scheduler->enqueue(std::forward<Function>(f)); scheduler->enqueue(Task(std::forward<Function>(f)));
} }
} // namespace marl } // namespace marl
......
// Copyright 2020 The Marl Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef marl_task_h
#define marl_task_h
#include <functional>
namespace marl {
// Task is a unit of work for the scheduler: a callable of signature void()
// paired with scheduling Flags.
class Task {
 public:
  using Function = std::function<void()>;

  enum class Flags {
    None = 0,

    // SameThread ensures the task will be run on the same thread that scheduled
    // the task. This can offer performance improvements if the current thread
    // is immediately going to block on the newly scheduled task, by reducing
    // overheads of waking another thread.
    SameThread = 1,
  };

  inline Task();
  inline Task(const Task&);
  inline Task(Task&&) noexcept;
  inline Task(const Function& function, Flags flags = Flags::None);
  inline Task(Function&& function, Flags flags = Flags::None);
  inline Task& operator=(const Task&);
  inline Task& operator=(Task&&) noexcept;

  // Assigning a bare Function resets the flags to Flags::None.
  inline Task& operator=(const Function&);
  inline Task& operator=(Function&&);

  // operator bool() returns true if the Task has a valid function.
  inline operator bool() const;

  // operator()() runs the task.
  inline void operator()() const;

  // is() returns true if the Task was created with the given flag.
  inline bool is(Flags flag) const;

 private:
  Function function;
  Flags flags = Flags::None;
};

Task::Task() {}
Task::Task(const Task& o) : function(o.function), flags(o.flags) {}
// Move operations are noexcept so containers of Tasks (e.g. the scheduler's
// queues) relocate elements by move instead of falling back to copies.
Task::Task(Task&& o) noexcept
    : function(std::move(o.function)), flags(o.flags) {}
Task::Task(const Function& function, Flags flags /* = Flags::None */)
    : function(function), flags(flags) {}
Task::Task(Function&& function, Flags flags /* = Flags::None */)
    : function(std::move(function)), flags(flags) {}
Task& Task::operator=(const Task& o) {
  function = o.function;
  flags = o.flags;
  return *this;
}
Task& Task::operator=(Task&& o) noexcept {
  function = std::move(o.function);
  flags = o.flags;
  return *this;
}
Task& Task::operator=(const Function& f) {
  function = f;
  flags = Flags::None;
  return *this;
}
Task& Task::operator=(Function&& f) {
  function = std::move(f);
  flags = Flags::None;
  return *this;
}
Task::operator bool() const {
  return function.operator bool();
}
void Task::operator()() const {
  function();
}
bool Task::is(Flags flag) const {
  // True iff every bit of flag is set in flags.
  return (static_cast<int>(flags) & static_cast<int>(flag)) ==
         static_cast<int>(flag);
}
} // namespace marl
#endif // marl_task_h
...@@ -62,6 +62,8 @@ class Ticket { ...@@ -62,6 +62,8 @@ class Ticket {
struct Record; struct Record;
public: public:
using OnCall = std::function<void()>;
// Queue hands out Tickets. // Queue hands out Tickets.
class Queue { class Queue {
public: public:
...@@ -93,7 +95,7 @@ class Ticket { ...@@ -93,7 +95,7 @@ class Ticket {
// onCall() registers the function f to be invoked when this ticket is // onCall() registers the function f to be invoked when this ticket is
// called. If the ticket is already called prior to calling onCall(), then // called. If the ticket is already called prior to calling onCall(), then
// f() will be executed immediately. // f() will be executed immediately.
// F must be a function of the signature: void F() // F must be a function of the OnCall signature.
template <typename F> template <typename F>
inline void onCall(F&& f) const; inline void onCall(F&& f) const;
...@@ -111,7 +113,7 @@ class Ticket { ...@@ -111,7 +113,7 @@ class Ticket {
Record* next = nullptr; // guarded by shared->mutex Record* next = nullptr; // guarded by shared->mutex
Record* prev = nullptr; // guarded by shared->mutex Record* prev = nullptr; // guarded by shared->mutex
inline void unlink(); // guarded by shared->mutex inline void unlink(); // guarded by shared->mutex
Task onCall; // guarded by shared->mutex OnCall onCall; // guarded by shared->mutex
bool isCalled = false; // guarded by shared->mutex bool isCalled = false; // guarded by shared->mutex
std::atomic<bool> isDone = {false}; std::atomic<bool> isDone = {false};
}; };
...@@ -155,7 +157,7 @@ void Ticket::onCall(Function&& f) const { ...@@ -155,7 +157,7 @@ void Ticket::onCall(Function&& f) const {
a(); a();
b(); b();
} }
Task a, b; OnCall a, b;
}; };
record->onCall = std::move(Joined{std::move(record->onCall), std::move(f)}); record->onCall = std::move(Joined{std::move(record->onCall), std::move(f)});
} else { } else {
...@@ -228,13 +230,13 @@ void Ticket::Record::callAndUnlock(std::unique_lock<std::mutex>& lock) { ...@@ -228,13 +230,13 @@ void Ticket::Record::callAndUnlock(std::unique_lock<std::mutex>& lock) {
return; return;
} }
isCalled = true; isCalled = true;
Task task; OnCall callback;
std::swap(task, onCall); std::swap(callback, onCall);
isCalledCondVar.notify_all(); isCalledCondVar.notify_all();
lock.unlock(); lock.unlock();
if (task) { if (callback) {
marl::schedule(std::move(task)); marl::schedule(std::move(callback));
} }
} }
......
...@@ -51,7 +51,8 @@ namespace marl { ...@@ -51,7 +51,8 @@ namespace marl {
class WaitGroup { class WaitGroup {
public: public:
// Constructs the WaitGroup with the specified initial count. // Constructs the WaitGroup with the specified initial count.
inline WaitGroup(unsigned int initialCount = 0); inline WaitGroup(unsigned int initialCount = 0,
Allocator* allocator = Allocator::Default);
// add() increments the internal counter by count. // add() increments the internal counter by count.
inline void add(unsigned int count = 1) const; inline void add(unsigned int count = 1) const;
...@@ -65,14 +66,20 @@ class WaitGroup { ...@@ -65,14 +66,20 @@ class WaitGroup {
private: private:
struct Data { struct Data {
inline Data(Allocator* allocator);
std::atomic<unsigned int> count = {0}; std::atomic<unsigned int> count = {0};
ConditionVariable condition; ConditionVariable cv;
std::mutex mutex; std::mutex mutex;
}; };
const std::shared_ptr<Data> data = std::make_shared<Data>(); const std::shared_ptr<Data> data;
}; };
inline WaitGroup::WaitGroup(unsigned int initialCount /* = 0 */) { WaitGroup::Data::Data(Allocator* allocator) : cv(allocator) {}
WaitGroup::WaitGroup(unsigned int initialCount /* = 0 */,
Allocator* allocator /* = Allocator::Default */)
: data(std::make_shared<Data>(allocator)) {
data->count = initialCount; data->count = initialCount;
} }
...@@ -85,7 +92,7 @@ bool WaitGroup::done() const { ...@@ -85,7 +92,7 @@ bool WaitGroup::done() const {
auto count = --data->count; auto count = --data->count;
if (count == 0) { if (count == 0) {
std::unique_lock<std::mutex> lock(data->mutex); std::unique_lock<std::mutex> lock(data->mutex);
data->condition.notify_all(); data->cv.notify_all();
return true; return true;
} }
return false; return false;
...@@ -93,7 +100,7 @@ bool WaitGroup::done() const { ...@@ -93,7 +100,7 @@ bool WaitGroup::done() const {
void WaitGroup::wait() const { void WaitGroup::wait() const {
std::unique_lock<std::mutex> lock(data->mutex); std::unique_lock<std::mutex> lock(data->mutex);
data->condition.wait(lock, [this] { return data->count == 0; }); data->cv.wait(lock, [this] { return data->count == 0; });
} }
} // namespace marl } // namespace marl
......
...@@ -193,3 +193,93 @@ TEST_F(ContainersVectorTest, Move) { ...@@ -193,3 +193,93 @@ TEST_F(ContainersVectorTest, Move) {
ASSERT_EQ(vectorB[1], "B"); ASSERT_EQ(vectorB[1], "B");
ASSERT_EQ(vectorB[2], "C"); ASSERT_EQ(vectorB[2], "C");
} }
// Tests for marl::containers::list<T>. emplace_front() prepends, so iteration
// order is the reverse of insertion order.
class ContainersListTest : public WithoutBoundScheduler {};

// A newly constructed list holds no elements.
TEST_F(ContainersListTest, Empty) {
  marl::containers::list<std::string> list(allocator);
  ASSERT_EQ(list.size(), size_t(0));
}

// A single emplaced element is reachable via both the returned iterator and
// begin(), and incrementing past it reaches end().
TEST_F(ContainersListTest, EmplaceOne) {
  marl::containers::list<std::string> list(allocator);
  auto itEntry = list.emplace_front("hello world");
  ASSERT_EQ(*itEntry, "hello world");
  ASSERT_EQ(list.size(), size_t(1));
  auto it = list.begin();
  ASSERT_EQ(it, itEntry);
  ++it;
  ASSERT_EQ(it, list.end());
}

// Three emplaced elements iterate in reverse insertion order (c, b, a).
TEST_F(ContainersListTest, EmplaceThree) {
  marl::containers::list<std::string> list(allocator);
  auto itA = list.emplace_front("a");
  auto itB = list.emplace_front("b");
  auto itC = list.emplace_front("c");
  ASSERT_EQ(*itA, "a");
  ASSERT_EQ(*itB, "b");
  ASSERT_EQ(*itC, "c");
  ASSERT_EQ(list.size(), size_t(3));
  auto it = list.begin();
  ASSERT_EQ(it, itC);
  ++it;
  ASSERT_EQ(it, itB);
  ++it;
  ASSERT_EQ(it, itA);
  ++it;
  ASSERT_EQ(it, list.end());
}

// Erasing the front element (the last one emplaced) leaves b, a.
TEST_F(ContainersListTest, EraseFront) {
  marl::containers::list<std::string> list(allocator);
  auto itA = list.emplace_front("a");
  auto itB = list.emplace_front("b");
  auto itC = list.emplace_front("c");
  list.erase(itC);
  ASSERT_EQ(list.size(), size_t(2));
  auto it = list.begin();
  ASSERT_EQ(it, itB);
  ++it;
  ASSERT_EQ(it, itA);
  ++it;
  ASSERT_EQ(it, list.end());
}

// Erasing the back element (the first one emplaced) leaves c, b.
TEST_F(ContainersListTest, EraseBack) {
  marl::containers::list<std::string> list(allocator);
  auto itA = list.emplace_front("a");
  auto itB = list.emplace_front("b");
  auto itC = list.emplace_front("c");
  list.erase(itA);
  ASSERT_EQ(list.size(), size_t(2));
  auto it = list.begin();
  ASSERT_EQ(it, itC);
  ++it;
  ASSERT_EQ(it, itB);
  ++it;
  ASSERT_EQ(it, list.end());
}

// Erasing a middle element relinks its neighbors, leaving c, a.
TEST_F(ContainersListTest, EraseMid) {
  marl::containers::list<std::string> list(allocator);
  auto itA = list.emplace_front("a");
  auto itB = list.emplace_front("b");
  auto itC = list.emplace_front("c");
  list.erase(itB);
  ASSERT_EQ(list.size(), size_t(2));
  auto it = list.begin();
  ASSERT_EQ(it, itC);
  ++it;
  ASSERT_EQ(it, itA);
  ++it;
  ASSERT_EQ(it, list.end());
}

// Inserting well past the initial pool size exercises grow().
TEST_F(ContainersListTest, Grow) {
  marl::containers::list<std::string> list(allocator);
  for (int i = 0; i < 256; i++) {
    list.emplace_front(std::to_string(i));
  }
  ASSERT_EQ(list.size(), size_t(256));
}
...@@ -36,3 +36,37 @@ BENCHMARK_DEFINE_F(Schedule, Event)(benchmark::State& state) { ...@@ -36,3 +36,37 @@ BENCHMARK_DEFINE_F(Schedule, Event)(benchmark::State& state) {
}); });
} }
BENCHMARK_REGISTER_F(Schedule, Event)->Apply(Schedule::args<512>); BENCHMARK_REGISTER_F(Schedule, Event)->Apply(Schedule::args<512>);
// EventBaton benchmarks alternating execution of two tasks.
// Two tasks pass a "baton" back and forth numPasses times using a pair of
// auto-reset events, measuring cross-task signal/wait latency. Both tasks use
// Flags::SameThread so scheduling does not wake additional worker threads.
BENCHMARK_DEFINE_F(Schedule, EventBaton)(benchmark::State& state) {
  run(state, [&](int numPasses) {
    for (auto _ : state) {
      marl::Event passToA(marl::Event::Mode::Auto);
      marl::Event passToB(marl::Event::Mode::Auto);
      marl::Event done(marl::Event::Mode::Auto);
      // The events are captured by value; marl::Event holds a shared_ptr to
      // its state, so the copies observe the same underlying event.
      marl::schedule(marl::Task(
          [=] {
            for (int i = 0; i < numPasses; i++) {
              passToA.wait();
              passToB.signal();
            }
          },
          marl::Task::Flags::SameThread));
      marl::schedule(marl::Task(
          [=] {
            for (int i = 0; i < numPasses; i++) {
              passToB.wait();
              passToA.signal();
            }
            done.signal();
          },
          marl::Task::Flags::SameThread));
      // Start the first pass, then block until the second task finishes.
      passToA.signal();
      done.wait();
    }
  });
}
BENCHMARK_REGISTER_F(Schedule, EventBaton)->Apply(Schedule::args<1000000>);
...@@ -48,13 +48,13 @@ class Schedule : public benchmark::Fixture { ...@@ -48,13 +48,13 @@ class Schedule : public benchmark::Fixture {
} }
} }
// numThreads return the number of threads in the benchmark run from the // numThreads() return the number of threads in the benchmark run from the
// state. // state.
static int numThreads(const ::benchmark::State& state) { static int numThreads(const ::benchmark::State& state) {
return static_cast<int>(state.range(1)); return static_cast<int>(state.range(1));
} }
// numTasks return the number of tasks in the benchmark run from the state. // numTasks() return the number of tasks in the benchmark run from the state.
static int numTasks(const ::benchmark::State& state) { static int numTasks(const ::benchmark::State& state) {
return static_cast<int>(state.range(0)); return static_cast<int>(state.range(0));
} }
......
...@@ -60,9 +60,9 @@ inline uint64_t threadID() { ...@@ -60,9 +60,9 @@ inline uint64_t threadID() {
#endif #endif
template <typename T> template <typename T>
inline T take(std::queue<T>& queue) { inline T take(std::deque<T>& queue) {
auto out = std::move(queue.front()); auto out = std::move(queue.front());
queue.pop(); queue.pop_front();
return out; return out;
} }
...@@ -189,6 +189,10 @@ int Scheduler::getWorkerThreadCount() { ...@@ -189,6 +189,10 @@ int Scheduler::getWorkerThreadCount() {
} }
void Scheduler::enqueue(Task&& task) { void Scheduler::enqueue(Task&& task) {
if (task.is(Task::Flags::SameThread)) {
Scheduler::Worker::getCurrent()->enqueue(std::move(task));
return;
}
if (numWorkerThreads > 0) { if (numWorkerThreads > 0) {
while (true) { while (true) {
// Prioritize workers that have recently started spinning. // Prioritize workers that have recently started spinning.
...@@ -220,7 +224,7 @@ bool Scheduler::stealWork(Worker* thief, uint64_t from, Task& out) { ...@@ -220,7 +224,7 @@ bool Scheduler::stealWork(Worker* thief, uint64_t from, Task& out) {
if (numWorkerThreads > 0) { if (numWorkerThreads > 0) {
auto thread = workerThreads[from % numWorkerThreads]; auto thread = workerThreads[from % numWorkerThreads];
if (thread != thief) { if (thread != thief) {
if (thread->dequeue(out)) { if (thread->steal(out)) {
return true; return true;
} }
} }
...@@ -372,7 +376,10 @@ void Scheduler::Worker::start() { ...@@ -372,7 +376,10 @@ void Scheduler::Worker::start() {
Worker::current = this; Worker::current = this;
mainFiber = Fiber::createFromCurrentThread(scheduler->allocator, 0); mainFiber = Fiber::createFromCurrentThread(scheduler->allocator, 0);
currentFiber = mainFiber.get(); currentFiber = mainFiber.get();
run(); {
std::unique_lock<std::mutex> lock(work.mutex);
run();
}
mainFiber.reset(); mainFiber.reset();
Worker::current = nullptr; Worker::current = nullptr;
}); });
...@@ -392,8 +399,7 @@ void Scheduler::Worker::start() { ...@@ -392,8 +399,7 @@ void Scheduler::Worker::start() {
void Scheduler::Worker::stop() { void Scheduler::Worker::stop() {
switch (mode) { switch (mode) {
case Mode::MultiThreaded: case Mode::MultiThreaded:
shutdown = true; enqueue(Task([this] { shutdown = true; }, Task::Flags::SameThread));
enqueue([] {}); // Ensure the worker is woken up to notice the shutdown.
thread.join(); thread.join();
break; break;
...@@ -462,22 +468,16 @@ void Scheduler::Worker::suspend( ...@@ -462,22 +468,16 @@ void Scheduler::Worker::suspend(
work.num--; work.num--;
auto to = take(work.fibers); auto to = take(work.fibers);
ASSERT_FIBER_STATE(to, Fiber::State::Queued); ASSERT_FIBER_STATE(to, Fiber::State::Queued);
work.mutex.unlock();
switchToFiber(to); switchToFiber(to);
work.mutex.lock();
} else if (idleFibers.size() > 0) { } else if (idleFibers.size() > 0) {
// There's an old fiber we can reuse, resume that. // There's an old fiber we can reuse, resume that.
auto to = take(idleFibers); auto to = take(idleFibers);
ASSERT_FIBER_STATE(to, Fiber::State::Idle); ASSERT_FIBER_STATE(to, Fiber::State::Idle);
work.mutex.unlock();
switchToFiber(to); switchToFiber(to);
work.mutex.lock();
} else { } else {
// Tasks to process and no existing fibers to resume. // Tasks to process and no existing fibers to resume.
// Spawn a new fiber. // Spawn a new fiber.
work.mutex.unlock();
switchToFiber(createWorkerFiber()); switchToFiber(createWorkerFiber());
work.mutex.lock();
} }
setFiberState(currentFiber, Fiber::State::Running); setFiberState(currentFiber, Fiber::State::Running);
...@@ -503,15 +503,15 @@ void Scheduler::Worker::enqueue(Fiber* fiber) { ...@@ -503,15 +503,15 @@ void Scheduler::Worker::enqueue(Fiber* fiber) {
case Fiber::State::Yielded: case Fiber::State::Yielded:
break; break;
} }
bool wasIdle = work.num == 0; bool notify = work.notifyAdded;
work.fibers.push(std::move(fiber)); work.fibers.push_back(std::move(fiber));
MARL_ASSERT(!work.waiting.contains(fiber), MARL_ASSERT(!work.waiting.contains(fiber),
"fiber is unexpectedly in the waiting list"); "fiber is unexpectedly in the waiting list");
setFiberState(fiber, Fiber::State::Queued); setFiberState(fiber, Fiber::State::Queued);
work.num++; work.num++;
lock.unlock(); lock.unlock();
if (wasIdle) { if (notify) {
work.added.notify_one(); work.added.notify_one();
} }
} }
...@@ -524,23 +524,24 @@ void Scheduler::Worker::enqueue(Task&& task) { ...@@ -524,23 +524,24 @@ void Scheduler::Worker::enqueue(Task&& task) {
_Requires_lock_held_(work.mutex) _Requires_lock_held_(work.mutex)
_Releases_lock_(work.mutex) _Releases_lock_(work.mutex)
void Scheduler::Worker::enqueueAndUnlock(Task&& task) { void Scheduler::Worker::enqueueAndUnlock(Task&& task) {
auto wasIdle = work.num == 0; auto notify = work.notifyAdded;
work.tasks.push(std::move(task)); work.tasks.push_back(std::move(task));
work.num++; work.num++;
work.mutex.unlock(); work.mutex.unlock();
if (wasIdle) { if (notify) {
work.added.notify_one(); work.added.notify_one();
} }
} }
bool Scheduler::Worker::dequeue(Task& out) { bool Scheduler::Worker::steal(Task& out) {
if (work.num.load() == 0) { if (work.num.load() == 0) {
return false; return false;
} }
if (!work.mutex.try_lock()) { if (!work.mutex.try_lock()) {
return false; return false;
} }
if (work.tasks.size() == 0) { if (work.tasks.size() == 0 ||
work.tasks.front().is(Task::Flags::SameThread)) {
work.mutex.unlock(); work.mutex.unlock();
return false; return false;
} }
...@@ -550,6 +551,7 @@ bool Scheduler::Worker::dequeue(Task& out) { ...@@ -550,6 +551,7 @@ bool Scheduler::Worker::dequeue(Task& out) {
return true; return true;
} }
_Requires_lock_held_(work.mutex)
void Scheduler::Worker::flush() { void Scheduler::Worker::flush() {
MARL_ASSERT(mode == Mode::SingleThreaded, MARL_ASSERT(mode == Mode::SingleThreaded,
"flush() can only be used on a single-threaded worker"); "flush() can only be used on a single-threaded worker");
...@@ -557,33 +559,30 @@ void Scheduler::Worker::flush() { ...@@ -557,33 +559,30 @@ void Scheduler::Worker::flush() {
runUntilIdle(); runUntilIdle();
} }
_Requires_lock_held_(work.mutex)
void Scheduler::Worker::run() { void Scheduler::Worker::run() {
switch (mode) { switch (mode) {
case Mode::MultiThreaded: { case Mode::MultiThreaded: {
MARL_NAME_THREAD("Thread<%.2d> Fiber<%.2d>", int(id), MARL_NAME_THREAD("Thread<%.2d> Fiber<%.2d>", int(id),
Fiber::current()->id); Fiber::current()->id);
{ work.wait([this] { return work.num > 0 || work.waiting || shutdown; });
std::unique_lock<std::mutex> lock(work.mutex); while (!shutdown || work.num > 0 || numBlockedFibers() > 0U) {
work.added.wait( waitForWork();
lock, [this] { return work.num > 0 || work.waiting || shutdown; }); runUntilIdle();
while (!shutdown || work.num > 0 || numBlockedFibers() > 0U) {
waitForWork();
runUntilIdle();
}
Worker::current = nullptr;
} }
Worker::current = nullptr;
switchToFiber(mainFiber.get()); switchToFiber(mainFiber.get());
break; break;
} }
case Mode::SingleThreaded: case Mode::SingleThreaded: {
ASSERT_FIBER_STATE(currentFiber, Fiber::State::Running); ASSERT_FIBER_STATE(currentFiber, Fiber::State::Running);
while (!shutdown) { while (!shutdown) {
flush(); runUntilIdle();
idleFibers.emplace(currentFiber); idleFibers.emplace(currentFiber);
switchToFiber(mainFiber.get()); switchToFiber(mainFiber.get());
} }
break; break;
}
default: default:
MARL_ASSERT(false, "Unknown mode: %d", int(mode)); MARL_ASSERT(false, "Unknown mode: %d", int(mode));
} }
...@@ -593,26 +592,22 @@ _Requires_lock_held_(work.mutex) ...@@ -593,26 +592,22 @@ _Requires_lock_held_(work.mutex)
void Scheduler::Worker::waitForWork() { void Scheduler::Worker::waitForWork() {
MARL_ASSERT(work.num == work.fibers.size() + work.tasks.size(), MARL_ASSERT(work.num == work.fibers.size() + work.tasks.size(),
"work.num out of sync"); "work.num out of sync");
if (work.num == 0 && mode == Mode::MultiThreaded) { if (work.num > 0) {
return;
}
if (mode == Mode::MultiThreaded) {
scheduler->onBeginSpinning(id); scheduler->onBeginSpinning(id);
work.mutex.unlock(); work.mutex.unlock();
spinForWork(); spinForWork();
work.mutex.lock(); work.mutex.lock();
} }
work.wait([this] {
return work.num > 0 || (shutdown && numBlockedFibers() == 0U);
});
if (work.waiting) { if (work.waiting) {
std::unique_lock<std::mutex> lock(work.mutex, std::adopt_lock);
work.added.wait_until(lock, work.waiting.next(), [this] {
return work.num > 0 || (shutdown && numBlockedFibers() == 0U);
});
lock.release(); // Keep the lock held.
enqueueFiberTimeouts(); enqueueFiberTimeouts();
} else {
std::unique_lock<std::mutex> lock(work.mutex, std::adopt_lock);
work.added.wait(lock, [this] {
return work.num > 0 || (shutdown && numBlockedFibers() == 0U);
});
lock.release(); // Keep the lock held.
} }
} }
...@@ -622,7 +617,7 @@ void Scheduler::Worker::enqueueFiberTimeouts() { ...@@ -622,7 +617,7 @@ void Scheduler::Worker::enqueueFiberTimeouts() {
while (auto fiber = work.waiting.take(now)) { while (auto fiber = work.waiting.take(now)) {
changeFiberState(fiber, Fiber::State::Waiting, Fiber::State::Queued); changeFiberState(fiber, Fiber::State::Waiting, Fiber::State::Queued);
DBG_LOG("%d: TIMEOUT(%d)", (int)id, (int)fiber->id); DBG_LOG("%d: TIMEOUT(%d)", (int)id, (int)fiber->id);
work.fibers.push(fiber); work.fibers.push_back(fiber);
work.num++; work.num++;
} }
} }
...@@ -667,7 +662,7 @@ void Scheduler::Worker::spinForWork() { ...@@ -667,7 +662,7 @@ void Scheduler::Worker::spinForWork() {
if (scheduler->stealWork(this, rng(), stolen)) { if (scheduler->stealWork(this, rng(), stolen)) {
std::unique_lock<std::mutex> lock(work.mutex); std::unique_lock<std::mutex> lock(work.mutex);
work.tasks.emplace(std::move(stolen)); work.tasks.emplace_back(std::move(stolen));
work.num++; work.num++;
return; return;
} }
...@@ -695,15 +690,11 @@ void Scheduler::Worker::runUntilIdle() { ...@@ -695,15 +690,11 @@ void Scheduler::Worker::runUntilIdle() {
ASSERT_FIBER_STATE(fiber, Fiber::State::Queued); ASSERT_FIBER_STATE(fiber, Fiber::State::Queued);
changeFiberState(currentFiber, Fiber::State::Running, Fiber::State::Idle); changeFiberState(currentFiber, Fiber::State::Running, Fiber::State::Idle);
work.mutex.unlock(); auto added = idleFibers.emplace(currentFiber).second;
{ // unlocked (void)added;
auto added = idleFibers.emplace(currentFiber).second; MARL_ASSERT(added, "fiber already idle");
(void)added;
MARL_ASSERT(added, "fiber already idle");
switchToFiber(fiber); switchToFiber(fiber);
}
work.mutex.lock();
changeFiberState(currentFiber, Fiber::State::Idle, Fiber::State::Running); changeFiberState(currentFiber, Fiber::State::Idle, Fiber::State::Running);
} }
...@@ -724,6 +715,7 @@ void Scheduler::Worker::runUntilIdle() { ...@@ -724,6 +715,7 @@ void Scheduler::Worker::runUntilIdle() {
} }
} }
_Requires_lock_held_(work.mutex)
Scheduler::Fiber* Scheduler::Worker::createWorkerFiber() { Scheduler::Fiber* Scheduler::Worker::createWorkerFiber() {
auto fiberId = static_cast<uint32_t>(workerFibers.size() + 1); auto fiberId = static_cast<uint32_t>(workerFibers.size() + 1);
DBG_LOG("%d: CREATE(%d)", (int)id, (int)fiberId); DBG_LOG("%d: CREATE(%d)", (int)id, (int)fiberId);
...@@ -734,6 +726,7 @@ Scheduler::Fiber* Scheduler::Worker::createWorkerFiber() { ...@@ -734,6 +726,7 @@ Scheduler::Fiber* Scheduler::Worker::createWorkerFiber() {
return ptr; return ptr;
} }
_Requires_lock_held_(work.mutex)
void Scheduler::Worker::switchToFiber(Fiber* to) { void Scheduler::Worker::switchToFiber(Fiber* to) {
DBG_LOG("%d: SWITCH(%d -> %d)", (int)id, (int)currentFiber->id, (int)to->id); DBG_LOG("%d: SWITCH(%d -> %d)", (int)id, (int)currentFiber->id, (int)to->id);
MARL_ASSERT(to == mainFiber.get() || idleFibers.count(to) == 0, MARL_ASSERT(to == mainFiber.get() || idleFibers.count(to) == 0,
...@@ -743,4 +736,21 @@ void Scheduler::Worker::switchToFiber(Fiber* to) { ...@@ -743,4 +736,21 @@ void Scheduler::Worker::switchToFiber(Fiber* to) {
from->switchTo(to); from->switchTo(to);
} }
////////////////////////////////////////////////////////////////////////////////
// Scheduler::Worker::Work
////////////////////////////////////////////////////////////////////////////////
_Requires_lock_held_(mutex)
template <typename F>
void Scheduler::Worker::Work::wait(F&& f) {
std::unique_lock<std::mutex> lock(mutex, std::adopt_lock);
notifyAdded = true;
if (waiting) {
added.wait_until(lock, waiting.next(), f);
} else {
added.wait(lock, f);
}
notifyAdded = false;
lock.release(); // Keep the lock held.
}
} // namespace marl } // namespace marl
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment