Skip to content

Commit

Permalink
Improve generic type argument inference. (#3370)
Browse files Browse the repository at this point in the history
* Improve generic type argument inference.

* Fix.

* Fix.

---------

Co-authored-by: Yong He <yhe@nvidia.com>
  • Loading branch information
csyonghe and Yong He authored Nov 29, 2023
1 parent 62426e9 commit 4fb3b10
Show file tree
Hide file tree
Showing 11 changed files with 255 additions and 45 deletions.
3 changes: 3 additions & 0 deletions source/slang/slang-ast-support-types.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ namespace Slang
// No conversion at all
kConversionCost_None = 0,

kConversionCost_GenericParamUpcast = 1,
kConversionCost_UnconstraintGenericParam = 20,

// Convert between matrices of different layout
kConversionCost_MatrixLayout = 5,

Expand Down
29 changes: 29 additions & 0 deletions source/slang/slang-ast-val.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,11 @@ Val* DeclaredSubtypeWitness::_resolveImplOverride()
return this;
}

ConversionCost DeclaredSubtypeWitness::_getOverloadResolutionCostOverride()
{
return kConversionCost_GenericParamUpcast;
}

Val* DeclaredSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int * ioDiff)
{
if (auto genConstraintDeclRef = getDeclRef().as<GenericTypeConstraintDecl>())
Expand Down Expand Up @@ -431,6 +436,11 @@ Val* TransitiveSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, S
return astBuilder->getTransitiveSubtypeWitness(substSubToMid, substMidToSup);
}

ConversionCost TransitiveSubtypeWitness::_getOverloadResolutionCostOverride()
{
return getSubToMid()->getOverloadResolutionCost() + getMidToSup()->getOverloadResolutionCost();
}

void TransitiveSubtypeWitness::_toTextOverride(StringBuilder& out)
{
// Note: we only print the constituent
Expand Down Expand Up @@ -471,6 +481,17 @@ Val* ExtractFromConjunctionSubtypeWitness::_substituteImplOverride(ASTBuilder* a
substSub, substSup, substWitness, getIndexInConjunction());
}

ConversionCost ExtractFromConjunctionSubtypeWitness::_getOverloadResolutionCostOverride()
{
auto witness = as<ConjunctionSubtypeWitness>(getConjunctionWitness());
if (!witness)
return kConversionCost_None;
auto index = getIndexInConjunction();
if (index < witness->getComponentCount())
return witness->getComponentWitness(index)->getOverloadResolutionCost();
return kConversionCost_None;
}

// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ExtractExistentialSubtypeWitness !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

void ExtractExistentialSubtypeWitness::_toTextOverride(StringBuilder& out)
Expand Down Expand Up @@ -541,6 +562,14 @@ Val* ConjunctionSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder,
return result;
}

ConversionCost ConjunctionSubtypeWitness::_getOverloadResolutionCostOverride()
{
ConversionCost result = kConversionCost_None;
for (Index i = 0; i < getComponentCount(); i++)
result += getComponentWitness(i)->getOverloadResolutionCost();
return result;
}

void ExtractFromConjunctionSubtypeWitness::_toTextOverride(StringBuilder& out)
{
out << "ExtractFromConjunctionSubtypeWitness(";
Expand Down
11 changes: 11 additions & 0 deletions source/slang/slang-ast-val.h
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,9 @@ class SubtypeWitness : public Witness

Type* getSub() { return as<Type>(getOperand(0)); }
Type* getSup() { return as<Type>(getOperand(1)); }

ConversionCost _getOverloadResolutionCostOverride();
ConversionCost getOverloadResolutionCost();
};

class TypeEqualityWitness : public SubtypeWitness
Expand Down Expand Up @@ -493,6 +496,8 @@ class DeclaredSubtypeWitness : public SubtypeWitness
{
setOperands(inSub, inSup, inDeclRef);
}

ConversionCost _getOverloadResolutionCostOverride();
};

// A witness that `sub : sup` because `sub : mid` and `mid : sup`
Expand Down Expand Up @@ -520,6 +525,8 @@ class TransitiveSubtypeWitness : public SubtypeWitness
{
setOperands(subType, supType, inSubToMid, inMidToSup);
}

ConversionCost _getOverloadResolutionCostOverride();
};

// A witness that `sub : sup` because `sub` was wrapped into
Expand Down Expand Up @@ -580,6 +587,8 @@ class ConjunctionSubtypeWitness : public SubtypeWitness

void _toTextOverride(StringBuilder& out);
Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);

ConversionCost _getOverloadResolutionCostOverride();
};

/// A witness that `T <: L` or `T <: R` because `T <: L&R`
Expand Down Expand Up @@ -609,6 +618,8 @@ class ExtractFromConjunctionSubtypeWitness : public SubtypeWitness

void _toTextOverride(StringBuilder& out);
Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);

ConversionCost _getOverloadResolutionCostOverride();
};

/// A value that represents a modifier attached to some other value
Expand Down
65 changes: 57 additions & 8 deletions source/slang/slang-check-constraint.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -261,8 +261,11 @@ namespace Slang
DeclRef<Decl> SemanticsVisitor::trySolveConstraintSystem(
ConstraintSystem* system,
DeclRef<GenericDecl> genericDeclRef,
ArrayView<Val*> knownGenericArgs)
ArrayView<Val*> knownGenericArgs,
ConversionCost& outBaseCost)
{
outBaseCost = kConversionCost_None;

// For now the "solver" is going to be ridiculously simplistic.

// The generic itself will have some constraints, and for now we add these
Expand Down Expand Up @@ -340,6 +343,8 @@ namespace Slang
}

QualType type;
bool typeConstraintOptional = true;

for (auto& c : system->constraints)
{
if (c.decl != typeParam.getDecl())
Expand All @@ -348,11 +353,12 @@ namespace Slang
auto cType = QualType(as<Type>(c.val), c.isUsedAsLValue);
SLANG_RELEASE_ASSERT(cType);

if (!type)
if (!type || (typeConstraintOptional && !c.isOptional))
{
type = cType;
typeConstraintOptional = c.isOptional;
}
else
else if (!typeConstraintOptional)
{
auto joinType = TryJoinTypes(type, cType);
if (!joinType)
Expand Down Expand Up @@ -397,6 +403,7 @@ namespace Slang
// TODO(tfoley): figure out how this needs to interact with
// compile-time integers that aren't just constants...
IntVal* val = nullptr;
bool valOptional = true;
for (auto& c : system->constraints)
{
if (c.decl != valParam.getDecl())
Expand All @@ -405,13 +412,14 @@ namespace Slang
auto cVal = as<IntVal>(c.val);
SLANG_RELEASE_ASSERT(cVal);

if (!val)
if (!val || (valOptional && !c.isOptional))
{
val = cVal;
valOptional = c.isOptional;
}
else
{
if(!val->equals(cVal))
if(!valOptional && !val->equals(cVal))
{
// failure!
return DeclRef<Decl>();
Expand Down Expand Up @@ -450,6 +458,8 @@ namespace Slang
// search for a conformance `Robin : ISidekick`, which involved
// apply the substitutions we already know...

HashSet<Decl*> constrainedGenericParams;

for( auto constraintDecl : genericDeclRef.getDecl()->getMembersOfType<GenericTypeConstraintDecl>() )
{
DeclRef<GenericTypeConstraintDecl> constraintDeclRef = m_astBuilder->getGenericAppDeclRef(
Expand All @@ -458,6 +468,10 @@ namespace Slang
// Extract the (substituted) sub- and super-type from the constraint.
auto sub = getSub(m_astBuilder, constraintDeclRef);
auto sup = getSup(m_astBuilder, constraintDeclRef);

// Mark sub type as constrained.
if (auto subDeclRefType = as<DeclRefType>(constraintDeclRef.getDecl()->sub.type))
constrainedGenericParams.add(subDeclRefType->getDeclRef().getDecl());

if (sub->equals(sup))
{
Expand All @@ -475,6 +489,7 @@ namespace Slang
{
// We found a witness, so it will become an (implicit) argument.
args.add(subTypeWitness);
outBaseCost += subTypeWitness->getOverloadResolutionCost();
}
else
{
Expand All @@ -489,6 +504,13 @@ namespace Slang
// system as being solved now, as a result of the witness we found.
}

// Add a flat cost to all unconstrained generic params.
for (auto typeParamDecl : genericDeclRef.getDecl()->getMembersOfType<GenericTypeParamDecl>())
{
if (!constrainedGenericParams.contains(typeParamDecl))
outBaseCost += kConversionCost_UnconstraintGenericParam;
}

// Make sure we haven't constructed any spurious constraints
// that we aren't able to satisfy:
for (auto c : system->constraints)
Expand Down Expand Up @@ -810,6 +832,29 @@ namespace Slang
return false;
}

void SemanticsVisitor::maybeUnifyUnconstraintIntParam(ConstraintSystem& constraints, IntVal* param, IntVal* arg, bool paramIsLVal)
{
// If `param` is an unconstrained integer val param, and `arg` is a const int val,
// we add a constraint to the system that `param` must be equal to `arg`.
// If `param` is already constrained, ignore and do nothing.
if (auto typeCastParam = as<TypeCastIntVal>(param))
{
param = as<IntVal>(typeCastParam->getBase());
}
auto intParam = as<GenericParamIntVal>(param);
if (!intParam)
return;
for (auto c : constraints.constraints)
if (c.decl == intParam->getDeclRef().getDecl())
return;
Constraint c;
c.decl = intParam->getDeclRef().getDecl();
c.isUsedAsLValue = paramIsLVal;
c.val = arg;
c.isOptional = true;
constraints.constraints.add(c);
}

bool SemanticsVisitor::TryUnifyTypes(
ConstraintSystem& constraints,
QualType fst,
Expand Down Expand Up @@ -880,6 +925,12 @@ namespace Slang
{
if(auto sndScalarType = as<BasicExpressionType>(snd))
{
// Try unify the vector count param. In case the vector count is defined by a generic value
// parameter, we want to be able to infer that parameter should be 1.
// However, we don't want a failed unification to fail the entire generic argument inference,
// because a scalar can still be casted into a vector of any length.

maybeUnifyUnconstraintIntParam(constraints, fstVectorType->getElementCount(), m_astBuilder->getIntVal(m_astBuilder->getIntType(), 1), fst.isLeftValue);
return TryUnifyTypes(
constraints,
QualType(fstVectorType->getElementType(), fst.isLeftValue),
Expand All @@ -891,15 +942,13 @@ namespace Slang
{
if(auto sndVectorType = as<VectorExpressionType>(snd))
{
maybeUnifyUnconstraintIntParam(constraints, sndVectorType->getElementCount(), m_astBuilder->getIntVal(m_astBuilder->getIntType(), 1), snd.isLeftValue);
return TryUnifyTypes(
constraints,
QualType(fstScalarType, fst.isLeftValue),
QualType(sndVectorType->getElementType(), snd.isLeftValue));
}
}

// TODO: the same thing for vectors...

return false;
}

Expand Down
4 changes: 3 additions & 1 deletion source/slang/slang-check-decl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6630,7 +6630,9 @@ namespace Slang

if (!TryUnifyTypes(constraints, extDecl->targetType.Ptr(), type))
return DeclRef<ExtensionDecl>();
auto solvedDeclRef = trySolveConstraintSystem(&constraints, makeDeclRef(extGenericDecl), ArrayView<Val*>());

ConversionCost baseCost;
auto solvedDeclRef = trySolveConstraintSystem(&constraints, makeDeclRef(extGenericDecl), ArrayView<Val*>(), baseCost);
if (!solvedDeclRef)
{
return DeclRef<ExtensionDecl>();
Expand Down
Loading

0 comments on commit 4fb3b10

Please sign in to comment.