From e1855dc9477bc64acd35baea138608cae6da5743 Mon Sep 17 00:00:00 2001 From: Mike Blumenkrantz Date: Fri, 23 Jan 2026 10:40:55 -0500 Subject: [PATCH] ntv: add basic vulkan support this enables (some) shaders generated by vtn to successfully pass through ntv and generate valid spirv the majority of the plumbing is to handle deref casts, which are currently assumed to originate solely from loading descriptors Part-of: --- .../drivers/zink/nir_to_spirv/nir_to_spirv.c | 254 +++++++++++++++--- .../drivers/zink/nir_to_spirv/nir_to_spirv.h | 4 + 2 files changed, 225 insertions(+), 33 deletions(-) diff --git a/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c b/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c index 00f01217721..5db99f450d6 100644 --- a/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c +++ b/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c @@ -278,6 +278,37 @@ find_image_type(struct ntv_context *ctx, nir_variable *var) return he ? (intptr_t)he->data : 0; } +static nir_variable * +find_vulkan_deref_var(struct ntv_context *ctx, nir_deref_instr *deref) +{ + nir_variable *var = nir_deref_instr_get_variable(deref); + if (var) + return var; + assert(ctx->sinfo->is_native_vulkan); + nir_deref_instr *parent = nir_deref_instr_parent(deref); + while (parent) { + if (parent->deref_type == nir_deref_type_cast && nir_def_is_intrinsic(parent->parent.ssa)) { + nir_intrinsic_instr *intr = nir_def_as_intrinsic(parent->parent.ssa); + while (intr->intrinsic != nir_intrinsic_vulkan_resource_index) + intr = nir_def_as_intrinsic(intr->src[0].ssa); + int desc_set = nir_intrinsic_desc_set(intr); + int binding = nir_intrinsic_binding(intr); + nir_foreach_variable_with_modes(i, ctx->nir, deref->modes) { + if (i->data.descriptor_set == desc_set && i->data.binding == binding) { + var = i; + break; + } + } + return var; + } else if (parent->deref_type == nir_deref_type_var) { + return parent->var; + } else { + parent = nir_deref_instr_parent(parent); + } + } + return NULL; +} + static SpvScope get_scope(mesa_scope scope) { @@ -1390,23 +1421,31 @@ get_bo_struct_type(struct ntv_context *ctx, struct nir_variable *var) struct hash_entry *he = _mesa_hash_table_search(ctx->bo_struct_types, var); if (he) return (SpvId)(uintptr_t)he->data; - const struct glsl_type *bare_type = glsl_without_array(var->type); - unsigned bitsize = glsl_get_bit_size(glsl_get_array_element(glsl_get_struct_field(bare_type, 0))); - SpvId array_type = get_bo_array_type(ctx, var); - _mesa_hash_table_insert(ctx->bo_array_types, var, (void *)(uintptr_t)array_type); - bool ssbo = var->data.mode == nir_var_mem_ssbo; + SpvId struct_type = 0; + if (ctx->sinfo->is_native_vulkan) { + struct_type = get_glsl_type(ctx, var->type, false); + } else { + const struct glsl_type *bare_type = glsl_without_array(var->type); + unsigned bitsize = glsl_get_bit_size(glsl_get_array_element(glsl_get_struct_field(bare_type, 0))); + SpvId array_type = get_bo_array_type(ctx, var); + _mesa_hash_table_insert(ctx->bo_array_types, var, (void *)(uintptr_t)array_type); + bool ssbo = var->data.mode == nir_var_mem_ssbo; - // wrap UBO-array in a struct - SpvId runtime_array = 0; - if (ssbo && glsl_get_length(bare_type) > 1) { - const struct glsl_type *last_member = glsl_get_struct_field(bare_type, glsl_get_length(bare_type) - 1); - if (glsl_type_is_unsized_array(last_member)) { - runtime_array = spirv_builder_type_runtime_array(&ctx->builder, get_uvec_type(ctx, bitsize, 1)); - spirv_builder_emit_array_stride(&ctx->builder, runtime_array, glsl_get_explicit_stride(last_member)); - } + // wrap UBO-array in a struct + SpvId runtime_array = 0; + if (ssbo && glsl_get_length(bare_type) > 1) { + const struct glsl_type *last_member = glsl_get_struct_field(bare_type, glsl_get_length(bare_type) - 1); + if (glsl_type_is_unsized_array(last_member)) { + runtime_array = spirv_builder_type_runtime_array(&ctx->builder, get_uvec_type(ctx, bitsize, 1)); + spirv_builder_emit_array_stride(&ctx->builder, runtime_array, glsl_get_explicit_stride(last_member)); + } + } + SpvId types[] = {array_type, runtime_array}; + struct_type = spirv_builder_type_struct(&ctx->builder, types, 1 + !!runtime_array); + spirv_builder_emit_member_offset(&ctx->builder, struct_type, 0, 0); + if (runtime_array) + spirv_builder_emit_member_offset(&ctx->builder, struct_type, 1, 0); } - SpvId types[] = {array_type, runtime_array}; - SpvId struct_type = spirv_builder_type_struct(&ctx->builder, types, 1 + !!runtime_array); if (var->name) { char struct_name[100]; snprintf(struct_name, sizeof(struct_name), "struct_%s", var->name); @@ -1415,9 +1454,6 @@ get_bo_struct_type(struct ntv_context *ctx, struct nir_variable *var) spirv_builder_emit_decoration(&ctx->builder, struct_type, SpvDecorationBlock); - spirv_builder_emit_member_offset(&ctx->builder, struct_type, 0, 0); - if (runtime_array) - spirv_builder_emit_member_offset(&ctx->builder, struct_type, 1, 0); return struct_type; } @@ -1425,12 +1461,12 @@ get_bo_struct_type(struct ntv_context *ctx, struct nir_variable *var) static void emit_bo(struct ntv_context *ctx, struct nir_variable *var, bool aliased) { - unsigned bitsize = glsl_get_bit_size(glsl_get_array_element(glsl_get_struct_field(glsl_without_array(var->type), 0))); - bool ssbo = var->data.mode == nir_var_mem_ssbo; SpvId struct_type = get_bo_struct_type(ctx, var); _mesa_hash_table_insert(ctx->bo_struct_types, var, (void *)(uintptr_t)struct_type); + bool ssbo = var->data.mode == nir_var_mem_ssbo; + unsigned bitsize = ctx->sinfo->is_native_vulkan ? 32 : glsl_get_bit_size(glsl_get_array_element(glsl_get_struct_field(glsl_without_array(var->type), 0))); SpvId array_length = emit_uint_const(ctx, 32, glsl_get_length(var->type)); - SpvId array_type = spirv_builder_type_array(&ctx->builder, struct_type, array_length); + SpvId array_type = ctx->sinfo->is_native_vulkan ? struct_type : spirv_builder_type_array(&ctx->builder, struct_type, array_length); SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder, ssbo ? SpvStorageClassStorageBuffer : SpvStorageClassUniform, array_type); @@ -2319,7 +2355,7 @@ emit_load_deref(struct ntv_context *ctx, nir_intrinsic_instr *intr) nir_deref_instr *deref = nir_src_as_deref(intr->src[0]); SpvId type; if (glsl_type_is_image(deref->type)) { - nir_variable *var = nir_deref_instr_get_variable(deref); + nir_variable *var = find_vulkan_deref_var(ctx, deref); const struct glsl_type *gtype = glsl_without_array(var->type); type = get_image_type(ctx, var, glsl_type_is_sampler(gtype) || glsl_type_is_texture(gtype), @@ -2347,6 +2383,11 @@ emit_store_deref(struct ntv_context *ctx, nir_intrinsic_instr *intr) const struct glsl_type *gtype = nir_src_as_deref(intr->src[0])->type; nir_variable *var = nir_intrinsic_get_var(intr, 0); + if (!var) { + assert(ctx->sinfo->is_native_vulkan); + assert(nir_def_is_deref(intr->src[0].ssa)); + var = find_vulkan_deref_var(ctx, nir_def_as_deref(intr->src[0].ssa)); + } SpvId type = get_glsl_type(ctx, gtype, var->data.mode & (nir_var_shader_temp | nir_var_function_temp)); unsigned wrmask = nir_intrinsic_write_mask(intr); if (!glsl_type_is_scalar(gtype) && @@ -2985,7 +3026,7 @@ emit_image_deref_store(struct ntv_context *ctx, nir_intrinsic_instr *intr) nir_alu_type atype; SpvId img_var = get_src(ctx, &intr->src[0], &atype); nir_deref_instr *deref = nir_src_as_deref(intr->src[0]); - nir_variable *var = nir_deref_instr_get_variable(deref); + nir_variable *var = find_vulkan_deref_var(ctx, deref); const struct glsl_type *type = glsl_without_array(var->type); SpvId base_type = get_glsl_basetype(ctx, glsl_get_sampler_result_type(type)); SpvId coord = get_image_coords(ctx, type, &intr->src[1]); @@ -3053,7 +3094,7 @@ emit_image_deref_load(struct ntv_context *ctx, nir_intrinsic_instr *intr) nir_alu_type atype; SpvId img_var = get_src(ctx, &intr->src[0], &atype); nir_deref_instr *deref = nir_src_as_deref(intr->src[0]); - nir_variable *var = nir_deref_instr_get_variable(deref); + nir_variable *var = find_vulkan_deref_var(ctx, deref); bool mediump = (var->data.precision == GLSL_PRECISION_MEDIUM || var->data.precision == GLSL_PRECISION_LOW); const struct glsl_type *type = glsl_without_array(var->type); SpvId base_type = get_glsl_basetype(ctx, glsl_get_sampler_result_type(type)); @@ -3095,7 +3136,7 @@ emit_image_deref_size(struct ntv_context *ctx, nir_intrinsic_instr *intr) nir_alu_type atype; SpvId img_var = get_src(ctx, &intr->src[0], &atype); nir_deref_instr *deref = nir_src_as_deref(intr->src[0]); - nir_variable *var = nir_deref_instr_get_variable(deref); + nir_variable *var = find_vulkan_deref_var(ctx, deref); SpvId img_type = find_image_type(ctx, var); const struct glsl_type *type = glsl_without_array(var->type); SpvId img = spirv_builder_emit_load(&ctx->builder, img_type, img_var, false); @@ -3115,7 +3156,7 @@ emit_image_deref_samples(struct ntv_context *ctx, nir_intrinsic_instr *intr) nir_alu_type atype; SpvId img_var = get_src(ctx, &intr->src[0], &atype); nir_deref_instr *deref = nir_src_as_deref(intr->src[0]); - nir_variable *var = nir_deref_instr_get_variable(deref); + nir_variable *var = find_vulkan_deref_var(ctx, deref); SpvId img_type = find_image_type(ctx, var); SpvId img = spirv_builder_emit_load(&ctx->builder, img_type, img_var, false); @@ -3131,7 +3172,7 @@ emit_image_intrinsic(struct ntv_context *ctx, nir_intrinsic_instr *intr) SpvId param = get_src(ctx, &intr->src[3], &ptype); SpvId img_var = get_src(ctx, &intr->src[0], &atype); nir_deref_instr *deref = nir_src_as_deref(intr->src[0]); - nir_variable *var = nir_deref_instr_get_variable(deref); + nir_variable *var = find_vulkan_deref_var(ctx, deref); const struct glsl_type *type = glsl_without_array(var->type); bool is_ms; type_to_dim(glsl_get_sampler_dim(type), &is_ms); @@ -3591,6 +3632,68 @@ init_sparse_resident(struct ntv_context *ctx) ctx->have_sparse = true; } +static void +emit_load_vulkan_descriptor(struct ntv_context *ctx, nir_intrinsic_instr *intr) +{ + nir_alu_type itype; + SpvId src = get_src(ctx, &intr->src[0], &itype); + store_def(ctx, intr->def.index, src, nir_type_uint); +} + +static void +emit_vulkan_resource_index(struct ntv_context *ctx, nir_intrinsic_instr *intr) +{ + nir_alu_type itype; + SpvId index = get_src(ctx, &intr->src[0], &itype); + if (itype == nir_type_float) + index = emit_bitcast(ctx, get_uvec_type(ctx, 32, 1), index); + + int desc_set = nir_intrinsic_desc_set(intr); + int binding = nir_intrinsic_binding(intr); + nir_variable *var = NULL; + nir_foreach_variable_with_modes(i, ctx->nir, nir_var_mem_ubo | nir_var_mem_ssbo | nir_var_image | nir_var_uniform) { + if (i->data.descriptor_set == desc_set && i->data.binding == binding) { + var = i; + break; + } + } + assert(var); + struct hash_entry *he = _mesa_hash_table_search(ctx->vars, var); + assert(he); + SpvId base = (SpvId)(intptr_t)he->data; + + const struct glsl_type *gtype = glsl_without_array(var->type); + SpvId type = 0; + if (glsl_type_is_bare_sampler(gtype)) { + type = spirv_builder_type_sampler(&ctx->builder); + } else if (glsl_type_is_sampler(gtype) || glsl_type_is_image(gtype) || glsl_type_is_texture(gtype)) { + struct hash_entry *he = _mesa_hash_table_search(&ctx->image_types, var); + assert(he); + type = (SpvId)(intptr_t)he->data; + } else { + assert(var->data.mode & (nir_var_mem_ubo | nir_var_mem_ssbo)); + type = get_bo_struct_type(ctx, var); + } + + if (glsl_type_is_array(var->type)) { + SpvId ptr_type = spirv_builder_type_pointer(&ctx->builder, + get_storage_class(var), + type); + + SpvId indices[] = { + emit_uint_const(ctx, 32, 0), + index + }; + SpvId result = spirv_builder_emit_access_chain(&ctx->builder, + ptr_type, + base, + indices, 2); + store_def(ctx, intr->def.index, result, nir_type_uint); + } else { + store_def(ctx, intr->def.index, base, nir_type_uint); + } +} + static void emit_intrinsic(struct ntv_context *ctx, nir_intrinsic_instr *intr) { @@ -3607,6 +3710,14 @@ emit_intrinsic(struct ntv_context *ctx, nir_intrinsic_instr *intr) emit_store_reg(ctx, intr); break; + case nir_intrinsic_vulkan_resource_index: + emit_vulkan_resource_index(ctx, intr); + break; + + case nir_intrinsic_load_vulkan_descriptor: + emit_load_vulkan_descriptor(ctx, intr); + break; + case nir_intrinsic_terminate: emit_discard(ctx, intr); break; @@ -3987,7 +4098,7 @@ get_tex_srcs(struct ntv_context *ctx, nir_tex_instr *tex, nir_const_value *cv; switch (tex->src[i].src_type) { case nir_tex_src_texture_deref: - var = nir_deref_instr_get_variable(nir_def_as_deref(tex->src[i].src.ssa)); + var = find_vulkan_deref_var(ctx, nir_def_as_deref(tex->src[i].src.ssa)); tex_src->src = get_src(ctx, &tex->src[i].src, &atype); break; case nir_tex_src_sampler_deref: @@ -4088,7 +4199,7 @@ get_tex_srcs(struct ntv_context *ctx, nir_tex_instr *tex, case nir_tex_src_texture_handle: tex_src->src = get_src(ctx, &tex->src[i].src, &atype); - var = *bindless_var = nir_deref_instr_get_variable(nir_src_as_deref(tex->src[i].src)); + var = *bindless_var = find_vulkan_deref_var(ctx, nir_src_as_deref(tex->src[i].src)); break; default: @@ -4402,7 +4513,7 @@ static void emit_deref_array(struct ntv_context *ctx, nir_deref_instr *deref) { assert(deref->deref_type == nir_deref_type_array); - nir_variable *var = nir_deref_instr_get_variable(deref); + nir_variable *var = find_vulkan_deref_var(ctx, deref); if (!nir_src_is_always_uniform(deref->arr.index)) { if (deref->modes & nir_var_mem_ubo) @@ -4491,12 +4602,12 @@ static void emit_deref_struct(struct ntv_context *ctx, nir_deref_instr *deref) { assert(deref->deref_type == nir_deref_type_struct); - nir_variable *var = nir_deref_instr_get_variable(deref); - + nir_variable *var = find_vulkan_deref_var(ctx, deref); SpvStorageClass storage_class = get_storage_class(var); SpvId index = emit_uint_const(ctx, 32, deref->strct.index); - SpvId type = (var->data.mode & (nir_var_mem_ubo | nir_var_mem_ssbo)) ? + bool is_zink_bo = !ctx->sinfo->is_native_vulkan && (var->data.mode & (nir_var_mem_ubo | nir_var_mem_ssbo)); + SpvId type = is_zink_bo ? get_bo_array_type(ctx, var) : get_glsl_type(ctx, deref->type, var->data.mode & (nir_var_shader_temp | nir_var_function_temp)); @@ -4513,6 +4624,15 @@ emit_deref_struct(struct ntv_context *ctx, nir_deref_instr *deref) store_def(ctx, deref->def.index, result, get_nir_alu_type(deref->type)); } +static void +emit_deref_cast(struct ntv_context *ctx, nir_deref_instr *deref) +{ + assert(deref->deref_type == nir_deref_type_cast); + nir_alu_type atype; + SpvId src = get_src(ctx, &deref->parent, &atype); + store_def(ctx, deref->def.index, src, nir_type_uint); +} + static void emit_deref(struct ntv_context *ctx, nir_deref_instr *deref) { @@ -4529,6 +4649,10 @@ emit_deref(struct ntv_context *ctx, nir_deref_instr *deref) emit_deref_struct(ctx, deref); break; + case nir_deref_type_cast: + emit_deref_cast(ctx, deref); + break; + default: UNREACHABLE("unexpected deref_type"); } @@ -5358,3 +5482,67 @@ spirv_shader_delete(struct spirv_shader *s) { ralloc_free(s); } + +static void +optimize_nir(struct nir_shader *nir) +{ + bool progress; + do { + progress = false; + + NIR_PASS(progress, nir, nir_lower_vars_to_ssa); + + NIR_PASS(progress, nir, nir_opt_copy_prop); + NIR_PASS(progress, nir, nir_opt_remove_phis); + NIR_PASS(progress, nir, nir_lower_all_phis_to_scalar); + NIR_PASS(progress, nir, nir_opt_dce); + NIR_PASS(progress, nir, nir_opt_dead_cf); + NIR_PASS(progress, nir, nir_opt_cse); + + nir_opt_peephole_select_options peephole_select_options = { + .limit = 64, + .expensive_alu_ok = true, + }; + NIR_PASS(progress, nir, nir_opt_peephole_select, &peephole_select_options); + NIR_PASS(progress, nir, nir_opt_phi_precision); + NIR_PASS(progress, nir, nir_opt_algebraic); + NIR_PASS(progress, nir, nir_opt_constant_folding); + + NIR_PASS(progress, nir, nir_opt_deref); + NIR_PASS(progress, nir, nir_opt_copy_prop_vars); + NIR_PASS(progress, nir, nir_opt_undef); + NIR_PASS(progress, nir, nir_opt_loop); + } while (progress); + + NIR_PASS(progress, nir, nir_opt_shrink_vectors, true); +} + +/* this is the bare minimum required to make vtn shaders work with ntv */ +void +ntv_shader_prepare(nir_shader *nir) +{ + struct nir_lower_compute_system_values_options cs_options = {0}; + NIR_PASS(_, nir, nir_lower_system_values); + NIR_PASS(_, nir, nir_lower_compute_system_values, &cs_options); + NIR_PASS(_, nir, nir_split_per_member_structs); + NIR_PASS(_, nir, nir_lower_returns); + NIR_PASS(_, nir, nir_inline_functions); + optimize_nir(nir); + /* required until phi support is complete */ + NIR_PASS(_, nir, nir_convert_from_ssa, true, false); + nir_foreach_variable_in_shader(var, nir) { + if (nir->info.stage == MESA_SHADER_VERTEX && var->data.mode & nir_var_shader_in) { + if (var->data.location >= VERT_ATTRIB_GENERIC0) + var->data.driver_location = var->data.location - VERT_ATTRIB_GENERIC0; + else + var->data.driver_location = var->data.location; + } else if (var->data.mode & (nir_var_shader_out | nir_var_shader_in)) { + if (var->data.location >= VARYING_SLOT_VAR0) + var->data.driver_location = var->data.location - VARYING_SLOT_VAR0; + else + var->data.driver_location = var->data.location; + } else { + var->data.driver_location = var->data.binding; + } + } +} diff --git a/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.h b/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.h index 8d2c3aeea90..5cb490c7c3a 100644 --- a/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.h +++ b/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.h @@ -50,6 +50,7 @@ struct ntv_info { bool have_workgroup_memory_explicit_layout; bool broken_arbitary_type_const; bool has_demote_to_helper; + bool is_native_vulkan; //ignore zink-isms struct { uint8_t flush_denorms:3; // 16, 32, 64 uint8_t preserve_denorms:3; // 16, 32, 64 @@ -69,6 +70,9 @@ nir_to_spirv(struct nir_shader *s, const struct ntv_info *sinfo); void spirv_shader_delete(struct spirv_shader *s); +void +ntv_shader_prepare(struct nir_shader *nir); + static inline bool type_is_counter(const struct glsl_type *type) {