Skip to content

Commit

Permalink
Fix ICE when lowering an associatedtype declref from an derived inter…
Browse files Browse the repository at this point in the history
…face. (#3312)

* Fix ICE when lowering an associatedtype declref from an derived interface.

* Fixes.

* Fix test.

* Fix GLSL/SPIRV image subscript swizzle store regression.

* Fix.

---------

Co-authored-by: Yong He <yhe@nvidia.com>
  • Loading branch information
csyonghe and Yong He authored Nov 6, 2023
1 parent da9e0ad commit 46529df
Show file tree
Hide file tree
Showing 14 changed files with 282 additions and 53 deletions.
25 changes: 23 additions & 2 deletions source/slang/slang-check-decl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2937,6 +2937,17 @@ namespace Slang
// With the big picture spelled out, we can settle into
// the work of constructing our synthesized method.
//

// First, we check that the differentiabliity of the method matches the requirement,
// and we don't attempt to synthesize a method if they don't match.
if (getShared()->getFuncDifferentiableLevel(
as<FunctionDeclBase>(lookupResult.item.declRef.getDecl()))
< getShared()->getFuncDifferentiableLevel(
as<FunctionDeclBase>(requiredMemberDeclRef.getDecl())))
{
return false;
}

ThisExpr* synThis = nullptr;
List<Expr*> synArgs;
auto synFuncDecl = synthesizeMethodSignatureForRequirementWitness(
Expand Down Expand Up @@ -3945,8 +3956,15 @@ namespace Slang
// and if nothing is found we print the candidates that made it
// furthest in checking.
//
getSink()->diagnose(inheritanceDecl, Diagnostics::typeDoesntImplementInterfaceRequirement, subType, requiredMemberDeclRef);
getSink()->diagnose(requiredMemberDeclRef, Diagnostics::seeDeclarationOf, requiredMemberDeclRef);
if (!lookupResult.isOverloaded() && lookupResult.isValid())
{
getSink()->diagnose(lookupResult.item.declRef, Diagnostics::memberDoesNotMatchRequirementSignature, lookupResult.item.declRef);
}
else
{
getSink()->diagnose(inheritanceDecl, Diagnostics::typeDoesntImplementInterfaceRequirement, subType, requiredMemberDeclRef);
}
getSink()->diagnose(requiredMemberDeclRef, Diagnostics::seeDeclarationOfInterfaceRequirement, requiredMemberDeclRef);
return false;
}

Expand Down Expand Up @@ -7004,6 +7022,9 @@ namespace Slang

FunctionDifferentiableLevel SharedSemanticsContext::_getFuncDifferentiableLevelImpl(FunctionDeclBase* func, int recurseLimit)
{
if (!func)
return FunctionDifferentiableLevel::None;

if (recurseLimit > 0)
{
if (auto primalSubst = func->findModifier<PrimalSubstituteAttribute>())
Expand Down
6 changes: 6 additions & 0 deletions source/slang/slang-check-expr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -581,6 +581,12 @@ namespace Slang
SourceLoc loc,
Expr* originalExpr)
{
if (!item.declRef)
{
originalExpr->type = QualType(m_astBuilder->getErrorType());
return originalExpr;
}

// We could be referencing a decl that will be synthesized. If so create a placeholder
// and return a DeclRefExpr to it.
if (auto lookupResultExpr = maybeUseSynthesizedDeclForLookupResult(item, originalExpr))
Expand Down
2 changes: 2 additions & 0 deletions source/slang/slang-diagnostic-defs.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ DIAGNOSTIC(-1, Note, seeRequirementDeclaration, "see requirement declaration")
DIAGNOSTIC(-1, Note, doYouForgetToMakeComponentAccessible, "do you forget to make component '$0' acessible from '$1' (missing public qualifier)?")

DIAGNOSTIC(-1, Note, seeDeclarationOf, "see declaration of '$0'")
DIAGNOSTIC(-1, Note, seeDeclarationOfInterfaceRequirement, "see interface requirement declaration of '$0'")
// An alternate wording of the above note, emphasing the position rather than content of the declaration.
DIAGNOSTIC(-1, Note, declaredHere, "declared here")
DIAGNOSTIC(-1, Note, seeOtherDeclarationOf, "see other declaration of '$0'")
Expand Down Expand Up @@ -542,6 +543,7 @@ DIAGNOSTIC(38008, Error, specializationParameterNotSpecialized, "no specializati
DIAGNOSTIC(38009, Error, expectedValueOfTypeForSpecializationArg, "expected a constant value of type '$0' as argument for specialization parameter '$1'")

DIAGNOSTIC(38100, Error, typeDoesntImplementInterfaceRequirement, "type '$0' does not provide required interface member '$1'")
DIAGNOSTIC(38105, Error, memberDoesNotMatchRequirementSignature, "member '$0' does not match interface requirement.")
DIAGNOSTIC(38101, Error, thisExpressionOutsideOfTypeDecl, "'this' expression can only be used in members of an aggregate type")
DIAGNOSTIC(38102, Error, initializerNotInsideType, "an 'init' declaration is only allowed inside a type or 'extension' declaration")
DIAGNOSTIC(38103, Error, thisTypeOutsideOfTypeDecl, "'This' type can only be used inside of an aggregate type")
Expand Down
20 changes: 20 additions & 0 deletions source/slang/slang-emit-spirv-ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -1177,6 +1177,26 @@ SpvInst* emitOpCompositeExtract(
);
}

// https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpCompositeInsert
template<typename T1, typename T2, typename T3, Index N>
SpvInst* emitOpCompositeInsert(
SpvInstParent* parent,
IRInst* inst,
const T1& idResultType,
const T2& object,
const T3& composite,
const Array<SpvLiteralInteger, N>& indexes
)
{
static_assert(isSingular<T1>);
static_assert(isSingular<T2>);
static_assert(isSingular<T3>);

return emitInst(
parent, inst, SpvOpCompositeInsert, idResultType, kResultID, object, composite, indexes
);
}

// https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpVectorExtractDynamic
template<typename T1, typename T2, typename T3>
SpvInst* emitOpVectorExtractDynamic(
Expand Down
14 changes: 14 additions & 0 deletions source/slang/slang-emit-spirv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3827,17 +3827,31 @@ struct SPIRVEmitContext

SpvInst* emitSwizzleSet(SpvInstParent* parent, IRSwizzleSet* inst)
{
if (inst->getElementCount() == 1)
{
auto index = inst->getElementIndex(0);
if (auto intLit = as<IRIntLit>(index))
return emitOpCompositeInsert(parent, inst, inst->getFullType(), inst->getSource(), inst->getBase(), makeArray(SpvLiteralInteger::from32((uint32_t)intLit->value.intVal)));
}
auto resultVectorType = as<IRVectorType>(inst->getDataType());
List<SpvLiteralInteger> shuffleIndices;
shuffleIndices.setCount((Index)getIntVal(resultVectorType->getElementCount()));
for (Index i = 0; i < shuffleIndices.getCount(); i++)
shuffleIndices[i] = SpvLiteralInteger::from32((int32_t)i);

for (UInt i = 0; i < inst->getElementCount(); i++)
{
auto destIndex = (int32_t)getIntVal(inst->getElementIndex(i));
SLANG_ASSERT(destIndex < shuffleIndices.getCount());
shuffleIndices[destIndex] = SpvLiteralInteger::from32((int32_t)(i + shuffleIndices.getCount()));
}
auto source = inst->getSource();
if (!as<IRVectorType>(source->getDataType()))
{
IRBuilder builder(inst);
builder.setInsertBefore(inst);
source = builder.emitMakeVectorFromScalar(resultVectorType, source);
}
return emitOpVectorShuffle(parent, inst, inst->getFullType(), inst->getBase(), inst->getSource(), shuffleIndices.getArrayView());
}

Expand Down
33 changes: 25 additions & 8 deletions source/slang/slang-ir-glsl-legalize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,12 @@ IRType* getIRVectorBaseType(IRType* type)
void legalizeImageSubscriptStoreForGLSL(IRBuilder& builder, IRInst* storeInst)
{
builder.setInsertBefore(storeInst);
auto imageSubscript = as<IRImageSubscript>(storeInst->getOperand(0));
auto getElementPtr = as<IRGetElementPtr>(storeInst->getOperand(0));
IRImageSubscript* imageSubscript = nullptr;
if (getElementPtr)
imageSubscript = as<IRImageSubscript>(getElementPtr->getBase());
else
imageSubscript = as<IRImageSubscript>(storeInst->getOperand(0));
assert(imageSubscript);
auto imageElementType = cast<IRPtrTypeBase>(imageSubscript->getDataType())->getValueType();
auto coordType = imageSubscript->getCoord()->getDataType();
Expand All @@ -52,14 +57,26 @@ void legalizeImageSubscriptStoreForGLSL(IRBuilder& builder, IRInst* storeInst)
{
case kIROp_Store:
{
auto newValue = storeInst->getOperand(1);
if (getIRVectorElementSize(imageElementType) != 4)
IRInst* newValue = nullptr;
if (getElementPtr)
{
auto vectorBaseType = getIRVectorBaseType(imageElementType);
newValue = builder.emitVectorReshape(
builder.getVectorType(
vectorBaseType, builder.getIntValue(builder.getIntType(), 4)),
newValue);
IRType* vector4Type = builder.getVectorType(vectorBaseType, 4);
auto originalValue = builder.emitImageLoad(vector4Type, imageSubscript->getImage(), legalizedCoord);
auto index = getElementPtr->getIndex();
newValue = builder.emitSwizzleSet(vector4Type, originalValue, storeInst->getOperand(1), 1, &index);
}
else
{
newValue = storeInst->getOperand(1);
if (getIRVectorElementSize(imageElementType) != 4)
{
auto vectorBaseType = getIRVectorBaseType(imageElementType);
newValue = builder.emitVectorReshape(
builder.getVectorType(
vectorBaseType, builder.getIntValue(builder.getIntType(), 4)),
newValue);
}
}
auto imageStore = builder.emitImageStore(
builder.getVoidType(),
Expand Down Expand Up @@ -128,7 +145,7 @@ void legalizeImageSubscriptForGLSL(IRModule* module)
{
case kIROp_Store:
case kIROp_SwizzledStore:
if (inst->getOperand(0)->getOp() == kIROp_ImageSubscript)
if (getRootAddr(inst->getOperand(0))->getOp() == kIROp_ImageSubscript)
{
legalizeImageSubscriptStoreForGLSL(builder, inst);
}
Expand Down
16 changes: 15 additions & 1 deletion source/slang/slang-language-server-ast-lookup.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,21 @@ struct ASTLookupExprVisitor: public ExprVisitor<ASTLookupExprVisitor, bool>
return false;
}

bool visitThisExpr(ThisExpr*) { return false; }
bool visitThisExpr(ThisExpr* expr)
{
static const int thisTokenLength = 4;
if (_isLocInRange(
context, expr->loc, thisTokenLength))
{
ASTLookupResult result;
result.path = context->nodePath;
result.path.add(expr);
context->results.add(result);
return true;
}
return false;
}

bool visitThisTypeExpr(ThisTypeExpr*) { return false; }
bool visitAndTypeExpr(AndTypeExpr* expr)
{
Expand Down
13 changes: 13 additions & 0 deletions source/slang/slang-language-server-semantic-tokens.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ const char* kSemanticTokenTypes[] = {
"string"
};

static const int kInitTokenLegnth = 6;

static_assert(SLANG_COUNT_OF(kSemanticTokenTypes) == (int)SemanticTokenType::NormalText, "kSemanticTokenTypes must match SemanticTokenType");

SemanticToken _createSemanticToken(SourceManager* manager, SourceLoc loc, Name* name)
Expand Down Expand Up @@ -183,6 +185,17 @@ List<SemanticToken> getSemanticTokens(Linkage* linkage, Module* module, UnownedS
maybeInsertToken(token);
}
}
else if (auto ctorDecl = as<ConstructorDecl>(node))
{
if (ctorDecl->getName())
{
SemanticToken token = _createSemanticToken(
manager, ctorDecl->getNameLoc(), ctorDecl->getName());
token.type = SemanticTokenType::Function;
token.length = kInitTokenLegnth;
maybeInsertToken(token);
}
}
else if (auto paramDecl = as<ParamDecl>(node))
{
if (paramDecl->getName())
Expand Down
27 changes: 21 additions & 6 deletions source/slang/slang-language-server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -642,10 +642,26 @@ SlangResult LanguageServer::hover(
}
}
};
auto fillLoc = [&](SourceLoc loc)
{
auto humaneLoc = version->linkage->getSourceManager()->getHumaneLoc(loc, SourceLocType::Actual);
hover.range.start.line = int(humaneLoc.line - 1);
hover.range.end.line = int(humaneLoc.line - 1);
hover.range.start.character = int(humaneLoc.column - 1);
hover.range.end.character = hover.range.start.character + int(doc->getTokenLength(humaneLoc.line, humaneLoc.column));
};
auto fillExprHoverInfo = [&](Expr* expr)
{
if (auto declRefExpr = as<DeclRefExpr>(expr))
return fillDeclRefHoverInfo(declRefExpr->declRef);
else if (auto thisExpr = as<ThisExpr>(expr))
{
if (expr->type)
{
sb << "```\n" << expr->type->toString() << " this" << "\n```\n";
}
fillLoc(expr->loc);
}
if (const auto higherOrderExpr = as<HigherOrderInvokeExpr>(expr))
{
String documentation;
Expand All @@ -657,12 +673,7 @@ SlangResult LanguageServer::hover(
<< "\n```\n";
sb << documentation;
maybeAppendAdditionalOverloadsHint();
auto humaneLoc = version->linkage->getSourceManager()->getHumaneLoc(
expr->loc, SourceLocType::Actual);
hover.range.start.line = int(humaneLoc.line - 1);
hover.range.end.line = int(humaneLoc.line - 1);
hover.range.start.character = int(humaneLoc.column - 1);
hover.range.end.character = hover.range.start.character + int(doc->getTokenLength(humaneLoc.line, humaneLoc.column));
fillLoc(expr->loc);
}
};
if (auto declRefExpr = as<DeclRefExpr>(leafNode))
Expand All @@ -686,6 +697,10 @@ SlangResult LanguageServer::hover(
{
fillExprHoverInfo(higherOrderExpr);
}
else if (auto thisExprExpr = as<ThisExpr>(leafNode))
{
fillExprHoverInfo(thisExprExpr);
}
else if (auto importDecl = as<ImportDecl>(leafNode))
{
auto moduleLoc = getModuleLoc(version->linkage->getSourceManager(), importDecl->importedModuleDecl);
Expand Down
34 changes: 26 additions & 8 deletions source/slang/slang-lower-to-ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1659,17 +1659,16 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
// produce transitive witnesses in shapes that will cuase us
// problems here.
//

if (!baseWitnessTable)
{
// If we don't have a valid baseWitnessTable,
// we are in a situation that `subToMid` is a `DifferentialBottomSubtypeWitness`
// that applies for all non-differentiable types.
// In this case `midToSup` will give us the `DifferentialBottom:IDifferentiable`
// witness table and we can just use that as the final result of
// this transitive witness.
SLANG_RELEASE_ASSERT(midToSup && as<IRWitnessTableType>(midToSup->getDataType()));
return LoweredValInfo::simple(midToSup);
// This can happen when we are looking up an associatedtype defined in the base interface from
// a derived interface that inherits the base interface.
// For now, we just emit a null witness.
// In the future, we may want to consider lower `ThisTypeConstraint` into IR as something like
// `IRThisTypeWitness`, and emit an explicit lookup through that witness instead.
SLANG_RELEASE_ASSERT(as<ThisType>(val->getSub()));
return LoweredValInfo();
}

if (auto declaredMidToSup = as<DeclaredSubtypeWitness>(val->getMidToSup()))
Expand Down Expand Up @@ -9854,6 +9853,25 @@ LoweredValInfo emitDeclRef(
// witness table for the concrete type that conforms to `ISomething<Foo>`.
//
auto irWitnessTable = lowerSimpleVal(context, thisTypeSubst->getWitness());
if (!irWitnessTable)
{
// If `thisTypeSubst` doesn't lower into an IRWitnessTable,
// this is a lookup of an interface requirement
// defined in some base interface from an interface type.
// For now we just lower that decl as if it is referenced
// from the same interface directly, e.g. a reference to
// IBase.AssocType from IDerived:IBase will be lowered as
// IRAssocType(IBase).
// We may want to consider extend our IR representation to
// have a `IRThisTypeWitness` object, so we can lower this case
// into an explicit lookup from `IRThisTypeWitness`,
// just like any other cases.
return emitDeclRef(
context,
createDefaultSpecializedDeclRef(context, nullptr, decl),
context->irBuilder->getTypeKind());
}

//
// The key to use for looking up the interface member is
// derived from the declaration.
Expand Down
Loading

0 comments on commit 46529df

Please sign in to comment.