Skip to content

Commit

Permalink
Capability type checking. (#3530)
Browse files Browse the repository at this point in the history
* Capability type checking.

* Fix.

---------

Co-authored-by: Yong He <yhe@nvidia.com>
  • Loading branch information
csyonghe and Yong He authored Feb 3, 2024
1 parent c15e7ad commit 1476489
Show file tree
Hide file tree
Showing 51 changed files with 1,869 additions and 489 deletions.
3 changes: 3 additions & 0 deletions source/slang/core.meta.slang
Original file line number Diff line number Diff line change
Expand Up @@ -2475,6 +2475,9 @@ attribute_syntax [vk_image_format(format : String)] : FormatAttribute;
__attributeTarget(Decl)
attribute_syntax [allow(diagnostic: String)] : AllowAttribute;

__attributeTarget(Decl)
attribute_syntax[require(capability)] : RequireCapabilityAttribute;

// Linking
__attributeTarget(Decl)
attribute_syntax [__extern] : ExternAttribute;
Expand Down
37 changes: 10 additions & 27 deletions source/slang/hlsl.meta.slang
Original file line number Diff line number Diff line change
Expand Up @@ -4716,7 +4716,6 @@ T GetAttributeAtVertex(T attribute, uint vertexIndex)
{
case hlsl:
__intrinsic_asm "GetAttributeAtVertex";
case _GL_NV_fragment_shader_barycentric:
case _GL_EXT_fragment_shader_barycentric:
__intrinsic_asm "$0[$1]";
case spirv:
Expand Down Expand Up @@ -4749,7 +4748,6 @@ vector<T,N> GetAttributeAtVertex(vector<T,N> attribute, uint vertexIndex)
{
case hlsl:
__intrinsic_asm "GetAttributeAtVertex";
case _GL_NV_fragment_shader_barycentric:
case _GL_EXT_fragment_shader_barycentric:
__intrinsic_asm "$0[$1]";
case spirv:
Expand Down Expand Up @@ -4782,7 +4780,6 @@ matrix<T,N,M> GetAttributeAtVertex(matrix<T,N,M> attribute, uint vertexIndex)
{
case hlsl:
__intrinsic_asm "GetAttributeAtVertex";
case _GL_NV_fragment_shader_barycentric:
case _GL_EXT_fragment_shader_barycentric:
__intrinsic_asm "$0[$1]";
case spirv:
Expand Down Expand Up @@ -9288,8 +9285,7 @@ struct BuiltInTriangleIntersectionAttributes
// `executeCallableNV` is the GLSL intrinsic that will be used to implement
// `CallShader()` for GLSL-based targets.
//
__target_intrinsic(GL_NV_ray_tracing, "executeCallableNV")
__target_intrinsic(GL_EXT_ray_tracing, "executeCallableEXT")
__target_intrinsic(_GL_EXT_ray_tracing, "executeCallableEXT")
void __executeCallable(uint shaderIndex, int payloadLocation);

// Next is the custom intrinsic that will compute the payload location
Expand Down Expand Up @@ -9335,8 +9331,7 @@ void CallShader(uint shaderIndex, inout Payload payload)

// 10.3.2

__target_intrinsic(GL_NV_ray_tracing, "traceNV")
__target_intrinsic(GL_EXT_ray_tracing, "traceRayEXT")
__target_intrinsic(_GL_EXT_ray_tracing, "traceRayEXT")
void __traceRay(
RaytracingAccelerationStructure AccelerationStructure,
uint RayFlags,
Expand Down Expand Up @@ -9528,7 +9523,6 @@ bool __reportIntersection(float tHit, uint hitKind)
__target_switch
{
case _GL_EXT_ray_tracing: __intrinsic_asm "reportIntersectionEXT";
case _GL_NV_ray_tracing: __intrinsic_asm "reportIntersectionNV";
case spirv:
return spirv_asm {
result:$$bool = OpReportIntersectionKHR $tHit $hitKind;
Expand All @@ -9555,7 +9549,6 @@ void IgnoreHit()
{
case hlsl: __intrinsic_asm "IgnoreHit";
case _GL_EXT_ray_tracing: __intrinsic_asm "ignoreIntersectionEXT;";
case _GL_NV_ray_tracing: __intrinsic_asm "ignoreIntersectionNV";
case cuda: __intrinsic_asm "optixIgnoreIntersection";
case spirv: spirv_asm { OpIgnoreIntersectionKHR; %_ = OpLabel };
}
Expand All @@ -9568,7 +9561,6 @@ void AcceptHitAndEndSearch()
{
case hlsl: __intrinsic_asm "AcceptHitAndEndSearch";
case _GL_EXT_ray_tracing: __intrinsic_asm "terminateRayEXT;";
case _GL_NV_ray_tracing: __intrinsic_asm "terminateRayNV";
case cuda: __intrinsic_asm "optixTerminateRay";
case spirv: spirv_asm { OpTerminateRayKHR; %_ = OpLabel };
}
Expand All @@ -9587,7 +9579,6 @@ uint3 DispatchRaysIndex()
{
case hlsl: __intrinsic_asm "DispatchRaysIndex";
case _GL_EXT_ray_tracing: __intrinsic_asm "(gl_LaunchIDEXT)";
case _GL_NV_ray_tracing: __intrinsic_asm "(gl_LaunchIDNV)";
case cuda: __intrinsic_asm "optixGetLaunchIndex";
case spirv:
return spirv_asm {
Expand All @@ -9602,7 +9593,6 @@ uint3 DispatchRaysDimensions()
{
case hlsl: __intrinsic_asm "DispatchRaysDimensions";
case _GL_EXT_ray_tracing: __intrinsic_asm "(gl_LaunchSizeEXT)";
case _GL_NV_ray_tracing: __intrinsic_asm "(gl_LaunchSizeNV)";
case cuda: __intrinsic_asm "optixGetLaunchDimensions";
case spirv:
return spirv_asm {
Expand All @@ -9619,7 +9609,6 @@ float3 WorldRayOrigin()
{
case hlsl: __intrinsic_asm "WorldRayOrigin";
case _GL_EXT_ray_tracing: __intrinsic_asm "(gl_WorldRayOriginEXT)";
case _GL_NV_ray_tracing: __intrinsic_asm "(gl_WorldRayOriginNV)";
case cuda: __intrinsic_asm "optixGetWorldRayOrigin";
case spirv:
return spirv_asm {
Expand All @@ -9634,7 +9623,6 @@ float3 WorldRayDirection()
{
case hlsl: __intrinsic_asm "WorldRayDirection";
case _GL_EXT_ray_tracing: __intrinsic_asm "(gl_WorldRayDirectionEXT)";
case _GL_NV_ray_tracing: __intrinsic_asm "(gl_WorldRayDirectionNV)";
case cuda: __intrinsic_asm "optixGetWorldRayDirection";
case spirv:
return spirv_asm {
Expand All @@ -9649,7 +9637,6 @@ float RayTMin()
{
case hlsl: __intrinsic_asm "RayTMin";
case _GL_EXT_ray_tracing: __intrinsic_asm "(gl_RayTminEXT)";
case _GL_NV_ray_tracing: __intrinsic_asm "(gl_RayTminNV)";
case cuda: __intrinsic_asm "optixGetRayTmin";
case spirv:
return spirv_asm {
Expand All @@ -9674,7 +9661,6 @@ float RayTCurrent()
{
case hlsl: __intrinsic_asm "RayTCurrent";
case _GL_EXT_ray_tracing: __intrinsic_asm "(gl_RayTmaxEXT)";
case _GL_NV_ray_tracing: __intrinsic_asm "(gl_RayTmaxNV)";
case cuda: __intrinsic_asm "optixGetRayTmax";
case spirv:
return spirv_asm {
Expand All @@ -9689,7 +9675,6 @@ uint RayFlags()
{
case hlsl: __intrinsic_asm "RayFlags";
case _GL_EXT_ray_tracing: __intrinsic_asm "(gl_IncomingRayFlagsEXT)";
case _GL_NV_ray_tracing: __intrinsic_asm "(gl_IncomingRayFlagsNV)";
case cuda: __intrinsic_asm "optixGetRayFlags";
case spirv:
return spirv_asm {
Expand Down Expand Up @@ -9720,7 +9705,6 @@ uint InstanceID()
{
case hlsl: __intrinsic_asm "InstanceID";
case _GL_EXT_ray_tracing: __intrinsic_asm "(gl_InstanceCustomIndexEXT)";
case _GL_NV_ray_tracing: __intrinsic_asm "(gl_InstanceCustomIndexNV)";
case cuda: __intrinsic_asm "optixGetInstanceId";
case spirv:
return spirv_asm {
Expand Down Expand Up @@ -9749,7 +9733,6 @@ float3 ObjectRayOrigin()
{
case hlsl: __intrinsic_asm "ObjectRayOrigin";
case _GL_EXT_ray_tracing: __intrinsic_asm "(gl_ObjectRayOriginEXT)";
case _GL_NV_ray_tracing: __intrinsic_asm "(gl_ObjectRayOriginNV)";
case cuda: __intrinsic_asm "optixGetObjectRayOrigin";
case spirv:
return spirv_asm {
Expand All @@ -9764,7 +9747,6 @@ float3 ObjectRayDirection()
{
case hlsl: __intrinsic_asm "ObjectRayDirection";
case _GL_EXT_ray_tracing: __intrinsic_asm "(gl_ObjectRayDirectionEXT)";
case _GL_NV_ray_tracing: __intrinsic_asm "(gl_ObjectRayDirectionNV)";
case cuda: __intrinsic_asm "optixGetObjectRayDirection";
case spirv:
return spirv_asm {
Expand All @@ -9781,7 +9763,6 @@ float3x4 ObjectToWorld3x4()
{
case hlsl: __intrinsic_asm "ObjectToWorld3x4";
case _GL_EXT_ray_tracing: __intrinsic_asm "transpose(gl_ObjectToWorldEXT)";
case _GL_NV_ray_tracing: __intrinsic_asm "transpose(gl_ObjectToWorldNV)";
case spirv:
return spirv_asm {
%mat:$$float4x3 = OpLoad builtin(ObjectToWorldKHR:float4x3);
Expand All @@ -9796,7 +9777,6 @@ float3x4 WorldToObject3x4()
{
case hlsl: __intrinsic_asm "WorldToObject3x4";
case _GL_EXT_ray_tracing: __intrinsic_asm "transpose(gl_WorldToObjectEXT)";
case _GL_NV_ray_tracing: __intrinsic_asm "transpose(gl_WorldToObjectNV)";
case spirv:
return spirv_asm {
%mat:$$float4x3 = OpLoad builtin(WorldToObjectKHR:float4x3);
Expand All @@ -9811,7 +9791,6 @@ float4x3 ObjectToWorld4x3()
{
case hlsl: __intrinsic_asm "ObjectToWorld4x3";
case _GL_EXT_ray_tracing: __intrinsic_asm "(gl_ObjectToWorldEXT)";
case _GL_NV_ray_tracing: __intrinsic_asm "(gl_ObjectToWorldNV)";
case spirv:
return spirv_asm {
result:$$float4x3 = OpLoad builtin(ObjectToWorldKHR:float4x3);
Expand All @@ -9825,7 +9804,6 @@ float4x3 WorldToObject4x3()
{
case hlsl: __intrinsic_asm "WorldToObject4x3";
case _GL_EXT_ray_tracing: __intrinsic_asm "(gl_WorldToObjectEXT)";
case _GL_NV_ray_tracing: __intrinsic_asm "(gl_WorldToObjectNV)";
case spirv:
return spirv_asm {
result:$$float4x3 = OpLoad builtin(WorldToObjectKHR:float4x3);
Expand Down Expand Up @@ -9872,7 +9850,6 @@ uint HitKind()
{
case hlsl: __intrinsic_asm "HitKind";
case _GL_EXT_ray_tracing: __intrinsic_asm "(gl_HitKindEXT)";
case _GL_NV_ray_tracing: __intrinsic_asm "(gl_HitKindNV)";
case cuda: __intrinsic_asm "optixGetHitKind";
case spirv:
return spirv_asm {
Expand Down Expand Up @@ -11874,6 +11851,7 @@ void debugBreak();

[__requiresNVAPI]
__glsl_extension(GL_EXT_shader_realtime_clock)
[require(shaderclock)]
uint getRealtimeClockLow()
{
__target_switch
Expand All @@ -11886,14 +11864,18 @@ uint getRealtimeClockLow()
__intrinsic_asm "clock";
case spirv:
return getRealtimeClock().x;
case cpp:
__intrinsic_asm "(uint32_t)std::chrono::high_resolution_clock::now().time_since_epoch().count()";
}
}

__target_intrinsic(cpp, "std::chrono::high_resolution_clock::now().time_since_epoch().count()")
__target_intrinsic(cuda, "clock64")
int64_t __cudaGetRealtimeClock();
int64_t __cudaCppGetRealtimeClock();

[__requiresNVAPI]
__glsl_extension(GL_EXT_shader_realtime_clock)
[require(shaderclock)]
uint2 getRealtimeClock()
{
__target_switch
Expand All @@ -11903,7 +11885,8 @@ uint2 getRealtimeClock()
case glsl:
__intrinsic_asm "clockRealtime2x32EXT()";
case cuda:
int64_t ticks = __cudaGetRealtimeClock();
case cpp:
int64_t ticks = __cudaCppGetRealtimeClock();
return uint2(uint(ticks), uint(uint64_t(ticks) >> 32));
case spirv:
return spirv_asm
Expand Down
10 changes: 9 additions & 1 deletion source/slang/slang-ast-base.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

#include "slang-generated-ast.h"
#include "slang-ast-reflect.h"

#include "slang-capability.h"
#include "slang-serialize-reflection.h"

// This file defines the primary base classes for the hierarchy of
Expand Down Expand Up @@ -695,6 +695,11 @@ class ModifiableSyntaxNode : public SyntaxNode
bool hasModifier() { return findModifier<T>() != nullptr; }
};

struct DeclReferenceWithLoc
{
Decl* referencedDecl;
SourceLoc referenceLoc;
};

// An intermediate type to represent either a single declaration, or a group of declarations
class DeclBase : public ModifiableSyntaxNode
Expand All @@ -716,6 +721,7 @@ class Decl : public DeclBase
DeclRefBase* getDefaultDeclRef();

NameLoc nameAndLoc;
CapabilitySet inferredCapabilityRequirements;

RefPtr<MarkupEntry> markup;

Expand All @@ -736,6 +742,8 @@ class Decl : public DeclBase
}
bool isChildOf(Decl* other) const;

// Track the decl reference that caused the requirement of a capability atom.
SLANG_UNREFLECTED Dictionary<CapabilityAtom, DeclReferenceWithLoc> capabilityRequirementProvenance;
private:
SLANG_UNREFLECTED DeclRefBase* m_defaultDeclRef = nullptr;
SLANG_UNREFLECTED Index m_defaultDeclRefEpoch = -1;
Expand Down
25 changes: 25 additions & 0 deletions source/slang/slang-ast-dump.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -692,6 +692,31 @@ struct ASTDumpContext
m_writer->emit("}");
}

void dump(const CapabilitySet& capSet)
{
m_writer->emit("capability_set(");
bool isFirstSet = true;
for (auto& set : capSet.getExpandedAtoms())
{
if (!isFirstSet)
{
m_writer->emit(" | ");
}
bool isFirst = true;
for (auto atom : set.getExpandedAtoms())
{
if (!isFirst)
{
m_writer->emit("+");
}
dump(capabilityNameToString((CapabilityName)atom));
isFirst = false;
}
isFirstSet = false;
}
m_writer->emit(")");
}

void dumpObjectFull(NodeBase* node);

ASTDumpContext(SourceWriter* writer, ASTDumpUtil::Flags flags, ASTDumpUtil::Style dumpStyle):
Expand Down
47 changes: 28 additions & 19 deletions source/slang/slang-ast-iterator.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,14 @@

namespace Slang
{
template <typename Callback>
template <typename Callback, typename Filter>
struct ASTIterator
{
const Callback& callback;
UnownedStringSlice fileName;
SourceManager* sourceManager;
ASTIterator(const Callback& func, SourceManager* manager, UnownedStringSlice sourceFileName)
const Filter& filter;
ASTIterator(const Callback& func, const Filter& filterFunc)
: callback(func)
, fileName(sourceFileName)
, sourceManager(manager)
, filter(filterFunc)
{}

void visitDecl(DeclBase* decl);
Expand Down Expand Up @@ -429,13 +427,11 @@ struct ASTIterator
};
};

template <typename CallbackFunc>
void ASTIterator<CallbackFunc>::visitDecl(DeclBase* decl)
template <typename CallbackFunc, typename FilterFunc>
void ASTIterator<CallbackFunc, FilterFunc>::visitDecl(DeclBase* decl)
{
// Don't look at the decl if it is defined in a different file.
if (!as<NamespaceDeclBase>(decl) && !sourceManager->getHumaneLoc(decl->loc, SourceLocType::Actual)
.pathInfo.foundPath.getUnownedSlice()
.endsWithCaseInsensitive(fileName))
if (!filter(decl))
return;

maybeDispatchCallback(decl);
Expand Down Expand Up @@ -490,24 +486,23 @@ void ASTIterator<CallbackFunc>::visitDecl(DeclBase* decl)
}
}
}
template <typename CallbackFunc>
void ASTIterator<CallbackFunc>::visitExpr(Expr* expr)
template <typename CallbackFunc, typename FilterFunc>
void ASTIterator<CallbackFunc, FilterFunc>::visitExpr(Expr* expr)
{
ASTIteratorExprVisitor visitor(this);
visitor.dispatchIfNotNull(expr);
}
template <typename CallbackFunc>
void ASTIterator<CallbackFunc>::visitStmt(Stmt* stmt)
template <typename CallbackFunc, typename FilterFunc>
void ASTIterator<CallbackFunc, FilterFunc>::visitStmt(Stmt* stmt)
{
ASTIteratorStmtVisitor visitor(this);
visitor.dispatchIfNotNull(stmt);
}

template <typename Func>
void iterateAST(
UnownedStringSlice fileName, SourceManager* manager, SyntaxNode* node, const Func& f)
template <typename Func, typename FilterFunc>
void iterateAST(SyntaxNode* node, const FilterFunc& filterFunc, const Func& f)
{
ASTIterator<Func> iter(f, manager, fileName);
ASTIterator<Func, FilterFunc> iter(f, filterFunc);
if (auto decl = as<Decl>(node))
{
iter.visitDecl(decl);
Expand All @@ -521,4 +516,18 @@ void iterateAST(
iter.visitStmt(stmt);
}
}

template <typename Func>
void iterateASTWithLanguageServerFilter(
UnownedStringSlice fileName, SourceManager* sourceManager, SyntaxNode* node, const Func& f)
{
auto filter = [&](DeclBase* decl)
{
return as<NamespaceDeclBase>(decl) ||
sourceManager->getHumaneLoc(decl->loc, SourceLocType::Actual)
.pathInfo.foundPath.getUnownedSlice()
.endsWithCaseInsensitive(fileName);
};
iterateAST(node, filter, f);
}
} // namespace Slang
Loading

0 comments on commit 1476489

Please sign in to comment.