Commit 193ce894 by Ben Clayton

Squashed 'third_party/marl/' changes from 12872a0df..49e4e3141

49e4e3141 Add blocking_call test for functions with non-void return. 0acd85c63 blocking_call: Workaround for GCC bug with parameter packs in lambdas. ce47eca32 Work around TSAN false-positive in old versions of GCC 85854b73a Implement allocation page guards d8b38213f Add missing header include guards 776f7a485 Add Allocator interface and use them throughout marl. git-subtree-dir: third_party/marl git-subtree-split: 49e4e314157e8560f6ad620f5ec28a877612f03b
parent 07ed7cf1
...@@ -53,6 +53,7 @@ endif(MARL_BUILD_TESTS) ...@@ -53,6 +53,7 @@ endif(MARL_BUILD_TESTS)
########################################################### ###########################################################
set(MARL_LIST set(MARL_LIST
${MARL_SRC_DIR}/debug.cpp ${MARL_SRC_DIR}/debug.cpp
${MARL_SRC_DIR}/memory.cpp
${MARL_SRC_DIR}/scheduler.cpp ${MARL_SRC_DIR}/scheduler.cpp
${MARL_SRC_DIR}/thread.cpp ${MARL_SRC_DIR}/thread.cpp
${MARL_SRC_DIR}/trace.cpp ${MARL_SRC_DIR}/trace.cpp
......
...@@ -12,11 +12,15 @@ ...@@ -12,11 +12,15 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#ifndef marl_blocking_call_h
#define marl_blocking_call_h
#include "defer.h" #include "defer.h"
#include "waitgroup.h" #include "waitgroup.h"
#include <thread> #include <thread>
#include <type_traits> #include <type_traits>
#include <utility>
namespace marl { namespace marl {
namespace detail { namespace detail {
...@@ -28,10 +32,12 @@ class OnNewThread { ...@@ -28,10 +32,12 @@ class OnNewThread {
inline static RETURN_TYPE call(F&& f, Args&&... args) { inline static RETURN_TYPE call(F&& f, Args&&... args) {
RETURN_TYPE result; RETURN_TYPE result;
WaitGroup wg(1); WaitGroup wg(1);
auto thread = std::thread([&] { auto thread = std::thread(
defer(wg.done()); [&](Args&&... args) {
result = f(args...); defer(wg.done());
}); result = f(std::forward<Args>(args)...);
},
std::forward<Args>(args)...);
wg.wait(); wg.wait();
thread.join(); thread.join();
return result; return result;
...@@ -44,10 +50,12 @@ class OnNewThread<void> { ...@@ -44,10 +50,12 @@ class OnNewThread<void> {
template <typename F, typename... Args> template <typename F, typename... Args>
inline static void call(F&& f, Args&&... args) { inline static void call(F&& f, Args&&... args) {
WaitGroup wg(1); WaitGroup wg(1);
auto thread = std::thread([&] { auto thread = std::thread(
defer(wg.done()); [&](Args&&... args) {
f(args...); defer(wg.done());
}); f(std::forward<Args>(args)...);
},
std::forward<Args>(args)...);
wg.wait(); wg.wait();
thread.join(); thread.join();
} }
...@@ -78,3 +86,5 @@ auto inline blocking_call(F&& f, Args&&... args) -> decltype(f(args...)) { ...@@ -78,3 +86,5 @@ auto inline blocking_call(F&& f, Args&&... args) -> decltype(f(args...)) {
} }
} // namespace marl } // namespace marl
#endif // marl_blocking_call_h
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#define marl_containers_h #define marl_containers_h
#include "debug.h" #include "debug.h"
#include "memory.h" // aligned_storage #include "memory.h"
#include <algorithm> // std::max #include <algorithm> // std::max
#include <utility> // std::move #include <utility> // std::move
...@@ -38,13 +38,15 @@ namespace containers { ...@@ -38,13 +38,15 @@ namespace containers {
template <typename T, int BASE_CAPACITY> template <typename T, int BASE_CAPACITY>
class vector { class vector {
public: public:
inline vector() = default; inline vector(Allocator* allocator = Allocator::Default);
template <int BASE_CAPACITY_2> template <int BASE_CAPACITY_2>
inline vector(const vector<T, BASE_CAPACITY_2>& other); inline vector(const vector<T, BASE_CAPACITY_2>& other,
Allocator* allocator = Allocator::Default);
template <int BASE_CAPACITY_2> template <int BASE_CAPACITY_2>
inline vector(vector<T, BASE_CAPACITY_2>&& other); inline vector(vector<T, BASE_CAPACITY_2>&& other,
Allocator* allocator = Allocator::Default);
inline ~vector(); inline ~vector();
...@@ -73,21 +75,34 @@ class vector { ...@@ -73,21 +75,34 @@ class vector {
inline void free(); inline void free();
Allocator* const allocator;
size_t count = 0; size_t count = 0;
size_t capacity = BASE_CAPACITY; size_t capacity = BASE_CAPACITY;
TStorage buffer[BASE_CAPACITY]; TStorage buffer[BASE_CAPACITY];
TStorage* elements = buffer; TStorage* elements = buffer;
Allocation allocation;
}; };
template <typename T, int BASE_CAPACITY> template <typename T, int BASE_CAPACITY>
vector<T, BASE_CAPACITY>::vector(
Allocator* allocator /* = Allocator::Default */)
: allocator(allocator) {}
template <typename T, int BASE_CAPACITY>
template <int BASE_CAPACITY_2> template <int BASE_CAPACITY_2>
vector<T, BASE_CAPACITY>::vector(const vector<T, BASE_CAPACITY_2>& other) { vector<T, BASE_CAPACITY>::vector(
const vector<T, BASE_CAPACITY_2>& other,
Allocator* allocator /* = Allocator::Default */)
: allocator(allocator) {
*this = other; *this = other;
} }
template <typename T, int BASE_CAPACITY> template <typename T, int BASE_CAPACITY>
template <int BASE_CAPACITY_2> template <int BASE_CAPACITY_2>
vector<T, BASE_CAPACITY>::vector(vector<T, BASE_CAPACITY_2>&& other) { vector<T, BASE_CAPACITY>::vector(
vector<T, BASE_CAPACITY_2>&& other,
Allocator* allocator /* = Allocator::Default */)
: allocator(allocator) {
*this = std::move(other); *this = std::move(other);
} }
...@@ -198,13 +213,21 @@ template <typename T, int BASE_CAPACITY> ...@@ -198,13 +213,21 @@ template <typename T, int BASE_CAPACITY>
void vector<T, BASE_CAPACITY>::reserve(size_t n) { void vector<T, BASE_CAPACITY>::reserve(size_t n) {
if (n > capacity) { if (n > capacity) {
capacity = std::max<size_t>(n * 2, 8); capacity = std::max<size_t>(n * 2, 8);
auto grown = new TStorage[capacity];
Allocation::Request request;
request.size = sizeof(T) * capacity;
request.alignment = alignof(T);
request.usage = Allocation::Usage::Vector;
auto alloc = allocator->allocate(request);
auto grown = reinterpret_cast<TStorage*>(alloc.ptr);
for (size_t i = 0; i < count; i++) { for (size_t i = 0; i < count; i++) {
new (&reinterpret_cast<T*>(grown)[i]) new (&reinterpret_cast<T*>(grown)[i])
T(std::move(reinterpret_cast<T*>(elements)[i])); T(std::move(reinterpret_cast<T*>(elements)[i]));
} }
free(); free();
elements = grown; elements = grown;
allocation = alloc;
} }
} }
...@@ -214,8 +237,8 @@ void vector<T, BASE_CAPACITY>::free() { ...@@ -214,8 +237,8 @@ void vector<T, BASE_CAPACITY>::free() {
reinterpret_cast<T*>(elements)[i].~T(); reinterpret_cast<T*>(elements)[i].~T();
} }
if (elements != buffer) { if (allocation.ptr != nullptr) {
delete[] elements; allocator->free(allocation);
elements = nullptr; elements = nullptr;
} }
} }
......
...@@ -18,63 +18,163 @@ ...@@ -18,63 +18,163 @@
#include "debug.h" #include "debug.h"
#include <stdint.h> #include <stdint.h>
#include <array>
#include <cstdlib> #include <cstdlib>
#include <memory> #include <memory>
#include <mutex>
#include <utility> // std::forward #include <utility> // std::forward
namespace marl { namespace marl {
template <typename T> // pageSize() returns the size in bytes of a virtual memory page for the host
inline T alignUp(T val, T alignment) { // system.
return alignment * ((val + alignment - 1) / alignment); size_t pageSize();
}
// aligned_malloc() allocates size bytes of uninitialized storage with the // Allocation holds the result of a memory allocation from an Allocator.
// specified minimum byte alignment. The pointer returned must be freed with struct Allocation {
// aligned_free(). // Intended usage of the allocation. Used for allocation trackers.
inline void* aligned_malloc(size_t alignment, size_t size) { enum class Usage {
MARL_ASSERT(alignment < 256, "alignment must less than 256"); Undefined = 0,
auto allocation = new uint8_t[size + sizeof(uint8_t) + alignment]; Stack, // Fiber stack
auto aligned = allocation; Create, // Allocator::create(), make_unique(), make_shared()
aligned += sizeof(uint8_t); // Make space for the base-address offset. Vector, // marl::vector<T>
aligned = reinterpret_cast<uint8_t*>( Count, // Not intended to be used as a usage type - used for upper bound.
alignUp(reinterpret_cast<uintptr_t>(aligned), alignment)); // align };
auto offset = static_cast<uint8_t>(aligned - allocation);
aligned[-1] = offset; // Request holds all the information required to make an allocation.
return aligned; struct Request {
} size_t size = 0; // The size of the allocation in bytes.
size_t alignment = 0; // The minimum alignment of the allocation.
bool useGuards = false; // Whether the allocation is guarded.
Usage usage = Usage::Undefined; // Intended usage of the allocation.
};
void* ptr = nullptr; // The pointer to the allocated memory.
Request request; // Request used for the allocation.
};
// Allocator is an interface to a memory allocator.
// Marl provides a default implementation with Allocator::Default.
class Allocator {
public:
// The default allocator. Initialized with an implementation that allocates
// from the OS. Can be assigned a custom implementation.
static Allocator* Default;
// Deleter is a smart-pointer compatible deleter that can be used to delete
// objects created by Allocator::create(). Deleter is used by the smart
// pointers returned by make_shared() and make_unique().
struct Deleter {
inline Deleter();
inline Deleter(Allocator* allocator);
template <typename T>
inline void operator()(T* object);
Allocator* allocator = nullptr;
};
// unique_ptr<T> is an alias to std::unique_ptr<T, Deleter>.
template <typename T>
using unique_ptr = std::unique_ptr<T, Deleter>;
// aligned_free() frees memory allocated by aligned_malloc. virtual ~Allocator() = default;
inline void aligned_free(void* ptr) {
auto aligned = reinterpret_cast<uint8_t*>(ptr); // allocate() allocates memory from the allocator.
auto offset = aligned[-1]; // The returned Allocation::request field must be equal to the Request
auto allocation = aligned - offset; // parameter.
delete[] allocation; virtual Allocation allocate(const Allocation::Request&) = 0;
// free() frees the memory returned by allocate().
// The Allocation must have all fields equal to those returned by allocate().
virtual void free(const Allocation&) = 0;
// create() allocates and constructs an object of type T, respecting the
// alignment of the type.
// The pointer returned by create() must be deleted with destroy().
template <typename T, typename... ARGS>
inline T* create(ARGS&&... args);
// destroy() destructs and frees the object allocated with create().
template <typename T>
inline void destroy(T* object);
// make_unique() returns a new object allocated from the allocator wrapped
// in a unique_ptr that respects the alignemnt of the type.
template <typename T, typename... ARGS>
inline unique_ptr<T> make_unique(ARGS&&... args);
// make_shared() returns a new object allocated from the allocator
// wrapped in a std::shared_ptr that respects the alignemnt of the type.
template <typename T, typename... ARGS>
inline std::shared_ptr<T> make_shared(ARGS&&... args);
protected:
Allocator() = default;
};
Allocator::Deleter::Deleter() : allocator(nullptr) {}
Allocator::Deleter::Deleter(Allocator* allocator) : allocator(allocator) {}
template <typename T>
void Allocator::Deleter::operator()(T* object) {
object->~T();
Allocation allocation;
allocation.ptr = object;
allocation.request.size = sizeof(T);
allocation.request.alignment = alignof(T);
allocation.request.usage = Allocation::Usage::Create;
allocator->free(allocation);
} }
// aligned_new() allocates and constructs an object of type T, respecting the
// alignment of the type.
// The pointer returned by aligned_new() must be deleted with aligned_delete().
template <typename T, typename... ARGS> template <typename T, typename... ARGS>
T* aligned_new(ARGS&&... args) { T* Allocator::create(ARGS&&... args) {
auto ptr = aligned_malloc(alignof(T), sizeof(T)); Allocation::Request request;
new (ptr) T(std::forward<ARGS>(args)...); request.size = sizeof(T);
return reinterpret_cast<T*>(ptr); request.alignment = alignof(T);
request.usage = Allocation::Usage::Create;
auto alloc = allocate(request);
new (alloc.ptr) T(std::forward<ARGS>(args)...);
return reinterpret_cast<T*>(alloc.ptr);
} }
// aligned_delete() destructs and frees the object allocated with aligned_new().
template <typename T> template <typename T>
void aligned_delete(T* object) { void Allocator::destroy(T* object) {
object->~T(); object->~T();
aligned_free(object);
Allocation alloc;
alloc.ptr = object;
alloc.request.size = sizeof(T);
alloc.request.alignment = alignof(T);
alloc.request.usage = Allocation::Usage::Create;
free(alloc);
} }
// make_aligned_shared() returns a new object wrapped in a std::shared_ptr that
// respects the alignemnt of the type.
template <typename T, typename... ARGS> template <typename T, typename... ARGS>
inline std::shared_ptr<T> make_aligned_shared(ARGS&&... args) { Allocator::unique_ptr<T> Allocator::make_unique(ARGS&&... args) {
auto ptr = aligned_new<T>(std::forward<ARGS>(args)...); Allocation::Request request;
return std::shared_ptr<T>(ptr, aligned_delete<T>); request.size = sizeof(T);
request.alignment = alignof(T);
request.usage = Allocation::Usage::Create;
auto alloc = allocate(request);
new (alloc.ptr) T(std::forward<ARGS>(args)...);
return unique_ptr<T>(reinterpret_cast<T*>(alloc.ptr), Deleter{this});
}
template <typename T, typename... ARGS>
std::shared_ptr<T> Allocator::make_shared(ARGS&&... args) {
Allocation::Request request;
request.size = sizeof(T);
request.alignment = alignof(T);
request.usage = Allocation::Usage::Create;
auto alloc = allocate(request);
new (alloc.ptr) T(std::forward<ARGS>(args)...);
return std::shared_ptr<T>(reinterpret_cast<T*>(alloc.ptr), Deleter{this});
} }
// aligned_storage() is a replacement for std::aligned_storage that isn't busted // aligned_storage() is a replacement for std::aligned_storage that isn't busted
...@@ -86,6 +186,94 @@ struct aligned_storage { ...@@ -86,6 +186,94 @@ struct aligned_storage {
}; };
}; };
// TrackedAllocator wraps an Allocator to track the allocations made.
class TrackedAllocator : public Allocator {
public:
struct UsageStats {
// Total number of allocations.
size_t count = 0;
// total allocation size in bytes (as requested, may be higher due to
// alignment or guards).
size_t bytes = 0;
};
struct Stats {
// numAllocations() returns the total number of allocations across all
// usages for the allocator.
inline size_t numAllocations() const;
// bytesAllocated() returns the total number of bytes allocated across all
// usages for the allocator.
inline size_t bytesAllocated() const;
// Statistics per usage.
std::array<UsageStats, size_t(Allocation::Usage::Count)> byUsage;
};
// Constructor that wraps an existing allocator.
inline TrackedAllocator(Allocator* allocator);
// stats() returns the current allocator statistics.
inline Stats stats();
// Allocator compliance
inline Allocation allocate(const Allocation::Request&) override;
inline void free(const Allocation&) override;
private:
Allocator* const allocator;
std::mutex mutex;
Stats stats_;
};
size_t TrackedAllocator::Stats::numAllocations() const {
size_t out = 0;
for (auto& stats : byUsage) {
out += stats.count;
}
return out;
}
size_t TrackedAllocator::Stats::bytesAllocated() const {
size_t out = 0;
for (auto& stats : byUsage) {
out += stats.bytes;
}
return out;
}
TrackedAllocator::TrackedAllocator(Allocator* allocator)
: allocator(allocator) {}
TrackedAllocator::Stats TrackedAllocator::stats() {
std::unique_lock<std::mutex> lock(mutex);
return stats_;
}
Allocation TrackedAllocator::allocate(const Allocation::Request& request) {
{
std::unique_lock<std::mutex> lock(mutex);
auto& usageStats = stats_.byUsage[int(request.usage)];
++usageStats.count;
usageStats.bytes += request.size;
}
return allocator->allocate(request);
}
void TrackedAllocator::free(const Allocation& allocation) {
{
std::unique_lock<std::mutex> lock(mutex);
auto& usageStats = stats_.byUsage[int(allocation.request.usage)];
MARL_ASSERT(usageStats.count > 0,
"TrackedAllocator detected abnormal free()");
MARL_ASSERT(usageStats.bytes >= allocation.request.size,
"TrackedAllocator detected abnormal free()");
--usageStats.count;
usageStats.bytes -= allocation.request.size;
}
return allocator->free(allocation);
}
} // namespace marl } // namespace marl
#endif // marl_memory_h #endif // marl_memory_h
...@@ -210,6 +210,8 @@ class BoundedPool : public Pool<T> { ...@@ -210,6 +210,8 @@ class BoundedPool : public Pool<T> {
using Item = typename Pool<T>::Item; using Item = typename Pool<T>::Item;
using Loan = typename Pool<T>::Loan; using Loan = typename Pool<T>::Loan;
inline BoundedPool(Allocator* allocator = Allocator::Default);
// borrow() borrows a single item from the pool, blocking until an item is // borrow() borrows a single item from the pool, blocking until an item is
// returned if the pool is empty. // returned if the pool is empty.
inline Loan borrow() const; inline Loan borrow() const;
...@@ -239,7 +241,7 @@ class BoundedPool : public Pool<T> { ...@@ -239,7 +241,7 @@ class BoundedPool : public Pool<T> {
ConditionVariable returned; ConditionVariable returned;
Item* free = nullptr; Item* free = nullptr;
}; };
std::shared_ptr<Storage> storage = make_aligned_shared<Storage>(); std::shared_ptr<Storage> storage;
}; };
template <typename T, int N, PoolPolicy POLICY> template <typename T, int N, PoolPolicy POLICY>
...@@ -263,6 +265,11 @@ BoundedPool<T, N, POLICY>::Storage::~Storage() { ...@@ -263,6 +265,11 @@ BoundedPool<T, N, POLICY>::Storage::~Storage() {
} }
template <typename T, int N, PoolPolicy POLICY> template <typename T, int N, PoolPolicy POLICY>
BoundedPool<T, N, POLICY>::BoundedPool(
Allocator* allocator /* = Allocator::Default */)
: storage(allocator->make_shared<Storage>()) {}
template <typename T, int N, PoolPolicy POLICY>
typename BoundedPool<T, N, POLICY>::Loan BoundedPool<T, N, POLICY>::borrow() typename BoundedPool<T, N, POLICY>::Loan BoundedPool<T, N, POLICY>::borrow()
const { const {
Loan out; Loan out;
...@@ -329,6 +336,8 @@ class UnboundedPool : public Pool<T> { ...@@ -329,6 +336,8 @@ class UnboundedPool : public Pool<T> {
using Item = typename Pool<T>::Item; using Item = typename Pool<T>::Item;
using Loan = typename Pool<T>::Loan; using Loan = typename Pool<T>::Loan;
inline UnboundedPool(Allocator* allocator = Allocator::Default);
// borrow() borrows a single item from the pool, automatically allocating // borrow() borrows a single item from the pool, automatically allocating
// more items if the pool is empty. // more items if the pool is empty.
// This function does not block. // This function does not block.
...@@ -344,27 +353,41 @@ class UnboundedPool : public Pool<T> { ...@@ -344,27 +353,41 @@ class UnboundedPool : public Pool<T> {
private: private:
class Storage : public Pool<T>::Storage { class Storage : public Pool<T>::Storage {
public: public:
inline Storage(Allocator* allocator);
inline ~Storage(); inline ~Storage();
inline void return_(Item*) override; inline void return_(Item*) override;
Allocator* allocator;
std::mutex mutex; std::mutex mutex;
std::vector<Item*> items; std::vector<Item*> items;
Item* free = nullptr; Item* free = nullptr;
}; };
std::shared_ptr<Storage> storage = std::make_shared<Storage>();
Allocator* allocator;
std::shared_ptr<Storage> storage;
}; };
template <typename T, PoolPolicy POLICY> template <typename T, PoolPolicy POLICY>
UnboundedPool<T, POLICY>::Storage::Storage(Allocator* allocator)
: allocator(allocator) {}
template <typename T, PoolPolicy POLICY>
UnboundedPool<T, POLICY>::Storage::~Storage() { UnboundedPool<T, POLICY>::Storage::~Storage() {
for (auto item : items) { for (auto item : items) {
if (POLICY == PoolPolicy::Preserve) { if (POLICY == PoolPolicy::Preserve) {
item->destruct(); item->destruct();
} }
aligned_delete(item); allocator->destroy(item);
} }
} }
template <typename T, PoolPolicy POLICY> template <typename T, PoolPolicy POLICY>
UnboundedPool<T, POLICY>::UnboundedPool(
Allocator* allocator /* = Allocator::Default */)
: allocator(allocator),
storage(allocator->make_shared<Storage>(allocator)) {}
template <typename T, PoolPolicy POLICY>
Loan<T> UnboundedPool<T, POLICY>::borrow() const { Loan<T> UnboundedPool<T, POLICY>::borrow() const {
Loan out; Loan out;
borrow(1, [&](Loan&& loan) { out = std::move(loan); }); borrow(1, [&](Loan&& loan) { out = std::move(loan); });
...@@ -379,7 +402,7 @@ inline void UnboundedPool<T, POLICY>::borrow(size_t n, const F& f) const { ...@@ -379,7 +402,7 @@ inline void UnboundedPool<T, POLICY>::borrow(size_t n, const F& f) const {
if (storage->free == nullptr) { if (storage->free == nullptr) {
auto count = std::max<size_t>(storage->items.size(), 32); auto count = std::max<size_t>(storage->items.size(), 32);
for (size_t j = 0; j < count; j++) { for (size_t j = 0; j < count; j++) {
auto item = aligned_new<Item>(); auto item = allocator->create<Item>();
if (POLICY == PoolPolicy::Preserve) { if (POLICY == PoolPolicy::Preserve) {
item->construct(); item->construct();
} }
......
// Copyright 2019 The Marl Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef marl_sanitizers_h
#define marl_sanitizers_h
// Define ADDRESS_SANITIZER_ENABLED to 1 if the project was built with the
// address sanitizer enabled (-fsanitize=address).
#if defined(__SANITIZE_ADDRESS__)
#define ADDRESS_SANITIZER_ENABLED 1
#else // defined(__SANITIZE_ADDRESS__)
#if defined(__clang__)
#if __has_feature(address_sanitizer)
#define ADDRESS_SANITIZER_ENABLED 1
#endif // __has_feature(address_sanitizer)
#endif // defined(__clang__)
#endif // defined(__SANITIZE_ADDRESS__)
// ADDRESS_SANITIZER_ONLY(X) resolves to X if ADDRESS_SANITIZER_ENABLED is
// defined to a non-zero value, otherwise ADDRESS_SANITIZER_ONLY() is stripped
// by the preprocessor.
#if ADDRESS_SANITIZER_ENABLED
#define ADDRESS_SANITIZER_ONLY(x) x
#else
#define ADDRESS_SANITIZER_ONLY(x)
#endif // ADDRESS_SANITIZER_ENABLED
// Define MEMORY_SANITIZER_ENABLED to 1 if the project was built with the memory
// sanitizer enabled (-fsanitize=memory).
#if defined(__SANITIZE_MEMORY__)
#define MEMORY_SANITIZER_ENABLED 1
#else // defined(__SANITIZE_MEMORY__)
#if defined(__clang__)
#if __has_feature(memory_sanitizer)
#define MEMORY_SANITIZER_ENABLED 1
#endif // __has_feature(memory_sanitizer)
#endif // defined(__clang__)
#endif // defined(__SANITIZE_MEMORY__)
// MEMORY_SANITIZER_ONLY(X) resolves to X if MEMORY_SANITIZER_ENABLED is defined
// to a non-zero value, otherwise MEMORY_SANITIZER_ONLY() is stripped by the
// preprocessor.
#if MEMORY_SANITIZER_ENABLED
#define MEMORY_SANITIZER_ONLY(x) x
#else
#define MEMORY_SANITIZER_ONLY(x)
#endif // MEMORY_SANITIZER_ENABLED
// Define THREAD_SANITIZER_ENABLED to 1 if the project was built with the thread
// sanitizer enabled (-fsanitize=thread).
#if defined(__SANITIZE_THREAD__)
#define THREAD_SANITIZER_ENABLED 1
#else // defined(__SANITIZE_THREAD__)
#if defined(__clang__)
#if __has_feature(thread_sanitizer)
#define THREAD_SANITIZER_ENABLED 1
#endif // __has_feature(thread_sanitizer)
#endif // defined(__clang__)
#endif // defined(__SANITIZE_THREAD__)
// THREAD_SANITIZER_ONLY(X) resolves to X if THREAD_SANITIZER_ENABLED is defined
// to a non-zero value, otherwise THREAD_SANITIZER_ONLY() is stripped by the
// preprocessor.
#if THREAD_SANITIZER_ENABLED
#define THREAD_SANITIZER_ONLY(x) x
#else
#define THREAD_SANITIZER_ONLY(x)
#endif // THREAD_SANITIZER_ENABLED
#endif // marl_sanitizers_h
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#define marl_scheduler_h #define marl_scheduler_h
#include "debug.h" #include "debug.h"
#include "memory.h"
#include "sal.h" #include "sal.h"
#include <array> #include <array>
...@@ -44,7 +45,7 @@ class Scheduler { ...@@ -44,7 +45,7 @@ class Scheduler {
class Worker; class Worker;
public: public:
Scheduler(); Scheduler(Allocator* allocator = Allocator::Default);
~Scheduler(); ~Scheduler();
// get() returns the scheduler bound to the current thread. // get() returns the scheduler bound to the current thread.
...@@ -91,8 +92,6 @@ class Scheduler { ...@@ -91,8 +92,6 @@ class Scheduler {
// thread that previously executed it. // thread that previously executed it.
class Fiber { class Fiber {
public: public:
~Fiber();
// current() returns the currently executing fiber, or nullptr if called // current() returns the currently executing fiber, or nullptr if called
// without a bound scheduler. // without a bound scheduler.
static Fiber* current(); static Fiber* current();
...@@ -109,9 +108,10 @@ class Scheduler { ...@@ -109,9 +108,10 @@ class Scheduler {
uint32_t const id; uint32_t const id;
private: private:
friend class Allocator;
friend class Scheduler; friend class Scheduler;
Fiber(OSFiber*, uint32_t id); Fiber(Allocator::unique_ptr<OSFiber>&&, uint32_t id);
// switchTo() switches execution to the given fiber. // switchTo() switches execution to the given fiber.
// switchTo() must only be called on the currently executing fiber. // switchTo() must only be called on the currently executing fiber.
...@@ -119,15 +119,19 @@ class Scheduler { ...@@ -119,15 +119,19 @@ class Scheduler {
// create() constructs and returns a new fiber with the given identifier, // create() constructs and returns a new fiber with the given identifier,
// stack size that will executed func when switched to. // stack size that will executed func when switched to.
static Fiber* create(uint32_t id, static Allocator::unique_ptr<Fiber> create(
size_t stackSize, Allocator* allocator,
const std::function<void()>& func); uint32_t id,
size_t stackSize,
const std::function<void()>& func);
// createFromCurrentThread() constructs and returns a new fiber with the // createFromCurrentThread() constructs and returns a new fiber with the
// given identifier for the current thread. // given identifier for the current thread.
static Fiber* createFromCurrentThread(uint32_t id); static Allocator::unique_ptr<Fiber> createFromCurrentThread(
Allocator* allocator,
uint32_t id);
OSFiber* const impl; Allocator::unique_ptr<OSFiber> const impl;
Worker* const worker; Worker* const worker;
}; };
...@@ -266,12 +270,12 @@ class Scheduler { ...@@ -266,12 +270,12 @@ class Scheduler {
Mode const mode; Mode const mode;
Scheduler* const scheduler; Scheduler* const scheduler;
std::unique_ptr<Fiber> mainFiber; Allocator::unique_ptr<Fiber> mainFiber;
Fiber* currentFiber = nullptr; Fiber* currentFiber = nullptr;
std::thread thread; std::thread thread;
Work work; Work work;
FiberQueue idleFibers; // Fibers that have completed which can be reused. FiberQueue idleFibers; // Fibers that have completed which can be reused.
std::vector<std::unique_ptr<Fiber>> std::vector<Allocator::unique_ptr<Fiber>>
workerFibers; // All fibers created by this worker. workerFibers; // All fibers created by this worker.
FastRnd rng; FastRnd rng;
std::atomic<bool> shutdown = {false}; std::atomic<bool> shutdown = {false};
...@@ -289,6 +293,8 @@ class Scheduler { ...@@ -289,6 +293,8 @@ class Scheduler {
// The scheduler currently bound to the current thread. // The scheduler currently bound to the current thread.
static thread_local Scheduler* bound; static thread_local Scheduler* bound;
Allocator* const allocator;
std::function<void()> threadInitFunc; std::function<void()> threadInitFunc;
std::mutex threadInitFuncMutex; std::mutex threadInitFuncMutex;
...@@ -302,7 +308,7 @@ class Scheduler { ...@@ -302,7 +308,7 @@ class Scheduler {
std::array<Worker*, MaxWorkerThreads> workerThreads; std::array<Worker*, MaxWorkerThreads> workerThreads;
std::mutex singleThreadedWorkerMutex; std::mutex singleThreadedWorkerMutex;
std::unordered_map<std::thread::id, std::unique_ptr<Worker>> std::unordered_map<std::thread::id, Allocator::unique_ptr<Worker>>
singleThreadedWorkers; singleThreadedWorkers;
}; };
......
...@@ -18,6 +18,9 @@ ...@@ -18,6 +18,9 @@
// https://www.chromium.org/developers/how-tos/trace-event-profiling-tool // https://www.chromium.org/developers/how-tos/trace-event-profiling-tool
// https://docs.google.com/document/d/1CvAClvFfyA5R-PhYUmn5OOQtYMH4h6I0nSsKchNAySU/edit // https://docs.google.com/document/d/1CvAClvFfyA5R-PhYUmn5OOQtYMH4h6I0nSsKchNAySU/edit
#ifndef marl_trace_h
#define marl_trace_h
#define MARL_TRACE_ENABLED 0 #define MARL_TRACE_ENABLED 0
#if MARL_TRACE_ENABLED #if MARL_TRACE_ENABLED
...@@ -242,3 +245,5 @@ Trace::ScopedAsyncEvent::~ScopedAsyncEvent() { ...@@ -242,3 +245,5 @@ Trace::ScopedAsyncEvent::~ScopedAsyncEvent() {
#define MARL_NAME_THREAD(...) #define MARL_NAME_THREAD(...)
#endif // MARL_TRACE_ENABLED #endif // MARL_TRACE_ENABLED
#endif // marl_trace_h
...@@ -20,7 +20,7 @@ ...@@ -20,7 +20,7 @@
#include <mutex> #include <mutex>
TEST_P(WithBoundScheduler, BlockingCall) { TEST_P(WithBoundScheduler, BlockingCallVoidReturn) {
auto mutex = std::make_shared<std::mutex>(); auto mutex = std::make_shared<std::mutex>();
mutex->lock(); mutex->lock();
...@@ -38,3 +38,26 @@ TEST_P(WithBoundScheduler, BlockingCall) { ...@@ -38,3 +38,26 @@ TEST_P(WithBoundScheduler, BlockingCall) {
mutex->unlock(); mutex->unlock();
wg.wait(); wg.wait();
} }
TEST_P(WithBoundScheduler, BlockingCallIntReturn) {
auto mutex = std::make_shared<std::mutex>();
mutex->lock();
marl::WaitGroup wg(100);
std::atomic<int> n = {0};
for (int i = 0; i < 100; i++) {
marl::schedule([=, &n] {
defer(wg.done());
n += marl::blocking_call([=] {
mutex->lock();
defer(mutex->unlock());
return i;
});
});
}
mutex->unlock();
wg.wait();
ASSERT_EQ(n.load(), 4950);
}
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#include "marl_test.h" #include "marl_test.h"
TEST(WithoutBoundScheduler, ConditionVariable) { TEST_F(WithoutBoundScheduler, ConditionVariable) {
bool trigger[3] = {false, false, false}; bool trigger[3] = {false, false, false};
bool signal[3] = {false, false, false}; bool signal[3] = {false, false, false};
std::mutex mutex; std::mutex mutex;
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#include "marl/containers.h" #include "marl/containers.h"
#include "marl_test.h"
#include "gmock/gmock.h" #include "gmock/gmock.h"
#include "gtest/gtest.h" #include "gtest/gtest.h"
...@@ -20,15 +21,15 @@ ...@@ -20,15 +21,15 @@
#include <cstddef> #include <cstddef>
#include <string> #include <string>
class ContainersVectorTest : public testing::Test {}; class ContainersVectorTest : public WithoutBoundScheduler {};
TEST(ContainersVectorTest, Empty) { TEST_F(ContainersVectorTest, Empty) {
marl::containers::vector<std::string, 4> vector; marl::containers::vector<std::string, 4> vector(allocator);
ASSERT_EQ(vector.size(), size_t(0)); ASSERT_EQ(vector.size(), size_t(0));
} }
TEST(ContainersVectorTest, WithinFixedCapIndex) { TEST_F(ContainersVectorTest, WithinFixedCapIndex) {
marl::containers::vector<std::string, 4> vector; marl::containers::vector<std::string, 4> vector(allocator);
vector.resize(4); vector.resize(4);
vector[0] = "A"; vector[0] = "A";
vector[1] = "B"; vector[1] = "B";
...@@ -41,8 +42,8 @@ TEST(ContainersVectorTest, WithinFixedCapIndex) { ...@@ -41,8 +42,8 @@ TEST(ContainersVectorTest, WithinFixedCapIndex) {
ASSERT_EQ(vector[3], "D"); ASSERT_EQ(vector[3], "D");
} }
TEST(ContainersVectorTest, BeyondFixedCapIndex) { TEST_F(ContainersVectorTest, BeyondFixedCapIndex) {
marl::containers::vector<std::string, 1> vector; marl::containers::vector<std::string, 1> vector(allocator);
vector.resize(4); vector.resize(4);
vector[0] = "A"; vector[0] = "A";
vector[1] = "B"; vector[1] = "B";
...@@ -55,8 +56,8 @@ TEST(ContainersVectorTest, BeyondFixedCapIndex) { ...@@ -55,8 +56,8 @@ TEST(ContainersVectorTest, BeyondFixedCapIndex) {
ASSERT_EQ(vector[3], "D"); ASSERT_EQ(vector[3], "D");
} }
TEST(ContainersVectorTest, WithinFixedCapPushPop) { TEST_F(ContainersVectorTest, WithinFixedCapPushPop) {
marl::containers::vector<std::string, 4> vector; marl::containers::vector<std::string, 4> vector(allocator);
vector.push_back("A"); vector.push_back("A");
vector.push_back("B"); vector.push_back("B");
vector.push_back("C"); vector.push_back("C");
...@@ -89,8 +90,8 @@ TEST(ContainersVectorTest, WithinFixedCapPushPop) { ...@@ -89,8 +90,8 @@ TEST(ContainersVectorTest, WithinFixedCapPushPop) {
ASSERT_EQ(vector.size(), size_t(0)); ASSERT_EQ(vector.size(), size_t(0));
} }
TEST(ContainersVectorTest, BeyondFixedCapPushPop) { TEST_F(ContainersVectorTest, BeyondFixedCapPushPop) {
marl::containers::vector<std::string, 2> vector; marl::containers::vector<std::string, 2> vector(allocator);
vector.push_back("A"); vector.push_back("A");
vector.push_back("B"); vector.push_back("B");
vector.push_back("C"); vector.push_back("C");
...@@ -123,39 +124,40 @@ TEST(ContainersVectorTest, BeyondFixedCapPushPop) { ...@@ -123,39 +124,40 @@ TEST(ContainersVectorTest, BeyondFixedCapPushPop) {
ASSERT_EQ(vector.size(), size_t(0)); ASSERT_EQ(vector.size(), size_t(0));
} }
TEST(ContainersVectorTest, CopyConstruct) { TEST_F(ContainersVectorTest, CopyConstruct) {
marl::containers::vector<std::string, 4> vectorA; marl::containers::vector<std::string, 4> vectorA(allocator);
vectorA.resize(3); vectorA.resize(3);
vectorA[0] = "A"; vectorA[0] = "A";
vectorA[1] = "B"; vectorA[1] = "B";
vectorA[2] = "C"; vectorA[2] = "C";
marl::containers::vector<std::string, 2> vectorB(vectorA); marl::containers::vector<std::string, 2> vectorB(vectorA, allocator);
ASSERT_EQ(vectorB.size(), size_t(3)); ASSERT_EQ(vectorB.size(), size_t(3));
ASSERT_EQ(vectorB[0], "A"); ASSERT_EQ(vectorB[0], "A");
ASSERT_EQ(vectorB[1], "B"); ASSERT_EQ(vectorB[1], "B");
ASSERT_EQ(vectorB[2], "C"); ASSERT_EQ(vectorB[2], "C");
} }
TEST(ContainersVectorTest, MoveConstruct) { TEST_F(ContainersVectorTest, MoveConstruct) {
marl::containers::vector<std::string, 4> vectorA; marl::containers::vector<std::string, 4> vectorA(allocator);
vectorA.resize(3); vectorA.resize(3);
vectorA[0] = "A"; vectorA[0] = "A";
vectorA[1] = "B"; vectorA[1] = "B";
vectorA[2] = "C"; vectorA[2] = "C";
marl::containers::vector<std::string, 2> vectorB(std::move(vectorA)); marl::containers::vector<std::string, 2> vectorB(std::move(vectorA),
allocator);
ASSERT_EQ(vectorB.size(), size_t(3)); ASSERT_EQ(vectorB.size(), size_t(3));
ASSERT_EQ(vectorB[0], "A"); ASSERT_EQ(vectorB[0], "A");
ASSERT_EQ(vectorB[1], "B"); ASSERT_EQ(vectorB[1], "B");
ASSERT_EQ(vectorB[2], "C"); ASSERT_EQ(vectorB[2], "C");
} }
TEST(ContainersVectorTest, Copy) { TEST_F(ContainersVectorTest, Copy) {
marl::containers::vector<std::string, 4> vectorA; marl::containers::vector<std::string, 4> vectorA(allocator);
marl::containers::vector<std::string, 2> vectorB; marl::containers::vector<std::string, 2> vectorB(allocator);
vectorA.resize(3); vectorA.resize(3);
vectorA[0] = "A"; vectorA[0] = "A";
...@@ -172,9 +174,9 @@ TEST(ContainersVectorTest, Copy) { ...@@ -172,9 +174,9 @@ TEST(ContainersVectorTest, Copy) {
ASSERT_EQ(vectorB[2], "C"); ASSERT_EQ(vectorB[2], "C");
} }
TEST(ContainersVectorTest, Move) { TEST_F(ContainersVectorTest, Move) {
marl::containers::vector<std::string, 4> vectorA; marl::containers::vector<std::string, 4> vectorA(allocator);
marl::containers::vector<std::string, 2> vectorB; marl::containers::vector<std::string, 2> vectorB(allocator);
vectorA.resize(3); vectorA.resize(3);
vectorA[0] = "A"; vectorA[0] = "A";
......
...@@ -16,13 +16,13 @@ ...@@ -16,13 +16,13 @@
#include "marl_test.h" #include "marl_test.h"
TEST(WithoutBoundScheduler, Defer) { TEST_F(WithoutBoundScheduler, Defer) {
bool deferCalled = false; bool deferCalled = false;
{ defer(deferCalled = true); } { defer(deferCalled = true); }
ASSERT_TRUE(deferCalled); ASSERT_TRUE(deferCalled);
} }
TEST(WithoutBoundScheduler, DeferOrder) { TEST_F(WithoutBoundScheduler, DeferOrder) {
int counter = 0; int counter = 0;
int a = 0, b = 0, c = 0; int a = 0, b = 0, c = 0;
{ {
......
...@@ -29,16 +29,32 @@ struct SchedulerParams { ...@@ -29,16 +29,32 @@ struct SchedulerParams {
}; };
// WithoutBoundScheduler is a test fixture that does not bind a scheduler. // WithoutBoundScheduler is a test fixture that does not bind a scheduler.
class WithoutBoundScheduler : public testing::Test {}; class WithoutBoundScheduler : public testing::Test {
public:
void SetUp() override {
allocator = new marl::TrackedAllocator(marl::Allocator::Default);
}
void TearDown() override {
auto stats = allocator->stats();
ASSERT_EQ(stats.numAllocations(), 0U);
ASSERT_EQ(stats.bytesAllocated(), 0U);
delete allocator;
}
marl::TrackedAllocator* allocator = nullptr;
};
// WithBoundScheduler is a parameterized test fixture that performs tests with // WithBoundScheduler is a parameterized test fixture that performs tests with
// a bound scheduler using a number of different configurations. // a bound scheduler using a number of different configurations.
class WithBoundScheduler : public testing::TestWithParam<SchedulerParams> { class WithBoundScheduler : public testing::TestWithParam<SchedulerParams> {
public: public:
void SetUp() override { void SetUp() override {
allocator = new marl::TrackedAllocator(marl::Allocator::Default);
auto& params = GetParam(); auto& params = GetParam();
auto scheduler = new marl::Scheduler(); auto scheduler = new marl::Scheduler(allocator);
scheduler->bind(); scheduler->bind();
scheduler->setWorkerThreadCount(params.numWorkerThreads); scheduler->setWorkerThreadCount(params.numWorkerThreads);
} }
...@@ -47,5 +63,12 @@ class WithBoundScheduler : public testing::TestWithParam<SchedulerParams> { ...@@ -47,5 +63,12 @@ class WithBoundScheduler : public testing::TestWithParam<SchedulerParams> {
auto scheduler = marl::Scheduler::get(); auto scheduler = marl::Scheduler::get();
scheduler->unbind(); scheduler->unbind();
delete scheduler; delete scheduler;
auto stats = allocator->stats();
ASSERT_EQ(stats.numAllocations(), 0U);
ASSERT_EQ(stats.bytesAllocated(), 0U);
delete allocator;
} }
marl::TrackedAllocator* allocator = nullptr;
}; };
// Copyright 2019 The Marl Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "marl/memory.h"
#include "marl/debug.h"
#include "marl/sanitizers.h"
#include <cstring>
#if defined(__linux__) || defined(__APPLE__)

#include <sys/mman.h>
#include <unistd.h>

namespace {

// This was a static in pageSize(), but due to the following TSAN false-positive
// bug, this has been moved out to a global.
// https://gcc.gnu.org/bugzilla/show_bug.cgi?id=68338
const size_t kPageSize = sysconf(_SC_PAGESIZE);

// pageSize() returns the size in bytes of a single OS memory page.
inline size_t pageSize() {
  return kPageSize;
}

// allocatePages() maps `count` whole pages of zero-initialized, read-write,
// private anonymous memory. Returns nullptr on failure (debug builds assert
// first).
inline void* allocatePages(size_t count) {
  auto mapping = mmap(nullptr, count * pageSize(), PROT_READ | PROT_WRITE,
                      MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
  MARL_ASSERT(mapping != MAP_FAILED, "Failed to allocate %d pages", int(count));
  if (mapping == MAP_FAILED) {
    // Normalize mmap()'s MAP_FAILED sentinel to nullptr for callers.
    mapping = nullptr;
  }
  return mapping;
}

// freePages() unmaps `count` pages previously obtained from allocatePages().
// ptr must be the exact pointer allocatePages() returned.
inline void freePages(void* ptr, size_t count) {
  auto res = munmap(ptr, count * pageSize());
  (void)res;  // res is only inspected by the debug-build assert below.
  MARL_ASSERT(res == 0, "Failed to free %d pages at %p", int(count), ptr);
}

// protectPage() marks the single page at addr inaccessible so that any read
// or write to it raises a page fault. Used to build allocation guard pages.
inline void protectPage(void* addr) {
  auto res = mprotect(addr, pageSize(), PROT_NONE);
  (void)res;  // res is only inspected by the debug-build assert below.
  MARL_ASSERT(res == 0, "Failed to protect page at %p", addr);
}

}  // anonymous namespace

#elif defined(_WIN32)

#define WIN32_LEAN_AND_MEAN 1
#include <Windows.h>

namespace {

// pageSize() returns the size in bytes of a single OS memory page.
inline size_t pageSize() {
  static auto size = [] {
    SYSTEM_INFO systemInfo = {};
    GetSystemInfo(&systemInfo);
    return systemInfo.dwPageSize;
  }();
  return size;
}

// allocatePages() reserves and commits `count` whole pages of read-write
// memory. Debug builds assert on failure.
inline void* allocatePages(size_t count) {
  auto mapping = VirtualAlloc(nullptr, count * pageSize(),
                              MEM_COMMIT | MEM_RESERVE, PAGE_READWRITE);
  MARL_ASSERT(mapping != nullptr, "Failed to allocate %d pages", int(count));
  return mapping;
}

// freePages() releases pages previously obtained from allocatePages().
// ptr must be the base address returned by allocatePages(); VirtualFree()
// with MEM_RELEASE frees the entire reservation, so count is unused here.
inline void freePages(void* ptr, size_t count) {
  (void)count;
  auto res = VirtualFree(ptr, 0, MEM_RELEASE);
  (void)res;  // res is only inspected by the debug-build assert below.
  MARL_ASSERT(res != 0, "Failed to free %d pages at %p", int(count), ptr);
}

// protectPage() marks the single page at addr inaccessible so that any read
// or write to it raises a page fault. Used to build allocation guard pages.
inline void protectPage(void* addr) {
  DWORD oldVal = 0;
  auto res = VirtualProtect(addr, pageSize(), PAGE_NOACCESS, &oldVal);
  (void)res;  // res is only inspected by the debug-build assert below.
  MARL_ASSERT(res != 0, "Failed to protect page at %p", addr);
}

}  // anonymous namespace

#else
// TODO: Fuchsia support
#error "Page based allocation not implemented for this platform"
#endif
namespace {
// alignUp() rounds value up to the nearest multiple of `multiple`.
// `multiple` must be non-zero, and value + multiple - 1 must not overflow T.
template <typename T>
inline T alignUp(T value, T multiple) {
  auto const buckets = (value + multiple - 1) / multiple;
  return buckets * multiple;
}
// pagedMalloc() returns size bytes of uninitialized, page-backed storage with
// the given minimum byte alignment (which must be smaller than a page).
// When guardLow / guardHigh are set, the page immediately below / above the
// usable region is made inaccessible, so out-of-bounds reads or writes there
// raise a page fault.
// Free the returned pointer with pagedFree(), passing identical guard flags.
void* pagedMalloc(size_t alignment,
                  size_t size,
                  bool guardLow,
                  bool guardHigh) {
  (void)alignment;  // page-aligned memory satisfies any smaller alignment.
  MARL_ASSERT(alignment < pageSize(),
              "alignment (0x%x) must be less than the page size (0x%x)",
              int(alignment), int(pageSize()));
  auto const dataPages = (size + pageSize() - 1) / pageSize();
  auto const totalPages =
      dataPages + (guardLow ? 1 : 0) + (guardHigh ? 1 : 0);
  auto memory = reinterpret_cast<uint8_t*>(allocatePages(totalPages));
  if (guardLow) {
    protectPage(memory);   // faults on underruns.
    memory += pageSize();  // usable region starts past the low guard page.
  }
  if (guardHigh) {
    protectPage(memory + dataPages * pageSize());  // faults on overruns.
  }
  return memory;
}
// pagedFree() releases a block obtained from pagedMalloc(). The alignment,
// size, guardLow and guardHigh arguments must match the original
// pagedMalloc() call so that the page count and mapping base address can be
// recomputed.
void pagedFree(void* ptr,
               size_t alignment,
               size_t size,
               bool guardLow,
               bool guardHigh) {
  (void)alignment;  // unused beyond the sanity assert below.
  MARL_ASSERT(alignment < pageSize(),
              "alignment (0x%x) must be less than the page size (0x%x)",
              int(alignment), int(pageSize()));
  auto const dataPages = (size + pageSize() - 1) / pageSize();
  auto const totalPages =
      dataPages + (guardLow ? 1 : 0) + (guardHigh ? 1 : 0);
  if (guardLow) {
    // The mapping actually begins one page below the pointer handed out.
    ptr = reinterpret_cast<uint8_t*>(ptr) - pageSize();
  }
  freePages(ptr, totalPages);
}
// alignedMalloc() allocates size bytes of uninitialized storage with the
// specified minimum byte alignment (assumed to be a power of two).
// The pointer returned must be freed with alignedFree(), passing the same
// size. Returns nullptr if the underlying malloc() fails.
inline void* alignedMalloc(size_t alignment, size_t size) {
  // Over-allocate so that the returned pointer can be aligned AND the base
  // pointer of the raw allocation can be stashed just after the usable region
  // (alignedFree() reads it back from there).
  size_t allocSize = size + alignment + sizeof(void*);
  auto allocation = malloc(allocSize);
  if (allocation == nullptr) {
    // Fix: previously a failed malloc() fell through to alignUp(0, ...) and
    // the memcpy() below wrote through a near-null pointer (UB). Report the
    // failure to the caller instead (DefaultAllocator::allocate() asserts).
    return nullptr;
  }
  auto aligned = reinterpret_cast<uint8_t*>(
      alignUp(reinterpret_cast<uintptr_t>(allocation), alignment));  // align
  memcpy(aligned + size, &allocation, sizeof(void*));  // pointer-to-allocation
  return aligned;
}
// alignedFree() frees memory allocated by alignedMalloc().
// ptr must be a pointer returned by alignedMalloc() and size must be the same
// size that was passed to alignedMalloc().
inline void alignedFree(void* ptr, size_t size) {
  // alignedMalloc() stashed the base pointer of the raw allocation
  // immediately after the usable region; recover it and return it to free().
  // Fix: read sizeof(void*) bytes to match the sizeof(void*) write in
  // alignedMalloc() - the previous sizeof(size_t) only happened to work on
  // platforms where sizeof(size_t) == sizeof(void*).
  void* base = nullptr;
  memcpy(&base, reinterpret_cast<uint8_t*>(ptr) + size, sizeof(void*));
  free(base);
}
// DefaultAllocator is the marl::Allocator implementation behind
// marl::Allocator::Default. Guarded requests are backed by whole OS pages
// with inaccessible guard pages on both sides; over-aligned requests use
// alignedMalloc(); everything else falls through to plain malloc().
class DefaultAllocator : public marl::Allocator {
 public:
  static DefaultAllocator instance;

  // allocate() satisfies the request via the cheapest strategy that honors
  // its useGuards / alignment fields, and records the request in the
  // returned Allocation so free() can pick the matching release path.
  virtual marl::Allocation allocate(
      const marl::Allocation::Request& request) override {
    void* memory = nullptr;
    if (request.useGuards) {
      memory = ::pagedMalloc(request.alignment, request.size, true, true);
    } else if (request.alignment > 1U) {
      memory = ::alignedMalloc(request.alignment, request.size);
    } else {
      memory = ::malloc(request.size);
    }
    MARL_ASSERT(memory != nullptr, "Allocation failed");
    MARL_ASSERT(reinterpret_cast<uintptr_t>(memory) % request.alignment == 0,
                "Allocation gave incorrect alignment");
    marl::Allocation result;
    result.ptr = memory;
    result.request = request;
    return result;
  }

  // free() releases memory obtained from allocate(). The request stored in
  // the allocation selects the deallocation path matching allocate().
  virtual void free(const marl::Allocation& allocation) override {
    if (allocation.request.useGuards) {
      ::pagedFree(allocation.ptr, allocation.request.alignment,
                  allocation.request.size, true, true);
    } else if (allocation.request.alignment > 1U) {
      ::alignedFree(allocation.ptr, allocation.request.size);
    } else {
      ::free(allocation.ptr);
    }
  }
};

DefaultAllocator DefaultAllocator::instance;
} // anonymous namespace
namespace marl {

// The process-wide default allocator, used wherever no explicit Allocator is
// provided (e.g. Scheduler's allocator parameter defaults to this).
Allocator* Allocator::Default = &DefaultAllocator::instance;

// pageSize() exposes the OS memory page size in bytes, forwarding to the
// platform-specific implementation above.
size_t pageSize() {
  return ::pageSize();
}

} // namespace marl
...@@ -16,19 +16,36 @@ ...@@ -16,19 +16,36 @@
#include "marl_test.h" #include "marl_test.h"
class MemoryTest : public testing::Test {}; class AllocatorTest : public testing::Test {
public:
marl::Allocator* allocator = marl::Allocator::Default;
};
TEST(MemoryTest, AlignedMalloc) { TEST_F(AllocatorTest, AlignedAllocate) {
std::vector<bool> guards = {false, true};
std::vector<size_t> sizes = {1, 2, 3, 4, 5, 7, 8, 14, 16, 17, std::vector<size_t> sizes = {1, 2, 3, 4, 5, 7, 8, 14, 16, 17,
31, 34, 50, 63, 64, 65, 100, 127, 128, 129, 31, 34, 50, 63, 64, 65, 100, 127, 128, 129,
200, 255, 256, 257, 500, 511, 512, 513}; 200, 255, 256, 257, 500, 511, 512, 513};
std::vector<size_t> alignments = {1, 2, 4, 8, 16, 32, 64, 128}; std::vector<size_t> alignments = {1, 2, 4, 8, 16, 32, 64, 128};
for (auto alignment : alignments) { for (auto useGuards : guards) {
for (auto size : sizes) { for (auto alignment : alignments) {
auto ptr = marl::aligned_malloc(alignment, size); for (auto size : sizes) {
ASSERT_EQ(reinterpret_cast<uintptr_t>(ptr) & (alignment - 1), 0U); marl::Allocation::Request request;
memset(ptr, 0, size); // Check the memory was actually allocated. request.alignment = alignment;
marl::aligned_free(ptr); request.size = size;
request.useGuards = useGuards;
auto allocation = allocator->allocate(request);
auto ptr = allocation.ptr;
ASSERT_EQ(allocation.request.size, request.size);
ASSERT_EQ(allocation.request.alignment, request.alignment);
ASSERT_EQ(allocation.request.useGuards, request.useGuards);
ASSERT_EQ(allocation.request.usage, request.usage);
ASSERT_EQ(reinterpret_cast<uintptr_t>(ptr) & (alignment - 1), 0U);
memset(ptr, 0,
size); // Check the memory was actually allocated.
allocator->free(allocation);
}
} }
} }
} }
...@@ -46,17 +63,28 @@ struct alignas(64) StructWith64ByteAlignment { ...@@ -46,17 +63,28 @@ struct alignas(64) StructWith64ByteAlignment {
uint8_t padding[63]; uint8_t padding[63];
}; };
TEST(MemoryTest, AlignedNew) { TEST_F(AllocatorTest, Create) {
auto s16 = marl::aligned_new<StructWith16ByteAlignment>(); auto s16 = allocator->create<StructWith16ByteAlignment>();
auto s32 = marl::aligned_new<StructWith32ByteAlignment>(); auto s32 = allocator->create<StructWith32ByteAlignment>();
auto s64 = marl::aligned_new<StructWith64ByteAlignment>(); auto s64 = allocator->create<StructWith64ByteAlignment>();
ASSERT_EQ(alignof(StructWith16ByteAlignment), 16U); ASSERT_EQ(alignof(StructWith16ByteAlignment), 16U);
ASSERT_EQ(alignof(StructWith32ByteAlignment), 32U); ASSERT_EQ(alignof(StructWith32ByteAlignment), 32U);
ASSERT_EQ(alignof(StructWith64ByteAlignment), 64U); ASSERT_EQ(alignof(StructWith64ByteAlignment), 64U);
ASSERT_EQ(reinterpret_cast<uintptr_t>(s16) & 15U, 0U); ASSERT_EQ(reinterpret_cast<uintptr_t>(s16) & 15U, 0U);
ASSERT_EQ(reinterpret_cast<uintptr_t>(s32) & 31U, 0U); ASSERT_EQ(reinterpret_cast<uintptr_t>(s32) & 31U, 0U);
ASSERT_EQ(reinterpret_cast<uintptr_t>(s64) & 63U, 0U); ASSERT_EQ(reinterpret_cast<uintptr_t>(s64) & 63U, 0U);
marl::aligned_delete(s64); allocator->destroy(s64);
marl::aligned_delete(s32); allocator->destroy(s32);
marl::aligned_delete(s16); allocator->destroy(s16);
} }
\ No newline at end of file
// Verifies that a guarded allocation faults on out-of-bounds access on both
// sides: touching the byte just below the block, or the page just above it,
// must terminate the process (caught here via EXPECT_DEATH).
TEST_F(AllocatorTest, Guards) {
  marl::Allocation::Request request;
  request.alignment = 16;
  request.size = 16;
  request.useGuards = true;
  auto allocation = allocator->allocate(request);
  auto bytes = reinterpret_cast<uint8_t*>(allocation.ptr);
  EXPECT_DEATH(bytes[-1] = 1, "");
  EXPECT_DEATH(bytes[marl::pageSize()] = 1, "");
}
...@@ -12,6 +12,20 @@ ...@@ -12,6 +12,20 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "marl/sanitizers.h"

// MARL_USE_FIBER_STACK_GUARDS controls whether fiber stacks are allocated
// with inaccessible guard pages so stack overflows fault instead of silently
// corrupting memory. If not defined by the build, it defaults to enabled in
// debug builds without ASAN, and disabled otherwise.
#ifndef MARL_USE_FIBER_STACK_GUARDS
#if !defined(NDEBUG) && !ADDRESS_SANITIZER_ENABLED
#define MARL_USE_FIBER_STACK_GUARDS 1
#else
#define MARL_USE_FIBER_STACK_GUARDS 0
#endif
#endif  // MARL_USE_FIBER_STACK_GUARDS

// Reaching here with both enabled means the build forced stack guards on
// under ASAN (the default above disables them in that case) - warn, since
// ASAN is known to misreport accesses on mmap()-allocated stacks.
#if MARL_USE_FIBER_STACK_GUARDS && ADDRESS_SANITIZER_ENABLED
#warning "ASAN can raise spurious failures when using mmap() allocated stacks"
#endif
#if defined(_WIN32) #if defined(_WIN32)
#include "osfiber_windows.h" #include "osfiber_windows.h"
#elif defined(MARL_FIBERS_USE_UCONTEXT) #elif defined(MARL_FIBERS_USE_UCONTEXT)
......
...@@ -36,6 +36,8 @@ ...@@ -36,6 +36,8 @@
#error "Unsupported target" #error "Unsupported target"
#endif #endif
#include "marl/memory.h"
#include <functional> #include <functional>
#include <memory> #include <memory>
...@@ -55,15 +57,21 @@ namespace marl { ...@@ -55,15 +57,21 @@ namespace marl {
class OSFiber { class OSFiber {
public: public:
inline OSFiber(Allocator*);
inline ~OSFiber();
// createFiberFromCurrentThread() returns a fiber created from the current // createFiberFromCurrentThread() returns a fiber created from the current
// thread. // thread.
static inline OSFiber* createFiberFromCurrentThread(); static inline Allocator::unique_ptr<OSFiber> createFiberFromCurrentThread(
Allocator* allocator);
// createFiber() returns a new fiber with the given stack size that will // createFiber() returns a new fiber with the given stack size that will
// call func when switched to. func() must end by switching back to another // call func when switched to. func() must end by switching back to another
// fiber, and must not return. // fiber, and must not return.
static inline OSFiber* createFiber(size_t stackSize, static inline Allocator::unique_ptr<OSFiber> createFiber(
const std::function<void()>& func); Allocator* allocator,
size_t stackSize,
const std::function<void()>& func);
// switchTo() immediately switches execution to the given fiber. // switchTo() immediately switches execution to the given fiber.
// switchTo() must be called on the currently executing fiber. // switchTo() must be called on the currently executing fiber.
...@@ -72,25 +80,46 @@ class OSFiber { ...@@ -72,25 +80,46 @@ class OSFiber {
private: private:
static inline void run(OSFiber* self); static inline void run(OSFiber* self);
Allocator* allocator;
marl_fiber_context context; marl_fiber_context context;
std::function<void()> target; std::function<void()> target;
std::unique_ptr<uint8_t[]> stack; Allocation stack;
}; };
OSFiber* OSFiber::createFiberFromCurrentThread() { OSFiber::OSFiber(Allocator* allocator) : allocator(allocator) {}
auto out = new OSFiber();
OSFiber::~OSFiber() {
if (stack.ptr != nullptr) {
allocator->free(stack);
}
}
Allocator::unique_ptr<OSFiber> OSFiber::createFiberFromCurrentThread(
Allocator* allocator) {
auto out = allocator->make_unique<OSFiber>(allocator);
out->context = {}; out->context = {};
return out; return out;
} }
OSFiber* OSFiber::createFiber(size_t stackSize, Allocator::unique_ptr<OSFiber> OSFiber::createFiber(
const std::function<void()>& func) { Allocator* allocator,
auto out = new OSFiber(); size_t stackSize,
const std::function<void()>& func) {
Allocation::Request request;
request.size = stackSize;
request.alignment = 16;
request.usage = Allocation::Usage::Stack;
#if MARL_USE_FIBER_STACK_GUARDS
request.useGuards = true;
#endif
auto out = allocator->make_unique<OSFiber>(allocator);
out->context = {}; out->context = {};
out->target = func; out->target = func;
out->stack = std::unique_ptr<uint8_t[]>(new uint8_t[stackSize]); out->stack = allocator->allocate(request);
marl_fiber_set_target(&out->context, out->stack.get(), stackSize, marl_fiber_set_target(&out->context, out->stack.ptr, stackSize,
reinterpret_cast<void (*)(void*)>(&OSFiber::run), out); reinterpret_cast<void (*)(void*)>(&OSFiber::run),
out.get());
return out; return out;
} }
......
...@@ -26,8 +26,9 @@ void marl_fiber_set_target(struct marl_fiber_context* ctx, ...@@ -26,8 +26,9 @@ void marl_fiber_set_target(struct marl_fiber_context* ctx,
void (*target)(void*), void (*target)(void*),
void* arg) { void* arg) {
uintptr_t stack_top = (uintptr_t)((uint8_t*)(stack) + stack_size); uintptr_t stack_top = (uintptr_t)((uint8_t*)(stack) + stack_size);
if ((stack_top % 16) != 0) if ((stack_top % 16) != 0) {
stack_top -= (stack_top % 16); stack_top -= (stack_top % 16);
}
// Write a backchain and subtract a minimum stack frame size (32) // Write a backchain and subtract a minimum stack frame size (32)
*(uintptr_t*)stack_top = 0; *(uintptr_t*)stack_top = 0;
......
...@@ -16,27 +16,23 @@ ...@@ -16,27 +16,23 @@
#include "marl_test.h" #include "marl_test.h"
TEST(WithoutBoundScheduler, OSFiber) { TEST_F(WithoutBoundScheduler, OSFiber) {
std::string str; std::string str;
auto constexpr fiberStackSize = 8 * 1024; auto constexpr fiberStackSize = 8 * 1024;
auto main = std::unique_ptr<marl::OSFiber>( auto main = marl::OSFiber::createFiberFromCurrentThread(allocator);
marl::OSFiber::createFiberFromCurrentThread()); marl::Allocator::unique_ptr<marl::OSFiber> fiberA, fiberB, fiberC;
std::unique_ptr<marl::OSFiber> fiberA, fiberB, fiberC; fiberC = marl::OSFiber::createFiber(allocator, fiberStackSize, [&] {
fiberC = std::unique_ptr<marl::OSFiber>( str += "C";
marl::OSFiber::createFiber(fiberStackSize, [&] { fiberC->switchTo(fiberB.get());
str += "C"; });
fiberC->switchTo(fiberB.get()); fiberB = marl::OSFiber::createFiber(allocator, fiberStackSize, [&] {
})); str += "B";
fiberB = std::unique_ptr<marl::OSFiber>( fiberB->switchTo(fiberA.get());
marl::OSFiber::createFiber(fiberStackSize, [&] { });
str += "B"; fiberA = marl::OSFiber::createFiber(allocator, fiberStackSize, [&] {
fiberB->switchTo(fiberA.get()); str += "A";
})); fiberA->switchTo(main.get());
fiberA = std::unique_ptr<marl::OSFiber>( });
marl::OSFiber::createFiber(fiberStackSize, [&] {
str += "A";
fiberA->switchTo(main.get());
}));
main->switchTo(fiberC.get()); main->switchTo(fiberC.get());
......
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#endif // !defined(_XOPEN_SOURCE) #endif // !defined(_XOPEN_SOURCE)
#include "marl/debug.h" #include "marl/debug.h"
#include "marl/memory.h"
#include <functional> #include <functional>
#include <memory> #include <memory>
...@@ -34,35 +35,53 @@ namespace marl { ...@@ -34,35 +35,53 @@ namespace marl {
class OSFiber { class OSFiber {
public: public:
inline OSFiber(Allocator*);
inline ~OSFiber();
// createFiberFromCurrentThread() returns a fiber created from the current // createFiberFromCurrentThread() returns a fiber created from the current
// thread. // thread.
static inline OSFiber* createFiberFromCurrentThread(); static inline Allocator::unique_ptr<OSFiber> createFiberFromCurrentThread(
Allocator* allocator);
// createFiber() returns a new fiber with the given stack size that will // createFiber() returns a new fiber with the given stack size that will
// call func when switched to. func() must end by switching back to another // call func when switched to. func() must end by switching back to another
// fiber, and must not return. // fiber, and must not return.
static inline OSFiber* createFiber(size_t stackSize, static inline Allocator::unique_ptr<OSFiber> createFiber(
const std::function<void()>& func); Allocator* allocator,
size_t stackSize,
const std::function<void()>& func);
// switchTo() immediately switches execution to the given fiber. // switchTo() immediately switches execution to the given fiber.
// switchTo() must be called on the currently executing fiber. // switchTo() must be called on the currently executing fiber.
inline void switchTo(OSFiber*); inline void switchTo(OSFiber*);
private: private:
std::unique_ptr<uint8_t[]> stack; Allocator* allocator;
ucontext_t context; ucontext_t context;
std::function<void()> target; std::function<void()> target;
Allocation stack;
}; };
OSFiber* OSFiber::createFiberFromCurrentThread() { OSFiber::OSFiber(Allocator* allocator) : allocator(allocator) {}
auto out = new OSFiber();
OSFiber::~OSFiber() {
if (stack.ptr != nullptr) {
allocator->free(stack);
}
}
Allocator::unique_ptr<OSFiber> OSFiber::createFiberFromCurrentThread(
Allocator* allocator) {
auto out = allocator->make_unique<OSFiber>(allocator);
out->context = {}; out->context = {};
getcontext(&out->context); getcontext(&out->context);
return out; return out;
} }
OSFiber* OSFiber::createFiber(size_t stackSize, Allocator::unique_ptr<OSFiber> OSFiber::createFiber(
const std::function<void()>& func) { Allocator* allocator,
size_t stackSize,
const std::function<void()>& func) {
union Args { union Args {
OSFiber* self; OSFiber* self;
struct { struct {
...@@ -82,21 +101,27 @@ OSFiber* OSFiber::createFiber(size_t stackSize, ...@@ -82,21 +101,27 @@ OSFiber* OSFiber::createFiber(size_t stackSize,
} }
}; };
auto out = new OSFiber(); Allocation::Request request;
request.size = stackSize;
request.alignment = 16;
request.usage = Allocation::Usage::Stack;
#if MARL_USE_FIBER_STACK_GUARDS
request.useGuards = true;
#endif
auto out = allocator->make_unique<OSFiber>(allocator);
out->context = {}; out->context = {};
out->stack = std::unique_ptr<uint8_t[]>(new uint8_t[stackSize]); out->stack = allocator->allocate(request);
out->target = func; out->target = func;
auto alignmentOffset =
15 - (reinterpret_cast<uintptr_t>(out->stack.get() + 15) & 15);
auto res = getcontext(&out->context); auto res = getcontext(&out->context);
MARL_ASSERT(res == 0, "getcontext() returned %d", int(res)); MARL_ASSERT(res == 0, "getcontext() returned %d", int(res));
out->context.uc_stack.ss_sp = out->stack.get() + alignmentOffset; out->context.uc_stack.ss_sp = out->stack.ptr;
out->context.uc_stack.ss_size = stackSize - alignmentOffset; out->context.uc_stack.ss_size = stackSize;
out->context.uc_link = nullptr; out->context.uc_link = nullptr;
Args args; Args args;
args.self = out; args.self = out.get();
makecontext(&out->context, reinterpret_cast<void (*)()>(&Target::Main), 2, makecontext(&out->context, reinterpret_cast<void (*)()>(&Target::Main), 2,
args.a, args.b); args.a, args.b);
......
...@@ -13,10 +13,12 @@ ...@@ -13,10 +13,12 @@
// limitations under the License. // limitations under the License.
#include "marl/debug.h" #include "marl/debug.h"
#include "marl/memory.h"
#include <functional> #include <functional>
#include <memory> #include <memory>
#define WIN32_LEAN_AND_MEAN 1
#include <Windows.h> #include <Windows.h>
namespace marl { namespace marl {
...@@ -27,13 +29,16 @@ class OSFiber { ...@@ -27,13 +29,16 @@ class OSFiber {
// createFiberFromCurrentThread() returns a fiber created from the current // createFiberFromCurrentThread() returns a fiber created from the current
// thread. // thread.
static inline OSFiber* createFiberFromCurrentThread(); static inline Allocator::unique_ptr<OSFiber> createFiberFromCurrentThread(
Allocator* allocator);
// createFiber() returns a new fiber with the given stack size that will // createFiber() returns a new fiber with the given stack size that will
// call func when switched to. func() must end by switching back to another // call func when switched to. func() must end by switching back to another
// fiber, and must not return. // fiber, and must not return.
static inline OSFiber* createFiber(size_t stackSize, static inline Allocator::unique_ptr<OSFiber> createFiber(
const std::function<void()>& func); Allocator* allocator,
size_t stackSize,
const std::function<void()>& func);
// switchTo() immediately switches execution to the given fiber. // switchTo() immediately switches execution to the given fiber.
// switchTo() must be called on the currently executing fiber. // switchTo() must be called on the currently executing fiber.
...@@ -56,8 +61,9 @@ OSFiber::~OSFiber() { ...@@ -56,8 +61,9 @@ OSFiber::~OSFiber() {
} }
} }
OSFiber* OSFiber::createFiberFromCurrentThread() { Allocator::unique_ptr<OSFiber> OSFiber::createFiberFromCurrentThread(
auto out = new OSFiber(); Allocator* allocator) {
auto out = allocator->make_unique<OSFiber>();
out->fiber = ConvertThreadToFiber(nullptr); out->fiber = ConvertThreadToFiber(nullptr);
out->isFiberFromThread = true; out->isFiberFromThread = true;
MARL_ASSERT(out->fiber != nullptr, MARL_ASSERT(out->fiber != nullptr,
...@@ -66,10 +72,12 @@ OSFiber* OSFiber::createFiberFromCurrentThread() { ...@@ -66,10 +72,12 @@ OSFiber* OSFiber::createFiberFromCurrentThread() {
return out; return out;
} }
OSFiber* OSFiber::createFiber(size_t stackSize, Allocator::unique_ptr<OSFiber> OSFiber::createFiber(
const std::function<void()>& func) { Allocator* allocator,
auto out = new OSFiber(); size_t stackSize,
out->fiber = CreateFiber(stackSize, &OSFiber::run, out); const std::function<void()>& func) {
auto out = allocator->make_unique<OSFiber>();
out->fiber = CreateFiber(stackSize, &OSFiber::run, out.get());
out->target = func; out->target = func;
MARL_ASSERT(out->fiber != nullptr, "CreateFiber() failed with error 0x%x", MARL_ASSERT(out->fiber != nullptr, "CreateFiber() failed with error 0x%x",
int(GetLastError())); int(GetLastError()));
......
...@@ -69,8 +69,8 @@ void Scheduler::bind() { ...@@ -69,8 +69,8 @@ void Scheduler::bind() {
bound = this; bound = this;
{ {
std::unique_lock<std::mutex> lock(singleThreadedWorkerMutex); std::unique_lock<std::mutex> lock(singleThreadedWorkerMutex);
auto worker = std::unique_ptr<Worker>( auto worker =
new Worker(this, Worker::Mode::SingleThreaded, 0)); allocator->make_unique<Worker>(this, Worker::Mode::SingleThreaded, 0);
worker->start(); worker->start();
auto tid = std::this_thread::get_id(); auto tid = std::this_thread::get_id();
singleThreadedWorkers.emplace(tid, std::move(worker)); singleThreadedWorkers.emplace(tid, std::move(worker));
...@@ -79,7 +79,7 @@ void Scheduler::bind() { ...@@ -79,7 +79,7 @@ void Scheduler::bind() {
void Scheduler::unbind() { void Scheduler::unbind() {
MARL_ASSERT(bound != nullptr, "No scheduler bound"); MARL_ASSERT(bound != nullptr, "No scheduler bound");
std::unique_ptr<Worker> worker; Allocator::unique_ptr<Worker> worker;
{ {
std::unique_lock<std::mutex> lock(bound->singleThreadedWorkerMutex); std::unique_lock<std::mutex> lock(bound->singleThreadedWorkerMutex);
auto tid = std::this_thread::get_id(); auto tid = std::this_thread::get_id();
...@@ -94,7 +94,8 @@ void Scheduler::unbind() { ...@@ -94,7 +94,8 @@ void Scheduler::unbind() {
bound = nullptr; bound = nullptr;
} }
Scheduler::Scheduler() { Scheduler::Scheduler(Allocator* allocator /* = Allocator::Default */)
: allocator(allocator) {
for (size_t i = 0; i < spinningWorkers.size(); i++) { for (size_t i = 0; i < spinningWorkers.size(); i++) {
spinningWorkers[i] = -1; spinningWorkers[i] = -1;
} }
...@@ -135,10 +136,11 @@ void Scheduler::setWorkerThreadCount(int newCount) { ...@@ -135,10 +136,11 @@ void Scheduler::setWorkerThreadCount(int newCount) {
workerThreads[idx]->stop(); workerThreads[idx]->stop();
} }
for (int idx = oldCount - 1; idx >= newCount; idx--) { for (int idx = oldCount - 1; idx >= newCount; idx--) {
delete workerThreads[idx]; allocator->destroy(workerThreads[idx]);
} }
for (int idx = oldCount; idx < newCount; idx++) { for (int idx = oldCount; idx < newCount; idx++) {
workerThreads[idx] = new Worker(this, Worker::Mode::MultiThreaded, idx); workerThreads[idx] =
allocator->create<Worker>(this, Worker::Mode::MultiThreaded, idx);
} }
numWorkerThreads = newCount; numWorkerThreads = newCount;
for (int idx = oldCount; idx < newCount; idx++) { for (int idx = oldCount; idx < newCount; idx++) {
...@@ -198,15 +200,11 @@ void Scheduler::onBeginSpinning(int workerId) { ...@@ -198,15 +200,11 @@ void Scheduler::onBeginSpinning(int workerId) {
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
// Fiber // Fiber
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
Scheduler::Fiber::Fiber(OSFiber* impl, uint32_t id) Scheduler::Fiber::Fiber(Allocator::unique_ptr<OSFiber>&& impl, uint32_t id)
: id(id), impl(impl), worker(Scheduler::Worker::getCurrent()) { : id(id), impl(std::move(impl)), worker(Scheduler::Worker::getCurrent()) {
MARL_ASSERT(worker != nullptr, "No Scheduler::Worker bound"); MARL_ASSERT(worker != nullptr, "No Scheduler::Worker bound");
} }
Scheduler::Fiber::~Fiber() {
delete impl;
}
Scheduler::Fiber* Scheduler::Fiber::current() { Scheduler::Fiber* Scheduler::Fiber::current() {
auto worker = Scheduler::Worker::getCurrent(); auto worker = Scheduler::Worker::getCurrent();
return worker != nullptr ? worker->getCurrentFiber() : nullptr; return worker != nullptr ? worker->getCurrentFiber() : nullptr;
...@@ -223,18 +221,23 @@ void Scheduler::Fiber::yield() { ...@@ -223,18 +221,23 @@ void Scheduler::Fiber::yield() {
void Scheduler::Fiber::switchTo(Fiber* to) { void Scheduler::Fiber::switchTo(Fiber* to) {
if (to != this) { if (to != this) {
impl->switchTo(to->impl); impl->switchTo(to->impl.get());
} }
} }
Scheduler::Fiber* Scheduler::Fiber::create(uint32_t id, Allocator::unique_ptr<Scheduler::Fiber> Scheduler::Fiber::create(
size_t stackSize, Allocator* allocator,
const std::function<void()>& func) { uint32_t id,
return new Fiber(OSFiber::createFiber(stackSize, func), id); size_t stackSize,
const std::function<void()>& func) {
return allocator->make_unique<Fiber>(
OSFiber::createFiber(allocator, stackSize, func), id);
} }
Scheduler::Fiber* Scheduler::Fiber::createFromCurrentThread(uint32_t id) { Allocator::unique_ptr<Scheduler::Fiber>
return new Fiber(OSFiber::createFiberFromCurrentThread(), id); Scheduler::Fiber::createFromCurrentThread(Allocator* allocator, uint32_t id) {
return allocator->make_unique<Fiber>(
OSFiber::createFiberFromCurrentThread(allocator), id);
} }
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
...@@ -257,7 +260,7 @@ void Scheduler::Worker::start() { ...@@ -257,7 +260,7 @@ void Scheduler::Worker::start() {
Scheduler::bound = scheduler; Scheduler::bound = scheduler;
Worker::current = this; Worker::current = this;
mainFiber.reset(Fiber::createFromCurrentThread(0)); mainFiber = Fiber::createFromCurrentThread(scheduler->allocator, 0);
currentFiber = mainFiber.get(); currentFiber = mainFiber.get();
run(); run();
mainFiber.reset(); mainFiber.reset();
...@@ -267,7 +270,7 @@ void Scheduler::Worker::start() { ...@@ -267,7 +270,7 @@ void Scheduler::Worker::start() {
case Mode::SingleThreaded: case Mode::SingleThreaded:
Worker::current = this; Worker::current = this;
mainFiber.reset(Fiber::createFromCurrentThread(0)); mainFiber = Fiber::createFromCurrentThread(scheduler->allocator, 0);
currentFiber = mainFiber.get(); currentFiber = mainFiber.get();
break; break;
...@@ -488,9 +491,11 @@ _Requires_lock_held_(lock) void Scheduler::Worker::runUntilIdle( ...@@ -488,9 +491,11 @@ _Requires_lock_held_(lock) void Scheduler::Worker::runUntilIdle(
Scheduler::Fiber* Scheduler::Worker::createWorkerFiber() { Scheduler::Fiber* Scheduler::Worker::createWorkerFiber() {
auto fiberId = static_cast<uint32_t>(workerFibers.size() + 1); auto fiberId = static_cast<uint32_t>(workerFibers.size() + 1);
auto fiber = Fiber::create(fiberId, FiberStackSize, [&] { run(); }); auto fiber = Fiber::create(scheduler->allocator, fiberId, FiberStackSize,
workerFibers.push_back(std::unique_ptr<Fiber>(fiber)); [&] { run(); });
return fiber; auto ptr = fiber.get();
workerFibers.push_back(std::move(fiber));
return ptr;
} }
void Scheduler::Worker::switchToFiber(Fiber* to) { void Scheduler::Worker::switchToFiber(Fiber* to) {
......
...@@ -20,12 +20,12 @@ ...@@ -20,12 +20,12 @@
#include <atomic> #include <atomic>
#include <unordered_set> #include <unordered_set>
TEST(WithoutBoundScheduler, SchedulerConstructAndDestruct) { TEST_F(WithoutBoundScheduler, SchedulerConstructAndDestruct) {
auto scheduler = new marl::Scheduler(); auto scheduler = new marl::Scheduler();
delete scheduler; delete scheduler;
} }
TEST(WithoutBoundScheduler, SchedulerBindGetUnbind) { TEST_F(WithoutBoundScheduler, SchedulerBindGetUnbind) {
auto scheduler = new marl::Scheduler(); auto scheduler = new marl::Scheduler();
scheduler->bind(); scheduler->bind();
auto got = marl::Scheduler::get(); auto got = marl::Scheduler::get();
...@@ -133,7 +133,7 @@ TEST_P(WithBoundScheduler, FibersResumeOnSameStdThread) { ...@@ -133,7 +133,7 @@ TEST_P(WithBoundScheduler, FibersResumeOnSameStdThread) {
} }
} }
TEST(WithoutBoundScheduler, TasksOnlyScheduledOnWorkerThreads) { TEST_F(WithoutBoundScheduler, TasksOnlyScheduledOnWorkerThreads) {
auto scheduler = std::unique_ptr<marl::Scheduler>(new marl::Scheduler()); auto scheduler = std::unique_ptr<marl::Scheduler>(new marl::Scheduler());
scheduler->bind(); scheduler->bind();
scheduler->setWorkerThreadCount(8); scheduler->setWorkerThreadCount(8);
......
...@@ -20,9 +20,7 @@ ...@@ -20,9 +20,7 @@
#include <cstdio> #include <cstdio>
#if defined(_WIN32) #if defined(_WIN32)
#ifndef WIN32_LEAN_AND_MEAN #define WIN32_LEAN_AND_MEAN 1
#define WIN32_LEAN_AND_MEAN
#endif
#include <windows.h> #include <windows.h>
#include <cstdlib> // mbstowcs #include <cstdlib> // mbstowcs
#elif defined(__APPLE__) #elif defined(__APPLE__)
......
...@@ -16,14 +16,14 @@ ...@@ -16,14 +16,14 @@
#include "marl/waitgroup.h" #include "marl/waitgroup.h"
TEST(WithoutBoundScheduler, WaitGroupDone) { TEST_F(WithoutBoundScheduler, WaitGroupDone) {
marl::WaitGroup wg(2); // Should not require a scheduler. marl::WaitGroup wg(2); // Should not require a scheduler.
wg.done(); wg.done();
wg.done(); wg.done();
} }
#if MARL_DEBUG_ENABLED #if MARL_DEBUG_ENABLED
TEST(WithoutBoundScheduler, WaitGroupDoneTooMany) { TEST_F(WithoutBoundScheduler, WaitGroupDoneTooMany) {
marl::WaitGroup wg(2); // Should not require a scheduler. marl::WaitGroup wg(2); // Should not require a scheduler.
wg.done(); wg.done();
wg.done(); wg.done();
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment