Commit b34275e4 by Nicolas Capens Committed by Nicolas Capens

Emulate gather/scatter for MSan builds

MemorySanitizer doesn't support instrumenting masked_gather and masked_scatter LLVM intrinsics. Its visitIntrinsicInst() method ends up calling handleUnknownIntrinsic(), which silently doesn't handle it and subsequently visitInstruction() checks all operands for poisoned bits. In the case of a scatter, a 0 bit in the mask means the corresponding element doesn't get written, so it doesn't matter if it's uninitialized data. The current implementation leads to false positives. Work around it by emulating gather and scatter as element-wise loads and stores. This can be correctly instrumented by MemorySanitizer. Note this change has no effect currently since we don't support MSan instrumentation for Reactor yet. We just unpoison all stores. Previously we did that in element-wise manner after the intrinsic executes. Now it's done as part of the element stores. Bug: b/155148722 Change-Id: I9058cd926667fb6df5d9626bc87fb2d0a596771b Reviewed-on: https://swiftshader-review.googlesource.com/c/SwiftShader/+/49809Tested-by: 's avatarNicolas Capens <nicolascapens@google.com> Reviewed-by: 's avatarAntonio Maiorano <amaiorano@google.com>
parent 782dbebf
...@@ -355,84 +355,6 @@ llvm::Value *lowerMulHigh(llvm::Value *x, llvm::Value *y, bool sext) ...@@ -355,84 +355,6 @@ llvm::Value *lowerMulHigh(llvm::Value *x, llvm::Value *y, bool sext)
return jit->builder->CreateTrunc(mulh, ty); return jit->builder->CreateTrunc(mulh, ty);
} }
llvm::Value *createGather(llvm::Value *base, llvm::Type *elTy, llvm::Value *offsets, llvm::Value *mask, unsigned int alignment, bool zeroMaskedLanes)
{
ASSERT(base->getType()->isPointerTy());
ASSERT(offsets->getType()->isVectorTy());
ASSERT(mask->getType()->isVectorTy());
auto numEls = llvm::cast<llvm::VectorType>(mask->getType())->getNumElements();
auto i1Ty = ::llvm::Type::getInt1Ty(jit->context);
auto i32Ty = ::llvm::Type::getInt32Ty(jit->context);
auto i8Ty = ::llvm::Type::getInt8Ty(jit->context);
auto i8PtrTy = i8Ty->getPointerTo();
auto elPtrTy = elTy->getPointerTo();
auto elVecTy = ::llvm::VectorType::get(elTy, numEls, false);
auto elPtrVecTy = ::llvm::VectorType::get(elPtrTy, numEls, false);
auto i8Base = jit->builder->CreatePointerCast(base, i8PtrTy);
auto i8Ptrs = jit->builder->CreateGEP(i8Base, offsets);
auto elPtrs = jit->builder->CreatePointerCast(i8Ptrs, elPtrVecTy);
auto i8Mask = jit->builder->CreateIntCast(mask, ::llvm::VectorType::get(i1Ty, numEls, false), false); // vec<int, int, ...> -> vec<bool, bool, ...>
auto passthrough = zeroMaskedLanes ? ::llvm::Constant::getNullValue(elVecTy) : llvm::UndefValue::get(elVecTy);
auto align = ::llvm::ConstantInt::get(i32Ty, alignment);
auto func = ::llvm::Intrinsic::getDeclaration(jit->module.get(), llvm::Intrinsic::masked_gather, { elVecTy, elPtrVecTy });
return jit->builder->CreateCall(func, { elPtrs, align, i8Mask, passthrough });
}
void createScatter(llvm::Value *base, llvm::Value *val, llvm::Value *offsets, llvm::Value *mask, unsigned int alignment)
{
ASSERT(base->getType()->isPointerTy());
ASSERT(val->getType()->isVectorTy());
ASSERT(offsets->getType()->isVectorTy());
ASSERT(mask->getType()->isVectorTy());
auto numEls = llvm::cast<llvm::VectorType>(mask->getType())->getNumElements();
auto i1Ty = ::llvm::Type::getInt1Ty(jit->context);
auto i32Ty = ::llvm::Type::getInt32Ty(jit->context);
auto i8Ty = ::llvm::Type::getInt8Ty(jit->context);
auto i8PtrTy = i8Ty->getPointerTo();
auto elVecTy = val->getType();
auto elTy = llvm::cast<llvm::VectorType>(elVecTy)->getElementType();
auto elPtrTy = elTy->getPointerTo();
auto elPtrVecTy = ::llvm::VectorType::get(elPtrTy, numEls, false);
auto i8Base = jit->builder->CreatePointerCast(base, i8PtrTy);
auto i8Ptrs = jit->builder->CreateGEP(i8Base, offsets);
auto elPtrs = jit->builder->CreatePointerCast(i8Ptrs, elPtrVecTy);
auto i1Mask = jit->builder->CreateIntCast(mask, ::llvm::VectorType::get(i1Ty, numEls, false), false); // vec<int, int, ...> -> vec<bool, bool, ...>
auto align = ::llvm::ConstantInt::get(i32Ty, alignment);
auto func = ::llvm::Intrinsic::getDeclaration(jit->module.get(), llvm::Intrinsic::masked_scatter, { elVecTy, elPtrVecTy });
jit->builder->CreateCall(func, { val, elPtrs, align, i1Mask });
#if __has_feature(memory_sanitizer)
// Mark memory writes as initialized by calling __msan_unpoison
{
// void __msan_unpoison(const volatile void *a, size_t size)
auto voidTy = ::llvm::Type::getVoidTy(jit->context);
auto int8Ty = ::llvm::Type::getInt8Ty(jit->context);
auto int8PtrTy = int8Ty->getPointerTo();
auto sizetTy = ::llvm::IntegerType::get(jit->context, sizeof(size_t) * 8);
auto funcTy = ::llvm::FunctionType::get(voidTy, { int8PtrTy, sizetTy }, false);
auto func = jit->module->getOrInsertFunction("__msan_unpoison", funcTy);
auto size = jit->module->getDataLayout().getTypeStoreSize(elTy);
for(unsigned i = 0; i < numEls; i++)
{
// Check mask for this element
auto idx = ::llvm::ConstantInt::get(i32Ty, i);
auto thenBlock = ::llvm::BasicBlock::Create(jit->context, "", jit->function);
auto mergeBlock = ::llvm::BasicBlock::Create(jit->context, "", jit->function);
jit->builder->CreateCondBr(jit->builder->CreateExtractElement(i1Mask, idx), thenBlock, mergeBlock);
jit->builder->SetInsertPoint(thenBlock);
// Insert __msan_unpoison call in conditional block
auto elPtr = jit->builder->CreateExtractElement(elPtrs, idx);
jit->builder->CreateCall(func, { jit->builder->CreatePointerCast(elPtr, int8PtrTy),
::llvm::ConstantInt::get(sizetTy, size) });
jit->builder->CreateBr(mergeBlock);
jit->builder->SetInsertPoint(mergeBlock);
}
}
#endif
}
} // namespace } // namespace
namespace rr { namespace rr {
...@@ -1044,9 +966,9 @@ Value *Nucleus::createStore(Value *value, Value *ptr, Type *type, bool isVolatil ...@@ -1044,9 +966,9 @@ Value *Nucleus::createStore(Value *value, Value *ptr, Type *type, bool isVolatil
auto elTy = T(type); auto elTy = T(type);
ASSERT(V(ptr)->getType()->getContainedType(0) == elTy); ASSERT(V(ptr)->getType()->getContainedType(0) == elTy);
#if __has_feature(memory_sanitizer) if(__has_feature(memory_sanitizer))
// Mark all memory writes as initialized by calling __msan_unpoison
{ {
// Mark all memory writes as initialized by calling __msan_unpoison
// void __msan_unpoison(const volatile void *a, size_t size) // void __msan_unpoison(const volatile void *a, size_t size)
auto voidTy = ::llvm::Type::getVoidTy(jit->context); auto voidTy = ::llvm::Type::getVoidTy(jit->context);
auto i8Ty = ::llvm::Type::getInt8Ty(jit->context); auto i8Ty = ::llvm::Type::getInt8Ty(jit->context);
...@@ -1055,10 +977,10 @@ Value *Nucleus::createStore(Value *value, Value *ptr, Type *type, bool isVolatil ...@@ -1055,10 +977,10 @@ Value *Nucleus::createStore(Value *value, Value *ptr, Type *type, bool isVolatil
auto funcTy = ::llvm::FunctionType::get(voidTy, { voidPtrTy, sizetTy }, false); auto funcTy = ::llvm::FunctionType::get(voidTy, { voidPtrTy, sizetTy }, false);
auto func = jit->module->getOrInsertFunction("__msan_unpoison", funcTy); auto func = jit->module->getOrInsertFunction("__msan_unpoison", funcTy);
auto size = jit->module->getDataLayout().getTypeStoreSize(elTy); auto size = jit->module->getDataLayout().getTypeStoreSize(elTy);
jit->builder->CreateCall(func, { jit->builder->CreatePointerCast(V(ptr), voidPtrTy), jit->builder->CreateCall(func, { jit->builder->CreatePointerCast(V(ptr), voidPtrTy),
::llvm::ConstantInt::get(sizetTy, size) }); ::llvm::ConstantInt::get(sizetTy, size) });
} }
#endif
if(!atomic) if(!atomic)
{ {
...@@ -1150,9 +1072,9 @@ void Nucleus::createMaskedStore(Value *ptr, Value *val, Value *mask, unsigned in ...@@ -1150,9 +1072,9 @@ void Nucleus::createMaskedStore(Value *ptr, Value *val, Value *mask, unsigned in
auto func = ::llvm::Intrinsic::getDeclaration(jit->module.get(), llvm::Intrinsic::masked_store, { elVecTy, elVecPtrTy }); auto func = ::llvm::Intrinsic::getDeclaration(jit->module.get(), llvm::Intrinsic::masked_store, { elVecTy, elVecPtrTy });
jit->builder->CreateCall(func, { V(val), V(ptr), align, i1Mask }); jit->builder->CreateCall(func, { V(val), V(ptr), align, i1Mask });
#if __has_feature(memory_sanitizer) if(__has_feature(memory_sanitizer))
// Mark memory writes as initialized by calling __msan_unpoison
{ {
// Mark memory writes as initialized by calling __msan_unpoison
// void __msan_unpoison(const volatile void *a, size_t size) // void __msan_unpoison(const volatile void *a, size_t size)
auto voidTy = ::llvm::Type::getVoidTy(jit->context); auto voidTy = ::llvm::Type::getVoidTy(jit->context);
auto voidPtrTy = voidTy->getPointerTo(); auto voidPtrTy = voidTy->getPointerTo();
...@@ -1160,6 +1082,7 @@ void Nucleus::createMaskedStore(Value *ptr, Value *val, Value *mask, unsigned in ...@@ -1160,6 +1082,7 @@ void Nucleus::createMaskedStore(Value *ptr, Value *val, Value *mask, unsigned in
auto funcTy = ::llvm::FunctionType::get(voidTy, { voidPtrTy, sizetTy }, false); auto funcTy = ::llvm::FunctionType::get(voidTy, { voidPtrTy, sizetTy }, false);
auto func = jit->module->getOrInsertFunction("__msan_unpoison", funcTy); auto func = jit->module->getOrInsertFunction("__msan_unpoison", funcTy);
auto size = jit->module->getDataLayout().getTypeStoreSize(llvm::cast<llvm::VectorType>(elVecTy)->getElementType()); auto size = jit->module->getDataLayout().getTypeStoreSize(llvm::cast<llvm::VectorType>(elVecTy)->getElementType());
for(unsigned i = 0; i < numEls; i++) for(unsigned i = 0; i < numEls; i++)
{ {
// Check mask for this element // Check mask for this element
...@@ -1173,11 +1096,66 @@ void Nucleus::createMaskedStore(Value *ptr, Value *val, Value *mask, unsigned in ...@@ -1173,11 +1096,66 @@ void Nucleus::createMaskedStore(Value *ptr, Value *val, Value *mask, unsigned in
auto elPtr = jit->builder->CreateGEP(V(ptr), idx); auto elPtr = jit->builder->CreateGEP(V(ptr), idx);
jit->builder->CreateCall(func, { jit->builder->CreatePointerCast(elPtr, voidPtrTy), jit->builder->CreateCall(func, { jit->builder->CreatePointerCast(elPtr, voidPtrTy),
::llvm::ConstantInt::get(sizetTy, size) }); ::llvm::ConstantInt::get(sizetTy, size) });
jit->builder->CreateBr(mergeBlock); jit->builder->CreateBr(mergeBlock);
jit->builder->SetInsertPoint(mergeBlock); jit->builder->SetInsertPoint(mergeBlock);
} }
} }
#endif } // namespace rr
static llvm::Value *createGather(llvm::Value *base, llvm::Type *elTy, llvm::Value *offsets, llvm::Value *mask, unsigned int alignment, bool zeroMaskedLanes)
{
ASSERT(base->getType()->isPointerTy());
ASSERT(offsets->getType()->isVectorTy());
ASSERT(mask->getType()->isVectorTy());
auto numEls = llvm::cast<llvm::VectorType>(mask->getType())->getNumElements();
auto i1Ty = ::llvm::Type::getInt1Ty(jit->context);
auto i32Ty = ::llvm::Type::getInt32Ty(jit->context);
auto i8Ty = ::llvm::Type::getInt8Ty(jit->context);
auto i8PtrTy = i8Ty->getPointerTo();
auto elPtrTy = elTy->getPointerTo();
auto elVecTy = ::llvm::VectorType::get(elTy, numEls, false);
auto elPtrVecTy = ::llvm::VectorType::get(elPtrTy, numEls, false);
auto i8Base = jit->builder->CreatePointerCast(base, i8PtrTy);
auto i8Ptrs = jit->builder->CreateGEP(i8Base, offsets);
auto elPtrs = jit->builder->CreatePointerCast(i8Ptrs, elPtrVecTy);
auto i1Mask = jit->builder->CreateIntCast(mask, ::llvm::VectorType::get(i1Ty, numEls, false), false); // vec<int, int, ...> -> vec<bool, bool, ...>
auto passthrough = zeroMaskedLanes ? ::llvm::Constant::getNullValue(elVecTy) : llvm::UndefValue::get(elVecTy);
if(!__has_feature(memory_sanitizer))
{
auto align = ::llvm::ConstantInt::get(i32Ty, alignment);
auto func = ::llvm::Intrinsic::getDeclaration(jit->module.get(), llvm::Intrinsic::masked_gather, { elVecTy, elPtrVecTy });
return jit->builder->CreateCall(func, { elPtrs, align, i1Mask, passthrough });
}
else // __has_feature(memory_sanitizer)
{
// MemorySanitizer currently does not support instrumenting llvm::Intrinsic::masked_gather
// Work around it by emulating gather with element-wise loads.
// TODO(b/172238865): Remove when supported by MemorySanitizer.
Value *result = Nucleus::allocateStackVariable(T(elVecTy));
Nucleus::createStore(V(passthrough), result, T(elVecTy));
for(unsigned i = 0; i < numEls; i++)
{
// Check mask for this element
Value *elementMask = Nucleus::createExtractElement(V(i1Mask), T(i1Ty), i);
If(RValue<Bool>(elementMask))
{
Value *elPtr = Nucleus::createExtractElement(V(elPtrs), T(elPtrTy), i);
Value *el = Nucleus::createLoad(elPtr, T(elTy), /*isVolatile */ false, alignment, /* atomic */ false, std::memory_order_relaxed);
Value *v = Nucleus::createLoad(result, T(elVecTy));
v = Nucleus::createInsertElement(v, el, i);
Nucleus::createStore(v, result, T(elVecTy));
}
}
return V(Nucleus::createLoad(result, T(elVecTy)));
}
} }
RValue<Float4> Gather(RValue<Pointer<Float>> base, RValue<Int4> offsets, RValue<Int4> mask, unsigned int alignment, bool zeroMaskedLanes /* = false */) RValue<Float4> Gather(RValue<Pointer<Float>> base, RValue<Int4> offsets, RValue<Int4> mask, unsigned int alignment, bool zeroMaskedLanes /* = false */)
...@@ -1187,7 +1165,60 @@ RValue<Float4> Gather(RValue<Pointer<Float>> base, RValue<Int4> offsets, RValue< ...@@ -1187,7 +1165,60 @@ RValue<Float4> Gather(RValue<Pointer<Float>> base, RValue<Int4> offsets, RValue<
RValue<Int4> Gather(RValue<Pointer<Int>> base, RValue<Int4> offsets, RValue<Int4> mask, unsigned int alignment, bool zeroMaskedLanes /* = false */) RValue<Int4> Gather(RValue<Pointer<Int>> base, RValue<Int4> offsets, RValue<Int4> mask, unsigned int alignment, bool zeroMaskedLanes /* = false */)
{ {
return As<Int4>(V(createGather(V(base.value()), T(Float::type()), V(offsets.value()), V(mask.value()), alignment, zeroMaskedLanes))); return As<Int4>(V(createGather(V(base.value()), T(Int::type()), V(offsets.value()), V(mask.value()), alignment, zeroMaskedLanes)));
}
static void createScatter(llvm::Value *base, llvm::Value *val, llvm::Value *offsets, llvm::Value *mask, unsigned int alignment)
{
ASSERT(base->getType()->isPointerTy());
ASSERT(val->getType()->isVectorTy());
ASSERT(offsets->getType()->isVectorTy());
ASSERT(mask->getType()->isVectorTy());
auto numEls = llvm::cast<llvm::VectorType>(mask->getType())->getNumElements();
auto i1Ty = ::llvm::Type::getInt1Ty(jit->context);
auto i32Ty = ::llvm::Type::getInt32Ty(jit->context);
auto i8Ty = ::llvm::Type::getInt8Ty(jit->context);
auto i8PtrTy = i8Ty->getPointerTo();
auto elVecTy = val->getType();
auto elTy = llvm::cast<llvm::VectorType>(elVecTy)->getElementType();
auto elPtrTy = elTy->getPointerTo();
auto elPtrVecTy = ::llvm::VectorType::get(elPtrTy, numEls, false);
auto i8Base = jit->builder->CreatePointerCast(base, i8PtrTy);
auto i8Ptrs = jit->builder->CreateGEP(i8Base, offsets);
auto elPtrs = jit->builder->CreatePointerCast(i8Ptrs, elPtrVecTy);
auto i1Mask = jit->builder->CreateIntCast(mask, ::llvm::VectorType::get(i1Ty, numEls, false), false); // vec<int, int, ...> -> vec<bool, bool, ...>
if(!__has_feature(memory_sanitizer))
{
auto align = ::llvm::ConstantInt::get(i32Ty, alignment);
auto func = ::llvm::Intrinsic::getDeclaration(jit->module.get(), llvm::Intrinsic::masked_scatter, { elVecTy, elPtrVecTy });
jit->builder->CreateCall(func, { val, elPtrs, align, i1Mask });
}
else // __has_feature(memory_sanitizer)
{
// MemorySanitizer currently does not support instrumenting llvm::Intrinsic::masked_scatter
// Work around it by emulating scatter with element-wise stores.
// TODO(b/172238865): Remove when supported by MemorySanitizer.
for(unsigned i = 0; i < numEls; i++)
{
// Check mask for this element
auto idx = ::llvm::ConstantInt::get(i32Ty, i);
auto thenBlock = ::llvm::BasicBlock::Create(jit->context, "", jit->function);
auto mergeBlock = ::llvm::BasicBlock::Create(jit->context, "", jit->function);
jit->builder->CreateCondBr(jit->builder->CreateExtractElement(i1Mask, idx), thenBlock, mergeBlock);
jit->builder->SetInsertPoint(thenBlock);
auto el = jit->builder->CreateExtractElement(val, idx);
auto elPtr = jit->builder->CreateExtractElement(elPtrs, idx);
Nucleus::createStore(V(el), V(elPtr), T(elTy), /*isVolatile */ false, alignment, /* atomic */ false, std::memory_order_relaxed);
jit->builder->CreateBr(mergeBlock);
jit->builder->SetInsertPoint(mergeBlock);
}
}
} }
void Scatter(RValue<Pointer<Float>> base, RValue<Float4> val, RValue<Int4> offsets, RValue<Int4> mask, unsigned int alignment) void Scatter(RValue<Pointer<Float>> base, RValue<Float4> val, RValue<Int4> offsets, RValue<Int4> mask, unsigned int alignment)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment