diff --git a/source/slang/slang-ast-support-types.h b/source/slang/slang-ast-support-types.h index 93c53a9754..5da19f377f 100644 --- a/source/slang/slang-ast-support-types.h +++ b/source/slang/slang-ast-support-types.h @@ -80,6 +80,9 @@ namespace Slang // No conversion at all kConversionCost_None = 0, + kConversionCost_GenericParamUpcast = 1, + kConversionCost_UnconstraintGenericParam = 20, + // Convert between matrices of different layout kConversionCost_MatrixLayout = 5, diff --git a/source/slang/slang-ast-val.cpp b/source/slang/slang-ast-val.cpp index e860e1ec6b..d1408a3fcb 100644 --- a/source/slang/slang-ast-val.cpp +++ b/source/slang/slang-ast-val.cpp @@ -286,6 +286,11 @@ Val* DeclaredSubtypeWitness::_resolveImplOverride() return this; } +ConversionCost DeclaredSubtypeWitness::_getOverloadResolutionCostOverride() +{ + return kConversionCost_GenericParamUpcast; +} + Val* DeclaredSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int * ioDiff) { if (auto genConstraintDeclRef = getDeclRef().as()) @@ -431,6 +436,11 @@ Val* TransitiveSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, S return astBuilder->getTransitiveSubtypeWitness(substSubToMid, substMidToSup); } +ConversionCost TransitiveSubtypeWitness::_getOverloadResolutionCostOverride() +{ + return getSubToMid()->getOverloadResolutionCost() + getMidToSup()->getOverloadResolutionCost(); +} + void TransitiveSubtypeWitness::_toTextOverride(StringBuilder& out) { // Note: we only print the constituent @@ -471,6 +481,17 @@ Val* ExtractFromConjunctionSubtypeWitness::_substituteImplOverride(ASTBuilder* a substSub, substSup, substWitness, getIndexInConjunction()); } +ConversionCost ExtractFromConjunctionSubtypeWitness::_getOverloadResolutionCostOverride() +{ + auto witness = as(getConjunctionWitness()); + if (!witness) + return kConversionCost_None; + auto index = getIndexInConjunction(); + if (index < witness->getComponentCount()) + return witness->getComponentWitness(index)->getOverloadResolutionCost(); + return kConversionCost_None; +} + // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ExtractExistentialSubtypeWitness !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! void ExtractExistentialSubtypeWitness::_toTextOverride(StringBuilder& out) @@ -541,6 +562,14 @@ Val* ConjunctionSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, return result; } +ConversionCost ConjunctionSubtypeWitness::_getOverloadResolutionCostOverride() +{ + ConversionCost result = kConversionCost_None; + for (Index i = 0; i < getComponentCount(); i++) + result += getComponentWitness(i)->getOverloadResolutionCost(); + return result; +} + void ExtractFromConjunctionSubtypeWitness::_toTextOverride(StringBuilder& out) { out << "ExtractFromConjunctionSubtypeWitness("; diff --git a/source/slang/slang-ast-val.h b/source/slang/slang-ast-val.h index c45c42e027..f85a761875 100644 --- a/source/slang/slang-ast-val.h +++ b/source/slang/slang-ast-val.h @@ -457,6 +457,9 @@ class SubtypeWitness : public Witness Type* getSub() { return as(getOperand(0)); } Type* getSup() { return as(getOperand(1)); } + + ConversionCost _getOverloadResolutionCostOverride(); + ConversionCost getOverloadResolutionCost(); }; class TypeEqualityWitness : public SubtypeWitness @@ -493,6 +496,8 @@ class DeclaredSubtypeWitness : public SubtypeWitness { setOperands(inSub, inSup, inDeclRef); } + + ConversionCost _getOverloadResolutionCostOverride(); }; // A witness that `sub : sup` because `sub : mid` and `mid : sup` @@ -520,6 +525,8 @@ class TransitiveSubtypeWitness : public SubtypeWitness { setOperands(subType, supType, inSubToMid, inMidToSup); } + + ConversionCost _getOverloadResolutionCostOverride(); }; // A witness that `sub : sup` because `sub` was wrapped into @@ -580,6 +587,8 @@ class ConjunctionSubtypeWitness : public SubtypeWitness void _toTextOverride(StringBuilder& out); Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); + + ConversionCost _getOverloadResolutionCostOverride(); }; /// A witness that `T <: L` or `T <: R` because `T <: L&R` @@ -609,6 +618,8 @@ class ExtractFromConjunctionSubtypeWitness : public SubtypeWitness void _toTextOverride(StringBuilder& out); Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); + + ConversionCost _getOverloadResolutionCostOverride(); }; /// A value that represents a modifier attached to some other value diff --git a/source/slang/slang-check-constraint.cpp b/source/slang/slang-check-constraint.cpp index 8fd4061db0..97dbbcfa39 100644 --- a/source/slang/slang-check-constraint.cpp +++ b/source/slang/slang-check-constraint.cpp @@ -261,8 +261,11 @@ namespace Slang DeclRef SemanticsVisitor::trySolveConstraintSystem( ConstraintSystem* system, DeclRef genericDeclRef, - ArrayView knownGenericArgs) + ArrayView knownGenericArgs, + ConversionCost& outBaseCost) { + outBaseCost = kConversionCost_None; + // For now the "solver" is going to be ridiculously simplistic. // The generic itself will have some constraints, and for now we add these @@ -340,6 +343,8 @@ namespace Slang } QualType type; + bool typeConstraintOptional = true; + for (auto& c : system->constraints) { if (c.decl != typeParam.getDecl()) @@ -348,11 +353,12 @@ namespace Slang auto cType = QualType(as(c.val), c.isUsedAsLValue); SLANG_RELEASE_ASSERT(cType); - if (!type) + if (!type || (typeConstraintOptional && !c.isOptional)) { type = cType; + typeConstraintOptional = c.isOptional; } - else + else if (!typeConstraintOptional) { auto joinType = TryJoinTypes(type, cType); if (!joinType) @@ -397,6 +403,7 @@ namespace Slang // TODO(tfoley): figure out how this needs to interact with // compile-time integers that aren't just constants... IntVal* val = nullptr; + bool valOptional = true; for (auto& c : system->constraints) { if (c.decl != valParam.getDecl()) @@ -405,13 +412,14 @@ namespace Slang auto cVal = as(c.val); SLANG_RELEASE_ASSERT(cVal); - if (!val) + if (!val || (valOptional && !c.isOptional)) { val = cVal; + valOptional = c.isOptional; } else { - if(!val->equals(cVal)) + if(!valOptional && !val->equals(cVal)) { // failure! return DeclRef(); @@ -450,6 +458,8 @@ namespace Slang // search for a conformance `Robin : ISidekick`, which involved // apply the substitutions we already know... + HashSet constrainedGenericParams; + for( auto constraintDecl : genericDeclRef.getDecl()->getMembersOfType() ) { DeclRef constraintDeclRef = m_astBuilder->getGenericAppDeclRef( @@ -458,6 +468,10 @@ namespace Slang // Extract the (substituted) sub- and super-type from the constraint. auto sub = getSub(m_astBuilder, constraintDeclRef); auto sup = getSup(m_astBuilder, constraintDeclRef); + + // Mark sub type as constrained. + if (auto subDeclRefType = as(constraintDeclRef.getDecl()->sub.type)) + constrainedGenericParams.add(subDeclRefType->getDeclRef().getDecl()); if (sub->equals(sup)) { @@ -475,6 +489,7 @@ namespace Slang { // We found a witness, so it will become an (implicit) argument. args.add(subTypeWitness); + outBaseCost += subTypeWitness->getOverloadResolutionCost(); } else { @@ -489,6 +504,13 @@ namespace Slang // system as being solved now, as a result of the witness we found. } + // Add a flat cost to all unconstrained generic params. + for (auto typeParamDecl : genericDeclRef.getDecl()->getMembersOfType()) + { + if (!constrainedGenericParams.contains(typeParamDecl)) + outBaseCost += kConversionCost_UnconstraintGenericParam; + } + // Make sure we haven't constructed any spurious constraints // that we aren't able to satisfy: for (auto c : system->constraints) @@ -810,6 +832,29 @@ namespace Slang return false; } + void SemanticsVisitor::maybeUnifyUnconstraintIntParam(ConstraintSystem& constraints, IntVal* param, IntVal* arg, bool paramIsLVal) + { + // If `param` is an unconstrained integer val param, and `arg` is a const int val, + // we add a constraint to the system that `param` must be equal to `arg`. + // If `param` is already constrained, ignore and do nothing. + if (auto typeCastParam = as(param)) + { + param = as(typeCastParam->getBase()); + } + auto intParam = as(param); + if (!intParam) + return; + for (auto c : constraints.constraints) + if (c.decl == intParam->getDeclRef().getDecl()) + return; + Constraint c; + c.decl = intParam->getDeclRef().getDecl(); + c.isUsedAsLValue = paramIsLVal; + c.val = arg; + c.isOptional = true; + constraints.constraints.add(c); + } + bool SemanticsVisitor::TryUnifyTypes( ConstraintSystem& constraints, QualType fst, @@ -880,6 +925,12 @@ namespace Slang { if(auto sndScalarType = as(snd)) { + // Try unify the vector count param. In case the vector count is defined by a generic value + // parameter, we want to be able to infer that parameter should be 1. + // However, we don't want a failed unification to fail the entire generic argument inference, + // because a scalar can still be casted into a vector of any length. + + maybeUnifyUnconstraintIntParam(constraints, fstVectorType->getElementCount(), m_astBuilder->getIntVal(m_astBuilder->getIntType(), 1), fst.isLeftValue); return TryUnifyTypes( constraints, QualType(fstVectorType->getElementType(), fst.isLeftValue), @@ -891,15 +942,13 @@ namespace Slang { if(auto sndVectorType = as(snd)) { + maybeUnifyUnconstraintIntParam(constraints, sndVectorType->getElementCount(), m_astBuilder->getIntVal(m_astBuilder->getIntType(), 1), snd.isLeftValue); return TryUnifyTypes( constraints, QualType(fstScalarType, fst.isLeftValue), QualType(sndVectorType->getElementType(), snd.isLeftValue)); } } - - // TODO: the same thing for vectors... - return false; } diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 7c36bdd5f1..98a2b18a13 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -6630,7 +6630,9 @@ namespace Slang if (!TryUnifyTypes(constraints, extDecl->targetType.Ptr(), type)) return DeclRef(); - auto solvedDeclRef = trySolveConstraintSystem(&constraints, makeDeclRef(extGenericDecl), ArrayView()); + + ConversionCost baseCost; + auto solvedDeclRef = trySolveConstraintSystem(&constraints, makeDeclRef(extGenericDecl), ArrayView(), baseCost); if (!solvedDeclRef) { return DeclRef(); diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index 5b67dc4134..31be012d32 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -1827,6 +1827,10 @@ namespace Slang Val* val = nullptr; // the value to which we are constraining it bool isUsedAsLValue = false; // If this constraint is for a type parameter, is the type used in an l-value parameter? bool satisfied = false; // Has this constraint been met? + + // Is this constraint optional? An optional constraint provides a hint value to a parameter + // if it is otherwise unconstrained, but doesn't take precedence over a constraint that is not optional. + bool isOptional = false; }; // A collection of constraints that will need to be satisfied (solved) @@ -1944,7 +1948,8 @@ namespace Slang DeclRef trySolveConstraintSystem( ConstraintSystem* system, DeclRef genericDeclRef, - ArrayView knownGenericArgs); + ArrayView knownGenericArgs, + ConversionCost& outBaseCost); // State related to overload resolution for a call @@ -2120,25 +2125,30 @@ namespace Slang void AddOverloadCandidate( OverloadResolveContext& context, - OverloadCandidate& candidate); + OverloadCandidate& candidate, + ConversionCost baseCost); void AddHigherOrderOverloadCandidates( Expr* funcExpr, - OverloadResolveContext& context); + OverloadResolveContext& context, + ConversionCost baseCost); void AddFuncOverloadCandidate( LookupResultItem item, DeclRef funcDeclRef, - OverloadResolveContext& context); + OverloadResolveContext& context, + ConversionCost baseCost); void AddFuncOverloadCandidate( FuncType* /*funcType*/, - OverloadResolveContext& /*context*/); + OverloadResolveContext& /*context*/, + ConversionCost baseCost); void AddFuncExprOverloadCandidate( FuncType* funcType, OverloadResolveContext& context, - Expr* expr); + Expr* expr, + ConversionCost baseCost); // Add a candidate callee for overload resolution, based on // calling a particular `ConstructorDecl`. @@ -2147,7 +2157,8 @@ namespace Slang Type* type, DeclRef ctorDeclRef, OverloadResolveContext& context, - Type* resultType); + Type* resultType, + ConversionCost baseCost); // If the given declaration has generic parameters, then // return the corresponding `GenericDecl` that holds the @@ -2216,6 +2227,12 @@ namespace Slang QualType fst, QualType snd); + void maybeUnifyUnconstraintIntParam( + ConstraintSystem& constraints, + IntVal* param, + IntVal* arg, + bool paramIsLVal); + // Is the candidate extension declaration actually applicable to the given type DeclRef applyExtensionToType( ExtensionDecl* extDecl, @@ -2226,12 +2243,26 @@ namespace Slang // arguments to form a `DeclRef` to the inner declaration // that could be applicable in the context of the given // overloaded call. + // Also computes a `baseCost` for the inferred arguments, + // so that we can prefer a more specialized generic candidate + // when there is ambiguity. For example, given + // ``` + // interface IBase; + // interface IDerived : IBase; + // struct Derived : IDerived {} + // void f1(T b) + // void f2(T b); + // ``` + // We will prefer f2 when seeing f(Derived()), because it takes + // less steps to upcast `Derived` to `IDerived` than it does + // to `IBase`. // DeclRef inferGenericArguments( DeclRef genericDeclRef, OverloadResolveContext& context, ArrayView knownGenericArgs, - List *innerParameterTypes = nullptr); + ConversionCost &outBaseCost, + List *innerParameterTypes = nullptr); void AddTypeOverloadCandidates( Type* type, @@ -2239,7 +2270,8 @@ namespace Slang void AddDeclRefOverloadCandidates( LookupResultItem item, - OverloadResolveContext& context); + OverloadResolveContext& context, + ConversionCost baseCost); void AddOverloadCandidates( LookupResult const& result, diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp index 2d7315cd27..d7d29a4e14 100644 --- a/source/slang/slang-check-overload.cpp +++ b/source/slang/slang-check-overload.cpp @@ -1246,11 +1246,14 @@ namespace Slang void SemanticsVisitor::AddOverloadCandidate( OverloadResolveContext& context, - OverloadCandidate& candidate) + OverloadCandidate& candidate, + ConversionCost baseCost) { // Try the candidate out, to see if it is applicable at all. TryCheckOverloadCandidate(context, candidate); + candidate.conversionCostSum += baseCost; + // Now (potentially) add it to the set of candidate overloads to consider. AddOverloadCandidateInner(context, candidate); } @@ -1258,7 +1261,8 @@ namespace Slang void SemanticsVisitor::AddFuncOverloadCandidate( LookupResultItem item, DeclRef funcDeclRef, - OverloadResolveContext& context) + OverloadResolveContext& context, + ConversionCost baseCost) { auto funcDecl = funcDeclRef.getDecl(); ensureDecl(funcDecl, DeclCheckState::CanUseFuncSignature); @@ -1288,25 +1292,27 @@ namespace Slang candidate.item = item; candidate.resultType = getResultType(m_astBuilder, funcDeclRef); - AddOverloadCandidate(context, candidate); + AddOverloadCandidate(context, candidate, baseCost); } void SemanticsVisitor::AddFuncOverloadCandidate( FuncType* funcType, - OverloadResolveContext& context) + OverloadResolveContext& context, + ConversionCost baseCost) { OverloadCandidate candidate; candidate.flavor = OverloadCandidate::Flavor::Expr; candidate.funcType = funcType; candidate.resultType = funcType->getResultType(); - AddOverloadCandidate(context, candidate); + AddOverloadCandidate(context, candidate, baseCost); } void SemanticsVisitor::AddFuncExprOverloadCandidate( FuncType* funcType, OverloadResolveContext& context, - Expr* expr) + Expr* expr, + ConversionCost baseCost) { SLANG_ASSERT(expr); OverloadCandidate candidate; @@ -1315,7 +1321,7 @@ namespace Slang candidate.resultType = funcType->getResultType(); candidate.exprVal = expr; - AddOverloadCandidate(context, candidate); + AddOverloadCandidate(context, candidate, baseCost); } void SemanticsVisitor::AddCtorOverloadCandidate( @@ -1323,7 +1329,8 @@ namespace Slang Type* type, DeclRef ctorDeclRef, OverloadResolveContext& context, - Type* resultType) + Type* resultType, + ConversionCost baseCost) { SLANG_UNUSED(type) @@ -1346,13 +1353,14 @@ namespace Slang candidate.item = ctorItem; candidate.resultType = resultType; - AddOverloadCandidate(context, candidate); + AddOverloadCandidate(context, candidate, baseCost); } DeclRef SemanticsVisitor::inferGenericArguments( DeclRef genericDeclRef, OverloadResolveContext& context, ArrayView knownGenericArgs, + ConversionCost& outBaseCost, List *innerParameterTypes) { // We have been asked to infer zero or more arguments to @@ -1469,7 +1477,7 @@ namespace Slang // so that the solver knows to accept those arguments as-is. // return trySolveConstraintSystem( - &constraints, genericDeclRef, knownGenericArgs); + &constraints, genericDeclRef, knownGenericArgs, outBaseCost); } void SemanticsVisitor::AddTypeOverloadCandidates( @@ -1517,8 +1525,10 @@ namespace Slang auto genericDeclRef = genericItem.declRef.as(); SLANG_ASSERT(genericDeclRef); + ConversionCost baseCost = kConversionCost_None; + // Try to infer generic arguments, based on the context - DeclRef innerRef = inferGenericArguments(genericDeclRef, context, knownGenericArgs); + DeclRef innerRef = inferGenericArguments(genericDeclRef, context, knownGenericArgs, baseCost); if (innerRef) { @@ -1528,7 +1538,7 @@ namespace Slang LookupResultItem innerItem; innerItem.breadcrumbs = genericItem.breadcrumbs; innerItem.declRef = innerRef; - AddDeclRefOverloadCandidates(innerItem, context); + AddDeclRefOverloadCandidates(innerItem, context, baseCost); } else { @@ -1546,11 +1556,12 @@ namespace Slang void SemanticsVisitor::AddDeclRefOverloadCandidates( LookupResultItem item, - OverloadResolveContext& context) + OverloadResolveContext& context, + ConversionCost baseCost) { if (auto funcDeclRef = item.declRef.as()) { - AddFuncOverloadCandidate(item, funcDeclRef, context); + AddFuncOverloadCandidate(item, funcDeclRef, context, baseCost); } else if (auto aggTypeDeclRef = item.declRef.as()) { @@ -1584,7 +1595,7 @@ namespace Slang const auto type = localDeclRef.getDecl()->getType(); // We can only add overload candidates if this is known to be a function if(const auto funType = as(type)) - AddFuncExprOverloadCandidate(funType, context, context.originalExpr->functionExpr); + AddFuncExprOverloadCandidate(funType, context, context.originalExpr->functionExpr, baseCost); else return; } @@ -1603,12 +1614,12 @@ namespace Slang { for(auto item : result.items) { - AddDeclRefOverloadCandidates(item, context); + AddDeclRefOverloadCandidates(item, context, kConversionCost_None); } } else { - AddDeclRefOverloadCandidates(result.item, context); + AddDeclRefOverloadCandidates(result.item, context, kConversionCost_None); } } @@ -1633,17 +1644,17 @@ namespace Slang // The expression directly referenced a declaration, // so we can use that declaration directly to look // for anything applicable. - AddDeclRefOverloadCandidates(LookupResultItem(declRefExpr->declRef), context); + AddDeclRefOverloadCandidates(LookupResultItem(declRefExpr->declRef), context, kConversionCost_None); } else if (auto higherOrderExpr = as(funcExpr)) { // The expression is the result of a higher order function application. - AddHigherOrderOverloadCandidates(higherOrderExpr, context); + AddHigherOrderOverloadCandidates(higherOrderExpr, context, kConversionCost_None); } else if (auto funcType = as(funcExprType)) { // TODO(tfoley): deprecate this path... - AddFuncOverloadCandidate(funcType, context); + AddFuncOverloadCandidate(funcType, context, kConversionCost_None); } else if (auto overloadedExpr = as(funcExpr)) { @@ -1683,7 +1694,8 @@ namespace Slang void SemanticsVisitor::AddHigherOrderOverloadCandidates( Expr* funcExpr, - OverloadResolveContext& context) + OverloadResolveContext& context, + ConversionCost baseCost) { // Lookup the higher order function and process types accordingly. In the future, // if there are enough varieties, we can have dispatch logic instead of an @@ -1705,7 +1717,7 @@ namespace Slang candidate.resultType = candidate.funcType->getResultType(); candidate.item = LookupResultItem(baseFuncDeclRef); candidate.exprVal = expr; - AddOverloadCandidate(context, candidate); + AddOverloadCandidate(context, candidate, baseCost); } else if (auto baseFuncGenericDeclRef = funcDeclRefExpr->declRef.as()) { @@ -1721,10 +1733,12 @@ namespace Slang // Try to infer generic arguments, based on the updated context. OverloadResolveContext subContext = context; + ConversionCost baseCost1 = kConversionCost_None; DeclRef innerRef = inferGenericArguments( baseFuncGenericDeclRef, context, ArrayView(), + baseCost1, ¶mTypes); if (!innerRef) @@ -1762,7 +1776,7 @@ namespace Slang } candidate.exprVal = expr; expr->type.type = diffFuncType; - AddOverloadCandidate(context, candidate); + AddOverloadCandidate(context, candidate, baseCost + baseCost1); } else { @@ -1868,7 +1882,6 @@ namespace Slang context.originalExpr = expr; context.funcLoc = funcExpr->loc; - context.argCount = expr->arguments.getCount(); context.args = expr->arguments.getBuffer(); context.loc = expr->loc; @@ -2039,7 +2052,7 @@ namespace Slang candidate.item = baseItem; candidate.resultType = nullptr; - AddOverloadCandidate(context, candidate); + AddOverloadCandidate(context, candidate, kConversionCost_None); } } diff --git a/source/slang/slang-check-resolve-val.cpp b/source/slang/slang-check-resolve-val.cpp index 91722f82c5..7cd78a1bfb 100644 --- a/source/slang/slang-check-resolve-val.cpp +++ b/source/slang/slang-check-resolve-val.cpp @@ -45,4 +45,14 @@ Val* SubtypeWitness::_resolveImplOverride() return as(defaultResolveImpl()); } +ConversionCost SubtypeWitness::_getOverloadResolutionCostOverride() +{ + return kConversionCost_None; +} + +ConversionCost SubtypeWitness::getOverloadResolutionCost() +{ + SLANG_AST_NODE_VIRTUAL_CALL(SubtypeWitness, getOverloadResolutionCost, ()); +} + } diff --git a/tests/diagnostics/bad-operator-call.slang b/tests/diagnostics/bad-operator-call.slang index 2764d27ae9..2e0a196a2f 100644 --- a/tests/diagnostics/bad-operator-call.slang +++ b/tests/diagnostics/bad-operator-call.slang @@ -14,7 +14,7 @@ void test() { int a; S b; - // CHECK:{{.*}}.slang(18): error {{.*}}: no overload for '+=' applicable to arguments of type (int, S) + // CHECK:{{.*}}.slang(18): error {{.*}}: a += b; // CHECK:{{.*}}.slang(20): error {{.*}}: no overload for '+' applicable to arguments of type (int, S) a = a + b; @@ -31,4 +31,4 @@ void test() d += c; // CHECK:{{.*}}.slang(33): error {{.*}}: no overload for '+' applicable to arguments of type (vector, vector) d = c + d; -} \ No newline at end of file +} diff --git a/tests/language-feature/generics/generic-overload-disambiguation.slang b/tests/language-feature/generics/generic-overload-disambiguation.slang new file mode 100644 index 0000000000..57d7a6d7e5 --- /dev/null +++ b/tests/language-feature/generics/generic-overload-disambiguation.slang @@ -0,0 +1,42 @@ +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -vk -shaderobj -output-using-type + +// Test that calls to overloaded generic functions can be resolved to prefer the +// generic candidate with more specialized constraints. +interface IBase +{ + float get(); +} +interface IDerived : IBase +{ + +} +float process(T v) +{ + return 0.0; +} +float process(T v) +{ + return v.get(); +} + +float process(T v) +{ + return v.get() + 1.0; +} + +struct D : IDerived +{ + float get() { return 1.0; } +} + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +[numthreads(1,1,1)] +void computeMain() +{ + D d; + outputBuffer[0] = process(d); + // CHECK: 2.0 +} diff --git a/tests/language-feature/generics/vector-generic.slang b/tests/language-feature/generics/vector-generic.slang new file mode 100644 index 0000000000..6bcd88018c --- /dev/null +++ b/tests/language-feature/generics/vector-generic.slang @@ -0,0 +1,19 @@ +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -vk -shaderobj -output-using-type + +// Test that generic argument inference works when passing a scalar to a generic vector parameter. +T process(vector v) +{ + return v[0]; +} + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +[numthreads(1,1,1)] +void computeMain() +{ + float a = 1.0; + outputBuffer[0] = process(a); + // CHECK: 1.0 +}