Shader object refactor #87

skallweitNV opened this issue Oct 22, 2024 · 0 comments

Introduction

This document describes a new API for shader objects in slang-rhi.

The main goals are:

  • Shader objects must be immutable
    • Allows them to be used in multi-threaded environments (e.g. parallel command encoding)
    • Allows removing much of the state tracking in the current implementation
  • Shader specialization through shader objects should be handled in a backend-agnostic way
    • Backends should not need to duplicate the same logic over and over
  • Creating mutated copies of shader objects should be cheap
  • Changing only uniform data should not require rebuilding descriptor sets/tables or bind groups
    • D3D12 should use root descriptors for binding the uniform constant buffers
    • Vulkan should use push constants or dynamic uniform buffers
    • WebGPU should use bind groups with dynamic offsets
    • Memory for uniform data should be sub-allocated from a large buffer
  • Shader objects are used to track the lifetime of all bound resources
    • This is used to ensure resources stay alive until command buffers have finished executing on the GPU
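To illustrate the uniform-data goals, here is a minimal sketch of a linear ("bump") sub-allocator, assuming each frozen shader object writes its uniform data once into a large shared buffer. Binding a modified copy would then only need a different offset (a root descriptor address on D3D12, a dynamic offset on Vulkan or WebGPU) rather than a rebuilt descriptor set or bind group. `UniformArena` and its methods are hypothetical names, a `std::vector` stands in for the GPU buffer, and the 256-byte alignment used in the usage check is just an example (it matches D3D12's constant buffer placement alignment).

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

// Hypothetical sketch (not slang-rhi API): a linear "bump" allocator that
// sub-allocates uniform data from one large buffer. A std::vector stands in
// for the GPU buffer.
class UniformArena
{
public:
    UniformArena(size_t capacity, size_t alignment)
        : m_storage(capacity), m_alignment(alignment)
    {
        // The round-up in allocate() requires a power-of-two alignment.
        assert(alignment != 0 && (alignment & (alignment - 1)) == 0);
    }

    // Copies `size` bytes into the arena and returns their offset,
    // or SIZE_MAX if there is not enough space left.
    size_t allocate(const void* data, size_t size)
    {
        size_t offset = (m_head + m_alignment - 1) & ~(m_alignment - 1);
        if (offset > m_storage.size() || size > m_storage.size() - offset)
            return SIZE_MAX;
        if (size != 0)
            std::memcpy(m_storage.data() + offset, data, size);
        m_head = offset + size;
        return offset;
    }

    // Makes the whole arena reusable; only safe once the GPU is done with it.
    void reset() { m_head = 0; }

    const uint8_t* data() const { return m_storage.data(); }

private:
    std::vector<uint8_t> m_storage;
    size_t m_alignment;
    size_t m_head = 0;
};
```

A real backend would also have to keep the arena alive (or delay its reuse) until the command buffers referencing it have finished, which is where the lifetime-tracking goal comes in.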

IShaderObject

Similar to the implementation in gfx, there is an IShaderObject interface that represents shader objects:

class IShaderObject
{
public:
    /// Return the associated element type layout.
    virtual slang::TypeLayoutReflection* getElementTypeLayout() = 0;
    /// Return the container type.
    virtual ShaderObjectContainerType getContainerType() = 0;
    /// Return the number of entry points (if this is a root shader object).
    virtual Count getEntryPointCount() = 0;
    /// Return an entry point by index.
    virtual Result getEntryPoint(GfxIndex index, IShaderObject** entryPoint) = 0;

    /// Set uniform data.
    virtual Result setData(ShaderOffset offset, const void *data, Size size) = 0;
    /// Get uniform data.
    virtual Result getData(ShaderOffset offset, void *data, Size size) = 0;

    /// Set a binding.
    virtual Result setBinding(ShaderOffset offset, Binding binding) = 0;
    /// Get a binding.
    virtual Result getBinding(ShaderOffset offset, Binding* binding) = 0;

    /// Set a sub-object.
    virtual Result setObject(ShaderOffset offset, IShaderObject* object) = 0;
    /// Get a sub-object.
    virtual Result getObject(ShaderOffset offset, IShaderObject** object) = 0;

    /// Freeze the shader object, making it immutable.
    /// Any calls to modify the shader object after this will result in an error.
    virtual Result freeze() = 0;
};

The main difference from gfx is that shader objects become immutable once freeze() has been called. A shader object can only be bound to a command encoder, or assigned as a sub-object of another shader object, after it has been frozen.

Frozen shader objects cannot under any circumstances be unfrozen.
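The freeze rules above can be modeled in a few lines. This is a hypothetical, self-contained sketch rather than slang-rhi code: `FreezableObject` and the `Result` values are made-up stand-ins, offsets are plain byte offsets and sub-object indices, and sub-objects are held in `std::shared_ptr` instead of `ComPtr`/`RefPtr`.

```cpp
#include <cassert>
#include <cstddef>
#include <cstring>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical stand-ins for slang-rhi's Result codes.
enum class Result { Ok, ErrorFrozen, ErrorNotFrozen, ErrorOutOfRange };

// Hypothetical model of the freeze rules: mutation fails after freeze(),
// and only frozen objects may be assigned as sub-objects.
class FreezableObject
{
public:
    FreezableObject(size_t dataSize, size_t subObjectCount)
        : m_data(dataSize), m_objects(subObjectCount) {}

    Result setData(size_t offset, const void* data, size_t size)
    {
        if (m_frozen)
            return Result::ErrorFrozen;
        if (offset > m_data.size() || size > m_data.size() - offset)
            return Result::ErrorOutOfRange;
        if (size != 0)
            std::memcpy(m_data.data() + offset, data, size);
        return Result::Ok;
    }

    Result setObject(size_t index, std::shared_ptr<FreezableObject> object)
    {
        if (m_frozen)
            return Result::ErrorFrozen;
        if (index >= m_objects.size())
            return Result::ErrorOutOfRange;
        // A null or still-mutable object cannot be used as a sub-object.
        if (!object || !object->isFrozen())
            return Result::ErrorNotFrozen;
        m_objects[index] = std::move(object);
        return Result::Ok;
    }

    // Freezing is one-way: there is deliberately no unfreeze().
    // Sub-objects are already frozen because setObject() requires it.
    Result freeze()
    {
        if (m_frozen)
            return Result::ErrorFrozen;
        m_frozen = true;
        return Result::Ok;
    }

    bool isFrozen() const { return m_frozen; }

private:
    bool m_frozen = false;
    std::vector<unsigned char> m_data;
    std::vector<std::shared_ptr<FreezableObject>> m_objects;
};
```

With this model, assigning a camera object to a scene fails until the camera has been frozen, and any later `setData` on the camera returns `ErrorFrozen`.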

To create shader objects, there are a few factory methods on IDevice:

class IDevice
{
public:
    /// Create a new shader object from a given slang type.
    virtual Result createShaderObject(
        slang::ISession* slangSession,
        slang::TypeReflection* type,
        ShaderObjectContainerType container,
        IShaderObject** outObject
    ) = 0;

    /// Create a new shader object by copying an existing shader object.
    /// The new shader object will be mutable.
    virtual Result createShaderObject(IShaderObject* object, IShaderObject** outObject) = 0;

    /// Create a new root shader object for the given shader program.
    virtual Result createRootShaderObject(IShaderProgram* program, IShaderObject** outObject) = 0;
};

Basic example

Slang code:

interface ICamera { Ray getRay(float2 uv);};

struct PinholeCamera : ICamera
{
    float3 position;
    float3 direction;
    float3 up;
    float2 fov;
    float2 resolution;

    Ray getRay(float2 uv) { /* ... */ }
};

struct Scene
{
    StructuredBuffer<Vertex> vertexBuffer;
    StructuredBuffer<Index> indexBuffer;
    ICamera camera;
};

[shader("compute")]
[numthreads(1, 1, 1)]
void render(uniform ParameterBlock<Scene> scene, uniform uint iteration)
{}

Host code:

// Load the shader program.
ComPtr<IShaderProgram> program = device->loadProgram("test.slang", "render");

// Create a shader object for the camera.
// This object will be used to specialize the program.
ComPtr<IShaderObject> cameraObject = device->createShaderObject(program->getReflection()->getType("PinholeCamera"));
{
    ShaderCursor cursor(cameraObject);
    cursor["position"] = float3(0, 0, 0);
    ...
}
// Freeze the camera shader object so we can use it as a sub-object.
cameraObject->freeze();

// Create a shader object for the scene.
ComPtr<IShaderObject> sceneObject = device->createShaderObject(program->getReflection()->getType("Scene"));
{
    ShaderCursor cursor(sceneObject);
    cursor["vertexBuffer"] = vertexBuffer;
    cursor["indexBuffer"] = indexBuffer;
    cursor["camera"] = cameraObject; // NOTE: Error if cameraObject was not frozen!
}
sceneObject->freeze();

// Create the root shader object.
ComPtr<IShaderObject> rootObject = device->createRootShaderObject(program);
{
    ShaderCursor cursor(rootObject);
    cursor["scene"] = sceneObject;
}
rootObject->freeze();

// With the root object done, we can now specialize our program.
ComPtr<IShaderProgram> specializedProgram = program->specialize(rootObject);

// Create a compute pipeline.
// NOTE: Creating a pipeline for an unspecialized program would be an error!
ComPtr<IComputePipeline> pipeline = device->createComputePipeline(specializedProgram);

// Submit a single compute dispatch.
ComPtr<ICommandEncoder> encoder = device->getQueue()->createCommandEncoder();
ComputeState state;
state.pipeline = pipeline;
state.rootObject = rootObject; // NOTE: Error if rootObject is not frozen.
encoder->setComputeState(state);
encoder->dispatchCompute(1, 1, 1);
device->getQueue()->submit(encoder->finish());

// Submit multiple dispatches on the same command list, with modified root objects.
encoder = device->getQueue()->createCommandEncoder();
for (int i = 0; i < 100; ++i)
{
    ComPtr<IShaderObject> modifiedRootObject = device->createShaderObject(rootObject);
    {
        ShaderCursor cursor(modifiedRootObject);
        cursor["iteration"] = i;
    }
    modifiedRootObject->freeze();
    ComputeState state;
    state.pipeline = pipeline;
    state.rootObject = modifiedRootObject;
    encoder->setComputeState(state);
    encoder->dispatchCompute(1, 1, 1);
}
device->getQueue()->submit(encoder->finish());

Implementation details

ShaderObject

Shader objects are implemented in a backend-agnostic way. Their main purpose is to hold all the resources, sub-objects and uniform data assigned to them. Binding ranges are used to map shader offsets to a linear array of binding slots. Each binding slot contains a reference to a resource plus additional data (e.g. buffer range or format).
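As a rough sketch of how binding ranges could map shader offsets to the linear slot array, assume each binding range knows how many bindings it holds (for example, an array of four textures is one range with four bindings). A prefix sum over those counts gives each range's first slot. `SimpleShaderOffset` and `BindingSlotMap` are hypothetical simplifications; in practice the counts would come from Slang's type layout reflection rather than being passed in directly.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical, simplified shader offset: which binding range, and which
// element of that range (for arrays of resources).
struct SimpleShaderOffset
{
    uint32_t bindingRangeIndex;
    uint32_t arrayIndex;
};

// Maps (binding range, array index) pairs to indices in one linear array of
// binding slots. Each range starts at the prefix sum of the preceding counts.
class BindingSlotMap
{
public:
    explicit BindingSlotMap(const std::vector<uint32_t>& rangeCounts)
    {
        uint32_t next = 0;
        for (uint32_t count : rangeCounts)
        {
            m_rangeStart.push_back(next);
            m_rangeCount.push_back(count);
            next += count;
        }
        m_slotCount = next;
    }

    // Returns the flat slot index, or std::nullopt if the offset is invalid.
    std::optional<uint32_t> slotIndex(SimpleShaderOffset offset) const
    {
        if (offset.bindingRangeIndex >= m_rangeStart.size())
            return std::nullopt;
        if (offset.arrayIndex >= m_rangeCount[offset.bindingRangeIndex])
            return std::nullopt;
        return m_rangeStart[offset.bindingRangeIndex] + offset.arrayIndex;
    }

    // Total number of binding slots to allocate for the object.
    uint32_t slotCount() const { return m_slotCount; }

private:
    std::vector<uint32_t> m_rangeStart;
    std::vector<uint32_t> m_rangeCount;
    uint32_t m_slotCount = 0;
};
```

For ranges of sizes {1, 4, 2}, the second range occupies slots 1 to 4 and the third starts at slot 5, so the bindings array needs 7 entries in total.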

class ShaderObject : public IShaderObject, public ComObject
{
public:
    void init()
    {
        // 1. Enumerate all binding ranges and populate m_bindingTypeToStartIndex and m_bindings.
        // 2. Enumerate all sub-objects and create a shader object for each sub-object (recursively).
        // 3. Allocate memory for uniform data.
    }

    Result setData(ShaderOffset offset, const void *data, Size size) override
    {
        // 1. Return error if the shader object is frozen.
        // 2. Copy the data into the uniform data buffer.
    }

    Result getData(ShaderOffset offset, void *data, Size size) override
    {
        // Copy the data from the uniform data buffer.
    }

    Result setBinding(ShaderOffset offset, Binding binding) override
    {
        // 1. Return error if the shader object is frozen.
        // 2. Find the binding range for the given offset.
        // 3. Copy the binding into the bindings array.
    }

    Result getBinding(ShaderOffset offset, Binding* binding) override
    {
        // Find the binding for the given offset and return it.
    }

    Result setObject(ShaderOffset offset, IShaderObject* object) override
    {
        // 1. Return error if the shader object is frozen.
        // 2. Find the index of the sub-object for the given offset.
        // 3. Set the object into the sub-objects array.
    }

    Result getObject(ShaderOffset offset, IShaderObject** object) override
    {
        // Find the sub-object for the given offset and return it.
    }

    Result freeze() override
    {
        // 1. Return error if the shader object is already frozen.
        // 2. Freeze all sub-objects (recursively).
        // 3. Freeze the shader object.
    }

private:
    struct BindingSlot
    {
        /// The bound resource.
        RefPtr<Resource> resource;
        /// Additional data.
        union
        {
            struct
            {
                BufferRange range;
                Format format;
            } buffer;
            // ...
        };
    };

    /// True if the shader object is frozen.
    bool m_frozen = false; 
    /// Map from binding type to start index in the bindings array.
    std::array<uint32_t, slang::BindingType::Count> m_bindingTypeToStartIndex;
    /// List of bindings.
    std::vector<BindingSlot> m_bindings;
    /// List of sub-objects.
    std::vector<RefPtr<ShaderObject>> m_objects;
    /// Uniform data.
    std::vector<uint8_t> m_data;
};

Shallow copy

When copying a shader object, we initially copy only the top-level object; its sub-objects keep referencing the same (frozen) sub-objects as the original. When a sub-object is accessed for modification through getObject, it is replaced with a mutable copy. This avoids copying the entire shader object tree when only a small part of it changes. We also keep a reference to the original shader object each copy was made from, so that backends can reuse work already done for the original.

class ShaderObject : public IShaderObject, public ComObject
{
public:
    Result initFromOther(ShaderObject* other)
    {
        // 1. Copy uniform data and bindings from `other`; sub-object references are shared, not copied.
        // 2. Set `m_is_copy` to true.
        // 3. Assign the original shader object to `m_original`.
    }

    Result getObject(ShaderOffset offset, IShaderObject** object) override
    {
        // 1. Find the sub-object for the given offset.
        // 2. If the sub-object is frozen (i.e. still referencing the original), replace the sub-object with a copy.
        // 3. Return the new sub-object.
    }

private:
    /// True if the shader object is a copy.
    bool m_is_copy = false;
    /// The shader object we copied from.
    RefPtr<ShaderObject> m_original;
};
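The copy-on-write behavior described above can be modeled with `std::shared_ptr`. In this hypothetical sketch (`CowObject` is not a slang-rhi type, and a single `value` field stands in for uniform data and bindings), a copy shares its sub-object pointers with the original, and `getObject()` replaces a still-frozen, possibly shared sub-object with a mutable copy before returning it. A non-null `original` plays the role of `m_is_copy`/`m_original`.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <vector>

// Hypothetical model of copy-on-write sub-objects.
struct CowObject
{
    bool frozen = false;
    int value = 0; // stand-in for uniform data and bindings
    std::vector<std::shared_ptr<CowObject>> objects;
    // The object this one was copied from (null if it is not a copy).
    std::shared_ptr<const CowObject> original;

    // Shallow copy: the new object is mutable, but its sub-objects are
    // shared with `other` until they are accessed through getObject().
    static std::shared_ptr<CowObject> copyFrom(const std::shared_ptr<CowObject>& other)
    {
        auto copy = std::make_shared<CowObject>();
        copy->value = other->value;
        copy->objects = other->objects; // copies pointers, not objects
        copy->original = other;
        return copy;
    }

    // Returns a mutable sub-object, copying it first if it is still frozen
    // (and therefore possibly shared). Only valid on a mutable object.
    std::shared_ptr<CowObject> getObject(size_t index)
    {
        assert(!frozen);
        auto& child = objects.at(index);
        if (child->frozen)
            child = copyFrom(child);
        return child;
    }

    // Freezes mutable sub-objects first (recursively), then this object.
    void freeze()
    {
        for (auto& child : objects)
            if (!child->frozen)
                child->freeze();
        frozen = true;
    }
};
```

In the per-iteration loop from the host example, each modified root copy would share every untouched sub-object (such as the scene) with the original, and only the path down to a modified sub-object would be duplicated.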

Backend data

Each backend binds shader objects differently. Once a shader object is frozen, we can create a per-backend data structure containing all the information needed to bind it to a command encoder. These objects can be lightweight, since they only store what is needed for binding.
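As a hedged sketch of lazily created backend data, the following assumes a Vulkan-like backend that needs a descriptor set handle plus a uniform buffer offset. Because a frozen object never changes, the data can be built once (guarded by `std::call_once`, so parallel encoders can request it safely) and then cached. It also shows one possible use of the link to the original object: a copy whose bindings match its original's reuses the original's descriptor set, so only the uniform offset differs. All names here, and the use of plain `uint64_t` values as resource and descriptor set handles, are assumptions for illustration.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <memory>
#include <mutex>
#include <utility>
#include <vector>

// Hypothetical per-backend binding data for one frozen shader object.
struct BackendBindingData
{
    uint64_t descriptorSet = 0; // stand-in for a real descriptor set handle
    uint32_t uniformOffset = 0; // e.g. a dynamic offset into a uniform arena
};

// Hypothetical callback that builds a descriptor set from bound resources.
using CreateDescriptorSetFn = std::function<uint64_t(const std::vector<uint64_t>&)>;

class FrozenShaderObject
{
public:
    FrozenShaderObject(std::vector<uint64_t> bindings, uint32_t uniformOffset,
                       std::shared_ptr<FrozenShaderObject> original = nullptr)
        : m_bindings(std::move(bindings)), m_uniformOffset(uniformOffset),
          m_original(std::move(original)) {}

    // Builds the backend data on first use and caches it. Caching is only
    // sound because a frozen object can never change afterwards.
    const BackendBindingData& getBackendData(const CreateDescriptorSetFn& createDescriptorSet)
    {
        std::call_once(m_once, [&] {
            // If only uniform data changed, reuse the original's descriptor set.
            if (m_original && m_original->m_bindings == m_bindings)
                m_data.descriptorSet = m_original->getBackendData(createDescriptorSet).descriptorSet;
            else
                m_data.descriptorSet = createDescriptorSet(m_bindings);
            m_data.uniformOffset = m_uniformOffset;
        });
        return m_data;
    }

private:
    std::vector<uint64_t> m_bindings;
    uint32_t m_uniformOffset;
    std::shared_ptr<FrozenShaderObject> m_original;
    std::once_flag m_once;
    BackendBindingData m_data;
};
```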
