Skip to content

Commit

Permalink
Add ConstBufferPointer::subscript. (#3415)
Browse files Browse the repository at this point in the history
Co-authored-by: Yong He <yhe@nvidia.com>
  • Loading branch information
csyonghe and Yong He authored Dec 16, 2023
1 parent f8b3027 commit b507d88
Show file tree
Hide file tree
Showing 27 changed files with 189 additions and 62 deletions.
12 changes: 12 additions & 0 deletions source/slang/core.meta.slang
Original file line number Diff line number Diff line change
Expand Up @@ -2294,6 +2294,18 @@ bool __isVector()
return __isVector_impl(__declVal<T>());
}

// Intrinsic declaration with no body: the attribute below binds it directly to
// the `GetNaturalStride` IR opcode. It takes a value of type `T` and yields an `int`.
// NOTE(review): the name suggests the result is the stride of `T` under its
// natural layout; how the opcode is evaluated is not visible here -- confirm.
__generic<T>
__intrinsic_op($(kIROp_GetNaturalStride))
int __naturalStrideOf_impl(T v);

// Type-only entry point for the stride query: callers write
// `__naturalStrideOf<T>()` without supplying a value. The body forwards
// `__declVal<T>()` to `__naturalStrideOf_impl` and returns its `int` result.
// NOTE(review): `__declVal<T>()` is presumably a placeholder value used only
// to carry the type `T` -- confirm against its definition.
__generic<T>
[__unsafeForceInlineEarly]
int __naturalStrideOf()
{
return __naturalStrideOf_impl(__declVal<T>());
}



// Binding Attributes

Expand Down
40 changes: 39 additions & 1 deletion source/slang/hlsl.meta.slang
Original file line number Diff line number Diff line change
Expand Up @@ -12555,7 +12555,45 @@ struct ConstBufferPointer
__intrinsic_asm "$0._data";
case spirv:
return spirv_asm {
result:$$T = OpLoad $this Aligned $Alignment;
result:$$T = OpLoad $this Aligned !Alignment;
};
}
}

// Read-only indexed access through the pointer.
//
// The getter turns this pointer into a 64-bit integer with `toUInt()`, adds
// `__naturalStrideOf<T>() * index`, rebuilds a `ConstBufferPointer<T>` from the
// sum with `fromUInt`, and returns the result of calling `get()` on it.
//
// NOTE(review): both operands of the multiplication are `int`, so the product
// appears to be formed before it is combined with the 64-bit address -- confirm
// the behavior for negative or very large `index` values.
__subscript(int index) -> T
{
[ForceInline]
get {return ConstBufferPointer<T>.fromUInt(toUInt() + __naturalStrideOf<T>() * index).get(); }
}

// Builds a `ConstBufferPointer<T>` from a 64-bit unsigned integer `val`.
//
// Declared for GLSL version 450 with the 64-bit-integer and buffer-reference
// extensions listed below. The implementation is chosen per target:
//   - glsl:  emits the intrinsic text "$TR($0)"
//            (presumably a constructor-style cast of `val` to the result type -- confirm).
//   - spirv: emits `OpConvertUToPtr` on `val`, with `ConstBufferPointer<T>` as the result type.
__glsl_version(450)
__glsl_extension(GL_EXT_shader_explicit_arithmetic_types_int64)
__glsl_extension(GL_EXT_buffer_reference)
static ConstBufferPointer<T> fromUInt(uint64_t val)
{
__target_switch
{
case glsl:
__intrinsic_asm "$TR($0)";
case spirv:
return spirv_asm {
result:$$ConstBufferPointer<T> = OpConvertUToPtr $val;
};
}
}

// Converts this pointer to a `uint64_t`; the counterpart of `fromUInt`.
//
// Declared for GLSL version 450 with the 64-bit-integer and buffer-reference
// extensions listed below. The implementation is chosen per target:
//   - glsl:  emits the intrinsic text "uint64_t($0)".
//   - spirv: emits `OpConvertPtrToU` on `this`, with `uint64_t` as the result type.
__glsl_version(450)
__glsl_extension(GL_EXT_shader_explicit_arithmetic_types_int64)
__glsl_extension(GL_EXT_buffer_reference)
uint64_t toUInt()
{
__target_switch
{
case glsl:
__intrinsic_asm "uint64_t($0)";
case spirv:
return spirv_asm {
result:$$uint64_t = OpConvertPtrToU $this;
};
}
}
Expand Down
3 changes: 3 additions & 0 deletions source/slang/slang-ast-dump.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -659,6 +659,9 @@ struct ASTDumpContext
case SPIRVAsmOperand::SlangType:
m_writer->emit("$$");
break;
case SPIRVAsmOperand::SlangImmediateValue:
m_writer->emit("!");
break;
default:
SLANG_UNREACHABLE("Unhandled case in ast dump for SPIRVAsmOperand");
}
Expand Down
1 change: 1 addition & 0 deletions source/slang/slang-ast-expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -658,6 +658,7 @@ class SPIRVAsmOperand
NamedValue, // Any other identifier
SlangValue,
SlangValueAddr,
SlangImmediateValue,
SlangType,
SampledType, // __sampledType(T), this becomes a 4 vector of the component type of T
ImageType, // __imageType(texture), returns the equivalent OpTypeImage of a given texture typed value.
Expand Down
1 change: 1 addition & 0 deletions source/slang/slang-check-expr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4262,6 +4262,7 @@ namespace Slang
operand.expr = typeExpr.exp;
}
else if(operand.flavor == SPIRVAsmOperand::SlangValue
|| operand.flavor == SPIRVAsmOperand::SlangImmediateValue
|| operand.flavor == SPIRVAsmOperand::SlangValueAddr
|| operand.flavor == SPIRVAsmOperand::ImageType
|| operand.flavor == SPIRVAsmOperand::SampledImageType)
Expand Down
20 changes: 10 additions & 10 deletions source/slang/slang-emit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,7 @@ Result linkAndOptimizeIR(
// Lower all the LValue implicit casts (used for out/inout/ref scenarios)
lowerLValueCast(targetRequest, irModule);

simplifyIR(irModule, IRSimplificationOptions::getDefault(), sink);
simplifyIR(targetRequest, irModule, IRSimplificationOptions::getDefault(), sink);

// Fill in default matrix layout into matrix types that left layout unspecified.
specializeMatrixLayout(codeGenContext->getTargetReq(), irModule);
Expand Down Expand Up @@ -407,7 +407,7 @@ Result linkAndOptimizeIR(
//auto b1 = dumpIRToString(irModule->getModuleInst());
dumpIRIfEnabled(codeGenContext, irModule, "BEFORE-SPECIALIZE");
if (!codeGenContext->isSpecializationDisabled())
changed |= specializeModule(irModule, codeGenContext->getSink());
changed |= specializeModule(codeGenContext->getTargetReq(), irModule, codeGenContext->getSink());
if (codeGenContext->getSink()->getErrorCount() != 0)
return SLANG_FAIL;
dumpIRIfEnabled(codeGenContext, irModule, "AFTER-SPECIALIZE");
Expand All @@ -425,7 +425,7 @@ Result linkAndOptimizeIR(
// Unroll loops.
if (codeGenContext->getSink()->getErrorCount() == 0)
{
if (!unrollLoopsInModule(irModule, codeGenContext->getSink()))
if (!unrollLoopsInModule(targetRequest, irModule, codeGenContext->getSink()))
return SLANG_FAIL;
}

Expand All @@ -438,15 +438,15 @@ Result linkAndOptimizeIR(

dumpIRIfEnabled(codeGenContext, irModule, "BEFORE-AUTODIFF");
enableIRValidationAtInsert();
changed |= processAutodiffCalls(irModule, sink);
changed |= processAutodiffCalls(targetRequest, irModule, sink);
disableIRValidationAtInsert();
dumpIRIfEnabled(codeGenContext, irModule, "AFTER-AUTODIFF");

if (!changed)
break;
}

finalizeAutoDiffPass(irModule);
finalizeAutoDiffPass(targetRequest, irModule);

finalizeSpecialization(irModule);

Expand Down Expand Up @@ -481,7 +481,7 @@ Result linkAndOptimizeIR(

validateIRModuleIfEnabled(codeGenContext, irModule);

simplifyIR(irModule, IRSimplificationOptions::getFast(), sink);
simplifyIR(targetRequest, irModule, IRSimplificationOptions::getFast(), sink);

if (!ArtifactDescUtil::isCpuLikeTarget(artifactDesc))
{
Expand Down Expand Up @@ -510,7 +510,7 @@ Result linkAndOptimizeIR(
// up downstream passes like type legalization, so we
// will run a DCE pass to clean up after the specialization.
//
simplifyIR(irModule, IRSimplificationOptions::getDefault(), sink);
simplifyIR(targetRequest, irModule, IRSimplificationOptions::getDefault(), sink);

validateIRModuleIfEnabled(codeGenContext, irModule);

Expand Down Expand Up @@ -596,7 +596,7 @@ Result linkAndOptimizeIR(
// to see if we can clean up any temporaries created by legalization.
// (e.g., things that used to be aggregated might now be split up,
// so that we can work with the individual fields).
simplifyIR(irModule, IRSimplificationOptions::getFast(), sink);
simplifyIR(targetRequest, irModule, IRSimplificationOptions::getFast(), sink);

#if 0
dumpIRIfEnabled(codeGenContext, irModule, "AFTER SSA");
Expand Down Expand Up @@ -940,7 +940,7 @@ Result linkAndOptimizeIR(
{
IRSimplificationOptions simplificationOptions = IRSimplificationOptions::getFast();
simplificationOptions.cfgOptions.removeTrivialSingleIterationLoops = true;
simplifyIR(irModule, simplificationOptions, sink);
simplifyIR(targetRequest, irModule, simplificationOptions, sink);
}

// As a late step, we need to take the SSA-form IR and move things *out*
Expand Down Expand Up @@ -1018,7 +1018,7 @@ Result linkAndOptimizeIR(
}

// Run a final round of simplifications to clean up unused things after phi-elimination.
simplifyNonSSAIR(irModule, IRSimplificationOptions::getFast());
simplifyNonSSAIR(targetRequest, irModule, IRSimplificationOptions::getFast());

// We include one final step to (optionally) dump the IR and validate
// it after all of the optimization passes are complete. This should
Expand Down
2 changes: 1 addition & 1 deletion source/slang/slang-ir-autodiff-fwd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1642,7 +1642,7 @@ SlangResult ForwardDiffTranscriber::prepareFuncForForwardDiff(IRFunc* func)
if (SLANG_SUCCEEDED(result))
{
disableIRValidationAtInsert();
simplifyFunc(func, IRSimplificationOptions::getDefault());
simplifyFunc(autoDiffSharedContext->targetRequest, func, IRSimplificationOptions::getDefault());
enableIRValidationAtInsert();
}
return result;
Expand Down
11 changes: 6 additions & 5 deletions source/slang/slang-ir-autodiff.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -339,8 +339,8 @@ IRInst* DifferentialPairTypeBuilder::lowerDiffPairType(
return result;
}

AutoDiffSharedContext::AutoDiffSharedContext(IRModuleInst* inModuleInst)
: moduleInst(inModuleInst)
AutoDiffSharedContext::AutoDiffSharedContext(TargetRequest* target, IRModuleInst* inModuleInst)
: moduleInst(inModuleInst), targetRequest(target)
{
differentiableInterfaceType = as<IRInterfaceType>(findDifferentiableInterface());
if (differentiableInterfaceType)
Expand Down Expand Up @@ -1979,6 +1979,7 @@ struct AutoDiffPass : public InstPassBase
};

bool processAutodiffCalls(
TargetRequest* target,
IRModule* module,
DiagnosticSink* sink,
IRAutodiffPassOptions const&)
Expand All @@ -1987,7 +1988,7 @@ bool processAutodiffCalls(
bool modified = false;

// Create shared context for all auto-diff related passes
AutoDiffSharedContext autodiffContext(module->getModuleInst());
AutoDiffSharedContext autodiffContext(target, module->getModuleInst());

AutoDiffPass pass(&autodiffContext, sink);

Expand Down Expand Up @@ -2077,12 +2078,12 @@ void releaseNullDifferentialType(AutoDiffSharedContext* context)
}
}

bool finalizeAutoDiffPass(IRModule* module)
bool finalizeAutoDiffPass(TargetRequest* target, IRModule* module)
{
bool modified = false;

// Create shared context for all auto-diff related passes
AutoDiffSharedContext autodiffContext(module->getModuleInst());
AutoDiffSharedContext autodiffContext(target, module->getModuleInst());

// Replaces IRDifferentialPairType with an auto-generated struct,
// IRDifferentialPairGetDifferential with 'differential' field access,
Expand Down
7 changes: 5 additions & 2 deletions source/slang/slang-ir-autodiff.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ struct DiffTranscriberSet

struct AutoDiffSharedContext
{
TargetRequest* targetRequest = nullptr;

IRModuleInst* moduleInst = nullptr;

// A reference to the builtin IDifferentiable interface type.
Expand Down Expand Up @@ -113,7 +115,7 @@ struct AutoDiffSharedContext

DiffTranscriberSet transcriberSet;

AutoDiffSharedContext(IRModuleInst* inModuleInst);
AutoDiffSharedContext(TargetRequest* target, IRModuleInst* inModuleInst);

private:

Expand Down Expand Up @@ -357,11 +359,12 @@ struct IRAutodiffPassOptions
};

bool processAutodiffCalls(
TargetRequest* target,
IRModule* module,
DiagnosticSink* sink,
IRAutodiffPassOptions const& options = IRAutodiffPassOptions());

bool finalizeAutoDiffPass(IRModule* module);
bool finalizeAutoDiffPass(TargetRequest* target, IRModule* module);

// Utility methods

Expand Down
2 changes: 1 addition & 1 deletion source/slang/slang-ir-check-differentiability.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ struct CheckDifferentiabilityPassContext : public InstPassBase
Dictionary<IRInst*, DifferentiableLevel> differentiableFunctions;

CheckDifferentiabilityPassContext(IRModule* inModule, DiagnosticSink* inSink)
: InstPassBase(inModule), sink(inSink), sharedContext(inModule->getModuleInst())
: InstPassBase(inModule), sink(inSink), sharedContext(nullptr, inModule->getModuleInst())
{}

bool _isFuncMarkedForAutoDiff(IRInst* func)
Expand Down
2 changes: 1 addition & 1 deletion source/slang/slang-ir-inline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -944,7 +944,7 @@ void performGLSLResourceReturnFunctionInlining(IRModule* module)
while (changed)
{
changed = pass.considerAllCallSites();
simplifyIR(module, IRSimplificationOptions::getFast());
simplifyIR(nullptr, module, IRSimplificationOptions::getFast());
}
}

Expand Down
2 changes: 2 additions & 0 deletions source/slang/slang-ir-inst-defs.h
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,8 @@ INST(StructuredBufferGetDimensions, StructuredBufferGetDimensions, 1, 0)
INST(AtomicCounterIncrement, AtomicCounterIncrement, 1, 0)
INST(AtomicCounterDecrement, AtomicCounterDecrement, 1, 0)

INST(GetNaturalStride, getNaturalStride, 1, 0)

INST(MeshOutputRef, meshOutputRef, 2, 0)

// Construct a vector from a scalar
Expand Down
15 changes: 9 additions & 6 deletions source/slang/slang-ir-loop-unroll.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ static int _getLoopMaxIterationsToUnroll(IRLoop* loopInst)
}

static void _foldAndSimplifyLoopIteration(
TargetRequest* targetRequest,
IRBuilder& builder,
List<IRBlock*>& clonedBlocks,
IRBlock* firstIterationBreakBlock,
Expand All @@ -80,15 +81,15 @@ static void _foldAndSimplifyLoopIteration(
{
for (auto inst : b->getChildren())
{
tryReplaceInstUsesWithSimplifiedValue(builder.getModule(), inst);
tryReplaceInstUsesWithSimplifiedValue(targetRequest, builder.getModule(), inst);
}
}

// It is important to also evaluate `firstIterationBreakBlock` because we need to have
// the phi arguments for next iteration evaluated (args in the new loop inst).
for (auto inst : firstIterationBreakBlock->getChildren())
{
tryReplaceInstUsesWithSimplifiedValue(builder.getModule(), inst);
tryReplaceInstUsesWithSimplifiedValue(targetRequest, builder.getModule(), inst);
}

// Fold conditional branches into unconditional branches if the condition is known.
Expand Down Expand Up @@ -147,6 +148,7 @@ static void _foldAndSimplifyLoopIteration(
// Returns true if we can statically determine that the loop terminated within the iteration limit.
// This operation assumes the loop does not have `continue` jumps, i.e. continueBlock == targetBlock.
static bool _unrollLoop(
TargetRequest* targetRequest,
IRModule* module,
IRLoop* loopInst,
List<IRBlock*>& blocks)
Expand Down Expand Up @@ -339,7 +341,7 @@ static bool _unrollLoop(
// conditional jumps can be folded into unconditional jumps.

_foldAndSimplifyLoopIteration(
builder, clonedBlocks, firstIterationBreakBlock, unreachableBlock);
targetRequest, builder, clonedBlocks, firstIterationBreakBlock, unreachableBlock);

// Now we have peeled off one iteration from the loop, we check if there are any
// branches into next iteration, if not, the loop terminates and we are done.
Expand Down Expand Up @@ -433,6 +435,7 @@ List<IRLoop*> collectLoopsInFunc(IRGlobalValueWithCode* func, const TFunc& filte
}

bool unrollLoopsInFunc(
TargetRequest* targetRequest,
IRModule* module,
IRGlobalValueWithCode* func,
DiagnosticSink* sink)
Expand All @@ -450,7 +453,7 @@ bool unrollLoopsInFunc(

auto blocks = collectBlocksInRegion(func, loop);
auto loopLoc = loop->sourceLoc;
if (!_unrollLoop(module, loop, blocks))
if (!_unrollLoop(targetRequest, module, loop, blocks))
{
if (sink)
sink->diagnose(loopLoc, Diagnostics::cannotUnrollLoop);
Expand All @@ -465,7 +468,7 @@ bool unrollLoopsInFunc(
return true;
}

bool unrollLoopsInModule(IRModule* module, DiagnosticSink* sink)
bool unrollLoopsInModule(TargetRequest* targetRequest, IRModule* module, DiagnosticSink* sink)
{
SLANG_PROFILE;

Expand All @@ -476,7 +479,7 @@ bool unrollLoopsInModule(IRModule* module, DiagnosticSink* sink)

if (auto func = as<IRGlobalValueWithCode>(inst))
{
bool result = unrollLoopsInFunc(module, func, sink);
bool result = unrollLoopsInFunc(targetRequest, module, func, sink);
if (!result)
return false;
}
Expand Down
5 changes: 3 additions & 2 deletions source/slang/slang-ir-loop-unroll.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,12 @@ namespace Slang
class DiagnosticSink;
struct IRModule;
struct IRBlock;
class TargetRequest;

// Return true if successful, false if errors occurred.
bool unrollLoopsInFunc(IRModule* module, IRGlobalValueWithCode* func, DiagnosticSink* sink);
bool unrollLoopsInFunc(TargetRequest*target, IRModule* module, IRGlobalValueWithCode* func, DiagnosticSink* sink);

bool unrollLoopsInModule(IRModule* module, DiagnosticSink* sink);
bool unrollLoopsInModule(TargetRequest* target, IRModule* module, DiagnosticSink* sink);

// Turn a loop with continue block into a loop with only back jumps and breaks.
// Each iteration will be wrapped in a breakable region, where everything before `continue`
Expand Down
2 changes: 1 addition & 1 deletion source/slang/slang-ir-lower-generics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ namespace Slang
// real RTTI objects and witness tables.
specializeRTTIObjects(&sharedContext, sink);

simplifyIR(module, IRSimplificationOptions::getFast());
simplifyIR(sharedContext.targetReq, module, IRSimplificationOptions::getFast());

lowerTuples(module, sink);
if (sink->getErrorCount() != 0)
Expand Down
Loading

0 comments on commit b507d88

Please sign in to comment.