ntv: add basic vulkan support

this enables (some) shaders generated by vtn to successfully pass through
ntv and generate valid spirv

the majority of the plumbing is to handle deref casts, which are currently
assumed to originate solely from loading descriptors

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/39488>
This commit is contained in:
Mike Blumenkrantz 2026-01-23 10:40:55 -05:00 committed by Marge Bot
parent 92622f7f44
commit e1855dc947
2 changed files with 225 additions and 33 deletions

View file

@ -278,6 +278,37 @@ find_image_type(struct ntv_context *ctx, nir_variable *var)
return he ? (intptr_t)he->data : 0;
}
static nir_variable *
find_vulkan_deref_var(struct ntv_context *ctx, nir_deref_instr *deref)
{
nir_variable *var = nir_deref_instr_get_variable(deref);
if (var)
return var;
assert(ctx->sinfo->is_native_vulkan);
nir_deref_instr *parent = nir_deref_instr_parent(deref);
while (parent) {
if (parent->deref_type == nir_deref_type_cast && nir_def_is_intrinsic(parent->parent.ssa)) {
nir_intrinsic_instr *intr = nir_def_as_intrinsic(parent->parent.ssa);
while (intr->intrinsic != nir_intrinsic_vulkan_resource_index)
intr = nir_def_as_intrinsic(intr->src[0].ssa);
int desc_set = nir_intrinsic_desc_set(intr);
int binding = nir_intrinsic_binding(intr);
nir_foreach_variable_with_modes(i, ctx->nir, deref->modes) {
if (i->data.descriptor_set == desc_set && i->data.binding == binding) {
var = i;
break;
}
}
return var;
} else if (parent->deref_type == nir_deref_type_var) {
return parent->var;
} else {
parent = nir_deref_instr_parent(parent);
}
}
return NULL;
}
static SpvScope
get_scope(mesa_scope scope)
{
@ -1390,23 +1421,31 @@ get_bo_struct_type(struct ntv_context *ctx, struct nir_variable *var)
struct hash_entry *he = _mesa_hash_table_search(ctx->bo_struct_types, var);
if (he)
return (SpvId)(uintptr_t)he->data;
const struct glsl_type *bare_type = glsl_without_array(var->type);
unsigned bitsize = glsl_get_bit_size(glsl_get_array_element(glsl_get_struct_field(bare_type, 0)));
SpvId array_type = get_bo_array_type(ctx, var);
_mesa_hash_table_insert(ctx->bo_array_types, var, (void *)(uintptr_t)array_type);
bool ssbo = var->data.mode == nir_var_mem_ssbo;
SpvId struct_type = 0;
if (ctx->sinfo->is_native_vulkan) {
struct_type = get_glsl_type(ctx, var->type, false);
} else {
const struct glsl_type *bare_type = glsl_without_array(var->type);
unsigned bitsize = glsl_get_bit_size(glsl_get_array_element(glsl_get_struct_field(bare_type, 0)));
SpvId array_type = get_bo_array_type(ctx, var);
_mesa_hash_table_insert(ctx->bo_array_types, var, (void *)(uintptr_t)array_type);
bool ssbo = var->data.mode == nir_var_mem_ssbo;
// wrap UBO-array in a struct
SpvId runtime_array = 0;
if (ssbo && glsl_get_length(bare_type) > 1) {
const struct glsl_type *last_member = glsl_get_struct_field(bare_type, glsl_get_length(bare_type) - 1);
if (glsl_type_is_unsized_array(last_member)) {
runtime_array = spirv_builder_type_runtime_array(&ctx->builder, get_uvec_type(ctx, bitsize, 1));
spirv_builder_emit_array_stride(&ctx->builder, runtime_array, glsl_get_explicit_stride(last_member));
}
// wrap UBO-array in a struct
SpvId runtime_array = 0;
if (ssbo && glsl_get_length(bare_type) > 1) {
const struct glsl_type *last_member = glsl_get_struct_field(bare_type, glsl_get_length(bare_type) - 1);
if (glsl_type_is_unsized_array(last_member)) {
runtime_array = spirv_builder_type_runtime_array(&ctx->builder, get_uvec_type(ctx, bitsize, 1));
spirv_builder_emit_array_stride(&ctx->builder, runtime_array, glsl_get_explicit_stride(last_member));
}
}
SpvId types[] = {array_type, runtime_array};
struct_type = spirv_builder_type_struct(&ctx->builder, types, 1 + !!runtime_array);
spirv_builder_emit_member_offset(&ctx->builder, struct_type, 0, 0);
if (runtime_array)
spirv_builder_emit_member_offset(&ctx->builder, struct_type, 1, 0);
}
SpvId types[] = {array_type, runtime_array};
SpvId struct_type = spirv_builder_type_struct(&ctx->builder, types, 1 + !!runtime_array);
if (var->name) {
char struct_name[100];
snprintf(struct_name, sizeof(struct_name), "struct_%s", var->name);
@ -1415,9 +1454,6 @@ get_bo_struct_type(struct ntv_context *ctx, struct nir_variable *var)
spirv_builder_emit_decoration(&ctx->builder, struct_type,
SpvDecorationBlock);
spirv_builder_emit_member_offset(&ctx->builder, struct_type, 0, 0);
if (runtime_array)
spirv_builder_emit_member_offset(&ctx->builder, struct_type, 1, 0);
return struct_type;
}
@ -1425,12 +1461,12 @@ get_bo_struct_type(struct ntv_context *ctx, struct nir_variable *var)
static void
emit_bo(struct ntv_context *ctx, struct nir_variable *var, bool aliased)
{
unsigned bitsize = glsl_get_bit_size(glsl_get_array_element(glsl_get_struct_field(glsl_without_array(var->type), 0)));
bool ssbo = var->data.mode == nir_var_mem_ssbo;
SpvId struct_type = get_bo_struct_type(ctx, var);
_mesa_hash_table_insert(ctx->bo_struct_types, var, (void *)(uintptr_t)struct_type);
bool ssbo = var->data.mode == nir_var_mem_ssbo;
unsigned bitsize = ctx->sinfo->is_native_vulkan ? 32 : glsl_get_bit_size(glsl_get_array_element(glsl_get_struct_field(glsl_without_array(var->type), 0)));
SpvId array_length = emit_uint_const(ctx, 32, glsl_get_length(var->type));
SpvId array_type = spirv_builder_type_array(&ctx->builder, struct_type, array_length);
SpvId array_type = ctx->sinfo->is_native_vulkan ? struct_type : spirv_builder_type_array(&ctx->builder, struct_type, array_length);
SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
ssbo ? SpvStorageClassStorageBuffer : SpvStorageClassUniform,
array_type);
@ -2319,7 +2355,7 @@ emit_load_deref(struct ntv_context *ctx, nir_intrinsic_instr *intr)
nir_deref_instr *deref = nir_src_as_deref(intr->src[0]);
SpvId type;
if (glsl_type_is_image(deref->type)) {
nir_variable *var = nir_deref_instr_get_variable(deref);
nir_variable *var = find_vulkan_deref_var(ctx, deref);
const struct glsl_type *gtype = glsl_without_array(var->type);
type = get_image_type(ctx, var,
glsl_type_is_sampler(gtype) || glsl_type_is_texture(gtype),
@ -2347,6 +2383,11 @@ emit_store_deref(struct ntv_context *ctx, nir_intrinsic_instr *intr)
const struct glsl_type *gtype = nir_src_as_deref(intr->src[0])->type;
nir_variable *var = nir_intrinsic_get_var(intr, 0);
if (!var) {
assert(ctx->sinfo->is_native_vulkan);
assert(nir_def_is_deref(intr->src[0].ssa));
var = find_vulkan_deref_var(ctx, nir_def_as_deref(intr->src[0].ssa));
}
SpvId type = get_glsl_type(ctx, gtype, var->data.mode & (nir_var_shader_temp | nir_var_function_temp));
unsigned wrmask = nir_intrinsic_write_mask(intr);
if (!glsl_type_is_scalar(gtype) &&
@ -2985,7 +3026,7 @@ emit_image_deref_store(struct ntv_context *ctx, nir_intrinsic_instr *intr)
nir_alu_type atype;
SpvId img_var = get_src(ctx, &intr->src[0], &atype);
nir_deref_instr *deref = nir_src_as_deref(intr->src[0]);
nir_variable *var = nir_deref_instr_get_variable(deref);
nir_variable *var = find_vulkan_deref_var(ctx, deref);
const struct glsl_type *type = glsl_without_array(var->type);
SpvId base_type = get_glsl_basetype(ctx, glsl_get_sampler_result_type(type));
SpvId coord = get_image_coords(ctx, type, &intr->src[1]);
@ -3053,7 +3094,7 @@ emit_image_deref_load(struct ntv_context *ctx, nir_intrinsic_instr *intr)
nir_alu_type atype;
SpvId img_var = get_src(ctx, &intr->src[0], &atype);
nir_deref_instr *deref = nir_src_as_deref(intr->src[0]);
nir_variable *var = nir_deref_instr_get_variable(deref);
nir_variable *var = find_vulkan_deref_var(ctx, deref);
bool mediump = (var->data.precision == GLSL_PRECISION_MEDIUM || var->data.precision == GLSL_PRECISION_LOW);
const struct glsl_type *type = glsl_without_array(var->type);
SpvId base_type = get_glsl_basetype(ctx, glsl_get_sampler_result_type(type));
@ -3095,7 +3136,7 @@ emit_image_deref_size(struct ntv_context *ctx, nir_intrinsic_instr *intr)
nir_alu_type atype;
SpvId img_var = get_src(ctx, &intr->src[0], &atype);
nir_deref_instr *deref = nir_src_as_deref(intr->src[0]);
nir_variable *var = nir_deref_instr_get_variable(deref);
nir_variable *var = find_vulkan_deref_var(ctx, deref);
SpvId img_type = find_image_type(ctx, var);
const struct glsl_type *type = glsl_without_array(var->type);
SpvId img = spirv_builder_emit_load(&ctx->builder, img_type, img_var, false);
@ -3115,7 +3156,7 @@ emit_image_deref_samples(struct ntv_context *ctx, nir_intrinsic_instr *intr)
nir_alu_type atype;
SpvId img_var = get_src(ctx, &intr->src[0], &atype);
nir_deref_instr *deref = nir_src_as_deref(intr->src[0]);
nir_variable *var = nir_deref_instr_get_variable(deref);
nir_variable *var = find_vulkan_deref_var(ctx, deref);
SpvId img_type = find_image_type(ctx, var);
SpvId img = spirv_builder_emit_load(&ctx->builder, img_type, img_var, false);
@ -3131,7 +3172,7 @@ emit_image_intrinsic(struct ntv_context *ctx, nir_intrinsic_instr *intr)
SpvId param = get_src(ctx, &intr->src[3], &ptype);
SpvId img_var = get_src(ctx, &intr->src[0], &atype);
nir_deref_instr *deref = nir_src_as_deref(intr->src[0]);
nir_variable *var = nir_deref_instr_get_variable(deref);
nir_variable *var = find_vulkan_deref_var(ctx, deref);
const struct glsl_type *type = glsl_without_array(var->type);
bool is_ms;
type_to_dim(glsl_get_sampler_dim(type), &is_ms);
@ -3591,6 +3632,68 @@ init_sparse_resident(struct ntv_context *ctx)
ctx->have_sparse = true;
}
static void
emit_load_vulkan_descriptor(struct ntv_context *ctx, nir_intrinsic_instr *intr)
{
nir_alu_type itype;
SpvId src = get_src(ctx, &intr->src[0], &itype);
store_def(ctx, intr->def.index, src, nir_type_uint);
}
static void
emit_vulkan_resource_index(struct ntv_context *ctx, nir_intrinsic_instr *intr)
{
nir_alu_type itype;
SpvId index = get_src(ctx, &intr->src[0], &itype);
if (itype == nir_type_float)
index = emit_bitcast(ctx, get_uvec_type(ctx, 32, 1), index);
int desc_set = nir_intrinsic_desc_set(intr);
int binding = nir_intrinsic_binding(intr);
nir_variable *var = NULL;
nir_foreach_variable_with_modes(i, ctx->nir, nir_var_mem_ubo | nir_var_mem_ssbo | nir_var_image | nir_var_uniform) {
if (i->data.descriptor_set == desc_set && i->data.binding == binding) {
var = i;
break;
}
}
assert(var);
struct hash_entry *he = _mesa_hash_table_search(ctx->vars, var);
assert(he);
SpvId base = (SpvId)(intptr_t)he->data;
const struct glsl_type *gtype = glsl_without_array(var->type);
SpvId type = 0;
if (glsl_type_is_bare_sampler(gtype)) {
type = spirv_builder_type_sampler(&ctx->builder);
} else if (glsl_type_is_sampler(gtype) || glsl_type_is_image(gtype) || glsl_type_is_texture(gtype)) {
struct hash_entry *he = _mesa_hash_table_search(&ctx->image_types, var);
assert(he);
type = (SpvId)(intptr_t)he->data;
} else {
assert(var->data.mode & (nir_var_mem_ubo | nir_var_mem_ssbo));
type = get_bo_struct_type(ctx, var);
}
if (glsl_type_is_array(var->type)) {
SpvId ptr_type = spirv_builder_type_pointer(&ctx->builder,
get_storage_class(var),
type);
SpvId indices[] = {
emit_uint_const(ctx, 32, 0),
index
};
SpvId result = spirv_builder_emit_access_chain(&ctx->builder,
ptr_type,
base,
indices, 2);
store_def(ctx, intr->def.index, result, nir_type_uint);
} else {
store_def(ctx, intr->def.index, base, nir_type_uint);
}
}
static void
emit_intrinsic(struct ntv_context *ctx, nir_intrinsic_instr *intr)
{
@ -3607,6 +3710,14 @@ emit_intrinsic(struct ntv_context *ctx, nir_intrinsic_instr *intr)
emit_store_reg(ctx, intr);
break;
case nir_intrinsic_vulkan_resource_index:
emit_vulkan_resource_index(ctx, intr);
break;
case nir_intrinsic_load_vulkan_descriptor:
emit_load_vulkan_descriptor(ctx, intr);
break;
case nir_intrinsic_terminate:
emit_discard(ctx, intr);
break;
@ -3987,7 +4098,7 @@ get_tex_srcs(struct ntv_context *ctx, nir_tex_instr *tex,
nir_const_value *cv;
switch (tex->src[i].src_type) {
case nir_tex_src_texture_deref:
var = nir_deref_instr_get_variable(nir_def_as_deref(tex->src[i].src.ssa));
var = find_vulkan_deref_var(ctx, nir_def_as_deref(tex->src[i].src.ssa));
tex_src->src = get_src(ctx, &tex->src[i].src, &atype);
break;
case nir_tex_src_sampler_deref:
@ -4088,7 +4199,7 @@ get_tex_srcs(struct ntv_context *ctx, nir_tex_instr *tex,
case nir_tex_src_texture_handle:
tex_src->src = get_src(ctx, &tex->src[i].src, &atype);
var = *bindless_var = nir_deref_instr_get_variable(nir_src_as_deref(tex->src[i].src));
var = *bindless_var = find_vulkan_deref_var(ctx, nir_src_as_deref(tex->src[i].src));
break;
default:
@ -4402,7 +4513,7 @@ static void
emit_deref_array(struct ntv_context *ctx, nir_deref_instr *deref)
{
assert(deref->deref_type == nir_deref_type_array);
nir_variable *var = nir_deref_instr_get_variable(deref);
nir_variable *var = find_vulkan_deref_var(ctx, deref);
if (!nir_src_is_always_uniform(deref->arr.index)) {
if (deref->modes & nir_var_mem_ubo)
@ -4491,12 +4602,12 @@ static void
emit_deref_struct(struct ntv_context *ctx, nir_deref_instr *deref)
{
assert(deref->deref_type == nir_deref_type_struct);
nir_variable *var = nir_deref_instr_get_variable(deref);
nir_variable *var = find_vulkan_deref_var(ctx, deref);
SpvStorageClass storage_class = get_storage_class(var);
SpvId index = emit_uint_const(ctx, 32, deref->strct.index);
SpvId type = (var->data.mode & (nir_var_mem_ubo | nir_var_mem_ssbo)) ?
bool is_zink_bo = !ctx->sinfo->is_native_vulkan && (var->data.mode & (nir_var_mem_ubo | nir_var_mem_ssbo));
SpvId type = is_zink_bo ?
get_bo_array_type(ctx, var) :
get_glsl_type(ctx, deref->type, var->data.mode & (nir_var_shader_temp | nir_var_function_temp));
@ -4513,6 +4624,15 @@ emit_deref_struct(struct ntv_context *ctx, nir_deref_instr *deref)
store_def(ctx, deref->def.index, result, get_nir_alu_type(deref->type));
}
static void
emit_deref_cast(struct ntv_context *ctx, nir_deref_instr *deref)
{
assert(deref->deref_type == nir_deref_type_cast);
nir_alu_type atype;
SpvId src = get_src(ctx, &deref->parent, &atype);
store_def(ctx, deref->def.index, src, nir_type_uint);
}
static void
emit_deref(struct ntv_context *ctx, nir_deref_instr *deref)
{
@ -4529,6 +4649,10 @@ emit_deref(struct ntv_context *ctx, nir_deref_instr *deref)
emit_deref_struct(ctx, deref);
break;
case nir_deref_type_cast:
emit_deref_cast(ctx, deref);
break;
default:
UNREACHABLE("unexpected deref_type");
}
@ -5358,3 +5482,67 @@ spirv_shader_delete(struct spirv_shader *s)
{
ralloc_free(s);
}
static void
optimize_nir(struct nir_shader *nir)
{
bool progress;
do {
progress = false;
NIR_PASS(progress, nir, nir_lower_vars_to_ssa);
NIR_PASS(progress, nir, nir_opt_copy_prop);
NIR_PASS(progress, nir, nir_opt_remove_phis);
NIR_PASS(progress, nir, nir_lower_all_phis_to_scalar);
NIR_PASS(progress, nir, nir_opt_dce);
NIR_PASS(progress, nir, nir_opt_dead_cf);
NIR_PASS(progress, nir, nir_opt_cse);
nir_opt_peephole_select_options peephole_select_options = {
.limit = 64,
.expensive_alu_ok = true,
};
NIR_PASS(progress, nir, nir_opt_peephole_select, &peephole_select_options);
NIR_PASS(progress, nir, nir_opt_phi_precision);
NIR_PASS(progress, nir, nir_opt_algebraic);
NIR_PASS(progress, nir, nir_opt_constant_folding);
NIR_PASS(progress, nir, nir_opt_deref);
NIR_PASS(progress, nir, nir_opt_copy_prop_vars);
NIR_PASS(progress, nir, nir_opt_undef);
NIR_PASS(progress, nir, nir_opt_loop);
} while (progress);
NIR_PASS(progress, nir, nir_opt_shrink_vectors, true);
}
/* this is the bare minimum required to make vtn shaders work with ntv */
void
ntv_shader_prepare(nir_shader *nir)
{
struct nir_lower_compute_system_values_options cs_options = {0};
NIR_PASS(_, nir, nir_lower_system_values);
NIR_PASS(_, nir, nir_lower_compute_system_values, &cs_options);
NIR_PASS(_, nir, nir_split_per_member_structs);
NIR_PASS(_, nir, nir_lower_returns);
NIR_PASS(_, nir, nir_inline_functions);
optimize_nir(nir);
/* required until phi support is complete */
NIR_PASS(_, nir, nir_convert_from_ssa, true, false);
nir_foreach_variable_in_shader(var, nir) {
if (nir->info.stage == MESA_SHADER_VERTEX && var->data.mode & nir_var_shader_in) {
if (var->data.location >= VERT_ATTRIB_GENERIC0)
var->data.driver_location = var->data.location - VERT_ATTRIB_GENERIC0;
else
var->data.driver_location = var->data.location;
} else if (var->data.mode & (nir_var_shader_out | nir_var_shader_in)) {
if (var->data.location >= VARYING_SLOT_VAR0)
var->data.driver_location = var->data.location - VARYING_SLOT_VAR0;
else
var->data.driver_location = var->data.location;
} else {
var->data.driver_location = var->data.binding;
}
}
}

View file

@ -50,6 +50,7 @@ struct ntv_info {
bool have_workgroup_memory_explicit_layout;
bool broken_arbitary_type_const;
bool has_demote_to_helper;
bool is_native_vulkan; //ignore zink-isms
struct {
uint8_t flush_denorms:3; // 16, 32, 64
uint8_t preserve_denorms:3; // 16, 32, 64
@ -69,6 +70,9 @@ nir_to_spirv(struct nir_shader *s, const struct ntv_info *sinfo);
void
spirv_shader_delete(struct spirv_shader *s);
void
ntv_shader_prepare(struct nir_shader *nir);
static inline bool
type_is_counter(const struct glsl_type *type)
{