Commit 1d5aaa6c by James Dong Committed by Commit Bot

Vulkan: support dynamic indices in array of arrays

Expands existing struct-sampler rewrite to flatten arrays of arrays. This allows us to support dynamically-uniform array indexing, which is core in ES 3.2. Samplers inside (possibly nested) structs are broken apart as before, and then if the type resulting from merging the array sizes of the field and its containing structs is an array of array, the array is flattened. Also adds an offset parameter to functions taking in arrays to account for this translation. As a result of outer array sizes leaking into function signatures, functions taking arrays of different sizes are duplicated according to how the function is invoked. Bug: angleproject:3604 Change-Id: Ic9373fd12a38f19bd811eac92e281055a63c1901 Reviewed-on: https://chromium-review.googlesource.com/c/angle/angle/+/1744177 Commit-Queue: James Dong <dongja@google.com> Reviewed-by: 's avatarShahbaz Youssefi <syoussefi@chromium.org>
parent 5f45432f
...@@ -26,7 +26,7 @@ ...@@ -26,7 +26,7 @@
// Version number for shader translation API. // Version number for shader translation API.
// It is incremented every time the API changes. // It is incremented every time the API changes.
#define ANGLE_SH_VERSION 213 #define ANGLE_SH_VERSION 214
enum ShShaderSpec enum ShShaderSpec
{ {
...@@ -300,6 +300,10 @@ const ShCompileOptions SH_EMULATE_SEAMFUL_CUBE_MAP_SAMPLING_WITH_SUBGROUP_OP = U ...@@ -300,6 +300,10 @@ const ShCompileOptions SH_EMULATE_SEAMFUL_CUBE_MAP_SAMPLING_WITH_SUBGROUP_OP = U
// If requested, validates the AST after every transformation. Useful for debugging. // If requested, validates the AST after every transformation. Useful for debugging.
const ShCompileOptions SH_VALIDATE_AST = UINT64_C(1) << 46; const ShCompileOptions SH_VALIDATE_AST = UINT64_C(1) << 46;
// Use old version of RewriteStructSamplers, which doesn't produce as many
// sampler arrays in parameters. This causes a few tests to pass on Android.
const ShCompileOptions SH_USE_OLD_REWRITE_STRUCT_SAMPLERS = UINT64_C(1) << 47;
// Defines alternate strategies for implementing array index clamping. // Defines alternate strategies for implementing array index clamping.
enum ShArrayIndexClampingStrategy enum ShArrayIndexClampingStrategy
{ {
......
...@@ -200,6 +200,14 @@ struct FeaturesVk : FeatureSetBase ...@@ -200,6 +200,14 @@ struct FeaturesVk : FeatureSetBase
"disallow_seamful_cube_map_emulation", FeatureCategory::VulkanWorkarounds, "disallow_seamful_cube_map_emulation", FeatureCategory::VulkanWorkarounds,
"Seamful cube map emulation misbehaves on the AMD windows driver, so it's disallowed", "Seamful cube map emulation misbehaves on the AMD windows driver, so it's disallowed",
&members, "http://anglebug.com/3243"}; &members, "http://anglebug.com/3243"};
// Qualcomm shader compiler doesn't support sampler arrays as parameters, so
// revert to old RewriteStructSamplers behavior, which produces fewer.
Feature forceOldRewriteStructSamplers = {
"force_old_rewrite_struct_samplers", FeatureCategory::VulkanWorkarounds,
"Qualcomm shader compiler doesn't support sampler arrays as parameters, so "
"revert to old RewriteStructSamplers behavior, which produces fewer.",
&members, "http://anglebug.com/2703"};
}; };
inline FeaturesVk::FeaturesVk() = default; inline FeaturesVk::FeaturesVk() = default;
......
...@@ -168,6 +168,7 @@ angle_translator_sources = [ ...@@ -168,6 +168,7 @@ angle_translator_sources = [
"src/compiler/translator/tree_ops/RewriteExpressionsWithShaderStorageBlock.h", "src/compiler/translator/tree_ops/RewriteExpressionsWithShaderStorageBlock.h",
"src/compiler/translator/tree_ops/RewriteStructSamplers.cpp", "src/compiler/translator/tree_ops/RewriteStructSamplers.cpp",
"src/compiler/translator/tree_ops/RewriteStructSamplers.h", "src/compiler/translator/tree_ops/RewriteStructSamplers.h",
"src/compiler/translator/tree_ops/RewriteStructSamplersOld.cpp",
"src/compiler/translator/tree_ops/RewriteRepeatedAssignToSwizzled.cpp", "src/compiler/translator/tree_ops/RewriteRepeatedAssignToSwizzled.cpp",
"src/compiler/translator/tree_ops/RewriteRepeatedAssignToSwizzled.h", "src/compiler/translator/tree_ops/RewriteRepeatedAssignToSwizzled.h",
"src/compiler/translator/tree_ops/RewriteRowMajorMatrices.cpp", "src/compiler/translator/tree_ops/RewriteRowMajorMatrices.cpp",
......
...@@ -268,6 +268,10 @@ bool TIntermLoop::replaceChildNode(TIntermNode *original, TIntermNode *replaceme ...@@ -268,6 +268,10 @@ bool TIntermLoop::replaceChildNode(TIntermNode *original, TIntermNode *replaceme
return false; return false;
} }
TIntermBranch::TIntermBranch(const TIntermBranch &node)
: TIntermBranch(node.mFlowOp, node.mExpression->deepCopy())
{}
size_t TIntermBranch::getChildCount() const size_t TIntermBranch::getChildCount() const
{ {
return (mExpression ? 1 : 0); return (mExpression ? 1 : 0);
...@@ -401,6 +405,14 @@ bool TIntermAggregate::replaceChildNode(TIntermNode *original, TIntermNode *repl ...@@ -401,6 +405,14 @@ bool TIntermAggregate::replaceChildNode(TIntermNode *original, TIntermNode *repl
return replaceChildNodeInternal(original, replacement); return replaceChildNodeInternal(original, replacement);
} }
TIntermBlock::TIntermBlock(const TIntermBlock &node)
{
for (TIntermNode *node : node.mStatements)
{
mStatements.push_back(node->deepCopy());
}
}
size_t TIntermBlock::getChildCount() const size_t TIntermBlock::getChildCount() const
{ {
return mStatements.size(); return mStatements.size();
...@@ -954,6 +966,8 @@ bool TIntermSwitch::replaceChildNode(TIntermNode *original, TIntermNode *replace ...@@ -954,6 +966,8 @@ bool TIntermSwitch::replaceChildNode(TIntermNode *original, TIntermNode *replace
return false; return false;
} }
TIntermCase::TIntermCase(const TIntermCase &node) : TIntermCase(node.mCondition->deepCopy()) {}
size_t TIntermCase::getChildCount() const size_t TIntermCase::getChildCount() const
{ {
return (mCondition ? 1 : 0); return (mCondition ? 1 : 0);
...@@ -1326,6 +1340,11 @@ TIntermInvariantDeclaration::TIntermInvariantDeclaration(TIntermSymbol *symbol, ...@@ -1326,6 +1340,11 @@ TIntermInvariantDeclaration::TIntermInvariantDeclaration(TIntermSymbol *symbol,
setLine(line); setLine(line);
} }
TIntermInvariantDeclaration::TIntermInvariantDeclaration(const TIntermInvariantDeclaration &node)
: TIntermInvariantDeclaration(static_cast<TIntermSymbol *>(node.mSymbol->deepCopy()),
node.mLine)
{}
TIntermTernary::TIntermTernary(TIntermTyped *cond, TIntermTernary::TIntermTernary(TIntermTyped *cond,
TIntermTyped *trueExpression, TIntermTyped *trueExpression,
TIntermTyped *falseExpression) TIntermTyped *falseExpression)
...@@ -1357,6 +1376,14 @@ TIntermLoop::TIntermLoop(TLoopType type, ...@@ -1357,6 +1376,14 @@ TIntermLoop::TIntermLoop(TLoopType type,
} }
} }
TIntermLoop::TIntermLoop(const TIntermLoop &node)
: TIntermLoop(node.mType,
node.mInit->deepCopy(),
node.mCond->deepCopy(),
node.mExpr->deepCopy(),
node.mBody->deepCopy())
{}
TIntermIfElse::TIntermIfElse(TIntermTyped *cond, TIntermBlock *trueB, TIntermBlock *falseB) TIntermIfElse::TIntermIfElse(TIntermTyped *cond, TIntermBlock *trueB, TIntermBlock *falseB)
: TIntermNode(), mCondition(cond), mTrueBlock(trueB), mFalseBlock(falseB) : TIntermNode(), mCondition(cond), mTrueBlock(trueB), mFalseBlock(falseB)
{ {
...@@ -1368,6 +1395,12 @@ TIntermIfElse::TIntermIfElse(TIntermTyped *cond, TIntermBlock *trueB, TIntermBlo ...@@ -1368,6 +1395,12 @@ TIntermIfElse::TIntermIfElse(TIntermTyped *cond, TIntermBlock *trueB, TIntermBlo
} }
} }
TIntermIfElse::TIntermIfElse(const TIntermIfElse &node)
: TIntermIfElse(node.mCondition->deepCopy(),
node.mTrueBlock->deepCopy(),
node.mFalseBlock ? node.mFalseBlock->deepCopy() : nullptr)
{}
TIntermSwitch::TIntermSwitch(TIntermTyped *init, TIntermBlock *statementList) TIntermSwitch::TIntermSwitch(TIntermTyped *init, TIntermBlock *statementList)
: TIntermNode(), mInit(init), mStatementList(statementList) : TIntermNode(), mInit(init), mStatementList(statementList)
{ {
...@@ -1375,6 +1408,10 @@ TIntermSwitch::TIntermSwitch(TIntermTyped *init, TIntermBlock *statementList) ...@@ -1375,6 +1408,10 @@ TIntermSwitch::TIntermSwitch(TIntermTyped *init, TIntermBlock *statementList)
ASSERT(mStatementList); ASSERT(mStatementList);
} }
TIntermSwitch::TIntermSwitch(const TIntermSwitch &node)
: TIntermSwitch(node.mInit->deepCopy(), node.mStatementList->deepCopy())
{}
void TIntermSwitch::setStatementList(TIntermBlock *statementList) void TIntermSwitch::setStatementList(TIntermBlock *statementList)
{ {
ASSERT(statementList); ASSERT(statementList);
...@@ -3772,6 +3809,10 @@ TIntermPreprocessorDirective::TIntermPreprocessorDirective(PreprocessorDirective ...@@ -3772,6 +3809,10 @@ TIntermPreprocessorDirective::TIntermPreprocessorDirective(PreprocessorDirective
: mDirective(directive), mCommand(std::move(command)) : mDirective(directive), mCommand(std::move(command))
{} {}
TIntermPreprocessorDirective::TIntermPreprocessorDirective(const TIntermPreprocessorDirective &node)
: TIntermPreprocessorDirective(node.mDirective, node.mCommand)
{}
TIntermPreprocessorDirective::~TIntermPreprocessorDirective() = default; TIntermPreprocessorDirective::~TIntermPreprocessorDirective() = default;
size_t TIntermPreprocessorDirective::getChildCount() const size_t TIntermPreprocessorDirective::getChildCount() const
......
...@@ -104,6 +104,8 @@ class TIntermNode : angle::NonCopyable ...@@ -104,6 +104,8 @@ class TIntermNode : angle::NonCopyable
virtual TIntermBranch *getAsBranchNode() { return nullptr; } virtual TIntermBranch *getAsBranchNode() { return nullptr; }
virtual TIntermPreprocessorDirective *getAsPreprocessorDirective() { return nullptr; } virtual TIntermPreprocessorDirective *getAsPreprocessorDirective() { return nullptr; }
virtual TIntermNode *deepCopy() const = 0;
virtual size_t getChildCount() const = 0; virtual size_t getChildCount() const = 0;
virtual TIntermNode *getChildNode(size_t index) const = 0; virtual TIntermNode *getChildNode(size_t index) const = 0;
// Replace a child node. Return true if |original| is a child // Replace a child node. Return true if |original| is a child
...@@ -131,7 +133,7 @@ class TIntermTyped : public TIntermNode ...@@ -131,7 +133,7 @@ class TIntermTyped : public TIntermNode
public: public:
TIntermTyped() {} TIntermTyped() {}
virtual TIntermTyped *deepCopy() const = 0; virtual TIntermTyped *deepCopy() const override = 0;
TIntermTyped *getAsTyped() override { return this; } TIntermTyped *getAsTyped() override { return this; }
...@@ -211,12 +213,17 @@ class TIntermLoop : public TIntermNode ...@@ -211,12 +213,17 @@ class TIntermLoop : public TIntermNode
void setExpression(TIntermTyped *expression) { mExpr = expression; } void setExpression(TIntermTyped *expression) { mExpr = expression; }
void setBody(TIntermBlock *body) { mBody = body; } void setBody(TIntermBlock *body) { mBody = body; }
virtual TIntermLoop *deepCopy() const override { return new TIntermLoop(*this); }
protected: protected:
TLoopType mType; TLoopType mType;
TIntermNode *mInit; // for-loop initialization TIntermNode *mInit; // for-loop initialization
TIntermTyped *mCond; // loop exit condition TIntermTyped *mCond; // loop exit condition
TIntermTyped *mExpr; // for-loop expression TIntermTyped *mExpr; // for-loop expression
TIntermBlock *mBody; // loop body TIntermBlock *mBody; // loop body
private:
TIntermLoop(const TIntermLoop &);
}; };
// //
...@@ -237,9 +244,14 @@ class TIntermBranch : public TIntermNode ...@@ -237,9 +244,14 @@ class TIntermBranch : public TIntermNode
TOperator getFlowOp() { return mFlowOp; } TOperator getFlowOp() { return mFlowOp; }
TIntermTyped *getExpression() { return mExpression; } TIntermTyped *getExpression() { return mExpression; }
virtual TIntermBranch *deepCopy() const override { return new TIntermBranch(*this); }
protected: protected:
TOperator mFlowOp; TOperator mFlowOp;
TIntermTyped *mExpression; // zero except for "return exp;" statements TIntermTyped *mExpression; // zero except for "return exp;" statements
private:
TIntermBranch(const TIntermBranch &);
}; };
// Nodes that correspond to variable symbols in the source code. These may be regular variables or // Nodes that correspond to variable symbols in the source code. These may be regular variables or
...@@ -676,8 +688,13 @@ class TIntermBlock : public TIntermNode, public TIntermAggregateBase ...@@ -676,8 +688,13 @@ class TIntermBlock : public TIntermNode, public TIntermAggregateBase
TIntermSequence *getSequence() override { return &mStatements; } TIntermSequence *getSequence() override { return &mStatements; }
const TIntermSequence *getSequence() const override { return &mStatements; } const TIntermSequence *getSequence() const override { return &mStatements; }
TIntermBlock *deepCopy() const override { return new TIntermBlock(*this); }
protected: protected:
TIntermSequence mStatements; TIntermSequence mStatements;
private:
TIntermBlock(const TIntermBlock &);
}; };
// Function prototype. May be in the AST either as a function prototype declaration or as a part of // Function prototype. May be in the AST either as a function prototype declaration or as a part of
...@@ -740,6 +757,12 @@ class TIntermFunctionDefinition : public TIntermNode ...@@ -740,6 +757,12 @@ class TIntermFunctionDefinition : public TIntermNode
const TFunction *getFunction() const { return mPrototype->getFunction(); } const TFunction *getFunction() const { return mPrototype->getFunction(); }
TIntermNode *deepCopy() const override
{
UNREACHABLE();
return nullptr;
}
private: private:
TIntermFunctionPrototype *mPrototype; TIntermFunctionPrototype *mPrototype;
TIntermBlock *mBody; TIntermBlock *mBody;
...@@ -767,6 +790,12 @@ class TIntermDeclaration : public TIntermNode, public TIntermAggregateBase ...@@ -767,6 +790,12 @@ class TIntermDeclaration : public TIntermNode, public TIntermAggregateBase
TIntermSequence *getSequence() override { return &mDeclarators; } TIntermSequence *getSequence() override { return &mDeclarators; }
const TIntermSequence *getSequence() const override { return &mDeclarators; } const TIntermSequence *getSequence() const override { return &mDeclarators; }
TIntermNode *deepCopy() const override
{
UNREACHABLE();
return nullptr;
}
protected: protected:
TIntermSequence mDeclarators; TIntermSequence mDeclarators;
}; };
...@@ -786,8 +815,15 @@ class TIntermInvariantDeclaration : public TIntermNode ...@@ -786,8 +815,15 @@ class TIntermInvariantDeclaration : public TIntermNode
TIntermNode *getChildNode(size_t index) const final; TIntermNode *getChildNode(size_t index) const final;
bool replaceChildNode(TIntermNode *original, TIntermNode *replacement) override; bool replaceChildNode(TIntermNode *original, TIntermNode *replacement) override;
TIntermInvariantDeclaration *deepCopy() const override
{
return new TIntermInvariantDeclaration(*this);
}
private: private:
TIntermSymbol *mSymbol; TIntermSymbol *mSymbol;
TIntermInvariantDeclaration(const TIntermInvariantDeclaration &);
}; };
// For ternary operators like a ? b : c. // For ternary operators like a ? b : c.
...@@ -845,10 +881,15 @@ class TIntermIfElse : public TIntermNode ...@@ -845,10 +881,15 @@ class TIntermIfElse : public TIntermNode
TIntermBlock *getTrueBlock() const { return mTrueBlock; } TIntermBlock *getTrueBlock() const { return mTrueBlock; }
TIntermBlock *getFalseBlock() const { return mFalseBlock; } TIntermBlock *getFalseBlock() const { return mFalseBlock; }
TIntermIfElse *deepCopy() const override { return new TIntermIfElse(*this); }
protected: protected:
TIntermTyped *mCondition; TIntermTyped *mCondition;
TIntermBlock *mTrueBlock; TIntermBlock *mTrueBlock;
TIntermBlock *mFalseBlock; TIntermBlock *mFalseBlock;
private:
TIntermIfElse(const TIntermIfElse &);
}; };
// //
...@@ -872,9 +913,14 @@ class TIntermSwitch : public TIntermNode ...@@ -872,9 +913,14 @@ class TIntermSwitch : public TIntermNode
// Must be called with a non-null statementList. // Must be called with a non-null statementList.
void setStatementList(TIntermBlock *statementList); void setStatementList(TIntermBlock *statementList);
TIntermSwitch *deepCopy() const override { return new TIntermSwitch(*this); }
protected: protected:
TIntermTyped *mInit; TIntermTyped *mInit;
TIntermBlock *mStatementList; TIntermBlock *mStatementList;
private:
TIntermSwitch(const TIntermSwitch &);
}; };
// //
...@@ -895,8 +941,13 @@ class TIntermCase : public TIntermNode ...@@ -895,8 +941,13 @@ class TIntermCase : public TIntermNode
bool hasCondition() const { return mCondition != nullptr; } bool hasCondition() const { return mCondition != nullptr; }
TIntermTyped *getCondition() const { return mCondition; } TIntermTyped *getCondition() const { return mCondition; }
TIntermCase *deepCopy() const override { return new TIntermCase(*this); }
protected: protected:
TIntermTyped *mCondition; TIntermTyped *mCondition;
private:
TIntermCase(const TIntermCase &);
}; };
// //
...@@ -930,9 +981,16 @@ class TIntermPreprocessorDirective : public TIntermNode ...@@ -930,9 +981,16 @@ class TIntermPreprocessorDirective : public TIntermNode
PreprocessorDirective getDirective() const { return mDirective; } PreprocessorDirective getDirective() const { return mDirective; }
const ImmutableString &getCommand() const { return mCommand; } const ImmutableString &getCommand() const { return mCommand; }
TIntermPreprocessorDirective *deepCopy() const override
{
return new TIntermPreprocessorDirective(*this);
}
private: private:
PreprocessorDirective mDirective; PreprocessorDirective mDirective;
ImmutableString mCommand; ImmutableString mCommand;
TIntermPreprocessorDirective(const TIntermPreprocessorDirective &);
}; };
} // namespace sh } // namespace sh
......
...@@ -222,12 +222,28 @@ bool TFunction::isAtomicCounterFunction() const ...@@ -222,12 +222,28 @@ bool TFunction::isAtomicCounterFunction() const
return SymbolType() == SymbolType::BuiltIn && name().beginsWith(kAtomicCounterName); return SymbolType() == SymbolType::BuiltIn && name().beginsWith(kAtomicCounterName);
} }
bool TFunction::hasSamplerInStructParams() const bool TFunction::hasSamplerInStructOrArrayParams() const
{ {
for (size_t paramIndex = 0; paramIndex < mParamCount; ++paramIndex) for (size_t paramIndex = 0; paramIndex < mParamCount; ++paramIndex)
{ {
const TVariable *param = getParam(paramIndex); const TVariable *param = getParam(paramIndex);
if (param->getType().isStructureContainingSamplers()) if (param->getType().isStructureContainingSamplers() ||
(param->getType().isArray() && param->getType().isSampler()))
{
return true;
}
}
return false;
}
bool TFunction::hasSamplerInStructOrArrayOfArrayParams() const
{
for (size_t paramIndex = 0; paramIndex < mParamCount; ++paramIndex)
{
const TVariable *param = getParam(paramIndex);
if (param->getType().isStructureContainingSamplers() ||
(param->getType().isArrayOfArrays() && param->getType().isSampler()))
{ {
return true; return true;
} }
......
...@@ -234,7 +234,8 @@ class TFunction : public TSymbol ...@@ -234,7 +234,8 @@ class TFunction : public TSymbol
bool isMain() const; bool isMain() const;
bool isImageFunction() const; bool isImageFunction() const;
bool isAtomicCounterFunction() const; bool isAtomicCounterFunction() const;
bool hasSamplerInStructParams() const; bool hasSamplerInStructOrArrayParams() const;
bool hasSamplerInStructOrArrayOfArrayParams() const;
// Note: Only to be used for static built-in functions! // Note: Only to be used for static built-in functions!
constexpr TFunction(const TSymbolUniqueId &id, constexpr TFunction(const TSymbolUniqueId &id,
......
...@@ -671,9 +671,9 @@ bool TranslatorVulkan::translate(TIntermBlock *root, ...@@ -671,9 +671,9 @@ bool TranslatorVulkan::translate(TIntermBlock *root,
} }
// Write out default uniforms into a uniform block assigned to a specific set/binding. // Write out default uniforms into a uniform block assigned to a specific set/binding.
int defaultUniformCount = 0; int defaultUniformCount = 0;
int structTypesUsedForUniforms = 0; int aggregateTypesUsedForUniforms = 0;
int atomicCounterCount = 0; int atomicCounterCount = 0;
for (const auto &uniform : getUniforms()) for (const auto &uniform : getUniforms())
{ {
if (!uniform.isBuiltIn() && uniform.staticUse && !gl::IsOpaqueType(uniform.type)) if (!uniform.isBuiltIn() && uniform.staticUse && !gl::IsOpaqueType(uniform.type))
...@@ -681,9 +681,9 @@ bool TranslatorVulkan::translate(TIntermBlock *root, ...@@ -681,9 +681,9 @@ bool TranslatorVulkan::translate(TIntermBlock *root,
++defaultUniformCount; ++defaultUniformCount;
} }
if (uniform.isStruct()) if (uniform.isStruct() || uniform.isArrayOfArrays())
{ {
++structTypesUsedForUniforms; ++aggregateTypesUsedForUniforms;
} }
if (gl::IsAtomicCounterType(uniform.type)) if (gl::IsAtomicCounterType(uniform.type))
...@@ -694,15 +694,28 @@ bool TranslatorVulkan::translate(TIntermBlock *root, ...@@ -694,15 +694,28 @@ bool TranslatorVulkan::translate(TIntermBlock *root,
// TODO(lucferron): Refactor this function to do fewer tree traversals. // TODO(lucferron): Refactor this function to do fewer tree traversals.
// http://anglebug.com/2461 // http://anglebug.com/2461
if (structTypesUsedForUniforms > 0) if (aggregateTypesUsedForUniforms > 0)
{ {
if (!NameEmbeddedStructUniforms(this, root, &getSymbolTable())) if (!NameEmbeddedStructUniforms(this, root, &getSymbolTable()))
{ {
return false; return false;
} }
int removedUniformsCount = 0; bool rewriteStructSamplersResult;
if (!RewriteStructSamplers(this, root, &getSymbolTable(), &removedUniformsCount)) int removedUniformsCount;
if (compileOptions & SH_USE_OLD_REWRITE_STRUCT_SAMPLERS)
{
rewriteStructSamplersResult =
RewriteStructSamplersOld(this, root, &getSymbolTable(), &removedUniformsCount);
}
else
{
rewriteStructSamplersResult =
RewriteStructSamplers(this, root, &getSymbolTable(), &removedUniformsCount);
}
if (!rewriteStructSamplersResult)
{ {
return false; return false;
} }
......
...@@ -716,6 +716,19 @@ void TType::toArrayElementType() ...@@ -716,6 +716,19 @@ void TType::toArrayElementType()
} }
} }
void TType::toArrayBaseType()
{
if (mArraySizes == nullptr)
{
return;
}
if (mArraySizes->size() > 0)
{
mArraySizes->clear();
}
invalidateMangledName();
}
void TType::setInterfaceBlock(const TInterfaceBlock *interfaceBlockIn) void TType::setInterfaceBlock(const TInterfaceBlock *interfaceBlockIn)
{ {
if (mInterfaceBlock != interfaceBlockIn) if (mInterfaceBlock != interfaceBlockIn)
......
...@@ -214,6 +214,8 @@ class TType ...@@ -214,6 +214,8 @@ class TType
// Note that the array element type might still be an array type in GLSL ES version >= 3.10. // Note that the array element type might still be an array type in GLSL ES version >= 3.10.
void toArrayElementType(); void toArrayElementType();
// Removes all array sizes.
void toArrayBaseType();
const TInterfaceBlock *getInterfaceBlock() const { return mInterfaceBlock; } const TInterfaceBlock *getInterfaceBlock() const { return mInterfaceBlock; }
void setInterfaceBlock(const TInterfaceBlock *interfaceBlockIn); void setInterfaceBlock(const TInterfaceBlock *interfaceBlockIn);
......
...@@ -406,6 +406,7 @@ void VariableNameVisitor::enterArray(const ShaderVariable &arrayVar) ...@@ -406,6 +406,7 @@ void VariableNameVisitor::enterArray(const ShaderVariable &arrayVar)
mNameStack.push_back(arrayVar.name); mNameStack.push_back(arrayVar.name);
mMappedNameStack.push_back(arrayVar.mappedName); mMappedNameStack.push_back(arrayVar.mappedName);
} }
mArraySizeStack.push_back(arrayVar.getOutermostArraySize());
} }
void VariableNameVisitor::exitArray(const ShaderVariable &arrayVar) void VariableNameVisitor::exitArray(const ShaderVariable &arrayVar)
...@@ -415,6 +416,7 @@ void VariableNameVisitor::exitArray(const ShaderVariable &arrayVar) ...@@ -415,6 +416,7 @@ void VariableNameVisitor::exitArray(const ShaderVariable &arrayVar)
mNameStack.pop_back(); mNameStack.pop_back();
mMappedNameStack.pop_back(); mMappedNameStack.pop_back();
} }
mArraySizeStack.pop_back();
} }
void VariableNameVisitor::enterArrayElement(const ShaderVariable &arrayVar, void VariableNameVisitor::enterArrayElement(const ShaderVariable &arrayVar,
...@@ -461,7 +463,7 @@ void VariableNameVisitor::visitSampler(const sh::ShaderVariable &sampler) ...@@ -461,7 +463,7 @@ void VariableNameVisitor::visitSampler(const sh::ShaderVariable &sampler)
mMappedNameStack.pop_back(); mMappedNameStack.pop_back();
} }
visitNamedSampler(sampler, name, mappedName); visitNamedSampler(sampler, name, mappedName, mArraySizeStack);
} }
void VariableNameVisitor::visitVariable(const ShaderVariable &variable, bool isRowMajor) void VariableNameVisitor::visitVariable(const ShaderVariable &variable, bool isRowMajor)
...@@ -481,7 +483,7 @@ void VariableNameVisitor::visitVariable(const ShaderVariable &variable, bool isR ...@@ -481,7 +483,7 @@ void VariableNameVisitor::visitVariable(const ShaderVariable &variable, bool isR
mMappedNameStack.pop_back(); mMappedNameStack.pop_back();
} }
visitNamedVariable(variable, isRowMajor, name, mappedName); visitNamedVariable(variable, isRowMajor, name, mappedName, mArraySizeStack);
} }
// BlockEncoderVisitor implementation. // BlockEncoderVisitor implementation.
...@@ -554,7 +556,8 @@ void BlockEncoderVisitor::exitArrayElement(const sh::ShaderVariable &arrayVar, ...@@ -554,7 +556,8 @@ void BlockEncoderVisitor::exitArrayElement(const sh::ShaderVariable &arrayVar,
void BlockEncoderVisitor::visitNamedVariable(const ShaderVariable &variable, void BlockEncoderVisitor::visitNamedVariable(const ShaderVariable &variable,
bool isRowMajor, bool isRowMajor,
const std::string &name, const std::string &name,
const std::string &mappedName) const std::string &mappedName,
const std::vector<unsigned int> &arraySizes)
{ {
std::vector<unsigned int> innermostArraySize; std::vector<unsigned int> innermostArraySize;
......
...@@ -230,12 +230,14 @@ class VariableNameVisitor : public ShaderVariableVisitor ...@@ -230,12 +230,14 @@ class VariableNameVisitor : public ShaderVariableVisitor
protected: protected:
virtual void visitNamedSampler(const sh::ShaderVariable &sampler, virtual void visitNamedSampler(const sh::ShaderVariable &sampler,
const std::string &name, const std::string &name,
const std::string &mappedName) const std::string &mappedName,
const std::vector<unsigned int> &arraySizes)
{} {}
virtual void visitNamedVariable(const ShaderVariable &variable, virtual void visitNamedVariable(const ShaderVariable &variable,
bool isRowMajor, bool isRowMajor,
const std::string &name, const std::string &name,
const std::string &mappedName) = 0; const std::string &mappedName,
const std::vector<unsigned int> &arraySizes) = 0;
std::string collapseNameStack() const; std::string collapseNameStack() const;
std::string collapseMappedNameStack() const; std::string collapseMappedNameStack() const;
...@@ -246,6 +248,7 @@ class VariableNameVisitor : public ShaderVariableVisitor ...@@ -246,6 +248,7 @@ class VariableNameVisitor : public ShaderVariableVisitor
std::vector<std::string> mNameStack; std::vector<std::string> mNameStack;
std::vector<std::string> mMappedNameStack; std::vector<std::string> mMappedNameStack;
std::vector<unsigned int> mArraySizeStack;
}; };
class BlockEncoderVisitor : public VariableNameVisitor class BlockEncoderVisitor : public VariableNameVisitor
...@@ -264,7 +267,8 @@ class BlockEncoderVisitor : public VariableNameVisitor ...@@ -264,7 +267,8 @@ class BlockEncoderVisitor : public VariableNameVisitor
void visitNamedVariable(const ShaderVariable &variable, void visitNamedVariable(const ShaderVariable &variable,
bool isRowMajor, bool isRowMajor,
const std::string &name, const std::string &name,
const std::string &mappedName) override; const std::string &mappedName,
const std::vector<unsigned int> &arraySizes) override;
virtual void encodeVariable(const ShaderVariable &variable, virtual void encodeVariable(const ShaderVariable &variable,
const BlockMemberInfo &variableInfo, const BlockMemberInfo &variableInfo,
......
...@@ -33,6 +33,10 @@ ANGLE_NO_DISCARD bool RewriteStructSamplers(TCompiler *compiler, ...@@ -33,6 +33,10 @@ ANGLE_NO_DISCARD bool RewriteStructSamplers(TCompiler *compiler,
TIntermBlock *root, TIntermBlock *root,
TSymbolTable *symbolTable, TSymbolTable *symbolTable,
int *removedUniformsCountOut); int *removedUniformsCountOut);
ANGLE_NO_DISCARD bool RewriteStructSamplersOld(TCompiler *compier,
TIntermBlock *root,
TSymbolTable *symbolTable,
int *removedUniformsCountOut);
} // namespace sh } // namespace sh
#endif // COMPILER_TRANSLATOR_TREEOPS_REWRITESTRUCTSAMPLERS_H_ #endif // COMPILER_TRANSLATOR_TREEOPS_REWRITESTRUCTSAMPLERS_H_
...@@ -248,7 +248,8 @@ class UniformBlockEncodingVisitor : public sh::VariableNameVisitor ...@@ -248,7 +248,8 @@ class UniformBlockEncodingVisitor : public sh::VariableNameVisitor
void visitNamedVariable(const sh::ShaderVariable &variable, void visitNamedVariable(const sh::ShaderVariable &variable,
bool isRowMajor, bool isRowMajor,
const std::string &name, const std::string &name,
const std::string &mappedName) override const std::string &mappedName,
const std::vector<unsigned int> &arraySizes) override
{ {
// If getBlockMemberInfo returns false, the variable is optimized out. // If getBlockMemberInfo returns false, the variable is optimized out.
sh::BlockMemberInfo variableInfo; sh::BlockMemberInfo variableInfo;
...@@ -308,7 +309,8 @@ class ShaderStorageBlockVisitor : public sh::BlockEncoderVisitor ...@@ -308,7 +309,8 @@ class ShaderStorageBlockVisitor : public sh::BlockEncoderVisitor
void visitNamedVariable(const sh::ShaderVariable &variable, void visitNamedVariable(const sh::ShaderVariable &variable,
bool isRowMajor, bool isRowMajor,
const std::string &name, const std::string &name,
const std::string &mappedName) override const std::string &mappedName,
const std::vector<unsigned int> &arraySizes) override
{ {
if (mSkipEnabled) if (mSkipEnabled)
return; return;
...@@ -397,15 +399,17 @@ class FlattenUniformVisitor : public sh::VariableNameVisitor ...@@ -397,15 +399,17 @@ class FlattenUniformVisitor : public sh::VariableNameVisitor
void visitNamedSampler(const sh::ShaderVariable &sampler, void visitNamedSampler(const sh::ShaderVariable &sampler,
const std::string &name, const std::string &name,
const std::string &mappedName) override const std::string &mappedName,
const std::vector<unsigned int> &arraySizes) override
{ {
visitNamedVariable(sampler, false, name, mappedName); visitNamedVariable(sampler, false, name, mappedName, arraySizes);
} }
void visitNamedVariable(const sh::ShaderVariable &variable, void visitNamedVariable(const sh::ShaderVariable &variable,
bool isRowMajor, bool isRowMajor,
const std::string &name, const std::string &name,
const std::string &mappedName) override const std::string &mappedName,
const std::vector<unsigned int> &arraySizes) override
{ {
bool isSampler = IsSamplerType(variable.type); bool isSampler = IsSamplerType(variable.type);
bool isImage = IsImageType(variable.type); bool isImage = IsImageType(variable.type);
...@@ -468,6 +472,7 @@ class FlattenUniformVisitor : public sh::VariableNameVisitor ...@@ -468,6 +472,7 @@ class FlattenUniformVisitor : public sh::VariableNameVisitor
linkedUniform.mappedName = fullMappedNameWithArrayIndex; linkedUniform.mappedName = fullMappedNameWithArrayIndex;
linkedUniform.active = mMarkActive; linkedUniform.active = mMarkActive;
linkedUniform.staticUse = mMarkStaticUse; linkedUniform.staticUse = mMarkStaticUse;
linkedUniform.outerArraySizes = arraySizes;
if (variable.hasParentArrayIndex()) if (variable.hasParentArrayIndex())
{ {
linkedUniform.setParentArrayIndex(variable.parentArrayIndex()); linkedUniform.setParentArrayIndex(variable.parentArrayIndex());
......
...@@ -406,7 +406,15 @@ void Shader::resolveCompile() ...@@ -406,7 +406,15 @@ void Shader::resolveCompile()
// Remove null characters from the source line // Remove null characters from the source line
line.erase(std::remove(line.begin(), line.end(), '\0'), line.end()); line.erase(std::remove(line.begin(), line.end(), '\0'), line.end());
shaderStream << "// " << line << std::endl; shaderStream << "// " << line;
// glslang complains if a comment ends with backslash
if (!line.empty() && line.back() == '\\')
{
shaderStream << "\\";
}
shaderStream << std::endl;
} }
shaderStream << "\n\n"; shaderStream << "\n\n";
shaderStream << mState.mTranslatedSource; shaderStream << mState.mTranslatedSource;
......
...@@ -80,7 +80,8 @@ LinkedUniform::LinkedUniform(const LinkedUniform &uniform) ...@@ -80,7 +80,8 @@ LinkedUniform::LinkedUniform(const LinkedUniform &uniform)
ActiveVariable(uniform), ActiveVariable(uniform),
typeInfo(uniform.typeInfo), typeInfo(uniform.typeInfo),
bufferIndex(uniform.bufferIndex), bufferIndex(uniform.bufferIndex),
blockInfo(uniform.blockInfo) blockInfo(uniform.blockInfo),
outerArraySizes(uniform.outerArraySizes)
{} {}
LinkedUniform &LinkedUniform::operator=(const LinkedUniform &uniform) LinkedUniform &LinkedUniform::operator=(const LinkedUniform &uniform)
...@@ -90,6 +91,7 @@ LinkedUniform &LinkedUniform::operator=(const LinkedUniform &uniform) ...@@ -90,6 +91,7 @@ LinkedUniform &LinkedUniform::operator=(const LinkedUniform &uniform)
typeInfo = uniform.typeInfo; typeInfo = uniform.typeInfo;
bufferIndex = uniform.bufferIndex; bufferIndex = uniform.bufferIndex;
blockInfo = uniform.blockInfo; blockInfo = uniform.blockInfo;
outerArraySizes = uniform.outerArraySizes;
return *this; return *this;
} }
......
...@@ -75,6 +75,7 @@ struct LinkedUniform : public sh::Uniform, public ActiveVariable ...@@ -75,6 +75,7 @@ struct LinkedUniform : public sh::Uniform, public ActiveVariable
// Identifies the containing buffer backed resource -- interface block or atomic counter buffer. // Identifies the containing buffer backed resource -- interface block or atomic counter buffer.
int bufferIndex; int bufferIndex;
sh::BlockMemberInfo blockInfo; sh::BlockMemberInfo blockInfo;
std::vector<unsigned int> outerArraySizes;
}; };
struct BufferVariable : public sh::ShaderVariable, public ActiveVariable struct BufferVariable : public sh::ShaderVariable, public ActiveVariable
......
...@@ -217,7 +217,8 @@ class UniformEncodingVisitorD3D : public sh::BlockEncoderVisitor ...@@ -217,7 +217,8 @@ class UniformEncodingVisitorD3D : public sh::BlockEncoderVisitor
void visitNamedSampler(const sh::ShaderVariable &sampler, void visitNamedSampler(const sh::ShaderVariable &sampler,
const std::string &name, const std::string &name,
const std::string &mappedName) override const std::string &mappedName,
const std::vector<unsigned int> &arraySizes) override
{ {
auto uniformMapEntry = mUniformMapOut->find(name); auto uniformMapEntry = mUniformMapOut->find(name);
if (uniformMapEntry == mUniformMapOut->end()) if (uniformMapEntry == mUniformMapOut->end())
......
...@@ -240,6 +240,7 @@ ContextVk::ContextVk(const gl::State &state, gl::ErrorSet *errorSet, RendererVk ...@@ -240,6 +240,7 @@ ContextVk::ContextVk(const gl::State &state, gl::ErrorSet *errorSet, RendererVk
mIsAnyHostVisibleBufferWritten(false), mIsAnyHostVisibleBufferWritten(false),
mEmulateSeamfulCubeMapSampling(false), mEmulateSeamfulCubeMapSampling(false),
mEmulateSeamfulCubeMapSamplingWithSubgroupOps(false), mEmulateSeamfulCubeMapSamplingWithSubgroupOps(false),
mUseOldRewriteStructSamplers(false),
mLastCompletedQueueSerial(renderer->nextSerial()), mLastCompletedQueueSerial(renderer->nextSerial()),
mCurrentQueueSerial(renderer->nextSerial()), mCurrentQueueSerial(renderer->nextSerial()),
mPoolAllocator(kDefaultPoolAllocatorPageSize, 1), mPoolAllocator(kDefaultPoolAllocatorPageSize, 1),
...@@ -446,6 +447,8 @@ angle::Result ContextVk::initialize() ...@@ -446,6 +447,8 @@ angle::Result ContextVk::initialize()
mEmulateSeamfulCubeMapSampling = mEmulateSeamfulCubeMapSampling =
shouldEmulateSeamfulCubeMapSampling(&mEmulateSeamfulCubeMapSamplingWithSubgroupOps); shouldEmulateSeamfulCubeMapSampling(&mEmulateSeamfulCubeMapSamplingWithSubgroupOps);
mUseOldRewriteStructSamplers = shouldUseOldRewriteStructSamplers();
return angle::Result::Continue; return angle::Result::Continue;
} }
...@@ -2947,4 +2950,9 @@ bool ContextVk::shouldEmulateSeamfulCubeMapSampling(bool *useSubgroupOpsOut) con ...@@ -2947,4 +2950,9 @@ bool ContextVk::shouldEmulateSeamfulCubeMapSampling(bool *useSubgroupOpsOut) con
return true; return true;
} }
bool ContextVk::shouldUseOldRewriteStructSamplers() const
{
return mRenderer->getFeatures().forceOldRewriteStructSamplers.enabled;
}
} // namespace rx } // namespace rx
...@@ -333,6 +333,8 @@ class ContextVk : public ContextImpl, public vk::Context, public vk::RenderPassO ...@@ -333,6 +333,8 @@ class ContextVk : public ContextImpl, public vk::Context, public vk::RenderPassO
return mEmulateSeamfulCubeMapSampling; return mEmulateSeamfulCubeMapSampling;
} }
bool useOldRewriteStructSamplers() const { return mUseOldRewriteStructSamplers; }
private: private:
// Dirty bits. // Dirty bits.
enum DirtyBitType : size_t enum DirtyBitType : size_t
...@@ -492,6 +494,8 @@ class ContextVk : public ContextImpl, public vk::Context, public vk::RenderPassO ...@@ -492,6 +494,8 @@ class ContextVk : public ContextImpl, public vk::Context, public vk::RenderPassO
bool shouldEmulateSeamfulCubeMapSampling(bool *useSubgroupOpsOut) const; bool shouldEmulateSeamfulCubeMapSampling(bool *useSubgroupOpsOut) const;
bool shouldUseOldRewriteStructSamplers() const;
vk::PipelineHelper *mCurrentGraphicsPipeline; vk::PipelineHelper *mCurrentGraphicsPipeline;
vk::PipelineAndSerial *mCurrentComputePipeline; vk::PipelineAndSerial *mCurrentComputePipeline;
gl::PrimitiveMode mCurrentDrawMode; gl::PrimitiveMode mCurrentDrawMode;
...@@ -558,6 +562,10 @@ class ContextVk : public ContextImpl, public vk::Context, public vk::RenderPassO ...@@ -558,6 +562,10 @@ class ContextVk : public ContextImpl, public vk::Context, public vk::RenderPassO
bool mEmulateSeamfulCubeMapSampling; bool mEmulateSeamfulCubeMapSampling;
bool mEmulateSeamfulCubeMapSamplingWithSubgroupOps; bool mEmulateSeamfulCubeMapSamplingWithSubgroupOps;
// Whether this context should use the old version of the
// RewriteStructSamplers pass.
bool mUseOldRewriteStructSamplers;
struct DriverUniformsDescriptorSet struct DriverUniformsDescriptorSet
{ {
vk::DynamicBuffer dynamicBuffer; vk::DynamicBuffer dynamicBuffer;
......
...@@ -366,7 +366,7 @@ std::string IntermediateShaderSource::getShaderSource() ...@@ -366,7 +366,7 @@ std::string IntermediateShaderSource::getShaderSource()
return shaderSource; return shaderSource;
} }
std::string GetMappedSamplerName(const std::string &originalName) std::string GetMappedSamplerNameOld(const std::string &originalName)
{ {
std::string samplerName = gl::ParseResourceName(originalName, nullptr); std::string samplerName = gl::ParseResourceName(originalName, nullptr);
...@@ -777,7 +777,8 @@ void AssignBufferBindings(const gl::ProgramState &programState, ...@@ -777,7 +777,8 @@ void AssignBufferBindings(const gl::ProgramState &programState,
bindingStart, shaderSources); bindingStart, shaderSources);
} }
void AssignTextureBindings(const gl::ProgramState &programState, void AssignTextureBindings(bool useOldRewriteStructSamplers,
const gl::ProgramState &programState,
gl::ShaderMap<IntermediateShaderSource> *shaderSources) gl::ShaderMap<IntermediateShaderSource> *shaderSources)
{ {
const std::string texturesDescriptorSet = "set = " + Str(kTextureDescriptorSetIndex); const std::string texturesDescriptorSet = "set = " + Str(kTextureDescriptorSetIndex);
...@@ -789,18 +790,28 @@ void AssignTextureBindings(const gl::ProgramState &programState, ...@@ -789,18 +790,28 @@ void AssignTextureBindings(const gl::ProgramState &programState,
for (unsigned int uniformIndex : programState.getSamplerUniformRange()) for (unsigned int uniformIndex : programState.getSamplerUniformRange())
{ {
const gl::LinkedUniform &samplerUniform = uniforms[uniformIndex]; const gl::LinkedUniform &samplerUniform = uniforms[uniformIndex];
if (!useOldRewriteStructSamplers &&
vk::SamplerNameContainsNonZeroArrayElement(samplerUniform.name))
{
continue;
}
const std::string bindingString = const std::string bindingString =
texturesDescriptorSet + ", binding = " + Str(bindingIndex++); texturesDescriptorSet + ", binding = " + Str(bindingIndex++);
// Samplers in structs are extracted and renamed. // Samplers in structs are extracted and renamed.
const std::string samplerName = GetMappedSamplerName(samplerUniform.name); const std::string samplerName = useOldRewriteStructSamplers
? GetMappedSamplerNameOld(samplerUniform.name)
: vk::GetMappedSamplerName(samplerUniform.name);
AssignResourceBinding(samplerUniform.activeShaders(), samplerName, bindingString, AssignResourceBinding(samplerUniform.activeShaders(), samplerName, bindingString,
kUniformQualifier, kUnusedUniformSubstitution, shaderSources); kUniformQualifier, kUnusedUniformSubstitution, shaderSources);
} }
} }
void CleanupUnusedEntities(const gl::ProgramState &programState, void CleanupUnusedEntities(bool useOldRewriteStructSamplers,
const gl::ProgramState &programState,
const gl::ProgramLinkedResources &resources, const gl::ProgramLinkedResources &resources,
gl::Shader *glVertexShader, gl::Shader *glVertexShader,
gl::ShaderMap<IntermediateShaderSource> *shaderSources) gl::ShaderMap<IntermediateShaderSource> *shaderSources)
...@@ -847,8 +858,11 @@ void CleanupUnusedEntities(const gl::ProgramState &programState, ...@@ -847,8 +858,11 @@ void CleanupUnusedEntities(const gl::ProgramState &programState,
// uniforms to a single line. // uniforms to a single line.
for (const gl::UnusedUniform &unusedUniform : resources.unusedUniforms) for (const gl::UnusedUniform &unusedUniform : resources.unusedUniforms)
{ {
std::string uniformName = std::string uniformName = unusedUniform.isSampler
unusedUniform.isSampler ? GetMappedSamplerName(unusedUniform.name) : unusedUniform.name; ? useOldRewriteStructSamplers
? GetMappedSamplerNameOld(unusedUniform.name)
: vk::GetMappedSamplerName(unusedUniform.name)
: unusedUniform.name;
for (IntermediateShaderSource &shaderSource : *shaderSources) for (IntermediateShaderSource &shaderSource : *shaderSources)
{ {
...@@ -880,7 +894,8 @@ void GlslangWrapper::Release() ...@@ -880,7 +894,8 @@ void GlslangWrapper::Release()
} }
// static // static
void GlslangWrapper::GetShaderSource(const gl::ProgramState &programState, void GlslangWrapper::GetShaderSource(bool useOldRewriteStructSamplers,
const gl::ProgramState &programState,
const gl::ProgramLinkedResources &resources, const gl::ProgramLinkedResources &resources,
gl::ShaderMap<std::string> *shaderSourcesOut) gl::ShaderMap<std::string> *shaderSourcesOut)
{ {
...@@ -906,9 +921,9 @@ void GlslangWrapper::GetShaderSource(const gl::ProgramState &programState, ...@@ -906,9 +921,9 @@ void GlslangWrapper::GetShaderSource(const gl::ProgramState &programState,
} }
AssignUniformBindings(&intermediateSources); AssignUniformBindings(&intermediateSources);
AssignBufferBindings(programState, &intermediateSources); AssignBufferBindings(programState, &intermediateSources);
AssignTextureBindings(programState, &intermediateSources); AssignTextureBindings(useOldRewriteStructSamplers, programState, &intermediateSources);
CleanupUnusedEntities(programState, resources, CleanupUnusedEntities(useOldRewriteStructSamplers, programState, resources,
programState.getAttachedShader(gl::ShaderType::Vertex), programState.getAttachedShader(gl::ShaderType::Vertex),
&intermediateSources); &intermediateSources);
......
...@@ -22,7 +22,8 @@ class GlslangWrapper ...@@ -22,7 +22,8 @@ class GlslangWrapper
static void Initialize(); static void Initialize();
static void Release(); static void Release();
static void GetShaderSource(const gl::ProgramState &programState, static void GetShaderSource(bool useOldRewriteStructSamplers,
const gl::ProgramState &programState,
const gl::ProgramLinkedResources &resources, const gl::ProgramLinkedResources &resources,
gl::ShaderMap<std::string> *shaderSourcesOut); gl::ShaderMap<std::string> *shaderSourcesOut);
......
...@@ -503,7 +503,8 @@ std::unique_ptr<LinkEvent> ProgramVk::link(const gl::Context *context, ...@@ -503,7 +503,8 @@ std::unique_ptr<LinkEvent> ProgramVk::link(const gl::Context *context,
// assignment done in that function. // assignment done in that function.
linkResources(resources); linkResources(resources);
GlslangWrapper::GetShaderSource(mState, resources, &mShaderSources); GlslangWrapper::GetShaderSource(contextVk->useOldRewriteStructSamplers(), mState, resources,
&mShaderSources);
reset(contextVk); reset(contextVk);
...@@ -565,6 +566,7 @@ angle::Result ProgramVk::linkImpl(const gl::Context *glContext, gl::InfoLog &inf ...@@ -565,6 +566,7 @@ angle::Result ProgramVk::linkImpl(const gl::Context *glContext, gl::InfoLog &inf
// Textures: // Textures:
vk::DescriptorSetLayoutDesc texturesSetDesc; vk::DescriptorSetLayoutDesc texturesSetDesc;
uint32_t bindingIndex = 0;
for (uint32_t textureIndex = 0; textureIndex < mState.getSamplerBindings().size(); for (uint32_t textureIndex = 0; textureIndex < mState.getSamplerBindings().size();
++textureIndex) ++textureIndex)
...@@ -575,11 +577,27 @@ angle::Result ProgramVk::linkImpl(const gl::Context *glContext, gl::InfoLog &inf ...@@ -575,11 +577,27 @@ angle::Result ProgramVk::linkImpl(const gl::Context *glContext, gl::InfoLog &inf
const gl::LinkedUniform &samplerUniform = mState.getUniforms()[uniformIndex]; const gl::LinkedUniform &samplerUniform = mState.getUniforms()[uniformIndex];
// The front-end always binds array sampler units sequentially. // The front-end always binds array sampler units sequentially.
const uint32_t arraySize = static_cast<uint32_t>(samplerBinding.boundTextureUnits.size()); uint32_t arraySize = static_cast<uint32_t>(samplerBinding.boundTextureUnits.size());
VkShaderStageFlags activeStages = VkShaderStageFlags activeStages =
gl_vk::GetShaderStageFlags(samplerUniform.activeShaders()); gl_vk::GetShaderStageFlags(samplerUniform.activeShaders());
texturesSetDesc.update(textureIndex, VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, arraySize, if (!contextVk->useOldRewriteStructSamplers())
{
// 2D arrays are split into multiple 1D arrays when generating
// LinkedUniforms. Since they are flattened into one array, ignore the
// nonzero elements and expand the array to the total array size.
if (vk::SamplerNameContainsNonZeroArrayElement(samplerUniform.name))
{
continue;
}
for (unsigned int outerArraySize : samplerUniform.outerArraySizes)
{
arraySize *= outerArraySize;
}
}
texturesSetDesc.update(bindingIndex++, VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, arraySize,
activeStages); activeStages);
} }
...@@ -1459,6 +1477,12 @@ angle::Result ProgramVk::updateTexturesDescriptorSet(ContextVk *contextVk) ...@@ -1459,6 +1477,12 @@ angle::Result ProgramVk::updateTexturesDescriptorSet(ContextVk *contextVk)
bool useSubgroupOps = false; bool useSubgroupOps = false;
bool emulateSeamfulCubeMapSampling = contextVk->emulateSeamfulCubeMapSampling(&useSubgroupOps); bool emulateSeamfulCubeMapSampling = contextVk->emulateSeamfulCubeMapSampling(&useSubgroupOps);
bool useOldRewriteStructSamplers = contextVk->useOldRewriteStructSamplers();
std::unordered_map<std::string, uint32_t> mappedSamplerNameToBindingIndex;
std::unordered_map<std::string, uint32_t> mappedSamplerNameToArrayOffset;
uint32_t currentBindingIndex = 0;
for (uint32_t textureIndex = 0; textureIndex < mState.getSamplerBindings().size(); for (uint32_t textureIndex = 0; textureIndex < mState.getSamplerBindings().size();
++textureIndex) ++textureIndex)
...@@ -1467,8 +1491,30 @@ angle::Result ProgramVk::updateTexturesDescriptorSet(ContextVk *contextVk) ...@@ -1467,8 +1491,30 @@ angle::Result ProgramVk::updateTexturesDescriptorSet(ContextVk *contextVk)
ASSERT(!samplerBinding.unreferenced); ASSERT(!samplerBinding.unreferenced);
for (uint32_t arrayElement = 0; arrayElement < samplerBinding.boundTextureUnits.size(); uint32_t uniformIndex = mState.getUniformIndexFromSamplerIndex(textureIndex);
++arrayElement) const gl::LinkedUniform &samplerUniform = mState.getUniforms()[uniformIndex];
std::string mappedSamplerName = vk::GetMappedSamplerName(samplerUniform.name);
if (useOldRewriteStructSamplers ||
mappedSamplerNameToBindingIndex.emplace(mappedSamplerName, currentBindingIndex).second)
{
currentBindingIndex++;
}
uint32_t bindingIndex = textureIndex;
uint32_t arrayOffset = 0;
uint32_t arraySize = static_cast<uint32_t>(samplerBinding.boundTextureUnits.size());
if (!useOldRewriteStructSamplers)
{
bindingIndex = mappedSamplerNameToBindingIndex[mappedSamplerName];
arrayOffset = mappedSamplerNameToArrayOffset[mappedSamplerName];
// Front-end generates array elements in order, so we can just increment
// the offset each time we process a nested array.
mappedSamplerNameToArrayOffset[mappedSamplerName] += arraySize;
}
for (uint32_t arrayElement = 0; arrayElement < arraySize; ++arrayElement)
{ {
GLuint textureUnit = samplerBinding.boundTextureUnits[arrayElement]; GLuint textureUnit = samplerBinding.boundTextureUnits[arrayElement];
TextureVk *textureVk = activeTextures[textureUnit].texture; TextureVk *textureVk = activeTextures[textureUnit].texture;
...@@ -1496,8 +1542,8 @@ angle::Result ProgramVk::updateTexturesDescriptorSet(ContextVk *contextVk) ...@@ -1496,8 +1542,8 @@ angle::Result ProgramVk::updateTexturesDescriptorSet(ContextVk *contextVk)
writeInfo.sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET; writeInfo.sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET;
writeInfo.pNext = nullptr; writeInfo.pNext = nullptr;
writeInfo.dstSet = descriptorSet; writeInfo.dstSet = descriptorSet;
writeInfo.dstBinding = textureIndex; writeInfo.dstBinding = bindingIndex;
writeInfo.dstArrayElement = arrayElement; writeInfo.dstArrayElement = arrayOffset + arrayElement;
writeInfo.descriptorCount = 1; writeInfo.descriptorCount = 1;
writeInfo.descriptorType = VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER; writeInfo.descriptorType = VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER;
writeInfo.pImageInfo = &imageInfo; writeInfo.pImageInfo = &imageInfo;
......
...@@ -1266,6 +1266,8 @@ void RendererVk::initFeatures(const ExtensionNameList &deviceExtensionNames) ...@@ -1266,6 +1266,8 @@ void RendererVk::initFeatures(const ExtensionNameList &deviceExtensionNames)
} }
mFeatures.bindEmptyForUnusedDescriptorSets.enabled = true; mFeatures.bindEmptyForUnusedDescriptorSets.enabled = true;
mFeatures.forceOldRewriteStructSamplers.enabled = true;
} }
if (IsWindows() && IsIntel(mPhysicalDeviceProperties.vendorID)) if (IsWindows() && IsIntel(mPhysicalDeviceProperties.vendorID))
......
...@@ -53,6 +53,11 @@ std::shared_ptr<WaitableCompileEvent> ShaderVk::compile(const gl::Context *conte ...@@ -53,6 +53,11 @@ std::shared_ptr<WaitableCompileEvent> ShaderVk::compile(const gl::Context *conte
} }
} }
if (contextVk->useOldRewriteStructSamplers())
{
compileOptions |= SH_USE_OLD_REWRITE_STRUCT_SAMPLERS;
}
return compileImpl(context, compilerInstance, mData.getSource(), compileOptions | options); return compileImpl(context, compilerInstance, mData.getSource(), compileOptions | options);
} }
......
...@@ -526,6 +526,57 @@ bool GarbageObject::destroyIfComplete(VkDevice device, Serial completedSerial) ...@@ -526,6 +526,57 @@ bool GarbageObject::destroyIfComplete(VkDevice device, Serial completedSerial)
return false; return false;
} }
bool SamplerNameContainsNonZeroArrayElement(const std::string &name)
{
constexpr char kZERO_ELEMENT[] = "[0]";
size_t start = 0;
while (true)
{
start = name.find(kZERO_ELEMENT[0], start);
if (start == std::string::npos)
{
break;
}
if (name.compare(start, strlen(kZERO_ELEMENT), kZERO_ELEMENT) != 0)
{
return true;
}
start++;
}
return false;
}
std::string GetMappedSamplerName(const std::string &originalName)
{
std::string samplerName = originalName;
// Samplers in structs are extracted.
std::replace(samplerName.begin(), samplerName.end(), '.', '_');
// Remove array elements
auto out = samplerName.begin();
for (auto in = samplerName.begin(); in != samplerName.end(); in++)
{
if (*in == '[')
{
while (*in != ']')
{
in++;
ASSERT(in != samplerName.end());
}
}
else
{
*out++ = *in;
}
}
samplerName.erase(out, samplerName.end());
return samplerName;
}
} // namespace vk } // namespace vk
// VK_EXT_debug_utils // VK_EXT_debug_utils
......
...@@ -528,6 +528,9 @@ class Recycler final : angle::NonCopyable ...@@ -528,6 +528,9 @@ class Recycler final : angle::NonCopyable
std::vector<T> mObjectFreeList; std::vector<T> mObjectFreeList;
}; };
bool SamplerNameContainsNonZeroArrayElement(const std::string &name);
std::string GetMappedSamplerName(const std::string &originalName);
} // namespace vk } // namespace vk
// List of function pointers for used extensions. // List of function pointers for used extensions.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment