Commit 77222c97 by maxvujovic@gmail.com

Apply SH_TIMING_RESTRICTIONS to all samplers.

Issue: 332 Review URL: https://codereview.appspot.com/6273044/ git-svn-id: https://angleproject.googlecode.com/svn/trunk@1131 736b8ea6-26fd-11df-bfd4-992fa37f6226
parent 911cd6d6
...@@ -165,12 +165,8 @@ bool TCompiler::compile(const char* const shaderStrings[], ...@@ -165,12 +165,8 @@ bool TCompiler::compile(const char* const shaderStrings[],
if (success && (compileOptions & SH_VALIDATE_LOOP_INDEXING)) if (success && (compileOptions & SH_VALIDATE_LOOP_INDEXING))
success = validateLimitations(root); success = validateLimitations(root);
// FIXME(mvujovic): For now, we only consider "u_texture" to be a potentially unsafe symbol.
// If we end up using timing restrictions in WebGL and CSS Shaders, we should expose an API
// to pass in the names of other potentially unsafe symbols (e.g. uniforms referencing
// cross-domain textures).
if (success && (compileOptions & SH_TIMING_RESTRICTIONS)) if (success && (compileOptions & SH_TIMING_RESTRICTIONS))
success = enforceTimingRestrictions(root, "u_texture", (compileOptions & SH_DEPENDENCY_GRAPH) != 0); success = enforceTimingRestrictions(root, (compileOptions & SH_DEPENDENCY_GRAPH) != 0);
// Unroll for-loop markup needs to happen after validateLimitations pass. // Unroll for-loop markup needs to happen after validateLimitations pass.
if (success && (compileOptions & SH_UNROLL_FOR_LOOP_WITH_INTEGER_INDEX)) if (success && (compileOptions & SH_UNROLL_FOR_LOOP_WITH_INTEGER_INDEX))
...@@ -252,9 +248,7 @@ bool TCompiler::validateLimitations(TIntermNode* root) { ...@@ -252,9 +248,7 @@ bool TCompiler::validateLimitations(TIntermNode* root) {
return validate.numErrors() == 0; return validate.numErrors() == 0;
} }
bool TCompiler::enforceTimingRestrictions(TIntermNode* root, bool TCompiler::enforceTimingRestrictions(TIntermNode* root, bool outputGraph)
const TString& restrictedSymbol,
bool outputGraph)
{ {
if (shaderSpec != SH_WEBGL_SPEC) { if (shaderSpec != SH_WEBGL_SPEC) {
infoSink.info << "Timing restrictions must be enforced under the WebGL spec."; infoSink.info << "Timing restrictions must be enforced under the WebGL spec.";
...@@ -265,7 +259,7 @@ bool TCompiler::enforceTimingRestrictions(TIntermNode* root, ...@@ -265,7 +259,7 @@ bool TCompiler::enforceTimingRestrictions(TIntermNode* root,
TDependencyGraph graph(root); TDependencyGraph graph(root);
// Output any errors first. // Output any errors first.
bool success = enforceFragmentShaderTimingRestrictions(graph, restrictedSymbol); bool success = enforceFragmentShaderTimingRestrictions(graph);
// Then, output the dependency graph. // Then, output the dependency graph.
if (outputGraph) { if (outputGraph) {
...@@ -276,22 +270,20 @@ bool TCompiler::enforceTimingRestrictions(TIntermNode* root, ...@@ -276,22 +270,20 @@ bool TCompiler::enforceTimingRestrictions(TIntermNode* root,
return success; return success;
} }
else { else {
return enforceVertexShaderTimingRestrictions(root, restrictedSymbol); return enforceVertexShaderTimingRestrictions(root);
} }
} }
bool TCompiler::enforceFragmentShaderTimingRestrictions(const TDependencyGraph& graph, bool TCompiler::enforceFragmentShaderTimingRestrictions(const TDependencyGraph& graph)
const TString& restrictedSymbol)
{ {
RestrictFragmentShaderTiming restrictor(infoSink.info, restrictedSymbol); RestrictFragmentShaderTiming restrictor(infoSink.info);
restrictor.enforceRestrictions(graph); restrictor.enforceRestrictions(graph);
return restrictor.numErrors() == 0; return restrictor.numErrors() == 0;
} }
bool TCompiler::enforceVertexShaderTimingRestrictions(TIntermNode* root, bool TCompiler::enforceVertexShaderTimingRestrictions(TIntermNode* root)
const TString& restrictedSymbol)
{ {
RestrictVertexShaderTiming restrictor(infoSink.info, restrictedSymbol); RestrictVertexShaderTiming restrictor(infoSink.info);
restrictor.enforceRestrictions(root); restrictor.enforceRestrictions(root);
return restrictor.numErrors() == 0; return restrictor.numErrors() == 0;
} }
......
...@@ -81,16 +81,12 @@ protected: ...@@ -81,16 +81,12 @@ protected:
// Translate to object code. // Translate to object code.
virtual void translate(TIntermNode* root) = 0; virtual void translate(TIntermNode* root) = 0;
// Returns true if the shader passes the restrictions that aim to prevent timing attacks. // Returns true if the shader passes the restrictions that aim to prevent timing attacks.
bool enforceTimingRestrictions(TIntermNode* root, bool enforceTimingRestrictions(TIntermNode* root, bool outputGraph);
const TString& restrictedSymbol, // Returns true if the shader does not use samplers.
bool outputGraph); bool enforceVertexShaderTimingRestrictions(TIntermNode* root);
// Returns true if the shader does not define the restricted symbol. // Returns true if the shader does not use sampler dependent values to affect control
bool enforceVertexShaderTimingRestrictions(TIntermNode* root, // flow or in operations whose time can depend on the input values.
const TString& restrictedSymbol); bool enforceFragmentShaderTimingRestrictions(const TDependencyGraph& graph);
// Returns true if the shader does not use the restricted symbol to affect control flow or in
// operations whose time can depend on the input values.
bool enforceFragmentShaderTimingRestrictions(const TDependencyGraph& graph,
const TString& restrictedSymbol);
// Get built-in extensions with default behavior. // Get built-in extensions with default behavior.
const TExtensionBehavior& getExtensionBehavior() const; const TExtensionBehavior& getExtensionBehavior() const;
......
...@@ -23,17 +23,6 @@ TDependencyGraph::~TDependencyGraph() ...@@ -23,17 +23,6 @@ TDependencyGraph::~TDependencyGraph()
} }
} }
TGraphSymbol* TDependencyGraph::getGlobalSymbolByName(const TString& name) const
{
TSymbolNameMap::const_iterator iter = mGlobalSymbolMap.find(name);
if (iter == mGlobalSymbolMap.end())
return NULL;
TSymbolNamePair pair = *iter;
TGraphSymbol* symbol = pair.second;
return symbol;
}
TGraphArgument* TDependencyGraph::createArgument(TIntermAggregate* intermFunctionCall, TGraphArgument* TDependencyGraph::createArgument(TIntermAggregate* intermFunctionCall,
int argumentNumber) int argumentNumber)
{ {
...@@ -51,7 +40,7 @@ TGraphFunctionCall* TDependencyGraph::createFunctionCall(TIntermAggregate* inter ...@@ -51,7 +40,7 @@ TGraphFunctionCall* TDependencyGraph::createFunctionCall(TIntermAggregate* inter
return functionCall; return functionCall;
} }
TGraphSymbol* TDependencyGraph::getOrCreateSymbol(TIntermSymbol* intermSymbol, bool isGlobalSymbol) TGraphSymbol* TDependencyGraph::getOrCreateSymbol(TIntermSymbol* intermSymbol)
{ {
TSymbolIdMap::const_iterator iter = mSymbolIdMap.find(intermSymbol->getId()); TSymbolIdMap::const_iterator iter = mSymbolIdMap.find(intermSymbol->getId());
...@@ -67,12 +56,9 @@ TGraphSymbol* TDependencyGraph::getOrCreateSymbol(TIntermSymbol* intermSymbol, b ...@@ -67,12 +56,9 @@ TGraphSymbol* TDependencyGraph::getOrCreateSymbol(TIntermSymbol* intermSymbol, b
TSymbolIdPair pair(intermSymbol->getId(), symbol); TSymbolIdPair pair(intermSymbol->getId(), symbol);
mSymbolIdMap.insert(pair); mSymbolIdMap.insert(pair);
if (isGlobalSymbol) { // We save all sampler symbols in a collection, so we can start graph traversals from them quickly.
// We map all symbols in the global scope by name, so traversers of the graph can if (IsSampler(intermSymbol->getBasicType()))
// quickly start searches at global symbols with specific names. mSamplerSymbols.push_back(symbol);
TSymbolNamePair pair(intermSymbol->getSymbol(), symbol);
mGlobalSymbolMap.insert(pair);
}
} }
return symbol; return symbol;
......
...@@ -25,6 +25,7 @@ class TDependencyGraphOutput; ...@@ -25,6 +25,7 @@ class TDependencyGraphOutput;
typedef std::set<TGraphNode*> TGraphNodeSet; typedef std::set<TGraphNode*> TGraphNodeSet;
typedef std::vector<TGraphNode*> TGraphNodeVector; typedef std::vector<TGraphNode*> TGraphNodeVector;
typedef std::vector<TGraphSymbol*> TGraphSymbolVector;
typedef std::vector<TGraphFunctionCall*> TFunctionCallVector; typedef std::vector<TGraphFunctionCall*> TFunctionCallVector;
// //
...@@ -142,6 +143,16 @@ public: ...@@ -142,6 +143,16 @@ public:
TGraphNodeVector::const_iterator begin() const { return mAllNodes.begin(); } TGraphNodeVector::const_iterator begin() const { return mAllNodes.begin(); }
TGraphNodeVector::const_iterator end() const { return mAllNodes.end(); } TGraphNodeVector::const_iterator end() const { return mAllNodes.end(); }
TGraphSymbolVector::const_iterator beginSamplerSymbols() const
{
return mSamplerSymbols.begin();
}
TGraphSymbolVector::const_iterator endSamplerSymbols() const
{
return mSamplerSymbols.end();
}
TFunctionCallVector::const_iterator beginUserDefinedFunctionCalls() const TFunctionCallVector::const_iterator beginUserDefinedFunctionCalls() const
{ {
return mUserDefinedFunctionCalls.begin(); return mUserDefinedFunctionCalls.begin();
...@@ -152,12 +163,9 @@ public: ...@@ -152,12 +163,9 @@ public:
return mUserDefinedFunctionCalls.end(); return mUserDefinedFunctionCalls.end();
} }
// Returns NULL if the symbol is not found.
TGraphSymbol* getGlobalSymbolByName(const TString& name) const;
TGraphArgument* createArgument(TIntermAggregate* intermFunctionCall, int argumentNumber); TGraphArgument* createArgument(TIntermAggregate* intermFunctionCall, int argumentNumber);
TGraphFunctionCall* createFunctionCall(TIntermAggregate* intermFunctionCall); TGraphFunctionCall* createFunctionCall(TIntermAggregate* intermFunctionCall);
TGraphSymbol* getOrCreateSymbol(TIntermSymbol* intermSymbol, bool isGlobalSymbol); TGraphSymbol* getOrCreateSymbol(TIntermSymbol* intermSymbol);
TGraphSelection* createSelection(TIntermSelection* intermSelection); TGraphSelection* createSelection(TIntermSelection* intermSelection);
TGraphLoop* createLoop(TIntermLoop* intermLoop); TGraphLoop* createLoop(TIntermLoop* intermLoop);
TGraphLogicalOp* createLogicalOp(TIntermBinary* intermLogicalOp); TGraphLogicalOp* createLogicalOp(TIntermBinary* intermLogicalOp);
...@@ -165,13 +173,10 @@ private: ...@@ -165,13 +173,10 @@ private:
typedef TMap<int, TGraphSymbol*> TSymbolIdMap; typedef TMap<int, TGraphSymbol*> TSymbolIdMap;
typedef std::pair<int, TGraphSymbol*> TSymbolIdPair; typedef std::pair<int, TGraphSymbol*> TSymbolIdPair;
typedef TMap<TString, TGraphSymbol*> TSymbolNameMap;
typedef std::pair<TString, TGraphSymbol*> TSymbolNamePair;
TSymbolIdMap mSymbolIdMap;
TSymbolNameMap mGlobalSymbolMap;
TFunctionCallVector mUserDefinedFunctionCalls;
TGraphNodeVector mAllNodes; TGraphNodeVector mAllNodes;
TGraphSymbolVector mSamplerSymbols;
TFunctionCallVector mUserDefinedFunctionCalls;
TSymbolIdMap mSymbolIdMap;
}; };
// //
......
...@@ -31,18 +31,11 @@ bool TDependencyGraphBuilder::visitAggregate(Visit visit, TIntermAggregate* inte ...@@ -31,18 +31,11 @@ bool TDependencyGraphBuilder::visitAggregate(Visit visit, TIntermAggregate* inte
void TDependencyGraphBuilder::visitFunctionDefinition(TIntermAggregate* intermAggregate) void TDependencyGraphBuilder::visitFunctionDefinition(TIntermAggregate* intermAggregate)
{ {
// Function defintions should only exist in the global scope.
ASSERT(mIsGlobalScope);
// Currently, we do not support user defined functions. // Currently, we do not support user defined functions.
if (intermAggregate->getName() != "main(") if (intermAggregate->getName() != "main(")
return; return;
mIsGlobalScope = false;
visitAggregateChildren(intermAggregate); visitAggregateChildren(intermAggregate);
mIsGlobalScope = true;
} }
// Takes an expression like "f(x)" and creates a dependency graph like // Takes an expression like "f(x)" and creates a dependency graph like
...@@ -93,7 +86,7 @@ void TDependencyGraphBuilder::visitSymbol(TIntermSymbol* intermSymbol) ...@@ -93,7 +86,7 @@ void TDependencyGraphBuilder::visitSymbol(TIntermSymbol* intermSymbol)
{ {
// Push this symbol into the set of dependent symbols for the current assignment or condition // Push this symbol into the set of dependent symbols for the current assignment or condition
// that we are traversing. // that we are traversing.
TGraphSymbol* symbol = mGraph->getOrCreateSymbol(intermSymbol, mIsGlobalScope); TGraphSymbol* symbol = mGraph->getOrCreateSymbol(intermSymbol);
mNodeSets.insertIntoTopSet(symbol); mNodeSets.insertIntoTopSet(symbol);
// If this symbol is the current leftmost symbol under an assignment, replace the previous // If this symbol is the current leftmost symbol under an assignment, replace the previous
......
...@@ -164,8 +164,7 @@ private: ...@@ -164,8 +164,7 @@ private:
TDependencyGraphBuilder(TDependencyGraph* graph) TDependencyGraphBuilder(TDependencyGraph* graph)
: TIntermTraverser(true, false, false) : TIntermTraverser(true, false, false)
, mGraph(graph) , mGraph(graph) {}
, mIsGlobalScope(true) {}
void build(TIntermNode* intermNode) { intermNode->traverse(this); } void build(TIntermNode* intermNode) { intermNode->traverse(this); }
void connectMultipleNodesToSingleNode(TParentNodeSet* nodes, TGraphNode* node) const; void connectMultipleNodesToSingleNode(TParentNodeSet* nodes, TGraphNode* node) const;
...@@ -180,7 +179,6 @@ private: ...@@ -180,7 +179,6 @@ private:
TDependencyGraph* mGraph; TDependencyGraph* mGraph;
TNodeSetStack mNodeSets; TNodeSetStack mNodeSets;
TSymbolStack mLeftmostSymbols; TSymbolStack mLeftmostSymbols;
bool mIsGlobalScope;
}; };
#endif // COMPILER_DEPGRAPH_DEPENDENCY_GRAPH_BUILDER_H #endif // COMPILER_DEPGRAPH_DEPENDENCY_GRAPH_BUILDER_H
...@@ -19,13 +19,16 @@ void RestrictFragmentShaderTiming::enforceRestrictions(const TDependencyGraph& g ...@@ -19,13 +19,16 @@ void RestrictFragmentShaderTiming::enforceRestrictions(const TDependencyGraph& g
// so we generate errors for them. // so we generate errors for them.
validateUserDefinedFunctionCallUsage(graph); validateUserDefinedFunctionCallUsage(graph);
// Traverse the dependency graph starting at s_texture and generate an error each time we hit a // Starting from each sampler, traverse the dependency graph and generate an error each time we
// condition node. // hit a node where sampler dependent values are not allowed.
TGraphSymbol* uTextureGraphSymbol = graph.getGlobalSymbolByName(mRestrictedSymbol); for (TGraphSymbolVector::const_iterator iter = graph.beginSamplerSymbols();
if (uTextureGraphSymbol && iter != graph.endSamplerSymbols();
uTextureGraphSymbol->getIntermSymbol()->getQualifier() == EvqUniform && ++iter)
uTextureGraphSymbol->getIntermSymbol()->getBasicType() == EbtSampler2D) {
uTextureGraphSymbol->traverse(this); TGraphSymbol* samplerSymbol = *iter;
clearVisited();
samplerSymbol->traverse(this);
}
} }
void RestrictFragmentShaderTiming::validateUserDefinedFunctionCallUsage(const TDependencyGraph& graph) void RestrictFragmentShaderTiming::validateUserDefinedFunctionCallUsage(const TDependencyGraph& graph)
...@@ -50,35 +53,33 @@ void RestrictFragmentShaderTiming::beginError(const TIntermNode* node) ...@@ -50,35 +53,33 @@ void RestrictFragmentShaderTiming::beginError(const TIntermNode* node)
void RestrictFragmentShaderTiming::visitArgument(TGraphArgument* parameter) void RestrictFragmentShaderTiming::visitArgument(TGraphArgument* parameter)
{ {
// FIXME(mvujovic): We should restrict sampler dependent values from being texture coordinates // FIXME(mvujovic): We should restrict sampler dependent values from being texture coordinates
// in all available sampling operationsn supported in GLSL ES. // in all available sampling operations supported in GLSL ES.
// This includes overloaded signatures of texture2D, textureCube, and others. // This includes overloaded signatures of texture2D, textureCube, and others.
if (parameter->getIntermFunctionCall()->getName() != "texture2D(s21;vf2;" || if (parameter->getIntermFunctionCall()->getName() != "texture2D(s21;vf2;" ||
parameter->getArgumentNumber() != 1) parameter->getArgumentNumber() != 1)
return; return;
beginError(parameter->getIntermFunctionCall()); beginError(parameter->getIntermFunctionCall());
mSink << "An expression dependent on a uniform sampler2D by the name '" << mRestrictedSymbol mSink << "An expression dependent on a sampler is not permitted to be the second argument"
<< "' is not permitted to be the second argument of a texture2D call.\n"; << " of a texture2D call.\n";
} }
void RestrictFragmentShaderTiming::visitSelection(TGraphSelection* selection) void RestrictFragmentShaderTiming::visitSelection(TGraphSelection* selection)
{ {
beginError(selection->getIntermSelection()); beginError(selection->getIntermSelection());
mSink << "An expression dependent on a uniform sampler2D by the name '" << mRestrictedSymbol mSink << "An expression dependent on a sampler is not permitted in a conditional statement.\n";
<< "' is not permitted in a conditional statement.\n";
} }
void RestrictFragmentShaderTiming::visitLoop(TGraphLoop* loop) void RestrictFragmentShaderTiming::visitLoop(TGraphLoop* loop)
{ {
beginError(loop->getIntermLoop()); beginError(loop->getIntermLoop());
mSink << "An expression dependent on a uniform sampler2D by the name '" << mRestrictedSymbol mSink << "An expression dependent on a sampler is not permitted in a loop condition.\n";
<< "' is not permitted in a loop condition.\n";
} }
void RestrictFragmentShaderTiming::visitLogicalOp(TGraphLogicalOp* logicalOp) void RestrictFragmentShaderTiming::visitLogicalOp(TGraphLogicalOp* logicalOp)
{ {
beginError(logicalOp->getIntermLogicalOp()); beginError(logicalOp->getIntermLogicalOp());
mSink << "An expression dependent on a uniform sampler2D by the name '" << mRestrictedSymbol mSink << "An expression dependent on a sampler is not permitted on the left hand side of a logical "
<< "' is not permitted on the left hand side of a logical " << logicalOp->getOpString() << logicalOp->getOpString()
<< " operator.\n"; << " operator.\n";
} }
...@@ -16,9 +16,8 @@ class TInfoSinkBase; ...@@ -16,9 +16,8 @@ class TInfoSinkBase;
class RestrictFragmentShaderTiming : TDependencyGraphTraverser { class RestrictFragmentShaderTiming : TDependencyGraphTraverser {
public: public:
RestrictFragmentShaderTiming(TInfoSinkBase& sink, const TString& restrictedSymbol) RestrictFragmentShaderTiming(TInfoSinkBase& sink)
: mSink(sink) : mSink(sink)
, mRestrictedSymbol(restrictedSymbol)
, mNumErrors(0) {} , mNumErrors(0) {}
void enforceRestrictions(const TDependencyGraph& graph); void enforceRestrictions(const TDependencyGraph& graph);
...@@ -34,7 +33,6 @@ private: ...@@ -34,7 +33,6 @@ private:
void validateUserDefinedFunctionCallUsage(const TDependencyGraph& graph); void validateUserDefinedFunctionCallUsage(const TDependencyGraph& graph);
TInfoSinkBase& mSink; TInfoSinkBase& mSink;
const TString mRestrictedSymbol;
int mNumErrors; int mNumErrors;
}; };
......
...@@ -8,23 +8,10 @@ ...@@ -8,23 +8,10 @@
void RestrictVertexShaderTiming::visitSymbol(TIntermSymbol* node) void RestrictVertexShaderTiming::visitSymbol(TIntermSymbol* node)
{ {
if (node->getQualifier() == EvqUniform && if (IsSampler(node->getBasicType())) {
node->getBasicType() == EbtSampler2D && ++mNumErrors;
node->getSymbol() == mRestrictedSymbol) {
mFoundRestrictedSymbol = true;
mSink.prefix(EPrefixError); mSink.prefix(EPrefixError);
mSink.location(node->getLine()); mSink.location(node->getLine());
mSink << "Definition of a uniform sampler2D by the name '" << mRestrictedSymbol mSink << "Samplers are not permitted in vertex shaders.\n";
<< "' is not permitted in vertex shaders.\n";
} }
} }
bool RestrictVertexShaderTiming::visitAggregate(Visit visit, TIntermAggregate* node)
{
// Don't keep exploring if we've found the restricted symbol, and don't explore anything besides
// the global scope (i.e. don't explore function definitions).
if (mFoundRestrictedSymbol || node->getOp() == EOpFunction)
return false;
return true;
}
...@@ -16,26 +16,18 @@ class TInfoSinkBase; ...@@ -16,26 +16,18 @@ class TInfoSinkBase;
class RestrictVertexShaderTiming : public TIntermTraverser { class RestrictVertexShaderTiming : public TIntermTraverser {
public: public:
RestrictVertexShaderTiming(TInfoSinkBase& sink, const TString& restrictedSymbol) RestrictVertexShaderTiming(TInfoSinkBase& sink)
: TIntermTraverser(true, false, false) : TIntermTraverser(true, false, false)
, mSink(sink) , mSink(sink)
, mRestrictedSymbol(restrictedSymbol) , mNumErrors(0) {}
, mFoundRestrictedSymbol(false) {}
void enforceRestrictions(TIntermNode* root) { root->traverse(this); } void enforceRestrictions(TIntermNode* root) { root->traverse(this); }
int numErrors() { return mFoundRestrictedSymbol ? 1 : 0; } int numErrors() { return mNumErrors; }
virtual void visitSymbol(TIntermSymbol*); virtual void visitSymbol(TIntermSymbol*);
virtual bool visitBinary(Visit visit, TIntermBinary*) { return false; }
virtual bool visitUnary(Visit visit, TIntermUnary*) { return false; }
virtual bool visitSelection(Visit visit, TIntermSelection*) { return false; }
virtual bool visitAggregate(Visit visit, TIntermAggregate*);
virtual bool visitLoop(Visit visit, TIntermLoop*) { return false; };
virtual bool visitBranch(Visit visit, TIntermBranch*) { return false; };
private: private:
TInfoSinkBase& mSink; TInfoSinkBase& mSink;
const TString mRestrictedSymbol; int mNumErrors;
bool mFoundRestrictedSymbol;
}; };
#endif // COMPILER_TIMING_RESTRICT_VERTEX_SHADER_TIMING_H_ #endif // COMPILER_TIMING_RESTRICT_VERTEX_SHADER_TIMING_H_
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment