Commit 4892f939 by Ben Clayton

Yarn: Add Scheduler class

The scheduler is the core of Yarn, and has tight coupling with the OS-abstracted Scheduler::Fiber class. Added basic tests - this will be expanded in later changes. Bug: b/139010488 Change-Id: I562c61d3c4551c4347d9306a3dd87efed06e45a5 Reviewed-on: https://swiftshader-review.googlesource.com/c/SwiftShader/+/34812Tested-by: 's avatarBen Clayton <bclayton@google.com> Kokoro-Presubmit: kokoro <noreply+kokoro@google.com> Reviewed-by: 's avatarNicolas Capens <nicolascapens@google.com>
parent 65a26b64
......@@ -14,6 +14,8 @@
#include "Debug.hpp"
#include "Scheduler.hpp"
#include <cstdlib>
#include <stdarg.h>
......@@ -33,7 +35,7 @@ void fatal(const char* msg, ...)
void assert_has_bound_scheduler(const char* feature)
{
// TODO
YARN_ASSERT(Scheduler::get() != nullptr, "%s requires a yarn::Scheduler to be bound", feature);
}
} // namespace yarn
// Copyright 2019 The SwiftShader Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "OSFiber.hpp" // Must come first. See OSFiber_ucontext.hpp.
#include "Scheduler.hpp"
#include "Debug.hpp"
#include "Defer.hpp"
#include "Thread.hpp"
#if defined(_WIN32)
#include <intrin.h> // __nop()
#endif
namespace
{
template <typename T>
inline T take(std::queue<T>& queue)
{
auto out = std::move(queue.front());
queue.pop();
return out;
}
inline void nop()
{
#if defined(_WIN32)
__nop();
#else
__asm__ __volatile__ ("nop");
#endif
}
} // anonymous namespace
namespace yarn {
////////////////////////////////////////////////////////////////////////////////
// Scheduler
////////////////////////////////////////////////////////////////////////////////
thread_local Scheduler* Scheduler::bound = nullptr;
Scheduler* Scheduler::get()
{
return bound;
}
void Scheduler::bind()
{
YARN_ASSERT(bound == nullptr, "Scheduler already bound");
bound = this;
{
std::unique_lock<std::mutex> lock(singleThreadedWorkerMutex);
auto worker = std::unique_ptr<Worker>(new Worker(this, Worker::Mode::SingleThreaded, 0));
worker->start();
auto tid = std::this_thread::get_id();
singleThreadedWorkers.emplace(tid, std::move(worker));
}
}
void Scheduler::unbind()
{
YARN_ASSERT(bound != nullptr, "No scheduler bound");
std::unique_ptr<Worker> worker;
{
std::unique_lock<std::mutex> lock(bound->singleThreadedWorkerMutex);
auto tid = std::this_thread::get_id();
auto it = bound->singleThreadedWorkers.find(tid);
YARN_ASSERT(it != bound->singleThreadedWorkers.end(), "singleThreadedWorker not found");
worker = std::move(it->second);
bound->singleThreadedWorkers.erase(tid);
}
worker->flush();
worker->stop();
bound = nullptr;
}
Scheduler::Scheduler()
{
for (size_t i = 0; i < spinningWorkers.size(); i++)
{
spinningWorkers[i] = -1;
}
}
Scheduler::~Scheduler()
{
{
std::unique_lock<std::mutex> lock(singleThreadedWorkerMutex);
YARN_ASSERT(singleThreadedWorkers.size() == 0, "Scheduler still bound on %d threads", int(singleThreadedWorkers.size()));
}
setWorkerThreadCount(0);
}
void Scheduler::setThreadInitializer(const std::function<void()>& func)
{
std::unique_lock<std::mutex> lock(threadInitFuncMutex);
threadInitFunc = func;
}
const std::function<void()>& Scheduler::getThreadInitializer()
{
std::unique_lock<std::mutex> lock(threadInitFuncMutex);
return threadInitFunc;
}
void Scheduler::setWorkerThreadCount(int newCount)
{
YARN_ASSERT(newCount >= 0, "count must be positive");
auto oldCount = numWorkerThreads;
for (int idx = oldCount - 1; idx >= newCount; idx--)
{
workerThreads[idx]->stop();
}
for (int idx = oldCount - 1; idx >= newCount; idx--)
{
delete workerThreads[idx];
}
for (int idx = oldCount; idx < newCount; idx++)
{
workerThreads[idx] = new Worker(this, Worker::Mode::MultiThreaded, idx);
}
numWorkerThreads = newCount;
for (int idx = oldCount; idx < newCount; idx++)
{
workerThreads[idx]->start();
}
}
int Scheduler::getWorkerThreadCount()
{
return numWorkerThreads;
}
void Scheduler::enqueue(Task&& task)
{
if (numWorkerThreads > 0)
{
while (true)
{
// Prioritize workers that have recently started spinning.
auto i = --nextSpinningWorkerIdx % spinningWorkers.size();
auto idx = spinningWorkers[i].exchange(-1);
if (idx < 0)
{
// If a spinning worker couldn't be found, round-robin the
// workers.
idx = nextEnqueueIndex++ % numWorkerThreads;
}
auto worker = workerThreads[idx];
if (worker->tryLock())
{
worker->enqueueAndUnlock(std::move(task));
return;
}
}
}
else
{
auto tid = std::this_thread::get_id();
std::unique_lock<std::mutex> lock(singleThreadedWorkerMutex);
auto it = singleThreadedWorkers.find(tid);
YARN_ASSERT(it != singleThreadedWorkers.end(), "singleThreadedWorker not found");
it->second->enqueue(std::move(task));
}
}
bool Scheduler::stealWork(Worker* thief, uint64_t from, Task& out)
{
if (numWorkerThreads > 0)
{
auto thread = workerThreads[from % numWorkerThreads];
if (thread != thief)
{
if (thread->dequeue(out))
{
return true;
}
}
}
return false;
}
void Scheduler::onBeginSpinning(int workerId)
{
auto idx = nextSpinningWorkerIdx++ % spinningWorkers.size();
spinningWorkers[idx] = workerId;
}
////////////////////////////////////////////////////////////////////////////////
// Fiber
////////////////////////////////////////////////////////////////////////////////
Scheduler::Fiber::Fiber(OSFiber* impl, uint32_t id) :
id(id), impl(impl), worker(Scheduler::Worker::getCurrent())
{
YARN_ASSERT(worker != nullptr, "No Scheduler::Worker bound");
}
Scheduler::Fiber::~Fiber()
{
delete impl;
}
Scheduler::Fiber* Scheduler::Fiber::current()
{
auto worker = Scheduler::Worker::getCurrent();
return worker != nullptr ? worker->getCurrentFiber() : nullptr;
}
void Scheduler::Fiber::schedule()
{
worker->enqueue(this);
}
void Scheduler::Fiber::yield()
{
worker->yield(this);
}
void Scheduler::Fiber::switchTo(Fiber* to)
{
if (to != this)
{
impl->switchTo(to->impl);
}
}
Scheduler::Fiber* Scheduler::Fiber::create(uint32_t id, size_t stackSize, const std::function<void()>& func)
{
return new Fiber(OSFiber::createFiber(stackSize, func), id);
}
Scheduler::Fiber* Scheduler::Fiber::createFromCurrentThread(uint32_t id)
{
return new Fiber(OSFiber::createFiberFromCurrentThread(), id);
}
////////////////////////////////////////////////////////////////////////////////
// Scheduler::Worker
////////////////////////////////////////////////////////////////////////////////
thread_local Scheduler::Worker* Scheduler::Worker::current = nullptr;
Scheduler::Worker::Worker(Scheduler *scheduler, Mode mode, uint32_t id) : id(id), mode(mode), scheduler(scheduler) {}
void Scheduler::Worker::start()
{
switch (mode)
{
case Mode::MultiThreaded:
thread = std::thread([=]
{
Thread::setName("Thread<%.2d>", int(id));
if (auto const &initFunc = scheduler->getThreadInitializer())
{
initFunc();
}
Scheduler::bound = scheduler;
Worker::current = this;
mainFiber.reset(Fiber::createFromCurrentThread(0));
currentFiber = mainFiber.get();
run();
mainFiber.reset();
Worker::current = nullptr;
});
break;
case Mode::SingleThreaded:
Worker::current = this;
mainFiber.reset(Fiber::createFromCurrentThread(0));
currentFiber = mainFiber.get();
break;
default:
YARN_ASSERT(false, "Unknown mode: %d", int(mode));
}
}
void Scheduler::Worker::stop()
{
switch (mode)
{
case Mode::MultiThreaded:
shutdown = true;
enqueue([]{}); // Ensure the worker is woken up to notice the shutdown.
thread.join();
break;
case Mode::SingleThreaded:
Worker::current = nullptr;
break;
default:
YARN_ASSERT(false, "Unknown mode: %d", int(mode));
}
}
void Scheduler::Worker::yield(Fiber *from)
{
YARN_ASSERT(currentFiber == from, "Attempting to call yield from a non-current fiber");
// Current fiber is yielding as it is blocked.
// First wait until there's something else this worker can do.
std::unique_lock<std::mutex> lock(work.mutex);
waitForWork(lock);
if (work.fibers.size() > 0)
{
// There's another fiber that has become unblocked, resume that.
work.num--;
auto to = take(work.fibers);
lock.unlock();
switchToFiber(to);
}
else if (idleFibers.size() > 0)
{
// There's an old fiber we can reuse, resume that.
auto to = take(idleFibers);
lock.unlock();
switchToFiber(to);
}
else
{
// Tasks to process and no existing fibers to resume. Spawn a new fiber.
lock.unlock();
switchToFiber(createWorkerFiber());
}
}
bool Scheduler::Worker::tryLock()
{
return work.mutex.try_lock();
}
void Scheduler::Worker::enqueue(Fiber* fiber)
{
std::unique_lock<std::mutex> lock(work.mutex);
auto wasIdle = work.num == 0;
work.fibers.push(std::move(fiber));
work.num++;
lock.unlock();
if (wasIdle) { work.added.notify_one(); }
}
void Scheduler::Worker::enqueue(Task&& task)
{
work.mutex.lock();
enqueueAndUnlock(std::move(task));
}
void Scheduler::Worker::enqueueAndUnlock(Task&& task)
{
auto wasIdle = work.num == 0;
work.tasks.push(std::move(task));
work.num++;
work.mutex.unlock();
if (wasIdle) { work.added.notify_one(); }
}
bool Scheduler::Worker::dequeue(Task& out)
{
if (work.num.load() == 0) { return false; }
if (!work.mutex.try_lock()) { return false; }
defer(work.mutex.unlock());
if (work.tasks.size() == 0) { return false; }
work.num--;
out = take(work.tasks);
return true;
}
void Scheduler::Worker::flush()
{
YARN_ASSERT(mode == Mode::SingleThreaded, "flush() can only be used on a single-threaded worker");
std::unique_lock<std::mutex> lock(work.mutex);
runUntilIdle(lock);
}
void Scheduler::Worker::run()
{
switch (mode)
{
case Mode::MultiThreaded:
{
{
std::unique_lock<std::mutex> lock(work.mutex);
work.added.wait(lock, [this] { return work.num > 0 || shutdown; });
while (!shutdown)
{
waitForWork(lock);
runUntilIdle(lock);
}
Worker::current = nullptr;
}
switchToFiber(mainFiber.get());
break;
}
case Mode::SingleThreaded:
while (!shutdown)
{
flush();
idleFibers.emplace(currentFiber);
switchToFiber(mainFiber.get());
}
break;
default:
YARN_ASSERT(false, "Unknown mode: %d", int(mode));
}
}
_Requires_lock_held_(lock)
void Scheduler::Worker::waitForWork(std::unique_lock<std::mutex> &lock)
{
YARN_ASSERT(work.num == work.fibers.size() + work.tasks.size(), "work.num out of sync");
if (work.num == 0)
{
scheduler->onBeginSpinning(id);
lock.unlock();
spinForWork();
lock.lock();
}
work.added.wait(lock, [this] { return work.num > 0 || shutdown; });
}
void Scheduler::Worker::spinForWork()
{
Task stolen;
constexpr auto duration = std::chrono::milliseconds(1);
auto start = std::chrono::high_resolution_clock::now();
while (std::chrono::high_resolution_clock::now() - start < duration)
{
for (int i = 0; i < 256; i++) // Empirically picked magic number!
{
nop(); nop(); nop(); nop(); nop(); nop(); nop(); nop();
nop(); nop(); nop(); nop(); nop(); nop(); nop(); nop();
nop(); nop(); nop(); nop(); nop(); nop(); nop(); nop();
nop(); nop(); nop(); nop(); nop(); nop(); nop(); nop();
if (work.num > 0)
{
return;
}
}
if (scheduler->stealWork(this, rng(), stolen))
{
std::unique_lock<std::mutex> lock(work.mutex);
work.tasks.emplace(std::move(stolen));
work.num++;
return;
}
std::this_thread::yield();
}
}
_Requires_lock_held_(lock)
void Scheduler::Worker::runUntilIdle(std::unique_lock<std::mutex> &lock)
{
YARN_ASSERT(work.num == work.fibers.size() + work.tasks.size(), "work.num out of sync");
while (work.fibers.size() > 0 || work.tasks.size() > 0)
{
// Note: we cannot take and store on the stack more than a single fiber
// or task at a time, as the Fiber may yield and these items may get
// held on suspended fiber stack.
while (work.fibers.size() > 0)
{
work.num--;
auto fiber = take(work.fibers);
lock.unlock();
idleFibers.push(currentFiber);
switchToFiber(fiber);
lock.lock();
}
if (work.tasks.size() > 0)
{
work.num--;
auto task = take(work.tasks);
lock.unlock();
// Run the task.
task();
// std::function<> can carry arguments with complex destructors.
// Ensure these are destructed outside of the lock.
task = Task();
lock.lock();
}
}
}
Scheduler::Fiber* Scheduler::Worker::createWorkerFiber()
{
auto id = workerFibers.size() + 1;
auto fiber = Fiber::create(id, FiberStackSize, [&] { run(); });
workerFibers.push_back(std::unique_ptr<Fiber>(fiber));
return fiber;
}
void Scheduler::Worker::switchToFiber(Fiber* to)
{
auto from = currentFiber;
currentFiber = to;
from->switchTo(to);
}
} // namespace yarn
// Copyright 2019 The SwiftShader Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef yarn_scheduler_hpp
#define yarn_scheduler_hpp
#include "Debug.hpp"
#include "SAL.hpp"
#include <array>
#include <atomic>
#include <condition_variable>
#include <functional>
#include <mutex>
#include <queue>
#include <thread>
#include <unordered_map>
namespace yarn {
class OSFiber;
// Task is a unit of work for the scheduler.
using Task = std::function<void()>;
// Scheduler asynchronously processes Tasks.
// A scheduler can be bound to one or more threads using the bind() method.
// Once bound to a thread, that thread can call yarn::schedule() to enqueue
// work tasks to be executed asynchronously.
// Scheduler are initially constructed in single-threaded mode.
// Call setWorkerThreadCount() to spawn dedicated worker threads.
class Scheduler
{
class Worker;
public:
Scheduler();
~Scheduler();
// get() returns the scheduler bound to the current thread.
static Scheduler *get();
// bind() binds this scheduler to the current thread.
// There must be no existing scheduler bound to the thread prior to calling.
void bind();
// unbind() unbinds the scheduler currently bound to the current thread.
// There must be a existing scheduler bound to the thread prior to calling.
static void unbind();
// enqueue() queues the task for asynchronous execution.
void enqueue(Task&& task);
// setThreadInitializer() sets the worker thread initializer function which
// will be called for each new worker thread spawned.
// The initializer will only be called on newly created threads (call
// setThreadInitializer() before setWorkerThreadCount()).
void setThreadInitializer(const std::function<void()>& init);
// getThreadInitializer() returns the thread initializer function set by
// setThreadInitializer().
const std::function<void()>& getThreadInitializer();
// setWorkerThreadCount() adjusts the number of dedicated worker threads.
// A count of 0 puts the scheduler into single-threaded mode.
// Note: Currently the number of threads cannot be adjusted once tasks
// have been enqueued. This restriction may be lifted at a later time.
void setWorkerThreadCount(int count);
// getWorkerThreadCount() returns the number of worker threads.
int getWorkerThreadCount();
// Fibers expose methods to perform cooperative multitasking and are
// automatically created by the Scheduler.
//
// The currently executing Fiber can be obtained by calling Fiber::current().
//
// When execution becomes blocked, yield() can be called to suspend execution of
// the fiber and start executing other pending work.
// Once the block has been lifted, schedule() can be called to reschedule the
// Fiber on the same thread that previously executed it.
class Fiber
{
public:
~Fiber();
// current() returns the currently executing fiber, or nullptr if called
// without a bound scheduler.
static Fiber* current();
// yield() suspends execution of this Fiber, allowing the thread to work
// on other tasks.
// yield() must only be called on the currently executing fiber.
void yield();
// schedule() reschedules the suspended Fiber for execution.
void schedule();
// id is the thread-unique identifier of the Fiber.
uint32_t const id;
private:
friend class Scheduler;
Fiber(OSFiber*, uint32_t id);
// switchTo() switches execution to the given fiber.
// switchTo() must only be called on the currently executing fiber.
void switchTo(Fiber*);
// create() constructs and returns a new fiber with the given identifier,
// stack size that will executed func when switched to.
static Fiber* create(uint32_t id, size_t stackSize, const std::function<void()>& func);
// createFromCurrentThread() constructs and returns a new fiber with the
// given identifier for the current thread.
static Fiber* createFromCurrentThread(uint32_t id);
OSFiber* const impl;
Worker* const worker;
};
private:
// Stack size in bytes of a new fiber.
// TODO: Make configurable so the default size can be reduced.
static constexpr size_t FiberStackSize = 1024 * 1024;
// Maximum number of worker threads.
static constexpr size_t MaxWorkerThreads = 64;
// TODO: Implement a queue that recycles elements to reduce number of
// heap allocations.
using TaskQueue = std::queue<Task>;
using FiberQueue = std::queue<Fiber*>;
// Workers executes Tasks on a single thread.
// Once a task is started, it may yield to other tasks on the same Worker.
// Tasks are always resumed by the same Worker.
class Worker
{
public:
enum class Mode
{
// Worker will spawn a background thread to process tasks.
MultiThreaded,
// Worker will execute tasks whenever it yields.
SingleThreaded,
};
Worker(Scheduler *scheduler, Mode mode, uint32_t id);
// start() begins execution of the worker.
void start();
// stop() ceases execution of the worker, blocking until all pending
// tasks have fully finished.
void stop();
// yield() suspends execution of the current task, and looks for other
// tasks to start or continue execution.
void yield(Fiber* fiber);
// enqueue(Fiber*) enqueues resuming of a suspended fiber.
void enqueue(Fiber* fiber);
// enqueue(Task&&) enqueues a new, unstarted task.
void enqueue(Task&& task);
// tryLock() attempts to lock the worker for task enqueing.
// If the lock was successful then true is returned, and the caller must
// call enqueueAndUnlock().
bool tryLock();
// enqueueAndUnlock() enqueues the task and unlocks the worker.
// Must only be called after a call to tryLock() which returned true.
void enqueueAndUnlock(Task&& task);
// flush() processes all pending tasks before returning.
void flush();
// dequeue() attempts to take a Task from the worker. Returns true if
// a task was taken and assigned to out, otherwise false.
bool dequeue(Task& out);
// getCurrent() returns the Worker currently bound to the current
// thread.
static inline Worker* getCurrent();
// getCurrentFiber() returns the Fiber currently being executed.
inline Fiber* getCurrentFiber() const;
// Unique identifier of the Worker.
const uint32_t id;
private:
// run() is the task processing function for the worker.
// If the worker was constructed in Mode::MultiThreaded, run() will
// continue to process tasks until stop() is called.
// If the worker was constructed in Mode::SingleThreaded, run() call
// flush() and return.
void run();
// createWorkerFiber() creates a new fiber that when executed calls
// run().
Fiber* createWorkerFiber();
// switchToFiber() switches execution to the given fiber. The fiber
// must belong to this worker.
void switchToFiber(Fiber*);
// runUntilIdle() executes all pending tasks and then returns.
_Requires_lock_held_(lock)
void runUntilIdle(std::unique_lock<std::mutex> &lock);
// waitForWork() blocks until new work is available, potentially calling
// spinForWork().
_Requires_lock_held_(lock)
void waitForWork(std::unique_lock<std::mutex> &lock);
// spinForWork() attempts to steal work from another Worker, and keeps
// the thread awake for a short duration. This reduces overheads of
// frequently putting the thread to sleep and re-waking.
void spinForWork();
// Work holds tasks and fibers that are enqueued on the Worker.
struct Work
{
std::atomic<uint64_t> num = { 0 }; // tasks.size() + fibers.size()
TaskQueue tasks; // guarded by mutex
FiberQueue fibers; // guarded by mutex
std::condition_variable added;
std::mutex mutex;
};
// https://en.wikipedia.org/wiki/Xorshift
class FastRnd
{
public:
inline uint64_t operator ()()
{
x ^= x << 13;
x ^= x >> 7;
x ^= x << 17;
return x;
}
private:
uint64_t x = std::chrono::system_clock::now().time_since_epoch().count();
};
// The current worker bound to the current thread.
static thread_local Worker* current;
Mode const mode;
Scheduler* const scheduler;
std::unique_ptr<Fiber> mainFiber;
Fiber* currentFiber = nullptr;
std::thread thread;
Work work;
FiberQueue idleFibers; // Fibers that have completed which can be reused.
std::vector<std::unique_ptr<Fiber>> workerFibers; // All fibers created by this worker.
FastRnd rng;
std::atomic<bool> shutdown = { false };
};
// stealWork() attempts to steal a task from the worker with the given id.
// Returns true if a task was stolen and assigned to out, otherwise false.
bool stealWork(Worker* thief, uint64_t from, Task& out);
// onBeginSpinning() is called when a Worker calls spinForWork().
// The scheduler will prioritize this worker for new tasks to try to prevent
// it going to sleep.
void onBeginSpinning(int workerId);
// The scheduler currently bound to the current thread.
static thread_local Scheduler* bound;
std::function<void()> threadInitFunc;
std::mutex threadInitFuncMutex;
std::array<std::atomic<int>, 8> spinningWorkers;
std::atomic<unsigned int> nextSpinningWorkerIdx = { 0x8000000 };
// TODO: Make this lot thread-safe so setWorkerThreadCount() can be called
// during execution of tasks.
unsigned int nextEnqueueIndex = 0;
unsigned int numWorkerThreads = 0;
std::array<Worker*, MaxWorkerThreads> workerThreads;
std::mutex singleThreadedWorkerMutex;
std::unordered_map<std::thread::id, std::unique_ptr<Worker>> singleThreadedWorkers;
};
Scheduler::Worker* Scheduler::Worker::getCurrent()
{
return Worker::current;
}
Scheduler::Fiber* Scheduler::Worker::getCurrentFiber() const
{
return currentFiber;
}
// schedule() schedules the function f to be asynchronously called with the
// given arguments using the currently bound scheduler.
template<typename Function, typename ... Args>
inline void schedule(Function&& f, Args&& ... args)
{
YARN_ASSERT_HAS_BOUND_SCHEDULER("yarn::schedule");
auto scheduler = Scheduler::get();
scheduler->enqueue(std::bind(std::forward<Function>(f), std::forward<Args>(args)...));
}
// schedule() schedules the function f to be asynchronously called using the
// currently bound scheduler.
template<typename Function>
inline void schedule(Function&& f)
{
YARN_ASSERT_HAS_BOUND_SCHEDULER("yarn::schedule");
auto scheduler = Scheduler::get();
scheduler->enqueue(std::forward<Function>(f));
}
} // namespace yarn
#endif // yarn_scheduler_hpp
// Copyright 2019 The SwiftShader Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "Yarn_test.hpp"
TEST(WithoutBoundScheduler, SchedulerConstructAndDestruct)
{
auto scheduler = new yarn::Scheduler();
delete scheduler;
}
TEST(WithoutBoundScheduler, SchedulerBindGetUnbind)
{
auto scheduler = new yarn::Scheduler();
scheduler->bind();
auto got = yarn::Scheduler::get();
ASSERT_EQ(scheduler, got);
scheduler->unbind();
got = yarn::Scheduler::get();
ASSERT_EQ(got, nullptr);
delete scheduler;
}
TEST_P(WithBoundScheduler, SetAndGetWorkerThreadCount)
{
ASSERT_EQ(yarn::Scheduler::get()->getWorkerThreadCount(), GetParam().numWorkerThreads);
}
TEST_P(WithBoundScheduler, DestructWithPendingTasks)
{
for (int i = 0; i < 10000; i++)
{
yarn::schedule([] {});
}
}
......@@ -14,8 +14,17 @@
#include "Yarn_test.hpp"
INSTANTIATE_TEST_SUITE_P(SchedulerParams, WithBoundScheduler, testing::Values(
SchedulerParams{0}, // Single-threaded mode test
SchedulerParams{1}, // Single worker thread
SchedulerParams{2}, // 2 worker threads...
SchedulerParams{4},
SchedulerParams{8},
SchedulerParams{64}
));
int main(int argc, char **argv)
{
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}
......@@ -15,4 +15,41 @@
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "Yarn/Scheduler.hpp"
// SchedulerParams holds Scheduler construction parameters for testing.
struct SchedulerParams
{
int numWorkerThreads;
friend std::ostream& operator<<(std::ostream& os, const SchedulerParams& params) {
return os << "SchedulerParams{" <<
"numWorkerThreads: " << params.numWorkerThreads <<
"}";
}
};
// WithoutBoundScheduler is a test fixture that does not bind a scheduler.
class WithoutBoundScheduler : public testing::Test {};
// WithBoundScheduler is a parameterized test fixture that performs tests with
// a bound scheduler using a number of different configurations.
class WithBoundScheduler : public testing::TestWithParam<SchedulerParams>
{
public:
void SetUp() override
{
auto &params = GetParam();
auto scheduler = new yarn::Scheduler();
scheduler->bind();
scheduler->setWorkerThreadCount(params.numWorkerThreads);
}
void TearDown() override
{
auto scheduler = yarn::Scheduler::get();
scheduler->unbind();
delete scheduler;
}
};
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment