d3d12: Support compute root signatures

Reviewed-by: Sil Vilerino <sivileri@microsoft.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/14367>
This commit is contained in:
Jesse Natalie 2021-12-31 10:08:54 -08:00 committed by Marge Bot
parent 6d38a35afb
commit 5f23b1d7cd
3 changed files with 21 additions and 11 deletions

View file

@ -752,7 +752,7 @@ d3d12_draw_vbo(struct pipe_context *pctx,
}
if (!ctx->gfx_pipeline_state.root_signature || ctx->state_dirty & D3D12_DIRTY_SHADER) {
ID3D12RootSignature *root_signature = d3d12_get_root_signature(ctx);
ID3D12RootSignature *root_signature = d3d12_get_root_signature(ctx, false);
if (ctx->gfx_pipeline_state.root_signature != root_signature) {
ctx->gfx_pipeline_state.root_signature = root_signature;
ctx->state_dirty |= D3D12_DIRTY_ROOT_SIGNATURE;

View file

@ -51,6 +51,8 @@ get_shader_visibility(enum pipe_shader_type stage)
return D3D12_SHADER_VISIBILITY_HULL;
case PIPE_SHADER_TESS_EVAL:
return D3D12_SHADER_VISIBILITY_DOMAIN;
case PIPE_SHADER_COMPUTE:
return D3D12_SHADER_VISIBILITY_ALL;
default:
unreachable("unknown shader stage");
}
@ -104,8 +106,10 @@ create_root_signature(struct d3d12_context *ctx, struct d3d12_root_signature_key
unsigned num_params = 0;
unsigned num_ranges = 0;
for (unsigned i = 0; i < D3D12_GFX_SHADER_STAGES; ++i) {
D3D12_SHADER_VISIBILITY visibility = get_shader_visibility((enum pipe_shader_type)i);
unsigned count = key->compute ? 1 : D3D12_GFX_SHADER_STAGES;
for (unsigned i = 0; i < count; ++i) {
unsigned stage = key->compute ? PIPE_SHADER_COMPUTE : i;
D3D12_SHADER_VISIBILITY visibility = get_shader_visibility((enum pipe_shader_type)stage);
if (key->stages[i].num_cb_bindings > 0) {
init_range_root_param(&root_params[num_params++],
@ -174,7 +178,8 @@ create_root_signature(struct d3d12_context *ctx, struct d3d12_root_signature_key
root_sig_desc.Desc_1_1.Flags = D3D12_ROOT_SIGNATURE_FLAG_NONE;
/* TODO Only enable this flag when needed (optimization) */
root_sig_desc.Desc_1_1.Flags |= D3D12_ROOT_SIGNATURE_FLAG_ALLOW_INPUT_ASSEMBLER_INPUT_LAYOUT;
if (!key->compute)
root_sig_desc.Desc_1_1.Flags |= D3D12_ROOT_SIGNATURE_FLAG_ALLOW_INPUT_ASSEMBLER_INPUT_LAYOUT;
if (key->has_stream_output)
root_sig_desc.Desc_1_1.Flags |= D3D12_ROOT_SIGNATURE_FLAG_ALLOW_STREAM_OUTPUT;
@ -198,12 +203,16 @@ create_root_signature(struct d3d12_context *ctx, struct d3d12_root_signature_key
}
static void
fill_key(struct d3d12_context *ctx, struct d3d12_root_signature_key *key)
fill_key(struct d3d12_context *ctx, struct d3d12_root_signature_key *key, bool compute)
{
memset(key, 0, sizeof(struct d3d12_root_signature_key));
for (unsigned i = 0; i < D3D12_GFX_SHADER_STAGES; ++i) {
struct d3d12_shader *shader = ctx->gfx_pipeline_state.stages[i];
key->compute = compute;
unsigned count = compute ? 1 : D3D12_GFX_SHADER_STAGES;
for (unsigned i = 0; i < count; ++i) {
struct d3d12_shader *shader = compute ?
ctx->compute_pipeline_state.stage :
ctx->gfx_pipeline_state.stages[i];
if (shader) {
key->stages[i].num_cb_bindings = shader->num_cb_bindings;
@ -214,18 +223,18 @@ fill_key(struct d3d12_context *ctx, struct d3d12_root_signature_key *key)
key->stages[i].num_ssbos = shader->nir->info.num_ssbos;
key->stages[i].num_images = shader->nir->info.num_images;
if (ctx->gfx_stages[i]->so_info.num_outputs > 0)
if (!compute && ctx->gfx_stages[i]->so_info.num_outputs > 0)
key->has_stream_output = true;
}
}
}
ID3D12RootSignature *
d3d12_get_root_signature(struct d3d12_context *ctx)
d3d12_get_root_signature(struct d3d12_context *ctx, bool compute)
{
struct d3d12_root_signature_key key;
fill_key(ctx, &key);
fill_key(ctx, &key, compute);
struct hash_entry *entry = _mesa_hash_table_search(ctx->root_signature_cache, &key);
if (!entry) {
struct d3d12_root_signature *data =

View file

@ -27,6 +27,7 @@
#include "d3d12_context.h"
struct d3d12_root_signature_key {
bool compute;
bool has_stream_output;
struct {
unsigned num_cb_bindings;
@ -46,6 +47,6 @@ void
d3d12_root_signature_cache_destroy(struct d3d12_context *ctx);
ID3D12RootSignature *
d3d12_get_root_signature(struct d3d12_context *ctx);
d3d12_get_root_signature(struct d3d12_context *ctx, bool compute);
#endif