microsoft/clc: Use NIR_PASS instead of NIR_PASS_V

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/36299>
This commit is contained in:
Jesse Natalie 2025-07-22 13:35:57 -07:00
parent 2ca10a854f
commit 4528ad5281
2 changed files with 136 additions and 190 deletions

View file

@ -273,17 +273,15 @@ clc_lower_input_image_deref(nir_builder *b, struct clc_image_lower_context *cont
nir_instr_remove(&context->deref->instr);
}
static void
static bool
clc_lower_images(nir_shader *nir, struct clc_image_lower_context *context)
{
nir_foreach_function(func, nir) {
if (!func->is_entrypoint)
continue;
assert(func->impl);
bool progress = false;
nir_foreach_function_impl(impl, nir) {
nir_builder b = nir_builder_create(impl);
nir_builder b = nir_builder_create(func->impl);
nir_foreach_block(block, func->impl) {
bool func_progress = false;
nir_foreach_block(block, impl) {
nir_foreach_instr_safe(instr, block) {
if (instr->type == nir_instr_type_deref) {
context->deref = nir_instr_as_deref(instr);
@ -291,24 +289,31 @@ clc_lower_images(nir_shader *nir, struct clc_image_lower_context *context)
if (glsl_type_is_image(context->deref->type)) {
assert(context->deref->deref_type == nir_deref_type_var);
clc_lower_input_image_deref(&b, context);
func_progress = true;
}
}
}
}
progress |= nir_progress(func_progress, impl, nir_metadata_control_flow | nir_metadata_loop_analysis);
}
nir_foreach_variable_with_modes_safe(var, nir, nir_var_image) {
if (glsl_type_is_image(var->type) && glsl_get_sampler_result_type(var->type) == GLSL_TYPE_VOID)
if (glsl_type_is_image(var->type) && glsl_get_sampler_result_type(var->type) == GLSL_TYPE_VOID) {
exec_node_remove(&var->node);
progress = true;
}
}
return progress;
}
static void
static bool
clc_lower_64bit_semantics(nir_shader *nir)
{
bool progress = false;
nir_foreach_function_impl(impl, nir) {
nir_builder b = nir_builder_create(impl);
bool func_progress = false;
nir_foreach_block(block, impl) {
nir_foreach_instr_safe(instr, block) {
if (instr->type == nir_instr_type_intrinsic) {
@ -328,6 +333,7 @@ clc_lower_64bit_semantics(nir_shader *nir)
if (nir_instr_def(instr)->bit_size != 64)
continue;
func_progress = true;
intrinsic->def.bit_size = 32;
b.cursor = nir_after_instr(instr);
@ -339,21 +345,21 @@ clc_lower_64bit_semantics(nir_shader *nir)
}
}
}
progress |= nir_progress(func_progress, impl, nir_metadata_control_flow | nir_metadata_loop_analysis);
}
return progress;
}
static void
static bool
clc_lower_nonnormalized_samplers(nir_shader *nir,
const dxil_wrap_sampler_state *states)
{
nir_foreach_function(func, nir) {
if (!func->is_entrypoint)
continue;
assert(func->impl);
bool progress = false;
nir_foreach_function_impl(impl, nir) {
nir_builder b = nir_builder_create(impl);
nir_builder b = nir_builder_create(func->impl);
nir_foreach_block(block, func->impl) {
bool func_progress = false;
nir_foreach_block(block, impl) {
nir_foreach_instr_safe(instr, block) {
if (instr->type != nir_instr_type_tex)
continue;
@ -376,6 +382,7 @@ clc_lower_nonnormalized_samplers(nir_shader *nir,
if (!states[sampler->data.binding].is_nonnormalized_coords)
continue;
func_progress = true;
b.cursor = nir_before_instr(&tex->instr);
int coords_idx = nir_tex_instr_src_index(tex, nir_tex_src_coord);
@ -408,7 +415,9 @@ clc_lower_nonnormalized_samplers(nir_shader *nir,
nir_src_rewrite(&tex->src[coords_idx].src, normalized_coords);
}
}
progress = nir_progress(func_progress, impl, nir_metadata_control_flow | nir_metadata_loop_analysis);
}
return progress;
}
static nir_variable *
@ -456,26 +465,24 @@ add_work_properties_var(struct clc_dxil_object *dxil,
return var;
}
static void
static bool
clc_lower_constant_to_ssbo(nir_shader *nir,
const struct clc_kernel_info *kerninfo, unsigned *uav_id)
{
/* Update UBO vars and assign them a binding. */
bool progress = false;
nir_foreach_variable_with_modes(var, nir, nir_var_mem_constant) {
var->data.mode = nir_var_mem_ssbo;
var->data.binding = (*uav_id)++;
progress = true;
}
/* And finally patch all the derefs referencing the constant
* variables/pointers.
*/
nir_foreach_function(func, nir) {
if (!func->is_entrypoint)
continue;
assert(func->impl);
nir_foreach_block(block, func->impl) {
nir_foreach_function_impl(impl, nir) {
bool func_progress = false;
nir_foreach_block(block, impl) {
nir_foreach_instr(instr, block) {
if (instr->type != nir_instr_type_deref)
continue;
@ -486,24 +493,26 @@ clc_lower_constant_to_ssbo(nir_shader *nir,
continue;
deref->modes = nir_var_mem_ssbo;
func_progress = true;
}
}
progress |= nir_progress(func_progress, impl, nir_metadata_all);
}
return progress;
}
static void
static bool
clc_change_variable_mode(nir_shader *nir, nir_variable_mode from, nir_variable_mode to)
{
nir_foreach_variable_with_modes(var, nir, from)
bool progress = false;
nir_foreach_variable_with_modes(var, nir, from) {
var->data.mode = to;
progress = true;
}
nir_foreach_function(func, nir) {
if (!func->is_entrypoint)
continue;
assert(func->impl);
nir_foreach_block(block, func->impl) {
nir_foreach_function_impl(impl, nir) {
bool func_progress = false;
nir_foreach_block(block, impl) {
nir_foreach_instr(instr, block) {
if (instr->type != nir_instr_type_deref)
continue;
@ -514,9 +523,12 @@ clc_change_variable_mode(nir_shader *nir, nir_variable_mode from, nir_variable_m
continue;
deref->modes = to;
func_progress = true;
}
}
progress |= nir_progress(func_progress, impl, nir_metadata_all);
}
return progress;
}
static void
@ -795,8 +807,8 @@ clc_spirv_to_dxil(struct clc_libclc *lib,
}
nir->info.workgroup_size_variable = true;
NIR_PASS_V(nir, nir_lower_goto_ifs);
NIR_PASS_V(nir, nir_opt_dead_cf);
NIR_PASS(_, nir, nir_lower_goto_ifs);
NIR_PASS(_, nir, nir_opt_dead_cf);
struct clc_dxil_metadata *metadata = &out_dxil->metadata;
@ -828,10 +840,10 @@ clc_spirv_to_dxil(struct clc_libclc *lib,
// Inline all functions first.
// according to the comment on nir_inline_functions
NIR_PASS_V(nir, nir_lower_variable_initializers, nir_var_function_temp);
NIR_PASS_V(nir, nir_lower_returns);
NIR_PASS_V(nir, nir_link_shader_functions, clc_libclc_get_clc_shader(lib));
NIR_PASS_V(nir, nir_inline_functions);
NIR_PASS(_, nir, nir_lower_variable_initializers, nir_var_function_temp);
NIR_PASS(_, nir, nir_lower_returns);
NIR_PASS(_, nir, nir_link_shader_functions, clc_libclc_get_clc_shader(lib));
NIR_PASS(_, nir, nir_inline_functions);
// Pick off the single entrypoint that we want.
nir_remove_non_entrypoints(nir);
@ -867,23 +879,23 @@ clc_spirv_to_dxil(struct clc_libclc *lib,
} while (progress);
}
NIR_PASS_V(nir, nir_scale_fdiv);
NIR_PASS(_, nir, nir_scale_fdiv);
/* 128 is the minimum value for CL_DEVICE_MAX_READ_IMAGE_ARGS and used by CLOn12 */
dxil_wrap_sampler_state int_sampler_states[128] = { {{0}} };
unsigned sampler_id = 0;
NIR_PASS_V(nir, nir_lower_variable_initializers, ~(nir_var_function_temp | nir_var_shader_temp));
NIR_PASS(_, nir, nir_lower_variable_initializers, ~(nir_var_function_temp | nir_var_shader_temp));
// Ensure the printf struct has explicit types, but we'll throw away the scratch size, because we haven't
// necessarily removed all temp variables (e.g. the printf struct itself) at this point, so we'll rerun this later
assert(nir->scratch_size == 0);
NIR_PASS_V(nir, nir_lower_vars_to_explicit_types, nir_var_function_temp, glsl_get_cl_type_size_align);
NIR_PASS(_, nir, nir_lower_vars_to_explicit_types, nir_var_function_temp, glsl_get_cl_type_size_align);
nir_lower_printf_options printf_options = {
.max_buffer_size = 1024 * 1024
};
NIR_PASS_V(nir, nir_lower_printf, &printf_options);
NIR_PASS(_, nir, nir_lower_printf, &printf_options);
metadata->printf.info_count = nir->printf_info_count;
metadata->printf.infos = calloc(nir->printf_info_count, sizeof(struct clc_printf_info));
@ -896,7 +908,7 @@ clc_spirv_to_dxil(struct clc_libclc *lib,
}
// For uniforms (kernel inputs, minus images), run this before adjusting variable list via image/sampler lowering
NIR_PASS_V(nir, nir_lower_vars_to_explicit_types, nir_var_uniform, glsl_get_cl_type_size_align);
NIR_PASS(_, nir, nir_lower_vars_to_explicit_types, nir_var_uniform, glsl_get_cl_type_size_align);
// Calculate input offsets/metadata.
unsigned uav_id = 0;
@ -930,12 +942,12 @@ clc_spirv_to_dxil(struct clc_libclc *lib,
unsigned num_global_inputs = uav_id;
// Before removing dead uniforms, dedupe inline samplers to make more dead uniforms
NIR_PASS_V(nir, nir_dedup_inline_samplers);
NIR_PASS_V(nir, nir_remove_dead_variables, nir_var_uniform | nir_var_mem_ubo |
NIR_PASS(_, nir, nir_dedup_inline_samplers);
NIR_PASS(_, nir, nir_remove_dead_variables, nir_var_uniform | nir_var_mem_ubo |
nir_var_mem_constant | nir_var_function_temp | nir_var_image, NULL);
nir->scratch_size = 0;
NIR_PASS_V(nir, nir_lower_vars_to_explicit_types,
NIR_PASS(_, nir, nir_lower_vars_to_explicit_types,
nir_var_mem_shared | nir_var_function_temp | nir_var_mem_global | nir_var_mem_constant,
glsl_get_cl_type_size_align);
@ -956,32 +968,32 @@ clc_spirv_to_dxil(struct clc_libclc *lib,
NIR_PASS(progress, nir, nir_opt_cse);
} while (progress);
}
NIR_PASS_V(nir, nir_lower_memcpy);
NIR_PASS(_, nir, nir_lower_memcpy);
NIR_PASS_V(nir, clc_nir_lower_global_pointers_to_constants);
NIR_PASS(_, nir, clc_nir_lower_global_pointers_to_constants);
// Attempt to preserve derefs to constants by moving them to shader_temp
NIR_PASS_V(nir, dxil_nir_lower_constant_to_temp);
NIR_PASS(_, nir, dxil_nir_lower_constant_to_temp);
// While inserting new var derefs for our "logical" addressing mode, temporarily
// switch the pointer size to 32-bit.
nir->info.cs.ptr_size = 32;
NIR_PASS_V(nir, nir_split_struct_vars, nir_var_shader_temp);
NIR_PASS_V(nir, dxil_nir_flatten_var_arrays, nir_var_shader_temp);
NIR_PASS_V(nir, dxil_nir_lower_var_bit_size, nir_var_shader_temp,
NIR_PASS(_, nir, nir_split_struct_vars, nir_var_shader_temp);
NIR_PASS(_, nir, dxil_nir_flatten_var_arrays, nir_var_shader_temp);
NIR_PASS(_, nir, dxil_nir_lower_var_bit_size, nir_var_shader_temp,
(supported_int_sizes & 16) ? 16 : 32, (supported_int_sizes & 64) ? 64 : 32);
nir->info.cs.ptr_size = 64;
NIR_PASS_V(nir, clc_lower_constant_to_ssbo, out_dxil->kernel, &uav_id);
NIR_PASS_V(nir, clc_change_variable_mode, nir_var_shader_temp, nir_var_mem_constant);
NIR_PASS_V(nir, clc_change_variable_mode, nir_var_mem_global, nir_var_mem_ssbo);
NIR_PASS(_, nir, clc_lower_constant_to_ssbo, out_dxil->kernel, &uav_id);
NIR_PASS(_, nir, clc_change_variable_mode, nir_var_shader_temp, nir_var_mem_constant);
NIR_PASS(_, nir, clc_change_variable_mode, nir_var_mem_global, nir_var_mem_ssbo);
bool has_printf = false;
NIR_PASS(has_printf, nir, clc_lower_printf_base, uav_id);
metadata->printf.uav_id = has_printf ? uav_id++ : -1;
NIR_PASS_V(nir, dxil_nir_lower_deref_ssbo);
NIR_PASS(_, nir, dxil_nir_lower_deref_ssbo);
NIR_PASS_V(nir, dxil_nir_split_unaligned_loads_stores, nir_var_mem_shared | nir_var_function_temp);
NIR_PASS(_, nir, dxil_nir_split_unaligned_loads_stores, nir_var_mem_shared | nir_var_function_temp);
// Second pass over inputs to calculate image bindings
unsigned srv_id = 0;
@ -1032,33 +1044,33 @@ clc_spirv_to_dxil(struct clc_libclc *lib,
}
// Needs to come before lower_explicit_io
NIR_PASS_V(nir, nir_lower_readonly_images_to_tex, false);
NIR_PASS(_, nir, nir_lower_readonly_images_to_tex, false);
struct clc_image_lower_context image_lower_context = { metadata, &srv_id, &uav_id };
NIR_PASS_V(nir, clc_lower_images, &image_lower_context);
NIR_PASS_V(nir, clc_lower_nonnormalized_samplers, int_sampler_states);
NIR_PASS_V(nir, nir_lower_samplers);
NIR_PASS_V(nir, dxil_lower_sample_to_txf_for_integer_tex,
NIR_PASS(_, nir, clc_lower_images, &image_lower_context);
NIR_PASS(_, nir, clc_lower_nonnormalized_samplers, int_sampler_states);
NIR_PASS(_, nir, nir_lower_samplers);
NIR_PASS(_, nir, dxil_lower_sample_to_txf_for_integer_tex,
sampler_id, int_sampler_states, NULL, 14.0f);
assert(nir->info.cs.ptr_size == 64);
NIR_PASS_V(nir, nir_lower_explicit_io, nir_var_mem_ssbo,
NIR_PASS(_, nir, nir_lower_explicit_io, nir_var_mem_ssbo,
nir_address_format_32bit_index_offset_pack64);
NIR_PASS_V(nir, nir_lower_explicit_io,
NIR_PASS(_, nir, nir_lower_explicit_io,
nir_var_mem_shared | nir_var_function_temp | nir_var_uniform,
nir_address_format_32bit_offset_as_64bit);
NIR_PASS_V(nir, nir_lower_system_values);
NIR_PASS(_, nir, nir_lower_system_values);
nir_lower_compute_system_values_options compute_options = {
.has_base_global_invocation_id = (conf && conf->support_global_work_id_offsets),
.has_base_workgroup_id = (conf && conf->support_workgroup_id_offsets),
};
NIR_PASS_V(nir, nir_lower_compute_system_values, &compute_options);
NIR_PASS(_, nir, nir_lower_compute_system_values, &compute_options);
NIR_PASS_V(nir, clc_lower_64bit_semantics);
NIR_PASS(_, nir, clc_lower_64bit_semantics);
NIR_PASS_V(nir, nir_opt_deref);
NIR_PASS_V(nir, nir_lower_vars_to_ssa);
NIR_PASS(_, nir, nir_opt_deref);
NIR_PASS(_, nir, nir_lower_vars_to_ssa);
unsigned cbv_id = 0;
@ -1097,10 +1109,10 @@ clc_spirv_to_dxil(struct clc_libclc *lib,
}
}
NIR_PASS_V(nir, clc_nir_lower_kernel_input_loads, inputs_var);
NIR_PASS_V(nir, nir_lower_explicit_io, nir_var_mem_ubo,
NIR_PASS(_, nir, clc_nir_lower_kernel_input_loads, inputs_var);
NIR_PASS(_, nir, nir_lower_explicit_io, nir_var_mem_ubo,
nir_address_format_32bit_index_offset);
NIR_PASS_V(nir, clc_nir_lower_system_values, work_properties_var);
NIR_PASS(_, nir, clc_nir_lower_system_values, work_properties_var);
const struct dxil_nir_lower_loads_stores_options loads_stores_options = {
.use_16bit_ssbo = false,
};
@ -1132,17 +1144,17 @@ clc_spirv_to_dxil(struct clc_libclc *lib,
nir->info.shared_size += size;
}
NIR_PASS_V(nir, dxil_nir_lower_loads_stores_to_dxil, &loads_stores_options);
NIR_PASS_V(nir, dxil_nir_opt_alu_deref_srcs);
NIR_PASS_V(nir, nir_lower_fp16_casts, nir_lower_fp16_all);
NIR_PASS_V(nir, nir_lower_convert_alu_types, NULL);
NIR_PASS(_, nir, dxil_nir_lower_loads_stores_to_dxil, &loads_stores_options);
NIR_PASS(_, nir, dxil_nir_opt_alu_deref_srcs);
NIR_PASS(_, nir, nir_lower_fp16_casts, nir_lower_fp16_all);
NIR_PASS(_, nir, nir_lower_convert_alu_types, NULL);
// Convert pack to pack_split
NIR_PASS_V(nir, nir_lower_pack);
NIR_PASS(_, nir, nir_lower_pack);
// Lower pack_split to bit math
NIR_PASS_V(nir, nir_opt_algebraic);
NIR_PASS(_, nir, nir_opt_algebraic);
NIR_PASS_V(nir, nir_opt_dce);
NIR_PASS(_, nir, nir_opt_dce);
nir_validate_shader(nir, "Validate before feeding NIR to the DXIL compiler");
struct nir_to_dxil_options opts = {

View file

@ -46,101 +46,60 @@ load_ubo(nir_builder *b, nir_intrinsic_instr *intr, nir_variable *var, unsigned
}
static bool
lower_load_base_global_invocation_id(nir_builder *b, nir_intrinsic_instr *intr,
nir_variable *var)
is_clc_system_value(const nir_instr *instr, const void *state)
{
b->cursor = nir_after_instr(&intr->instr);
nir_def *offset = load_ubo(b, intr, var, offsetof(struct clc_work_properties_data,
global_offset_x));
nir_def_replace(&intr->def, offset);
return true;
if (instr->type != nir_instr_type_intrinsic)
return false;
nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
switch (intr->intrinsic) {
case nir_intrinsic_load_base_global_invocation_id:
case nir_intrinsic_load_work_dim:
case nir_intrinsic_load_num_workgroups:
case nir_intrinsic_load_base_workgroup_id:
return true;
default:
return false;
}
}
static bool
lower_load_work_dim(nir_builder *b, nir_intrinsic_instr *intr,
nir_variable *var)
static nir_def *
lower_clc_system_value(nir_builder *b, nir_instr *instr, void *state)
{
b->cursor = nir_after_instr(&intr->instr);
nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
nir_variable *var = (nir_variable *)state;
nir_def *dim = load_ubo(b, intr, var, offsetof(struct clc_work_properties_data,
work_dim));
nir_def_replace(&intr->def, dim);
return true;
}
static bool
lower_load_num_workgroups(nir_builder *b, nir_intrinsic_instr *intr,
nir_variable *var)
{
b->cursor = nir_after_instr(&intr->instr);
nir_def *count =
load_ubo(b, intr, var, offsetof(struct clc_work_properties_data,
group_count_total_x));
nir_def_replace(&intr->def, count);
return true;
}
static bool
lower_load_base_workgroup_id(nir_builder *b, nir_intrinsic_instr *intr,
nir_variable *var)
{
b->cursor = nir_after_instr(&intr->instr);
nir_def *offset =
load_ubo(b, intr, var, offsetof(struct clc_work_properties_data,
group_id_offset_x));
nir_def_replace(&intr->def, offset);
return true;
switch (intr->intrinsic) {
case nir_intrinsic_load_base_global_invocation_id:
return load_ubo(b, intr, var, offsetof(struct clc_work_properties_data, global_offset_x));
case nir_intrinsic_load_work_dim:
return load_ubo(b, intr, var, offsetof(struct clc_work_properties_data, work_dim));
case nir_intrinsic_load_num_workgroups:
return load_ubo(b, intr, var, offsetof(struct clc_work_properties_data, group_count_total_x));
case nir_intrinsic_load_base_workgroup_id:
return load_ubo(b, intr, var, offsetof(struct clc_work_properties_data, group_id_offset_x));
default:
return NULL;
}
}
bool
clc_nir_lower_system_values(nir_shader *nir, nir_variable *var)
{
bool progress = false;
foreach_list_typed(nir_function, func, node, &nir->functions) {
if (!func->is_entrypoint)
continue;
assert(func->impl);
nir_builder b = nir_builder_create(func->impl);
nir_foreach_block(block, func->impl) {
nir_foreach_instr_safe(instr, block) {
if (instr->type != nir_instr_type_intrinsic)
continue;
nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
switch (intr->intrinsic) {
case nir_intrinsic_load_base_global_invocation_id:
progress |= lower_load_base_global_invocation_id(&b, intr, var);
break;
case nir_intrinsic_load_work_dim:
progress |= lower_load_work_dim(&b, intr, var);
break;
case nir_intrinsic_load_num_workgroups:
progress |= lower_load_num_workgroups(&b, intr, var);
break;
case nir_intrinsic_load_base_workgroup_id:
progress |= lower_load_base_workgroup_id(&b, intr, var);
break;
default: break;
}
}
}
}
return progress;
return nir_shader_lower_instructions(nir, is_clc_system_value, lower_clc_system_value, var);
}
static bool
lower_load_kernel_input(nir_builder *b, nir_intrinsic_instr *intr,
nir_variable *var)
is_load_kernel_input(const nir_instr *instr, const void *state)
{
b->cursor = nir_before_instr(&intr->instr);
return instr->type == nir_instr_type_intrinsic &&
nir_instr_as_intrinsic(instr)->intrinsic == nir_intrinsic_load_kernel_input;
}
static nir_def *
lower_load_kernel_input(nir_builder *b, nir_instr *instr, void *state)
{
nir_variable *var = (nir_variable *)state;
nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
unsigned bit_size = intr->def.bit_size;
enum glsl_base_type base_type;
@ -171,38 +130,13 @@ lower_load_kernel_input(nir_builder *b, nir_intrinsic_instr *intr,
deref->cast.align_mul = nir_intrinsic_align_mul(intr);
deref->cast.align_offset = nir_intrinsic_align_offset(intr);
nir_def *result =
nir_load_deref(b, deref);
nir_def_replace(&intr->def, result);
return true;
return nir_load_deref(b, deref);
}
bool
clc_nir_lower_kernel_input_loads(nir_shader *nir, nir_variable *var)
{
bool progress = false;
foreach_list_typed(nir_function, func, node, &nir->functions) {
if (!func->is_entrypoint)
continue;
assert(func->impl);
nir_builder b = nir_builder_create(func->impl);
nir_foreach_block(block, func->impl) {
nir_foreach_instr_safe(instr, block) {
if (instr->type != nir_instr_type_intrinsic)
continue;
nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
if (intr->intrinsic == nir_intrinsic_load_kernel_input)
progress |= lower_load_kernel_input(&b, intr, var);
}
}
}
return progress;
return nir_shader_lower_instructions(nir, is_load_kernel_input, lower_load_kernel_input, var);
}