Commit b6fa043d by Olli Etuaho Committed by Commit Bot

Split vector swizzle AST nodes into a different node class

This avoids creating a weird aggregate node with a sequence of constant union nodes to store the offsets. They're stored neatly inside a vector instead. This makes code that needs to iterate over the swizzle offsets much simpler. BUG=angleproject:1490 TEST=angle_unittests Change-Id: I156b95723529ee05a94d30295ffb6d0952a98564 Reviewed-on: https://chromium-review.googlesource.com/390832Reviewed-by: 's avatarJamie Madill <jmadill@chromium.org> Reviewed-by: 's avatarCorentin Wallez <cwallez@chromium.org> Commit-Queue: Olli Etuaho <oetuaho@nvidia.com>
parent fde9d8c1
...@@ -511,7 +511,7 @@ bool EmulatePrecision::visitBinary(Visit visit, TIntermBinary *node) ...@@ -511,7 +511,7 @@ bool EmulatePrecision::visitBinary(Visit visit, TIntermBinary *node)
if (op == EOpInitialize && visit == InVisit) if (op == EOpInitialize && visit == InVisit)
mDeclaringVariables = false; mDeclaringVariables = false;
if ((op == EOpIndexDirectStruct || op == EOpVectorSwizzle) && visit == InVisit) if ((op == EOpIndexDirectStruct) && visit == InVisit)
visitChildren = false; visitChildren = false;
if (visit != PreVisit) if (visit != PreVisit)
......
...@@ -181,6 +181,13 @@ bool TIntermBranch::replaceChildNode( ...@@ -181,6 +181,13 @@ bool TIntermBranch::replaceChildNode(
return false; return false;
} }
bool TIntermSwizzle::replaceChildNode(TIntermNode *original, TIntermNode *replacement)
{
ASSERT(original->getAsTyped()->getType() == replacement->getAsTyped()->getType());
REPLACE_IF_IS(mOperand, TIntermTyped, original, replacement);
return false;
}
bool TIntermBinary::replaceChildNode( bool TIntermBinary::replaceChildNode(
TIntermNode *original, TIntermNode *replacement) TIntermNode *original, TIntermNode *replacement)
{ {
...@@ -445,6 +452,13 @@ TIntermAggregate::TIntermAggregate(const TIntermAggregate &node) ...@@ -445,6 +452,13 @@ TIntermAggregate::TIntermAggregate(const TIntermAggregate &node)
} }
} }
TIntermSwizzle::TIntermSwizzle(const TIntermSwizzle &node) : TIntermTyped(node)
{
TIntermTyped *operandCopy = node.mOperand->deepCopy();
ASSERT(operandCopy != nullptr);
mOperand = operandCopy;
}
TIntermBinary::TIntermBinary(const TIntermBinary &node) TIntermBinary::TIntermBinary(const TIntermBinary &node)
: TIntermOperator(node), mAddIndexClamp(node.mAddIndexClamp) : TIntermOperator(node), mAddIndexClamp(node.mAddIndexClamp)
{ {
...@@ -682,6 +696,15 @@ void TIntermUnary::promote() ...@@ -682,6 +696,15 @@ void TIntermUnary::promote()
} }
} }
TIntermSwizzle::TIntermSwizzle(TIntermTyped *operand, const TVector<int> &swizzleOffsets)
: TIntermTyped(TType(EbtFloat, EbpUndefined)),
mOperand(operand),
mSwizzleOffsets(swizzleOffsets)
{
ASSERT(mSwizzleOffsets.size() <= 4);
promote();
}
TIntermUnary::TIntermUnary(TOperator op, TIntermTyped *operand) TIntermUnary::TIntermUnary(TOperator op, TIntermTyped *operand)
: TIntermOperator(op), mOperand(operand), mUseEmulatedFunction(false) : TIntermOperator(op), mOperand(operand), mUseEmulatedFunction(false)
{ {
...@@ -719,13 +742,57 @@ TQualifier TIntermTernary::DetermineQualifier(TIntermTyped *cond, ...@@ -719,13 +742,57 @@ TQualifier TIntermTernary::DetermineQualifier(TIntermTyped *cond,
return EvqTemporary; return EvqTemporary;
} }
// void TIntermSwizzle::promote()
// Establishes the type of the resultant operation, as well as {
// makes the operator the correct one for the operands. TQualifier resultQualifier = EvqTemporary;
// if (mOperand->getQualifier() == EvqConst)
// For lots of operations it should already be established that the operand resultQualifier = EvqConst;
// combination is valid, but returns false if operator can't work on operands.
// auto numFields = mSwizzleOffsets.size();
setType(TType(mOperand->getBasicType(), mOperand->getPrecision(), resultQualifier,
static_cast<unsigned char>(numFields)));
}
bool TIntermSwizzle::hasDuplicateOffsets() const
{
int offsetCount[4] = {0u, 0u, 0u, 0u};
for (const auto offset : mSwizzleOffsets)
{
offsetCount[offset]++;
if (offsetCount[offset] > 1)
{
return true;
}
}
return false;
}
void TIntermSwizzle::writeOffsetsAsXYZW(TInfoSinkBase *out) const
{
for (const int offset : mSwizzleOffsets)
{
switch (offset)
{
case 0:
*out << "x";
break;
case 1:
*out << "y";
break;
case 2:
*out << "z";
break;
case 3:
*out << "w";
break;
default:
UNREACHABLE();
}
}
}
// Establishes the type of the result of the binary operation.
void TIntermBinary::promote() void TIntermBinary::promote()
{ {
ASSERT(!isMultiplication() || ASSERT(!isMultiplication() ||
...@@ -783,13 +850,6 @@ void TIntermBinary::promote() ...@@ -783,13 +850,6 @@ void TIntermBinary::promote()
getTypePointer()->setQualifier(resultQualifier); getTypePointer()->setQualifier(resultQualifier);
return; return;
} }
case EOpVectorSwizzle:
{
auto numFields = mRight->getAsAggregate()->getSequence()->size();
setType(TType(mLeft->getBasicType(), mLeft->getPrecision(), resultQualifier,
static_cast<unsigned char>(numFields)));
return;
}
default: default:
break; break;
} }
...@@ -925,7 +985,6 @@ void TIntermBinary::promote() ...@@ -925,7 +985,6 @@ void TIntermBinary::promote()
case EOpIndexIndirect: case EOpIndexIndirect:
case EOpIndexDirectInterfaceBlock: case EOpIndexDirectInterfaceBlock:
case EOpIndexDirectStruct: case EOpIndexDirectStruct:
case EOpVectorSwizzle:
// These ops should be already fully handled. // These ops should be already fully handled.
UNREACHABLE(); UNREACHABLE();
break; break;
...@@ -963,6 +1022,22 @@ const TConstantUnion *TIntermConstantUnion::foldIndexing(int index) ...@@ -963,6 +1022,22 @@ const TConstantUnion *TIntermConstantUnion::foldIndexing(int index)
} }
} }
TIntermTyped *TIntermSwizzle::fold()
{
TIntermConstantUnion *operandConstant = mOperand->getAsConstantUnion();
if (operandConstant == nullptr)
{
return nullptr;
}
TConstantUnion *constArray = new TConstantUnion[mSwizzleOffsets.size()];
for (size_t i = 0; i < mSwizzleOffsets.size(); ++i)
{
constArray[i] = *operandConstant->foldIndexing(mSwizzleOffsets.at(i));
}
return CreateFoldedNode(constArray, this, mType.getQualifier());
}
TIntermTyped *TIntermBinary::fold(TDiagnostics *diagnostics) TIntermTyped *TIntermBinary::fold(TDiagnostics *diagnostics)
{ {
TIntermConstantUnion *leftConstant = mLeft->getAsConstantUnion(); TIntermConstantUnion *leftConstant = mLeft->getAsConstantUnion();
...@@ -1002,24 +1077,6 @@ TIntermTyped *TIntermBinary::fold(TDiagnostics *diagnostics) ...@@ -1002,24 +1077,6 @@ TIntermTyped *TIntermBinary::fold(TDiagnostics *diagnostics)
case EOpIndexDirectInterfaceBlock: case EOpIndexDirectInterfaceBlock:
// Can never be constant folded. // Can never be constant folded.
return nullptr; return nullptr;
case EOpVectorSwizzle:
{
if (leftConstant == nullptr)
{
return nullptr;
}
TIntermAggregate *fieldsAgg = mRight->getAsAggregate();
TIntermSequence *fieldsSequence = fieldsAgg->getSequence();
size_t numFields = fieldsSequence->size();
TConstantUnion *constArray = new TConstantUnion[numFields];
for (size_t i = 0; i < numFields; i++)
{
int fieldOffset = fieldsSequence->at(i)->getAsConstantUnion()->getIConst(0);
constArray[i] = *leftConstant->foldIndexing(fieldOffset);
}
return CreateFoldedNode(constArray, this, mType.getQualifier());
}
default: default:
{ {
if (leftConstant == nullptr || rightConstant == nullptr) if (leftConstant == nullptr || rightConstant == nullptr)
......
...@@ -31,6 +31,7 @@ class TDiagnostics; ...@@ -31,6 +31,7 @@ class TDiagnostics;
class TIntermTraverser; class TIntermTraverser;
class TIntermAggregate; class TIntermAggregate;
class TIntermSwizzle;
class TIntermBinary; class TIntermBinary;
class TIntermUnary; class TIntermUnary;
class TIntermConstantUnion; class TIntermConstantUnion;
...@@ -92,6 +93,7 @@ class TIntermNode : angle::NonCopyable ...@@ -92,6 +93,7 @@ class TIntermNode : angle::NonCopyable
virtual TIntermTyped *getAsTyped() { return 0; } virtual TIntermTyped *getAsTyped() { return 0; }
virtual TIntermConstantUnion *getAsConstantUnion() { return 0; } virtual TIntermConstantUnion *getAsConstantUnion() { return 0; }
virtual TIntermAggregate *getAsAggregate() { return 0; } virtual TIntermAggregate *getAsAggregate() { return 0; }
virtual TIntermSwizzle *getAsSwizzleNode() { return nullptr; }
virtual TIntermBinary *getAsBinaryNode() { return 0; } virtual TIntermBinary *getAsBinaryNode() { return 0; }
virtual TIntermUnary *getAsUnaryNode() { return 0; } virtual TIntermUnary *getAsUnaryNode() { return 0; }
virtual TIntermTernary *getAsTernaryNode() { return nullptr; } virtual TIntermTernary *getAsTernaryNode() { return nullptr; }
...@@ -410,6 +412,38 @@ class TIntermOperator : public TIntermTyped ...@@ -410,6 +412,38 @@ class TIntermOperator : public TIntermTyped
TOperator mOp; TOperator mOp;
}; };
// Node for vector swizzles.
class TIntermSwizzle : public TIntermTyped
{
public:
// This constructor determines the type of the node based on the operand.
TIntermSwizzle(TIntermTyped *operand, const TVector<int> &swizzleOffsets);
TIntermTyped *deepCopy() const override { return new TIntermSwizzle(*this); }
TIntermSwizzle *getAsSwizzleNode() override { return this; };
void traverse(TIntermTraverser *it) override;
bool replaceChildNode(TIntermNode *original, TIntermNode *replacement) override;
bool hasSideEffects() const override { return mOperand->hasSideEffects(); }
TIntermTyped *getOperand() { return mOperand; }
void writeOffsetsAsXYZW(TInfoSinkBase *out) const;
bool hasDuplicateOffsets() const;
TIntermTyped *fold();
protected:
TIntermTyped *mOperand;
TVector<int> mSwizzleOffsets;
private:
void promote();
TIntermSwizzle(const TIntermSwizzle &node); // Note: not deleted, just private!
};
// //
// Nodes for all the basic binary math operators. // Nodes for all the basic binary math operators.
// //
...@@ -707,6 +741,7 @@ class TIntermTraverser : angle::NonCopyable ...@@ -707,6 +741,7 @@ class TIntermTraverser : angle::NonCopyable
virtual void visitSymbol(TIntermSymbol *node) {} virtual void visitSymbol(TIntermSymbol *node) {}
virtual void visitRaw(TIntermRaw *node) {} virtual void visitRaw(TIntermRaw *node) {}
virtual void visitConstantUnion(TIntermConstantUnion *node) {} virtual void visitConstantUnion(TIntermConstantUnion *node) {}
virtual bool visitSwizzle(Visit visit, TIntermSwizzle *node) { return true; }
virtual bool visitBinary(Visit visit, TIntermBinary *node) { return true; } virtual bool visitBinary(Visit visit, TIntermBinary *node) { return true; }
virtual bool visitUnary(Visit visit, TIntermUnary *node) { return true; } virtual bool visitUnary(Visit visit, TIntermUnary *node) { return true; }
virtual bool visitTernary(Visit visit, TIntermTernary *node) { return true; } virtual bool visitTernary(Visit visit, TIntermTernary *node) { return true; }
...@@ -723,6 +758,7 @@ class TIntermTraverser : angle::NonCopyable ...@@ -723,6 +758,7 @@ class TIntermTraverser : angle::NonCopyable
virtual void traverseSymbol(TIntermSymbol *node); virtual void traverseSymbol(TIntermSymbol *node);
virtual void traverseRaw(TIntermRaw *node); virtual void traverseRaw(TIntermRaw *node);
virtual void traverseConstantUnion(TIntermConstantUnion *node); virtual void traverseConstantUnion(TIntermConstantUnion *node);
virtual void traverseSwizzle(TIntermSwizzle *node);
virtual void traverseBinary(TIntermBinary *node); virtual void traverseBinary(TIntermBinary *node);
virtual void traverseUnary(TIntermUnary *node); virtual void traverseUnary(TIntermUnary *node);
virtual void traverseTernary(TIntermTernary *node); virtual void traverseTernary(TIntermTernary *node);
......
...@@ -23,6 +23,11 @@ void TIntermConstantUnion::traverse(TIntermTraverser *it) ...@@ -23,6 +23,11 @@ void TIntermConstantUnion::traverse(TIntermTraverser *it)
it->traverseConstantUnion(this); it->traverseConstantUnion(this);
} }
void TIntermSwizzle::traverse(TIntermTraverser *it)
{
it->traverseSwizzle(this);
}
void TIntermBinary::traverse(TIntermTraverser *it) void TIntermBinary::traverse(TIntermTraverser *it)
{ {
it->traverseBinary(this); it->traverseBinary(this);
...@@ -232,6 +237,26 @@ void TIntermTraverser::traverseConstantUnion(TIntermConstantUnion *node) ...@@ -232,6 +237,26 @@ void TIntermTraverser::traverseConstantUnion(TIntermConstantUnion *node)
visitConstantUnion(node); visitConstantUnion(node);
} }
void TIntermTraverser::traverseSwizzle(TIntermSwizzle *node)
{
bool visit = true;
if (preVisit)
visit = visitSwizzle(PreVisit, node);
if (visit)
{
incrementDepth(node);
node->getOperand()->traverse(this);
decrementDepth();
}
if (visit && postVisit)
visitSwizzle(PostVisit, node);
}
// //
// Traverse a binary node. // Traverse a binary node.
// //
......
...@@ -304,25 +304,22 @@ TIntermConstantUnion *TIntermediate::addConstantUnion(const TConstantUnion *cons ...@@ -304,25 +304,22 @@ TIntermConstantUnion *TIntermediate::addConstantUnion(const TConstantUnion *cons
return node; return node;
} }
TIntermTyped *TIntermediate::addSwizzle( TIntermTyped *TIntermediate::AddSwizzle(TIntermTyped *baseExpression,
TVectorFields &fields, const TSourceLoc &line) const TVectorFields &fields,
const TSourceLoc &dotLocation)
{ {
TVector<int> fieldsVector;
for (int i = 0; i < fields.num; ++i)
{
fieldsVector.push_back(fields.offsets[i]);
}
TIntermSwizzle *node = new TIntermSwizzle(baseExpression, fieldsVector);
node->setLine(dotLocation);
TIntermAggregate *node = new TIntermAggregate(EOpSequence); TIntermTyped *folded = node->fold();
node->getTypePointer()->setQualifier(EvqConst); if (folded)
node->setLine(line);
TIntermConstantUnion *constIntNode;
TIntermSequence *sequenceVector = node->getSequence();
TConstantUnion *unionArray;
for (int i = 0; i < fields.num; i++)
{ {
unionArray = new TConstantUnion[1]; return folded;
unionArray->setIConst(fields.offsets[i]);
constIntNode = addConstantUnion(
unionArray, TType(EbtInt, EbpUndefined, EvqConst), line);
sequenceVector->push_back(constIntNode);
} }
return node; return node;
......
...@@ -58,7 +58,9 @@ class TIntermediate ...@@ -58,7 +58,9 @@ class TIntermediate
TIntermNode *, const TSourceLoc &); TIntermNode *, const TSourceLoc &);
TIntermBranch *addBranch(TOperator, const TSourceLoc &); TIntermBranch *addBranch(TOperator, const TSourceLoc &);
TIntermBranch *addBranch(TOperator, TIntermTyped *, const TSourceLoc &); TIntermBranch *addBranch(TOperator, TIntermTyped *, const TSourceLoc &);
TIntermTyped *addSwizzle(TVectorFields &, const TSourceLoc &); static TIntermTyped *AddSwizzle(TIntermTyped *baseExpression,
const TVectorFields &fields,
const TSourceLoc &dotLocation);
static TIntermAggregate *PostProcess(TIntermNode *root); static TIntermAggregate *PostProcess(TIntermNode *root);
static void outputTree(TIntermNode *, TInfoSinkBase &); static void outputTree(TIntermNode *, TInfoSinkBase &);
......
...@@ -62,8 +62,6 @@ const char *GetOperatorString(TOperator op) ...@@ -62,8 +62,6 @@ const char *GetOperatorString(TOperator op)
case EOpIndexDirectStruct: case EOpIndexDirectStruct:
case EOpIndexDirectInterfaceBlock: return "."; case EOpIndexDirectInterfaceBlock: return ".";
case EOpVectorSwizzle: return ".";
case EOpRadians: return "radians"; case EOpRadians: return "radians";
case EOpDegrees: return "degrees"; case EOpDegrees: return "degrees";
case EOpSin: return "sin"; case EOpSin: return "sin";
......
...@@ -77,8 +77,6 @@ enum TOperator ...@@ -77,8 +77,6 @@ enum TOperator
EOpIndexDirectStruct, EOpIndexDirectStruct,
EOpIndexDirectInterfaceBlock, EOpIndexDirectInterfaceBlock,
EOpVectorSwizzle,
// //
// Built-in functions potentially mapped to operators // Built-in functions potentially mapped to operators
// //
......
...@@ -280,6 +280,17 @@ void TOutputGLSLBase::visitConstantUnion(TIntermConstantUnion *node) ...@@ -280,6 +280,17 @@ void TOutputGLSLBase::visitConstantUnion(TIntermConstantUnion *node)
writeConstantUnion(node->getType(), node->getUnionArrayPointer()); writeConstantUnion(node->getType(), node->getUnionArrayPointer());
} }
bool TOutputGLSLBase::visitSwizzle(Visit visit, TIntermSwizzle *node)
{
TInfoSinkBase &out = objSink();
if (visit == PostVisit)
{
out << ".";
node->writeOffsetsAsXYZW(&out);
}
return true;
}
bool TOutputGLSLBase::visitBinary(Visit visit, TIntermBinary *node) bool TOutputGLSLBase::visitBinary(Visit visit, TIntermBinary *node)
{ {
bool visitChildren = true; bool visitChildren = true;
...@@ -410,40 +421,6 @@ bool TOutputGLSLBase::visitBinary(Visit visit, TIntermBinary *node) ...@@ -410,40 +421,6 @@ bool TOutputGLSLBase::visitBinary(Visit visit, TIntermBinary *node)
visitChildren = false; visitChildren = false;
} }
break; break;
case EOpVectorSwizzle:
if (visit == InVisit)
{
out << ".";
TIntermAggregate *rightChild = node->getRight()->getAsAggregate();
TIntermSequence *sequence = rightChild->getSequence();
for (TIntermSequence::iterator sit = sequence->begin(); sit != sequence->end(); ++sit)
{
TIntermConstantUnion *element = (*sit)->getAsConstantUnion();
ASSERT(element->getBasicType() == EbtInt);
ASSERT(element->getNominalSize() == 1);
const TConstantUnion& data = element->getUnionArrayPointer()[0];
ASSERT(data.getType() == EbtInt);
switch (data.getIConst())
{
case 0:
out << "x";
break;
case 1:
out << "y";
break;
case 2:
out << "z";
break;
case 3:
out << "w";
break;
default:
UNREACHABLE();
}
}
visitChildren = false;
}
break;
case EOpAdd: case EOpAdd:
writeTriplet(visit, "(", " + ", ")"); writeTriplet(visit, "(", " + ", ")");
......
...@@ -42,6 +42,7 @@ class TOutputGLSLBase : public TIntermTraverser ...@@ -42,6 +42,7 @@ class TOutputGLSLBase : public TIntermTraverser
void visitSymbol(TIntermSymbol *node) override; void visitSymbol(TIntermSymbol *node) override;
void visitConstantUnion(TIntermConstantUnion *node) override; void visitConstantUnion(TIntermConstantUnion *node) override;
bool visitSwizzle(Visit visit, TIntermSwizzle *node) override;
bool visitBinary(Visit visit, TIntermBinary *node) override; bool visitBinary(Visit visit, TIntermBinary *node) override;
bool visitUnary(Visit visit, TIntermUnary *node) override; bool visitUnary(Visit visit, TIntermUnary *node) override;
bool visitTernary(Visit visit, TIntermTernary *node) override; bool visitTernary(Visit visit, TIntermTernary *node) override;
......
...@@ -848,6 +848,17 @@ bool OutputHLSL::ancestorEvaluatesToSamplerInStruct(Visit visit) ...@@ -848,6 +848,17 @@ bool OutputHLSL::ancestorEvaluatesToSamplerInStruct(Visit visit)
return false; return false;
} }
bool OutputHLSL::visitSwizzle(Visit visit, TIntermSwizzle *node)
{
TInfoSinkBase &out = getInfoSink();
if (visit == PostVisit)
{
out << ".";
node->writeOffsetsAsXYZW(&out);
}
return true;
}
bool OutputHLSL::visitBinary(Visit visit, TIntermBinary *node) bool OutputHLSL::visitBinary(Visit visit, TIntermBinary *node)
{ {
TInfoSinkBase &out = getInfoSink(); TInfoSinkBase &out = getInfoSink();
...@@ -1066,42 +1077,6 @@ bool OutputHLSL::visitBinary(Visit visit, TIntermBinary *node) ...@@ -1066,42 +1077,6 @@ bool OutputHLSL::visitBinary(Visit visit, TIntermBinary *node)
return false; return false;
} }
break; break;
case EOpVectorSwizzle:
if (visit == InVisit)
{
out << ".";
TIntermAggregate *swizzle = node->getRight()->getAsAggregate();
if (swizzle)
{
TIntermSequence *sequence = swizzle->getSequence();
for (TIntermSequence::iterator sit = sequence->begin(); sit != sequence->end(); sit++)
{
TIntermConstantUnion *element = (*sit)->getAsConstantUnion();
if (element)
{
int i = element->getIConst(0);
switch (i)
{
case 0: out << "x"; break;
case 1: out << "y"; break;
case 2: out << "z"; break;
case 3: out << "w"; break;
default: UNREACHABLE();
}
}
else UNREACHABLE();
}
}
else UNREACHABLE();
return false; // Fully processed
}
break;
case EOpAdd: case EOpAdd:
outputTriplet(out, visit, "(", " + ", ")"); outputTriplet(out, visit, "(", " + ", ")");
break; break;
......
...@@ -59,6 +59,7 @@ class OutputHLSL : public TIntermTraverser ...@@ -59,6 +59,7 @@ class OutputHLSL : public TIntermTraverser
void visitSymbol(TIntermSymbol*); void visitSymbol(TIntermSymbol*);
void visitRaw(TIntermRaw*); void visitRaw(TIntermRaw*);
void visitConstantUnion(TIntermConstantUnion*); void visitConstantUnion(TIntermConstantUnion*);
bool visitSwizzle(Visit visit, TIntermSwizzle *node) override;
bool visitBinary(Visit visit, TIntermBinary*); bool visitBinary(Visit visit, TIntermBinary*);
bool visitUnary(Visit visit, TIntermUnary*); bool visitUnary(Visit visit, TIntermUnary*);
bool visitTernary(Visit visit, TIntermTernary *); bool visitTernary(Visit visit, TIntermTernary *);
......
...@@ -247,6 +247,18 @@ bool TParseContext::checkCanBeLValue(const TSourceLoc &line, const char *op, TIn ...@@ -247,6 +247,18 @@ bool TParseContext::checkCanBeLValue(const TSourceLoc &line, const char *op, TIn
{ {
TIntermSymbol *symNode = node->getAsSymbolNode(); TIntermSymbol *symNode = node->getAsSymbolNode();
TIntermBinary *binaryNode = node->getAsBinaryNode(); TIntermBinary *binaryNode = node->getAsBinaryNode();
TIntermSwizzle *swizzleNode = node->getAsSwizzleNode();
if (swizzleNode)
{
bool ok = checkCanBeLValue(line, op, swizzleNode->getOperand());
if (ok && swizzleNode->hasDuplicateOffsets())
{
error(line, " l-value of swizzle cannot have duplicate components", op);
return false;
}
return ok;
}
if (binaryNode) if (binaryNode)
{ {
...@@ -257,34 +269,10 @@ bool TParseContext::checkCanBeLValue(const TSourceLoc &line, const char *op, TIn ...@@ -257,34 +269,10 @@ bool TParseContext::checkCanBeLValue(const TSourceLoc &line, const char *op, TIn
case EOpIndexDirectStruct: case EOpIndexDirectStruct:
case EOpIndexDirectInterfaceBlock: case EOpIndexDirectInterfaceBlock:
return checkCanBeLValue(line, op, binaryNode->getLeft()); return checkCanBeLValue(line, op, binaryNode->getLeft());
case EOpVectorSwizzle:
{
bool ok = checkCanBeLValue(line, op, binaryNode->getLeft());
if (ok)
{
int offsetCount[4] = {0, 0, 0, 0};
TIntermAggregate *swizzleOffsets = binaryNode->getRight()->getAsAggregate();
for (const auto &offset : *swizzleOffsets->getSequence())
{
int value = offset->getAsTyped()->getAsConstantUnion()->getIConst(0);
offsetCount[value]++;
if (offsetCount[value] > 1)
{
error(line, " l-value of swizzle cannot have duplicate components", op);
return false;
}
}
}
return ok;
}
default: default:
break; break;
} }
error(line, " l-value required", op); error(line, " l-value required", op);
return false; return false;
} }
...@@ -2770,9 +2758,7 @@ TIntermTyped *TParseContext::addFieldSelectionExpression(TIntermTyped *baseExpre ...@@ -2770,9 +2758,7 @@ TIntermTyped *TParseContext::addFieldSelectionExpression(TIntermTyped *baseExpre
fields.offsets[0] = 0; fields.offsets[0] = 0;
} }
TIntermTyped *index = intermediate.addSwizzle(fields, fieldLocation); return TIntermediate::AddSwizzle(baseExpression, fields, dotLocation);
return intermediate.addIndex(EOpVectorSwizzle, baseExpression, index, dotLocation,
&mDiagnostics);
} }
else if (baseExpression->getBasicType() == EbtStruct) else if (baseExpression->getBasicType() == EbtStruct)
{ {
......
...@@ -42,6 +42,7 @@ class TOutputTraverser : public TIntermTraverser ...@@ -42,6 +42,7 @@ class TOutputTraverser : public TIntermTraverser
protected: protected:
void visitSymbol(TIntermSymbol *) override; void visitSymbol(TIntermSymbol *) override;
void visitConstantUnion(TIntermConstantUnion *) override; void visitConstantUnion(TIntermConstantUnion *) override;
bool visitSwizzle(Visit visit, TIntermSwizzle *node) override;
bool visitBinary(Visit visit, TIntermBinary *) override; bool visitBinary(Visit visit, TIntermBinary *) override;
bool visitUnary(Visit visit, TIntermUnary *) override; bool visitUnary(Visit visit, TIntermUnary *) override;
bool visitTernary(Visit visit, TIntermTernary *node) override; bool visitTernary(Visit visit, TIntermTernary *node) override;
...@@ -84,6 +85,14 @@ void TOutputTraverser::visitSymbol(TIntermSymbol *node) ...@@ -84,6 +85,14 @@ void TOutputTraverser::visitSymbol(TIntermSymbol *node)
sink << "(" << node->getCompleteString() << ")\n"; sink << "(" << node->getCompleteString() << ")\n";
} }
bool TOutputTraverser::visitSwizzle(Visit visit, TIntermSwizzle *node)
{
TInfoSinkBase &out = sink;
OutputTreeText(out, node, mDepth);
out << "vector swizzle";
return true;
}
bool TOutputTraverser::visitBinary(Visit visit, TIntermBinary *node) bool TOutputTraverser::visitBinary(Visit visit, TIntermBinary *node)
{ {
TInfoSinkBase& out = sink; TInfoSinkBase& out = sink;
...@@ -153,9 +162,6 @@ bool TOutputTraverser::visitBinary(Visit visit, TIntermBinary *node) ...@@ -153,9 +162,6 @@ bool TOutputTraverser::visitBinary(Visit visit, TIntermBinary *node)
case EOpIndexDirectInterfaceBlock: case EOpIndexDirectInterfaceBlock:
out << "direct index for interface block"; out << "direct index for interface block";
break; break;
case EOpVectorSwizzle:
out << "vector swizzle";
break;
case EOpAdd: case EOpAdd:
out << "add"; out << "add";
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment