diff --git a/src/microsoft/spirv_to_dxil/spirv_to_dxil.c b/src/microsoft/spirv_to_dxil/spirv_to_dxil.c index 3c9e0580502..15710c91cff 100644 --- a/src/microsoft/spirv_to_dxil/spirv_to_dxil.c +++ b/src/microsoft/spirv_to_dxil/spirv_to_dxil.c @@ -152,6 +152,108 @@ dxil_spirv_nir_lower_shader_system_values(nir_shader *shader, &data); } +static nir_variable * +add_push_constant_var(nir_shader *nir, unsigned size, unsigned desc_set, unsigned binding) +{ + /* Size must be a multiple of 16 as buffer load is loading 16 bytes at a time */ + size = ALIGN_POT(size, 16) / 16; + + const struct glsl_type *array_type = glsl_array_type(glsl_uint_type(), size, 4); + const struct glsl_struct_field field = {array_type, "arr"}; + nir_variable *var = nir_variable_create( + nir, nir_var_mem_ubo, + glsl_struct_type(&field, 1, "block", false), "push_constants"); + var->data.descriptor_set = desc_set; + var->data.binding = binding; + var->data.how_declared = nir_var_hidden; + return var; +} + +struct lower_load_push_constant_data { + nir_address_format ubo_format; + unsigned desc_set; + unsigned binding; + uint32_t min; + uint32_t max; +}; + +static bool +lower_load_push_constant(struct nir_builder *builder, nir_instr *instr, + void *cb_data) +{ + struct lower_load_push_constant_data *data = + (struct lower_load_push_constant_data *)cb_data; + + if (instr->type != nir_instr_type_intrinsic) + return false; + + nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr); + + /* All the intrinsics we care about are loads */ + if (intrin->intrinsic != nir_intrinsic_load_push_constant) + return false; + + uint32_t base = nir_intrinsic_base(intrin); + uint32_t range = nir_intrinsic_range(intrin); + data->min = MIN2(data->min, base); + data->max = MAX2(data->max, base + range); + + builder->cursor = nir_after_instr(instr); + nir_address_format ubo_format = data->ubo_format; + + nir_ssa_def *index = nir_vulkan_resource_index( + builder, nir_address_format_num_components(ubo_format), + nir_address_format_bit_size(ubo_format), + nir_imm_int(builder, 0), + .desc_set = data->desc_set, .binding = data->binding, + .desc_type = VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER); + + nir_ssa_def *load_desc = nir_load_vulkan_descriptor( + builder, nir_address_format_num_components(ubo_format), + nir_address_format_bit_size(ubo_format), + index, .desc_type = VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER); + + nir_ssa_def *offset = nir_ssa_for_src(builder, intrin->src[0], 1); + nir_ssa_def *load_data = build_load_ubo_dxil( + builder, nir_channel(builder, load_desc, 0), + nir_iadd_imm(builder, offset, base), + nir_dest_num_components(intrin->dest), nir_dest_bit_size(intrin->dest)); + + nir_ssa_def_rewrite_uses(&intrin->dest.ssa, load_data); + nir_instr_remove(instr); + return true; +} + +static bool +dxil_spirv_nir_lower_load_push_constant(nir_shader *shader, + nir_address_format ubo_format, + unsigned desc_set, unsigned binding, + uint32_t *size) +{ + bool ret; + struct lower_load_push_constant_data data = { + .ubo_format = ubo_format, + .desc_set = desc_set, + .binding = binding, + .min = UINT32_MAX, + .max = 0, + }; + ret = nir_shader_instructions_pass(shader, lower_load_push_constant, + nir_metadata_block_index | + nir_metadata_dominance | + nir_metadata_loop_analysis, + &data); + + if (data.min >= data.max) + *size = 0; + else + *size = (data.max - data.min); + + assert(ret == (*size > 0)); + + return ret; +} + struct lower_yz_flip_data { bool *reads_sysval_ubo; const struct dxil_spirv_runtime_conf *rt_conf; @@ -361,6 +463,15 @@ spirv_to_dxil(const uint32_t *words, size_t word_count, nir_var_system_value | nir_var_mem_shared, NULL); + uint32_t push_constant_size = 0; + NIR_PASS_V(nir, nir_lower_explicit_io, nir_var_mem_push_const, + nir_address_format_32bit_offset); + NIR_PASS_V(nir, dxil_spirv_nir_lower_load_push_constant, + spirv_opts.ubo_addr_format, + conf->push_constant_cbv.register_space, + conf->push_constant_cbv.base_shader_register, + &push_constant_size); + NIR_PASS_V(nir, nir_lower_explicit_io, nir_var_mem_ubo | nir_var_mem_ssbo, nir_address_format_32bit_index_offset); @@ -411,6 +522,12 @@ spirv_to_dxil(const uint32_t *words, size_t word_count, conf->runtime_data_cbv.base_shader_register); } + if (push_constant_size > 0) { + add_push_constant_var(nir, push_constant_size, + conf->push_constant_cbv.register_space, + conf->push_constant_cbv.base_shader_register); + } + NIR_PASS_V(nir, nir_lower_alu_to_scalar, NULL, NULL); NIR_PASS_V(nir, nir_opt_dce); NIR_PASS_V(nir, dxil_nir_lower_double_math); diff --git a/src/microsoft/spirv_to_dxil/spirv_to_dxil.h b/src/microsoft/spirv_to_dxil/spirv_to_dxil.h index 23bcaeb8787..afc799de03a 100644 --- a/src/microsoft/spirv_to_dxil/spirv_to_dxil.h +++ b/src/microsoft/spirv_to_dxil/spirv_to_dxil.h @@ -132,6 +132,11 @@ struct dxil_spirv_runtime_conf { uint32_t base_shader_register; } runtime_data_cbv; + struct { + uint32_t register_space; + uint32_t base_shader_register; + } push_constant_cbv; + // Set true if vertex and instance ids have already been converted to // zero-based. Otherwise, runtime_data will be required to lower them. bool zero_based_vertex_instance_id;