Commit 16da2813 by Ben Clayton

LLVMReactor: Lazily promote functions to coroutines.

Coroutines have a larger performance overhead to regular functions. If a rr::Coroutine does not call Yield(), then keep it as a regular function, and stub `coroutine_await` and `coroutine_destroy`. Bug: b/135609394 Bug: b/137167505 Change-Id: Id890e86eb3602f9cd9ed4e68e50c68a98705c753 Reviewed-on: https://swiftshader-review.googlesource.com/c/SwiftShader/+/33950Reviewed-by: 's avatarNicolas Capens <nicolascapens@google.com> Reviewed-by: 's avatarChris Forbes <chrisforbes@google.com> Tested-by: 's avatarBen Clayton <bclayton@google.com> Kokoro-Presubmit: kokoro <noreply+kokoro@google.com>
parent fc7bb8f0
......@@ -420,6 +420,8 @@ namespace
llvm::Value *handle = nullptr;
llvm::Value *id = nullptr;
llvm::Value *promise = nullptr;
llvm::Type *yieldType = nullptr;
llvm::BasicBlock *entryBlock = nullptr;
llvm::BasicBlock *suspendBlock = nullptr;
llvm::BasicBlock *endBlock = nullptr;
llvm::BasicBlock *destroyBlock = nullptr;
......@@ -4492,22 +4494,19 @@ namespace {
SuspendActionDestroy = 1
};
} // anonymous namespace
namespace rr {
void Nucleus::createCoroutine(Type *YieldType, std::vector<Type*> &Params)
void promoteFunctionToCoroutine()
{
ASSERT(jit->coroutine.id == nullptr);
// Types
auto voidTy = ::llvm::Type::getVoidTy(jit->context);
auto i1Ty = ::llvm::Type::getInt1Ty(jit->context);
auto i8Ty = ::llvm::Type::getInt8Ty(jit->context);
auto i32Ty = ::llvm::Type::getInt32Ty(jit->context);
auto i8PtrTy = ::llvm::Type::getInt8PtrTy(jit->context);
auto promiseTy = T(YieldType);
auto promiseTy = jit->coroutine.yieldType;
auto promisePtrTy = promiseTy->getPointerTo();
auto handleTy = i8PtrTy;
auto boolTy = i1Ty;
// LLVM intrinsics
auto coro_id = ::llvm::Intrinsic::getDeclaration(jit->module.get(), llvm::Intrinsic::coro_id);
......@@ -4526,6 +4525,8 @@ void Nucleus::createCoroutine(Type *YieldType, std::vector<Type*> &Params)
auto freeFrameTy = ::llvm::FunctionType::get(voidTy, {i8PtrTy}, false);
auto freeFrame = jit->module->getOrInsertFunction("coroutine_free_frame", freeFrameTy);
auto oldInsertionPoint = jit->builder->saveIP();
// Build the coroutine_await() function:
//
// bool coroutine_await(CoroutineHandle* handle, YieldType* out)
......@@ -4542,7 +4543,6 @@ void Nucleus::createCoroutine(Type *YieldType, std::vector<Type*> &Params)
// }
// }
//
jit->coroutine.await = rr::createFunction("coroutine_await", boolTy, {handleTy, promisePtrTy});
{
auto args = jit->coroutine.await->arg_begin();
auto handle = args++;
......@@ -4573,7 +4573,6 @@ void Nucleus::createCoroutine(Type *YieldType, std::vector<Type*> &Params)
// llvm.coro.destroy(handle);
// }
//
jit->coroutine.destroy = rr::createFunction("coroutine_destroy", voidTy, {handleTy});
{
auto handle = jit->coroutine.destroy->arg_begin();
jit->builder->SetInsertPoint(llvm::BasicBlock::Create(jit->context, "", jit->coroutine.destroy));
......@@ -4613,20 +4612,17 @@ void Nucleus::createCoroutine(Type *YieldType, std::vector<Type*> &Params)
// return handle;
// }
//
jit->function = rr::createFunction("coroutine_begin", handleTy, T(Params));
#ifdef ENABLE_RR_DEBUG_INFO
jit->debugInfo = std::unique_ptr<DebugInfo>(new DebugInfo(jit->builder, jit->context, jit->module, jit->function));
jit->debugInfo = std::unique_ptr<rr::DebugInfo>(new rr::DebugInfo(jit->builder.get(), &jit->context, jit->module.get(), jit->function));
#endif // ENABLE_RR_DEBUG_INFO
auto entryBlock = llvm::BasicBlock::Create(jit->context, "coroutine", jit->function);
jit->coroutine.suspendBlock = llvm::BasicBlock::Create(jit->context, "suspend", jit->function);
jit->coroutine.endBlock = llvm::BasicBlock::Create(jit->context, "end", jit->function);
jit->coroutine.destroyBlock = llvm::BasicBlock::Create(jit->context, "destroy", jit->function);
jit->builder->SetInsertPoint(entryBlock);
Variable::materializeAll();
jit->coroutine.promise = jit->builder->CreateAlloca(T(YieldType), nullptr, "promise");
jit->builder->SetInsertPoint(jit->coroutine.entryBlock, jit->coroutine.entryBlock->begin());
jit->coroutine.promise = jit->builder->CreateAlloca(promiseTy, nullptr, "promise");
jit->coroutine.id = jit->builder->CreateCall(coro_id, {
::llvm::ConstantInt::get(i32Ty, 0),
jit->builder->CreatePointerCast(jit->coroutine.promise, i8PtrTy),
......@@ -4658,13 +4654,45 @@ void Nucleus::createCoroutine(Type *YieldType, std::vector<Type*> &Params)
jit->builder->CreateCall(freeFrame, {memory});
jit->builder->CreateBr(jit->coroutine.suspendBlock);
// Switch back to the entry block for reactor codegen.
jit->builder->SetInsertPoint(entryBlock);
// Switch back to original insert point to continue building the coroutine.
jit->builder->restoreIP(oldInsertionPoint);
}
} // anonymous namespace
namespace rr {
void Nucleus::createCoroutine(Type *YieldType, std::vector<Type*> &Params)
{
// Coroutines are initially created as a regular function.
// Upon the first call to Yield(), the function is promoted to a true
// coroutine.
auto voidTy = ::llvm::Type::getVoidTy(jit->context);
auto i1Ty = ::llvm::Type::getInt1Ty(jit->context);
auto i8PtrTy = ::llvm::Type::getInt8PtrTy(jit->context);
auto handleTy = i8PtrTy;
auto boolTy = i1Ty;
auto promiseTy = T(YieldType);
auto promisePtrTy = promiseTy->getPointerTo();
jit->function = rr::createFunction("coroutine_begin", handleTy, T(Params));
jit->coroutine.await = rr::createFunction("coroutine_await", boolTy, {handleTy, promisePtrTy});
jit->coroutine.destroy = rr::createFunction("coroutine_destroy", voidTy, {handleTy});
jit->coroutine.yieldType = promiseTy;
jit->coroutine.entryBlock = llvm::BasicBlock::Create(jit->context, "function", jit->function);
jit->builder->SetInsertPoint(jit->coroutine.entryBlock);
}
void Nucleus::yield(Value* val)
{
ASSERT_MSG(jit->coroutine.id != nullptr, "yield() can only be called when building a Coroutine");
if (jit->coroutine.id == nullptr)
{
// First call to yield().
// Promote the function to a full coroutine.
promoteFunctionToCoroutine();
ASSERT(jit->coroutine.id != nullptr);
}
// promise = val;
//
......@@ -4710,9 +4738,24 @@ void Nucleus::yield(Value* val)
Routine* Nucleus::acquireCoroutine(const char *name, const Config::Edit &cfgEdit /* = Config::Edit::None */)
{
ASSERT_MSG(jit->coroutine.id != nullptr, "acquireCoroutine() called without a call to createCoroutine()");
jit->builder->CreateBr(jit->coroutine.endBlock);
bool isCoroutine = jit->coroutine.id != nullptr;
if (isCoroutine)
{
jit->builder->CreateBr(jit->coroutine.endBlock);
}
else
{
// Coroutine without a Yield acts as a regular function.
// The 'coroutine_begin' function returns a nullptr for the coroutine
// handle.
jit->builder->CreateRet(llvm::Constant::getNullValue(jit->function->getReturnType()));
// The 'coroutine_await' function always returns false (coroutine done).
jit->builder->SetInsertPoint(llvm::BasicBlock::Create(jit->context, "", jit->coroutine.await));
jit->builder->CreateRet(llvm::Constant::getNullValue(jit->coroutine.await->getReturnType()));
// The 'coroutine_destroy' does nothing, returns void.
jit->builder->SetInsertPoint(llvm::BasicBlock::Create(jit->context, "", jit->coroutine.destroy));
jit->builder->CreateRetVoid();
}
#ifdef ENABLE_RR_DEBUG_INFO
if (jit->debugInfo != nullptr)
......@@ -4728,14 +4771,25 @@ Routine* Nucleus::acquireCoroutine(const char *name, const Config::Edit &cfgEdit
jit->module->print(file, 0);
}
// Run manadory coroutine transforms.
llvm::legacy::PassManager pm;
pm.add(llvm::createCoroEarlyPass());
pm.add(llvm::createCoroSplitPass());
pm.add(llvm::createCoroElidePass());
pm.add(llvm::createBarrierNoopPass());
pm.add(llvm::createCoroCleanupPass());
pm.run(*jit->module);
if (isCoroutine)
{
// Run manadory coroutine transforms.
llvm::legacy::PassManager pm;
pm.add(llvm::createCoroEarlyPass());
pm.add(llvm::createCoroSplitPass());
pm.add(llvm::createCoroElidePass());
pm.add(llvm::createBarrierNoopPass());
pm.add(llvm::createCoroCleanupPass());
pm.run(*jit->module);
}
#if defined(ENABLE_RR_LLVM_IR_VERIFICATION) || !defined(NDEBUG)
{
llvm::legacy::PassManager pm;
pm.add(llvm::createVerifierPass());
pm.run(*jit->module);
}
#endif // defined(ENABLE_RR_LLVM_IR_VERIFICATION) || !defined(NDEBUG)
auto cfg = cfgEdit.apply(jit->config);
jit->optimize(cfg);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment