Commit 2c9cc8b6 by Olli Etuaho Committed by Commit Bot

Don't duplicate symbol type information in AST nodes

Function prototype nodes and symbol nodes already refer to symbols that have type information, so the type doesn't need to be copied to the TInterm* AST node classes. Now type is only stored in those AST node classes that represent other types of expressions. They use a new TIntermExpression base class for this. Since now we may use the TType from builtin symbols directly instead of copying it, building the mangled names of types in the correct memory pool is also required. The code now realizes the types of built-in variables when they get added to the symbol table. BUG=angleproject:2267 TEST=angle_unittests Change-Id: Ic8d7fc912937cb8abb1e306e58c63bb9c146aae9 Reviewed-on: https://chromium-review.googlesource.com/857005Reviewed-by: 's avatarCorentin Wallez <cwallez@chromium.org> Reviewed-by: 's avatarGeoff Lang <geofflang@chromium.org> Commit-Queue: Olli Etuaho <oetuaho@nvidia.com>
parent 28efd82c
......@@ -182,7 +182,11 @@ bool CanFoldAggregateBuiltInOp(TOperator op)
//
////////////////////////////////////////////////////////////////
void TIntermTyped::setTypePreservePrecision(const TType &t)
TIntermExpression::TIntermExpression(const TType &t) : TIntermTyped(), mType(t)
{
}
void TIntermExpression::setTypePreservePrecision(const TType &t)
{
TPrecision precision = getPrecision();
mType = t;
......@@ -303,8 +307,7 @@ bool TIntermAggregateBase::insertChildNodes(TIntermSequence::size_type position,
return true;
}
TIntermSymbol::TIntermSymbol(const TVariable *variable)
: TIntermTyped(variable->getType()), mVariable(variable)
TIntermSymbol::TIntermSymbol(const TVariable *variable) : TIntermTyped(), mVariable(variable)
{
}
......@@ -328,6 +331,11 @@ const TString &TIntermSymbol::getName() const
return mVariable->name();
}
const TType &TIntermSymbol::getType() const
{
return mVariable->getType();
}
TIntermAggregate *TIntermAggregate::CreateFunctionCall(const TFunction &func,
TIntermSequence *arguments)
{
......@@ -752,11 +760,11 @@ bool TIntermCase::replaceChildNode(TIntermNode *original, TIntermNode *replaceme
return false;
}
TIntermTyped::TIntermTyped(const TIntermTyped &node) : TIntermNode(), mType(node.mType)
TIntermTyped::TIntermTyped(const TIntermTyped &node) : TIntermNode()
{
// Copy constructor is disallowed for TIntermNode in order to disallow it for subclasses that
// don't explicitly allow it, so normal TIntermNode constructor is used to construct the copy.
// We need to manually copy any fields of TIntermNode besides handling fields in TIntermTyped.
// We need to manually copy any fields of TIntermNode.
mLine = node.mLine;
}
......@@ -770,17 +778,23 @@ const TConstantUnion *TIntermTyped::getConstantValue() const
return nullptr;
}
TIntermConstantUnion::TIntermConstantUnion(const TIntermConstantUnion &node) : TIntermTyped(node)
TIntermConstantUnion::TIntermConstantUnion(const TIntermConstantUnion &node)
: TIntermExpression(node)
{
mUnionArrayPointer = node.mUnionArrayPointer;
}
TIntermFunctionPrototype::TIntermFunctionPrototype(const TFunction *function)
: TIntermTyped(function->getReturnType()), mFunction(function)
: TIntermTyped(), mFunction(function)
{
ASSERT(mFunction->symbolType() != SymbolType::Empty);
}
const TType &TIntermFunctionPrototype::getType() const
{
return mFunction->getReturnType();
}
TIntermAggregate::TIntermAggregate(const TIntermAggregate &node)
: TIntermOperator(node),
mUseEmulatedFunction(node.mUseEmulatedFunction),
......@@ -805,7 +819,7 @@ TIntermAggregate *TIntermAggregate::shallowCopy() const
return copyNode;
}
TIntermSwizzle::TIntermSwizzle(const TIntermSwizzle &node) : TIntermTyped(node)
TIntermSwizzle::TIntermSwizzle(const TIntermSwizzle &node) : TIntermExpression(node)
{
TIntermTyped *operandCopy = node.mOperand->deepCopy();
ASSERT(operandCopy != nullptr);
......@@ -831,7 +845,7 @@ TIntermUnary::TIntermUnary(const TIntermUnary &node)
mOperand = operandCopy;
}
TIntermTernary::TIntermTernary(const TIntermTernary &node) : TIntermTyped(node)
TIntermTernary::TIntermTernary(const TIntermTernary &node) : TIntermExpression(node)
{
TIntermTyped *conditionCopy = node.mCondition->deepCopy();
TIntermTyped *trueCopy = node.mTrueExpression->deepCopy();
......@@ -1055,7 +1069,7 @@ void TIntermUnary::promote()
}
TIntermSwizzle::TIntermSwizzle(TIntermTyped *operand, const TVector<int> &swizzleOffsets)
: TIntermTyped(TType(EbtFloat, EbpUndefined)),
: TIntermExpression(TType(EbtFloat, EbpUndefined)),
mOperand(operand),
mSwizzleOffsets(swizzleOffsets)
{
......@@ -1085,7 +1099,7 @@ TIntermInvariantDeclaration::TIntermInvariantDeclaration(TIntermSymbol *symbol,
TIntermTernary::TIntermTernary(TIntermTyped *cond,
TIntermTyped *trueExpression,
TIntermTyped *falseExpression)
: TIntermTyped(trueExpression->getType()),
: TIntermExpression(trueExpression->getType()),
mCondition(cond),
mTrueExpression(trueExpression),
mFalseExpression(falseExpression)
......
......@@ -123,7 +123,7 @@ struct TIntermNodePair
class TIntermTyped : public TIntermNode
{
public:
TIntermTyped(const TType &t) : mType(t) {}
TIntermTyped() {}
virtual TIntermTyped *deepCopy() const = 0;
......@@ -142,34 +142,29 @@ class TIntermTyped : public TIntermNode
// affecting state. May return true conservatively.
virtual bool hasSideEffects() const = 0;
void setType(const TType &t) { mType = t; }
void setTypePreservePrecision(const TType &t);
const TType &getType() const { return mType; }
TType *getTypePointer() { return &mType; }
virtual const TType &getType() const = 0;
TBasicType getBasicType() const { return mType.getBasicType(); }
TQualifier getQualifier() const { return mType.getQualifier(); }
TPrecision getPrecision() const { return mType.getPrecision(); }
TMemoryQualifier getMemoryQualifier() const { return mType.getMemoryQualifier(); }
int getCols() const { return mType.getCols(); }
int getRows() const { return mType.getRows(); }
int getNominalSize() const { return mType.getNominalSize(); }
int getSecondarySize() const { return mType.getSecondarySize(); }
bool isInterfaceBlock() const { return mType.isInterfaceBlock(); }
bool isMatrix() const { return mType.isMatrix(); }
bool isArray() const { return mType.isArray(); }
bool isVector() const { return mType.isVector(); }
bool isScalar() const { return mType.isScalar(); }
bool isScalarInt() const { return mType.isScalarInt(); }
const char *getBasicString() const { return mType.getBasicString(); }
TString getCompleteString() const { return mType.getCompleteString(); }
unsigned int getOutermostArraySize() const { return mType.getOutermostArraySize(); }
TBasicType getBasicType() const { return getType().getBasicType(); }
TQualifier getQualifier() const { return getType().getQualifier(); }
TPrecision getPrecision() const { return getType().getPrecision(); }
TMemoryQualifier getMemoryQualifier() const { return getType().getMemoryQualifier(); }
int getCols() const { return getType().getCols(); }
int getRows() const { return getType().getRows(); }
int getNominalSize() const { return getType().getNominalSize(); }
int getSecondarySize() const { return getType().getSecondarySize(); }
protected:
TType mType;
bool isInterfaceBlock() const { return getType().isInterfaceBlock(); }
bool isMatrix() const { return getType().isMatrix(); }
bool isArray() const { return getType().isArray(); }
bool isVector() const { return getType().isVector(); }
bool isScalar() const { return getType().isScalar(); }
bool isScalarInt() const { return getType().isScalarInt(); }
const char *getBasicString() const { return getType().getBasicString(); }
TString getCompleteString() const { return getType().getCompleteString(); }
unsigned int getOutermostArraySize() const { return getType().getOutermostArraySize(); }
protected:
TIntermTyped(const TIntermTyped &node);
};
......@@ -250,6 +245,8 @@ class TIntermSymbol : public TIntermTyped
bool hasSideEffects() const override { return false; }
const TType &getType() const override;
const TSymbolUniqueId &uniqueId() const;
const TString &getName() const;
const TVariable &variable() const { return *mVariable; }
......@@ -264,13 +261,34 @@ class TIntermSymbol : public TIntermTyped
const TVariable *const mVariable; // Guaranteed to be non-null
};
// A typed expression that is not just representing a symbol table symbol.
class TIntermExpression : public TIntermTyped
{
public:
TIntermExpression(const TType &t);
const TType &getType() const override { return mType; }
TType *getTypePointer() { return &mType; }
protected:
void setType(const TType &t) { mType = t; }
void setTypePreservePrecision(const TType &t);
TIntermExpression(const TIntermExpression &node) = default;
TType mType;
};
// A Raw node stores raw code, that the translator will insert verbatim
// into the output stream. Useful for transformation operations that make
// complex code that might not fit naturally into the GLSL model.
class TIntermRaw : public TIntermTyped
class TIntermRaw : public TIntermExpression
{
public:
TIntermRaw(const TType &type, const TString &rawText) : TIntermTyped(type), mRawText(rawText) {}
TIntermRaw(const TType &type, const TString &rawText)
: TIntermExpression(type), mRawText(rawText)
{
}
TIntermRaw(const TIntermRaw &) = delete;
TIntermTyped *deepCopy() const override
......@@ -298,11 +316,11 @@ class TIntermRaw : public TIntermTyped
// "true ? 1.0 : non_constant"
// Other nodes than TIntermConstantUnion may also be constant expressions.
//
class TIntermConstantUnion : public TIntermTyped
class TIntermConstantUnion : public TIntermExpression
{
public:
TIntermConstantUnion(const TConstantUnion *unionPointer, const TType &type)
: TIntermTyped(type), mUnionArrayPointer(unionPointer)
: TIntermExpression(type), mUnionArrayPointer(unionPointer)
{
ASSERT(unionPointer);
}
......@@ -375,7 +393,7 @@ class TIntermConstantUnion : public TIntermTyped
//
// Intermediate class for node types that hold operators.
//
class TIntermOperator : public TIntermTyped
class TIntermOperator : public TIntermExpression
{
public:
TOperator getOp() const { return mOp; }
......@@ -391,8 +409,8 @@ class TIntermOperator : public TIntermTyped
bool hasSideEffects() const override { return isAssignment(); }
protected:
TIntermOperator(TOperator op) : TIntermTyped(TType(EbtFloat, EbpUndefined)), mOp(op) {}
TIntermOperator(TOperator op, const TType &type) : TIntermTyped(type), mOp(op) {}
TIntermOperator(TOperator op) : TIntermExpression(TType(EbtFloat, EbpUndefined)), mOp(op) {}
TIntermOperator(TOperator op, const TType &type) : TIntermExpression(type), mOp(op) {}
TIntermOperator(const TIntermOperator &) = default;
......@@ -400,7 +418,7 @@ class TIntermOperator : public TIntermTyped
};
// Node for vector swizzles.
class TIntermSwizzle : public TIntermTyped
class TIntermSwizzle : public TIntermExpression
{
public:
// This constructor determines the type of the node based on the operand.
......@@ -665,6 +683,8 @@ class TIntermFunctionPrototype : public TIntermTyped, public TIntermAggregateBas
void traverse(TIntermTraverser *it) override;
bool replaceChildNode(TIntermNode *original, TIntermNode *replacement) override;
const TType &getType() const override;
TIntermTyped *deepCopy() const override
{
UNREACHABLE();
......@@ -756,7 +776,7 @@ class TIntermInvariantDeclaration : public TIntermNode
};
// For ternary operators like a ? b : c.
class TIntermTernary : public TIntermTyped
class TIntermTernary : public TIntermExpression
{
public:
TIntermTernary(TIntermTyped *cond, TIntermTyped *trueExpression, TIntermTyped *falseExpression);
......
......@@ -13,9 +13,8 @@ namespace sh
void RegenerateStructNames::visitSymbol(TIntermSymbol *symbol)
{
ASSERT(symbol);
TType *type = symbol->getTypePointer();
ASSERT(type);
TStructure *userType = type->getStruct();
const TType &type = symbol->getType();
const TStructure *userType = type.getStruct();
if (!userType)
return;
......@@ -59,7 +58,10 @@ void RegenerateStructNames::visitSymbol(TIntermSymbol *symbol)
std::string id = Str(uniqueId);
TString tmp = kPrefix + TString(id.c_str());
tmp += "_" + userType->name();
userType->setName(tmp);
// TODO(oetuaho): Add another mechanism to change symbol names so that the const_cast is not
// needed.
const_cast<TStructure *>(userType)->setName(tmp);
}
bool RegenerateStructNames::visitBlock(Visit, TIntermBlock *block)
......
......@@ -236,7 +236,6 @@ TIntermFunctionDefinition *GetIndexFunctionDefinition(const TType &type,
TIntermBinary *cond =
new TIntermBinary(EOpLessThan, indexParam->deepCopy(), CreateIntConstantNode(0));
cond->setType(TType(EbtBool, EbpUndefined));
// Two blocks: one accesses (either reads or writes) the first element and returns,
// the other accesses the last element.
......
......@@ -237,8 +237,7 @@ TVariable *TSymbolTable::insertVariable(ESymbolLevel level,
TVariable *var = new TVariable(this, name, type, symbolType);
if (insert(level, var))
{
// Do lazy initialization for struct types, so we allocate to the current scope.
if (var->getType().getBasicType() == EbtStruct)
if (level <= LAST_BUILTIN_LEVEL)
{
var->getType().realize();
}
......@@ -255,7 +254,7 @@ TVariable *TSymbolTable::insertVariableExt(ESymbolLevel level,
TVariable *var = new TVariable(this, NewPoolTString(name), type, SymbolType::BuiltIn, ext);
if (insert(level, var))
{
if (var->getType().getBasicType() == EbtStruct)
if (level <= LAST_BUILTIN_LEVEL)
{
var->getType().realize();
}
......
......@@ -159,6 +159,7 @@ class TSymbolTable : angle::NonCopyable
{
TVariable *constant = new TVariable(
this, NewPoolTString(name), TType(EbtInt, precision, EvqConst, 1), SymbolType::BuiltIn);
constant->getType().realize();
TConstantUnion *unionArray = new TConstantUnion[1];
unionArray[0].setIConst(value);
constant->shareConstPointer(unionArray);
......@@ -174,6 +175,7 @@ class TSymbolTable : angle::NonCopyable
TVariable *constant =
new TVariable(this, NewPoolTString(name), TType(EbtInt, precision, EvqConst, 1),
SymbolType::BuiltIn, ext);
constant->getType().realize();
TConstantUnion *unionArray = new TConstantUnion[1];
unionArray[0].setIConst(value);
constant->shareConstPointer(unionArray);
......@@ -187,6 +189,7 @@ class TSymbolTable : angle::NonCopyable
{
TVariable *constantIvec3 = new TVariable(
this, NewPoolTString(name), TType(EbtInt, precision, EvqConst, 3), SymbolType::BuiltIn);
constantIvec3->getType().realize();
TConstantUnion *unionArray = new TConstantUnion[3];
for (size_t index = 0u; index < 3u; ++index)
......
......@@ -193,7 +193,7 @@ TType::TType(const TPublicType &p)
}
}
TType::TType(TStructure *userDef)
TType::TType(const TStructure *userDef)
: type(EbtStruct),
precision(EbpUndefined),
qualifier(EvqTemporary),
......@@ -766,15 +766,6 @@ void TType::setInterfaceBlock(TInterfaceBlock *interfaceBlockIn)
}
}
void TType::setStruct(TStructure *s)
{
if (mStructure != s)
{
mStructure = s;
invalidateMangledName();
}
}
const char *TType::getMangledName() const
{
if (mMangledName == nullptr)
......
......@@ -97,7 +97,7 @@ class TType
unsigned char ps = 1,
unsigned char ss = 1);
explicit TType(const TPublicType &p);
explicit TType(TStructure *userDef);
explicit TType(const TStructure *userDef);
TType(TInterfaceBlock *interfaceBlockIn,
TQualifier qualifierIn,
TLayoutQualifier layoutQualifierIn);
......@@ -227,9 +227,7 @@ class TType
bool canBeConstructed() const;
TStructure *getStruct() { return mStructure; }
const TStructure *getStruct() const { return mStructure; }
void setStruct(TStructure *s);
const char *getMangledName() const;
......@@ -343,8 +341,8 @@ class TType
// It's nullptr also for members of named interface blocks.
TInterfaceBlock *mInterfaceBlock;
// 0 unless this is a struct
TStructure *mStructure;
// nullptr unless this is a struct
const TStructure *mStructure;
bool mIsStructSpecifier;
mutable const char *mMangledName;
......
......@@ -167,7 +167,7 @@ class FindStructByName final : public TIntermTraverser
return;
}
TStructure *structure = symbol->getTypePointer()->getStruct();
const TStructure *structure = symbol->getType().getStruct();
if (structure != nullptr && structure->symbolType() != SymbolType::Empty &&
structure->name() == mStructName)
......@@ -177,11 +177,11 @@ class FindStructByName final : public TIntermTraverser
}
bool isStructureFound() const { return mStructure != nullptr; };
TStructure *getStructure() const { return mStructure; }
const TStructure *getStructure() const { return mStructure; }
private:
TString mStructName;
TStructure *mStructure;
const TStructure *mStructure;
};
} // namespace
......@@ -301,8 +301,8 @@ TEST_F(InitOutputVariablesWebGL2VertexShaderTest, OutputStruct)
mASTRoot->traverse(&findStruct);
ASSERT(findStruct.isStructureFound());
TType type(EbtStruct, EbpUndefined, EvqVertexOut);
type.setStruct(findStruct.getStructure());
TType type(findStruct.getStructure());
type.setQualifier(EvqVertexOut);
TIntermTyped *expectedLValue = CreateLValueNode("out1", type);
EXPECT_TRUE(verifier.isExpectedLValueFound(expectedLValue));
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment