microsoft/compiler: Do basic I/O analysis for dependency tables

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/22949>
This commit is contained in:
Jesse Natalie 2023-05-09 14:57:08 -07:00 committed by Marge Bot
parent 8ff95b766d
commit e169a402a8
3 changed files with 244 additions and 0 deletions

View file

@ -22,9 +22,11 @@
*/
#include "dxil_nir.h"
#include "dxil_module.h"
#include "nir_builder.h"
#include "nir_deref.h"
#include "nir_worklist.h"
#include "nir_to_dxil.h"
#include "util/u_math.h"
#include "vulkan/vulkan_core.h"
@ -2411,3 +2413,214 @@ dxil_nir_forward_front_face(nir_shader *nir)
nir_metadata_block_index | nir_metadata_dominance,
var);
}
static void
clear_pass_flags(nir_function_impl *impl)
{
nir_foreach_block(block, impl) {
nir_foreach_instr(instr, block) {
instr->pass_flags = 0;
}
}
}
static bool
add_dest_to_worklist(nir_dest *dest, void *state)
{
assert(dest->is_ssa);
nir_foreach_use_including_if(src, &dest->ssa) {
assert(src->is_ssa);
if (src->is_if) {
nir_if *nif = src->parent_if;
nir_foreach_block_in_cf_node(block, &nif->cf_node) {
nir_foreach_instr(instr, block)
nir_instr_worklist_push_tail(state, instr);
}
} else
nir_instr_worklist_push_tail(state, src->parent_instr);
}
return true;
}
static bool
set_input_bits(struct dxil_module *mod, nir_intrinsic_instr *intr, BITSET_WORD *input_bits, uint32_t ***tables, const uint32_t **table_sizes)
{
if (intr->intrinsic == nir_intrinsic_load_view_index) {
BITSET_SET(input_bits, 0);
return true;
}
bool any_bits_set = false;
nir_src *row_src = intr->intrinsic == nir_intrinsic_load_per_vertex_input ? &intr->src[1] : &intr->src[0];
bool is_patch_constant = mod->shader_kind == DXIL_DOMAIN_SHADER && intr->intrinsic == nir_intrinsic_load_input;
const struct dxil_signature_record *sig_rec = is_patch_constant ?
&mod->patch_consts[nir_intrinsic_base(intr)] :
&mod->inputs[mod->input_mappings[nir_intrinsic_base(intr)]];
if (is_patch_constant) {
/* Redirect to the second I/O table */
*tables = *tables + 1;
*table_sizes = *table_sizes + 1;
}
for (uint32_t component = 0; component < intr->num_components; ++component) {
uint32_t base_element = 0;
uint32_t num_elements = sig_rec->num_elements;
if (nir_src_is_const(*row_src)) {
base_element = (uint32_t)nir_src_as_uint(*row_src);
num_elements = 1;
}
for (uint32_t element = 0; element < num_elements; ++element) {
uint32_t row = sig_rec->elements[element + base_element].reg;
if (row == 0xffffffff)
continue;
BITSET_SET(input_bits, row * 4 + component + nir_intrinsic_component(intr));
any_bits_set = true;
}
}
return any_bits_set;
}
static bool
set_output_bits(struct dxil_module *mod, nir_intrinsic_instr *intr, BITSET_WORD *input_bits, uint32_t **tables, const uint32_t *table_sizes)
{
bool any_bits_set = false;
nir_src *row_src = intr->intrinsic == nir_intrinsic_store_per_vertex_output ? &intr->src[2] : &intr->src[1];
bool is_patch_constant = mod->shader_kind == DXIL_HULL_SHADER && intr->intrinsic == nir_intrinsic_store_output;
const struct dxil_signature_record *sig_rec = is_patch_constant ?
&mod->patch_consts[nir_intrinsic_base(intr)] :
&mod->outputs[nir_intrinsic_base(intr)];
for (uint32_t component = 0; component < intr->num_components; ++component) {
uint32_t base_element = 0;
uint32_t num_elements = sig_rec->num_elements;
if (nir_src_is_const(*row_src)) {
base_element = (uint32_t)nir_src_as_uint(*row_src);
num_elements = 1;
}
for (uint32_t element = 0; element < num_elements; ++element) {
uint32_t row = sig_rec->elements[element + base_element].reg;
if (row == 0xffffffff)
continue;
uint32_t stream = sig_rec->elements[element + base_element].stream;
uint32_t table_idx = is_patch_constant ? 1 : stream;
uint32_t *table = tables[table_idx];
uint32_t output_component = component + nir_intrinsic_component(intr);
uint32_t input_component;
BITSET_FOREACH_SET(input_component, input_bits, 32 * 4) {
uint32_t *table_for_input_component = table + table_sizes[table_idx] * input_component;
BITSET_SET(table_for_input_component, row * 4 + output_component);
any_bits_set = true;
}
}
}
return any_bits_set;
}
static bool
propagate_input_to_output_dependencies(struct dxil_module *mod, nir_intrinsic_instr *load_intr, uint32_t **tables, const uint32_t *table_sizes)
{
/* Which input components are being loaded by this instruction */
BITSET_DECLARE(input_bits, 32 * 4) = { 0 };
if (!set_input_bits(mod, load_intr, input_bits, &tables, &table_sizes))
return false;
nir_instr_worklist *worklist = nir_instr_worklist_create();
nir_instr_worklist_push_tail(worklist, &load_intr->instr);
bool any_bits_set = false;
nir_foreach_instr_in_worklist(instr, worklist) {
if (instr->pass_flags)
continue;
instr->pass_flags = 1;
nir_foreach_dest(instr, add_dest_to_worklist, worklist);
switch (instr->type) {
case nir_instr_type_jump: {
nir_jump_instr *jump = nir_instr_as_jump(instr);
switch (jump->type) {
case nir_jump_break:
case nir_jump_continue: {
nir_cf_node *parent = &instr->block->cf_node;
while (parent->type != nir_cf_node_loop)
parent = parent->parent;
nir_foreach_block_in_cf_node(block, parent)
nir_foreach_instr(i, block)
nir_instr_worklist_push_tail(worklist, i);
}
break;
default:
unreachable("Don't expect any other jumps");
}
break;
}
case nir_instr_type_intrinsic: {
nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
switch (intr->intrinsic) {
case nir_intrinsic_store_output:
case nir_intrinsic_store_per_vertex_output:
any_bits_set |= set_output_bits(mod, intr, input_bits, tables, table_sizes);
break;
/* TODO: Memory writes */
default:
break;
}
break;
}
default:
break;
}
}
nir_instr_worklist_destroy(worklist);
return any_bits_set;
}
/* For every input load, compute the set of output stores that it can contribute to.
* If it contributes to a store to memory, If it's used for control flow, then any
* instruction in the CFG that it impacts is considered to contribute.
* Ideally, we should also handle stores to outputs/memory and then loads from that
* output/memory, but this is non-trivial and unclear how much impact that would have. */
bool
dxil_nir_analyze_io_dependencies(struct dxil_module *mod, nir_shader *s)
{
bool any_outputs = false;
for (uint32_t i = 0; i < 4; ++i)
any_outputs |= mod->num_psv_outputs[i] > 0;
if (mod->shader_kind == DXIL_HULL_SHADER)
any_outputs |= mod->num_psv_patch_consts > 0;
if (!any_outputs)
return false;
bool any_bits_set = false;
nir_foreach_function(func, s) {
assert(func->impl);
/* Hull shaders have a patch constant function */
assert(func->is_entrypoint || s->info.stage == MESA_SHADER_TESS_CTRL);
/* Pass 1: input/view ID -> output dependencies */
nir_foreach_block(block, func->impl) {
nir_foreach_instr(instr, block) {
if (instr->type != nir_instr_type_intrinsic)
continue;
nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
uint32_t **tables = mod->io_dependency_table;
const uint32_t *table_sizes = mod->dependency_table_dwords_per_input;
switch (intr->intrinsic) {
case nir_intrinsic_load_view_index:
tables = mod->viewid_dependency_table;
FALLTHROUGH;
case nir_intrinsic_load_input:
case nir_intrinsic_load_per_vertex_input:
case nir_intrinsic_load_interpolated_input:
break;
default:
continue;
}
clear_pass_flags(func->impl);
any_bits_set |= propagate_input_to_output_dependencies(mod, intr, tables, table_sizes);
}
}
/* Pass 2: output -> output dependencies */
/* TODO */
}
return any_bits_set;
}

View file

@ -85,6 +85,9 @@ bool dxil_nir_split_unaligned_loads_stores(nir_shader *shader, nir_variable_mode
bool dxil_nir_lower_unsupported_subgroup_scan(nir_shader *s);
bool dxil_nir_forward_front_face(nir_shader *s);
struct dxil_module;
bool dxil_nir_analyze_io_dependencies(struct dxil_module *mod, nir_shader *s);
#ifdef __cplusplus
}
#endif

View file

@ -1960,6 +1960,34 @@ emit_metadata(struct ntd_context *ctx)
&dx_resources, 1);
}
if (ctx->mod.minor_version >= 2 &&
dxil_nir_analyze_io_dependencies(&ctx->mod, ctx->shader)) {
const struct dxil_type *i32_type = dxil_module_get_int_type(&ctx->mod, 32);
if (!i32_type)
return false;
const struct dxil_type *array_type = dxil_module_get_array_type(&ctx->mod, i32_type, ctx->mod.serialized_dependency_table_size);
if (!array_type)
return false;
const struct dxil_value **array_entries = malloc(sizeof(const struct value *) * ctx->mod.serialized_dependency_table_size);
if (!array_entries)
return false;
for (uint32_t i = 0; i < ctx->mod.serialized_dependency_table_size; ++i)
array_entries[i] = dxil_module_get_int32_const(&ctx->mod, ctx->mod.serialized_dependency_table[i]);
const struct dxil_value *array_val = dxil_module_get_array_const(&ctx->mod, array_type, array_entries);
free((void *)array_entries);
const struct dxil_mdnode *view_id_state_val = dxil_get_metadata_value(&ctx->mod, array_type, array_val);
if (!view_id_state_val)
return false;
const struct dxil_mdnode *view_id_state_node = dxil_get_metadata_node(&ctx->mod, &view_id_state_val, 1);
dxil_add_metadata_named_node(&ctx->mod, "dx.viewIdState", &view_id_state_node, 1);
}
const struct dxil_mdnode *dx_type_annotations[] = { main_type_annotation };
return dxil_add_metadata_named_node(&ctx->mod, "dx.typeAnnotations",
dx_type_annotations,