Skip to content

Commit

Permalink
vkd3d: Support basic node overrides in workgraphs.
Browse files Browse the repository at this point in the history
Used by NV Donut demo.

Signed-off-by: Hans-Kristian Arntzen <post@arntzen-software.no>
  • Loading branch information
HansKristian-Work committed Oct 22, 2024
1 parent 3840edf commit d94d27f
Showing 1 changed file with 162 additions and 3 deletions.
165 changes: 162 additions & 3 deletions libs/vkd3d/workgraphs.c
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,9 @@ struct d3d12_wg_state_object_program

const D3D12_NODE_ID *explicit_entry_points;
size_t explicit_entry_point_count;

const D3D12_NODE *explicit_defined_nodes;
size_t explicit_defined_nodes_count;
};

struct d3d12_wg_state_object_module
Expand Down Expand Up @@ -235,9 +238,6 @@ static HRESULT d3d12_wg_state_object_parse_subobject(
const D3D12_WORK_GRAPH_DESC *wg_desc = obj->pDesc;
struct d3d12_wg_state_object_program *program;

if (wg_desc->NumExplicitlyDefinedNodes != 0)
FIXME("Explicitly stated nodes is not supported.\n");

if (wg_desc->Flags != D3D12_WORK_GRAPH_FLAG_INCLUDE_ALL_AVAILABLE_NODES)
{
FIXME("Only INCLUDE_ALL_AVAILABLE_NODES mode is supported.\n");
Expand All @@ -255,6 +255,12 @@ static HRESULT d3d12_wg_state_object_parse_subobject(
program->explicit_entry_point_count = wg_desc->NumEntrypoints;
}

if (wg_desc->NumExplicitlyDefinedNodes != 0)
{
program->explicit_defined_nodes = wg_desc->pExplicitlyDefinedNodes;
program->explicit_defined_nodes_count = wg_desc->NumExplicitlyDefinedNodes;
}

program->name = vkd3d_wstrdup(wg_desc->ProgramName);
break;
}
Expand Down Expand Up @@ -691,6 +697,148 @@ static HRESULT d3d12_wg_state_object_rearrange_entry_points(struct d3d12_wg_stat
return S_OK;
}

static const struct vkd3d_shader_library_entry_point *
d3d12_wg_state_object_find_exported_entry_point(
struct d3d12_wg_state_object_data *data, LPCWSTR shader)
{
size_t i;
for (i = 0; i < data->entry_points_count; i++)
if (vkd3d_export_equal(shader, &data->entry_points[i]))
return &data->entry_points[i];
return NULL;
}

static HRESULT d3d12_wg_state_object_apply_node_overrides(
struct d3d12_wg_state_object_data *data,
struct d3d12_wg_state_object_program *program)
{
const struct vkd3d_shader_library_entry_point *entry;
size_t i;

for (i = 0; i < program->explicit_defined_nodes_count; i++)
{
const D3D12_SHADER_NODE *shader_node;
const D3D12_NODE *node;

node = &program->explicit_defined_nodes[i];
if (node->NodeType != D3D12_NODE_TYPE_SHADER)
{
/* Not supported in current workgraph spec. Would be app error to try this. */
WARN("Attempting to use non-supported D3D12_NODE_TYPE_PROGRAM.\n");
return E_INVALIDARG;
}

shader_node = &node->Shader;

TRACE("Adding overrides for export %s.\n", debugstr_w(shader_node->Shader));

entry = d3d12_wg_state_object_find_exported_entry_point(data, shader_node->Shader);
if (!entry)
{
FIXME("Could not find shader with export name %s.\n", debugstr_w(shader_node->Shader));
return E_INVALIDARG;
}

if (!entry->node_input)
{
FIXME("Override node does not have a node input structure associated with it.\n");
return E_INVALIDARG;
}

switch (shader_node->OverridesType)
{
case D3D12_NODE_OVERRIDES_TYPE_BROADCASTING_LAUNCH:
{
const D3D12_BROADCASTING_LAUNCH_OVERRIDES *override = shader_node->pBroadcastingLaunchOverrides;
if (override->NumOutputOverrides)
FIXME("Output overrides not supported yet.\n");
if (override->pShareInputOf)
FIXME("ShaderInputOf overrides not supported yet.\n");
if (override->pNewName)
FIXME("NameView overrides not supported yet.\n");

if (override->pDispatchGrid)
{
memcpy(entry->node_input->broadcast_grid, override->pDispatchGrid, sizeof(UINT) * 3);
entry->node_input->dispatch_grid_is_upper_bound = false;
TRACE("Overriding export %s dispatch grid [%u, %u, %u]\n",
debugstr_w(shader_node->Shader),
entry->node_input->broadcast_grid[0],
entry->node_input->broadcast_grid[1],
entry->node_input->broadcast_grid[2]);
}

if (override->pMaxDispatchGrid)
{
memcpy(entry->node_input->broadcast_grid, override->pMaxDispatchGrid, sizeof(UINT) * 3);
entry->node_input->dispatch_grid_is_upper_bound = true;
TRACE("Overriding export %s max dispatch grid [%u, %u, %u]\n",
debugstr_w(shader_node->Shader),
entry->node_input->broadcast_grid[0],
entry->node_input->broadcast_grid[1],
entry->node_input->broadcast_grid[2]);
}

if (override->pProgramEntry)
{
entry->node_input->is_program_entry = *override->pProgramEntry;
TRACE("Overriding export %s IsProgramEntry %u\n",
debugstr_w(shader_node->Shader), entry->node_input->is_program_entry);
}

if (override->pLocalRootArgumentsTableIndex)
{
entry->node_input->local_root_arguments_table_index = *override->pLocalRootArgumentsTableIndex;
TRACE("Overriding export %s LocalRootArgumentsTableIndex %u\n",
debugstr_w(shader_node->Shader), entry->node_input->local_root_arguments_table_index);
}

break;
}

/* These three are exactly the same.
* It's a union of pointers, so should be no issue. */
case D3D12_NODE_OVERRIDES_TYPE_COALESCING_LAUNCH:
case D3D12_NODE_OVERRIDES_TYPE_THREAD_LAUNCH:
case D3D12_NODE_OVERRIDES_TYPE_COMMON_COMPUTE:
{
const D3D12_COALESCING_LAUNCH_OVERRIDES *override = shader_node->pCoalescingLaunchOverrides;
if (override->NumOutputOverrides)
FIXME("Output overrides not supported yet.\n");
if (override->pShareInputOf)
FIXME("ShaderInputOf overrides not supported yet.\n");
if (override->pNewName)
FIXME("NameView overrides not supported yet.\n");

if (override->pProgramEntry)
{
entry->node_input->is_program_entry = *override->pProgramEntry;
TRACE("Overriding export %s IsProgramEntry %u\n",
debugstr_w(shader_node->Shader), entry->node_input->is_program_entry);
}

if (override->pLocalRootArgumentsTableIndex)
{
entry->node_input->local_root_arguments_table_index = *override->pLocalRootArgumentsTableIndex;
TRACE("Overriding export %s LocalRootArgumentsTableIndex %u\n",
debugstr_w(shader_node->Shader), entry->node_input->local_root_arguments_table_index);
}

break;
}

case D3D12_NODE_OVERRIDES_TYPE_NONE:
break;

default:
FIXME("Unrecognized node override type %u.\n", shader_node->OverridesType);
return E_INVALIDARG;
}
}

return S_OK;
}

static HRESULT d3d12_wg_state_object_resolve_entry_points_explicit(struct d3d12_wg_state_object *object,
struct d3d12_wg_state_object_data *data,
struct d3d12_wg_state_object_program *program)
Expand Down Expand Up @@ -2155,6 +2303,17 @@ static HRESULT d3d12_wg_state_object_compile_programs(
{
struct d3d12_wg_state_object_program *program = &object->programs[i];

/* TODO: Potential hazard is that each program can have different overrides per node,
* meaning we need a local copy of entry point meta. However, ignore this for now. */
if (object->programs_count > 1 && program->explicit_defined_nodes_count)
{
FIXME("More than one program is used and explicitly defined nodes is used. "
"This may not work if there are conflicts in overrides.\n");
}

if (FAILED(hr = d3d12_wg_state_object_apply_node_overrides(data, program)))
return hr;

if (program->explicit_entry_point_count)
{
if (FAILED(hr = d3d12_wg_state_object_resolve_entry_points_explicit(object, data, program)))
Expand Down

0 comments on commit d94d27f

Please sign in to comment.