Skip to content

Commit

Permalink
refactor shader program creation
Browse files Browse the repository at this point in the history
  • Loading branch information
skallweitNV committed Sep 5, 2024
1 parent a423c1d commit d88d9f3
Show file tree
Hide file tree
Showing 25 changed files with 159 additions and 280 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
- remove IDevice::createProgram2
- rename IDevice::createProgram -> IDevice::createShaderProgram
- rename IShaderProgram::Desc -> ShaderProgramDesc
- remove SLANG_RHI_FORMAT(x) macro, extended FormatInfo with name
- refactor NativeHandle to hold the actual type of the handle instead of just a category
- rename InteropHandle -> NativeHandle
Expand Down
97 changes: 41 additions & 56 deletions include/slang-rhi.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,63 +89,41 @@ const GfxCount kMaxRenderTargetCount = 8;
class ITransientResourceHeap;
class IPersistentShaderCache;

enum class ShaderModuleSourceType
/// Defines how linking should be performed for a shader program.
enum class LinkingStyle
{
SlangSource, // a slang source string in memory.
SlangModuleBinary, // a slang module binary code in memory.
SlangSourceFile, // a slang source from file.
SlangModuleBinaryFile, // a slang module binary code from file.
// Compose all entry-points in a single program, then compile all entry-points together with the same
// set of root shader arguments.
SingleProgram,

// Link and compile each entry-point individually, potentially with different specializations.
SeparateEntryPointCompilation
};

class IShaderProgram : public ISlangUnknown
struct ShaderProgramDesc
{
SLANG_COM_INTERFACE(0x19cabd0d, 0xf3e3, 0x4b3d, {0x93, 0x43, 0xea, 0xcc, 0x00, 0x1e, 0xc5, 0xf2});

public:
// Defines how linking should be performed for a shader program.
enum class LinkingStyle
{
// Compose all entry-points in a single program, then compile all entry-points together with the same
// set of root shader arguments.
SingleProgram,

// Link and compile each entry-point individually, potentially with different specializations.
SeparateEntryPointCompilation
};

struct Desc
{
// TODO: Tess doesn't like this but doesn't know what to do about it
// The linking style of this program.
LinkingStyle linkingStyle = LinkingStyle::SingleProgram;
// TODO: Tess doesn't like this but doesn't know what to do about it
// The linking style of this program.
LinkingStyle linkingStyle = LinkingStyle::SingleProgram;

// The global scope or a Slang composite component that represents the entire program.
slang::IComponentType* slangGlobalScope;
// The global scope or a Slang composite component that represents the entire program.
slang::IComponentType* slangGlobalScope;

// Number of separate entry point components in the `slangEntryPoints` array to link in.
// If set to 0, then `slangGlobalScope` must contain Slang EntryPoint components.
// If not 0, then `slangGlobalScope` must not contain any EntryPoint components.
GfxCount entryPointCount = 0;
// An array of Slang entry points. The size of the array must be `slangEntryPointCount`.
// Each element must define only 1 Slang EntryPoint.
slang::IComponentType** slangEntryPoints = nullptr;

// An array of Slang entry points. The size of the array must be `entryPointCount`.
// Each element must define only 1 Slang EntryPoint.
slang::IComponentType** slangEntryPoints = nullptr;
};

struct CreateDesc2
{
ShaderModuleSourceType sourceType;
void* sourceData;
Size sourceDataSize;
// Number of separate entry point components in the `slangEntryPoints` array to link in.
// If set to 0, then `slangGlobalScope` must contain Slang EntryPoint components.
// If not 0, then `slangGlobalScope` must not contain any EntryPoint components.
GfxCount slangEntryPointCount = 0;
};

// Number of entry points to include in the shader program. 0 means include all entry points
// defined in the module.
GfxCount entryPointCount = 0;
// Names of entry points to include in the shader program. The size of the array must be
// `entryPointCount`.
const char** entryPointNames = nullptr;
};
class IShaderProgram : public ISlangUnknown
{
SLANG_COM_INTERFACE(0x19cabd0d, 0xf3e3, 0x4b3d, {0x93, 0x43, 0xea, 0xcc, 0x00, 0x1e, 0xc5, 0xf2});

public:
virtual SLANG_NO_THROW slang::TypeReflection* SLANG_MCALL findTypeByName(const char* name) = 0;
};

Expand Down Expand Up @@ -2348,24 +2326,31 @@ class IDevice : public ISlangUnknown
virtual SLANG_NO_THROW Result SLANG_MCALL
createShaderTable(const IShaderTable::Desc& desc, IShaderTable** outTable) = 0;

virtual SLANG_NO_THROW Result SLANG_MCALL createProgram(
const IShaderProgram::Desc& desc,
virtual SLANG_NO_THROW Result SLANG_MCALL createShaderProgram(
const ShaderProgramDesc& desc,
IShaderProgram** outProgram,
ISlangBlob** outDiagnosticBlob = nullptr
) = 0;

inline ComPtr<IShaderProgram> createProgram(const IShaderProgram::Desc& desc)
inline ComPtr<IShaderProgram> createShaderProgram(
const ShaderProgramDesc& desc,
ISlangBlob** outDiagnosticBlob = nullptr
)
{
ComPtr<IShaderProgram> program;
SLANG_RETURN_NULL_ON_FAIL(createProgram(desc, program.writeRef()));
SLANG_RETURN_NULL_ON_FAIL(createShaderProgram(desc, program.writeRef(), outDiagnosticBlob));
return program;
}

virtual SLANG_NO_THROW Result SLANG_MCALL createProgram2(
const IShaderProgram::CreateDesc2& createDesc,
IShaderProgram** outProgram,
inline ComPtr<IShaderProgram> createShaderProgram(
slang::IComponentType* linkedProgram,
ISlangBlob** outDiagnosticBlob = nullptr
) = 0;
)
{
ShaderProgramDesc desc = {};
desc.slangGlobalScope = linkedProgram;
return createShaderProgram(desc, outDiagnosticBlob);
}

virtual SLANG_NO_THROW Result SLANG_MCALL
createRenderPipeline(const RenderPipelineDesc& desc, IPipeline** outPipeline) = 0;
Expand Down
7 changes: 5 additions & 2 deletions src/cpu/cpu-device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,11 @@ Result DeviceImpl::createRootShaderObject(IShaderProgram* program, ShaderObjectB
return SLANG_OK;
}

SLANG_NO_THROW Result SLANG_MCALL
DeviceImpl::createProgram(const IShaderProgram::Desc& desc, IShaderProgram** outProgram, ISlangBlob** outDiagnosticBlob)
SLANG_NO_THROW Result SLANG_MCALL DeviceImpl::createShaderProgram(
const ShaderProgramDesc& desc,
IShaderProgram** outProgram,
ISlangBlob** outDiagnosticBlob
)
{
RefPtr<ShaderProgramImpl> cpuProgram = new ShaderProgramImpl();
cpuProgram->init(desc);
Expand Down
4 changes: 2 additions & 2 deletions src/cpu/cpu-device.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ class DeviceImpl : public ImmediateComputeDeviceBase

virtual Result createRootShaderObject(IShaderProgram* program, ShaderObjectBase** outObject) override;

virtual SLANG_NO_THROW Result SLANG_MCALL createProgram(
const IShaderProgram::Desc& desc,
virtual SLANG_NO_THROW Result SLANG_MCALL createShaderProgram(
const ShaderProgramDesc& desc,
IShaderProgram** outProgram,
ISlangBlob** outDiagnosticBlob
) override;
Expand Down
7 changes: 5 additions & 2 deletions src/cuda/cuda-device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -951,8 +951,11 @@ Result DeviceImpl::createRootShaderObject(IShaderProgram* program, ShaderObjectB
return SLANG_OK;
}

SLANG_NO_THROW Result SLANG_MCALL
DeviceImpl::createProgram(const IShaderProgram::Desc& desc, IShaderProgram** outProgram, ISlangBlob** outDiagnosticBlob)
SLANG_NO_THROW Result SLANG_MCALL DeviceImpl::createShaderProgram(
const ShaderProgramDesc& desc,
IShaderProgram** outProgram,
ISlangBlob** outDiagnosticBlob
)
{
// If this is a specializable program, we just keep a reference to the slang program and
// don't actually create any kernels. This program will be specialized later when we know
Expand Down
4 changes: 2 additions & 2 deletions src/cuda/cuda-device.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@ class DeviceImpl : public RendererBase

Result createRootShaderObject(IShaderProgram* program, ShaderObjectBase** outObject);

virtual SLANG_NO_THROW Result SLANG_MCALL createProgram(
const IShaderProgram::Desc& desc,
virtual SLANG_NO_THROW Result SLANG_MCALL createShaderProgram(
const ShaderProgramDesc& desc,
IShaderProgram** outProgram,
ISlangBlob** outDiagnosticBlob
) override;
Expand Down
4 changes: 2 additions & 2 deletions src/d3d11/d3d11-device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1330,8 +1330,8 @@ void DeviceImpl::drawIndexedInstanced(
);
}

Result DeviceImpl::createProgram(
const IShaderProgram::Desc& desc,
Result DeviceImpl::createShaderProgram(
const ShaderProgramDesc& desc,
IShaderProgram** outProgram,
ISlangBlob** outDiagnosticBlob
)
Expand Down
4 changes: 2 additions & 2 deletions src/d3d11/d3d11-device.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ class DeviceImpl : public ImmediateRendererBase
virtual Result createRootShaderObject(IShaderProgram* program, ShaderObjectBase** outObject) override;
virtual void bindRootShaderObject(IShaderObject* shaderObject) override;

virtual SLANG_NO_THROW Result SLANG_MCALL createProgram(
const IShaderProgram::Desc& desc,
virtual SLANG_NO_THROW Result SLANG_MCALL createShaderProgram(
const ShaderProgramDesc& desc,
IShaderProgram** outProgram,
ISlangBlob** outDiagnosticBlob
) override;
Expand Down
4 changes: 2 additions & 2 deletions src/d3d12/d3d12-device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1861,8 +1861,8 @@ Result DeviceImpl::readBuffer(IBuffer* bufferIn, Offset offset, Size size, ISlan
return SLANG_OK;
}

Result DeviceImpl::createProgram(
const IShaderProgram::Desc& desc,
Result DeviceImpl::createShaderProgram(
const ShaderProgramDesc& desc,
IShaderProgram** outProgram,
ISlangBlob** outDiagnosticBlob
)
Expand Down
7 changes: 5 additions & 2 deletions src/d3d12/d3d12-device.h
Original file line number Diff line number Diff line change
Expand Up @@ -153,8 +153,11 @@ class DeviceImpl : public RendererBase

virtual SLANG_NO_THROW Result SLANG_MCALL
createShaderTable(const IShaderTable::Desc& desc, IShaderTable** outShaderTable) override;
virtual SLANG_NO_THROW Result SLANG_MCALL
createProgram(const IShaderProgram::Desc& desc, IShaderProgram** outProgram, ISlangBlob** outDiagnostics) override;
virtual SLANG_NO_THROW Result SLANG_MCALL createShaderProgram(
const ShaderProgramDesc& desc,
IShaderProgram** outProgram,
ISlangBlob** outDiagnostics
) override;
virtual SLANG_NO_THROW Result SLANG_MCALL
createRenderPipeline(const RenderPipelineDesc& desc, IPipeline** outPipeline) override;
virtual SLANG_NO_THROW Result SLANG_MCALL
Expand Down
24 changes: 3 additions & 21 deletions src/debug-layer/debug-device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -458,41 +458,23 @@ Result DebugDevice::createMutableShaderObjectFromTypeLayout(
return result;
}

Result DebugDevice::createProgram(
const IShaderProgram::Desc& desc,
Result DebugDevice::createShaderProgram(
const ShaderProgramDesc& desc,
IShaderProgram** outProgram,
ISlangBlob** outDiagnostics
)
{
SLANG_RHI_API_FUNC;

RefPtr<DebugShaderProgram> outObject = new DebugShaderProgram();
auto result = baseObject->createProgram(desc, outObject->baseObject.writeRef(), outDiagnostics);
auto result = baseObject->createShaderProgram(desc, outObject->baseObject.writeRef(), outDiagnostics);
if (SLANG_FAILED(result))
return result;
outObject->m_slangProgram = desc.slangGlobalScope;
returnComPtr(outProgram, outObject);
return result;
}

Result DebugDevice::createProgram2(
const IShaderProgram::CreateDesc2& desc,
IShaderProgram** outProgram,
ISlangBlob** outDiagnostics
)
{
SLANG_RHI_API_FUNC;
IShaderProgram::Desc desc1 = {};
RefPtr<DebugShaderProgram> outObject = new DebugShaderProgram();
auto result = baseObject->createProgram2(desc, outObject->baseObject.writeRef(), outDiagnostics);
if (SLANG_FAILED(result))
return result;
auto base = static_cast<ShaderProgramBase*>(outObject->baseObject.get());
outObject->m_slangProgram = base->desc.slangGlobalScope;
returnComPtr(outProgram, outObject);
return result;
}

Result DebugDevice::createRenderPipeline(const RenderPipelineDesc& desc, IPipeline** outPipeline)
{
SLANG_RHI_API_FUNC;
Expand Down
6 changes: 2 additions & 4 deletions src/debug-layer/debug-device.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,10 +96,8 @@ class DebugDevice : public DebugObject<IDevice>
) override;
virtual SLANG_NO_THROW Result SLANG_MCALL
createMutableRootShaderObject(IShaderProgram* program, IShaderObject** outObject) override;
virtual SLANG_NO_THROW Result SLANG_MCALL
createProgram(const IShaderProgram::Desc& desc, IShaderProgram** outProgram, ISlangBlob** outDiagnostics) override;
virtual SLANG_NO_THROW Result SLANG_MCALL createProgram2(
const IShaderProgram::CreateDesc2& desc,
virtual SLANG_NO_THROW Result SLANG_MCALL createShaderProgram(
const ShaderProgramDesc& desc,
IShaderProgram** outProgram,
ISlangBlob** outDiagnostics
) override;
Expand Down
4 changes: 2 additions & 2 deletions src/metal/metal-device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -687,8 +687,8 @@ Result DeviceImpl::createInputLayout(IInputLayout::Desc const& desc, IInputLayou
return SLANG_OK;
}

Result DeviceImpl::createProgram(
const IShaderProgram::Desc& desc,
Result DeviceImpl::createShaderProgram(
const ShaderProgramDesc& desc,
IShaderProgram** outProgram,
ISlangBlob** outDiagnosticBlob
)
Expand Down
4 changes: 2 additions & 2 deletions src/metal/metal-device.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ class DeviceImpl : public RendererBase

virtual SLANG_NO_THROW Result SLANG_MCALL
createShaderTable(const IShaderTable::Desc& desc, IShaderTable** outShaderTable) override;
virtual SLANG_NO_THROW Result SLANG_MCALL createProgram(
const IShaderProgram::Desc& desc,
virtual SLANG_NO_THROW Result SLANG_MCALL createShaderProgram(
const ShaderProgramDesc& desc,
IShaderProgram** outProgram,
ISlangBlob** outDiagnosticBlob
) override;
Expand Down
Loading

0 comments on commit d88d9f3

Please sign in to comment.