zink: infer types from load_const instrs to avoid more bitcasts

this walks to uses list for the ssa def to infer a type from one of the
uses to reduce the need to bitcast

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/22934>
This commit is contained in:
Mike Blumenkrantz 2023-05-10 08:51:10 -04:00 committed by Marge Bot
parent 9f6be8effb
commit 208c31b25f

View file

@ -164,6 +164,129 @@ get_nir_alu_type(const struct glsl_type *type)
return nir_alu_type_get_base_type(nir_get_nir_type_for_glsl_base_type(glsl_get_base_type(glsl_without_array_or_matrix(type))));
}
static nir_alu_type
infer_nir_alu_type_from_uses_ssa(nir_ssa_def *ssa, unsigned depth);
static nir_alu_type
infer_nir_alu_type_from_uses_reg(nir_register *reg, unsigned depth);
static nir_alu_type
infer_nir_alu_type_from_use(nir_src *src, unsigned depth)
{
nir_instr *instr = src->parent_instr;
nir_alu_type atype = nir_type_invalid;
switch (instr->type) {
case nir_instr_type_alu: {
nir_alu_instr *alu = nir_instr_as_alu(instr);
if (alu->op == nir_op_bcsel) {
if (nir_srcs_equal(alu->src[0].src, *src)) {
/* special case: the first src in bcsel is always bool */
return nir_type_bool;
}
}
/* ignore typeless ops */
if (alu_op_is_typeless(alu->op)) {
if (alu->dest.dest.is_ssa) {
atype = infer_nir_alu_type_from_uses_ssa(&alu->dest.dest.ssa, depth);
} else {
/* avoid infinite recursion */
if (depth > 10)
break;
if (!src->is_ssa && src->reg.reg == alu->dest.dest.reg.reg)
break;
atype = infer_nir_alu_type_from_uses_reg(alu->dest.dest.reg.reg, ++depth);
}
break;
}
for (unsigned i = 0; i < nir_op_infos[alu->op].num_inputs; i++) {
if (!nir_srcs_equal(alu->src[i].src, *src))
continue;
atype = nir_op_infos[alu->op].input_types[i];
break;
}
break;
}
case nir_instr_type_tex: {
nir_tex_instr *tex = nir_instr_as_tex(instr);
for (unsigned i = 0; i < tex->num_srcs; i++) {
if (!nir_srcs_equal(tex->src[i].src, *src))
continue;
switch (tex->src[i].src_type) {
case nir_tex_src_coord:
case nir_tex_src_lod:
if (tex->op == nir_texop_txf ||
tex->op == nir_texop_txf_ms ||
tex->op == nir_texop_txs)
atype = nir_type_int;
else
atype = nir_type_float;
break;
case nir_tex_src_projector:
case nir_tex_src_bias:
case nir_tex_src_min_lod:
case nir_tex_src_comparator:
case nir_tex_src_ddx:
case nir_tex_src_ddy:
atype = nir_type_float;
break;
case nir_tex_src_offset:
case nir_tex_src_ms_index:
case nir_tex_src_texture_offset:
case nir_tex_src_sampler_offset:
case nir_tex_src_sampler_handle:
case nir_tex_src_texture_handle:
atype = nir_type_int;
break;
default:
break;
}
break;
}
break;
}
case nir_instr_type_intrinsic: {
if (nir_instr_as_intrinsic(instr)->intrinsic == nir_intrinsic_load_deref) {
atype = get_nir_alu_type(nir_instr_as_deref(instr)->type);
} else if (nir_instr_as_intrinsic(instr)->intrinsic == nir_intrinsic_store_deref) {
atype = get_nir_alu_type(nir_src_as_deref(nir_instr_as_intrinsic(instr)->src[0])->type);
}
break;
}
default:
break;
}
return nir_alu_type_get_base_type(atype);
}
static nir_alu_type
infer_nir_alu_type_from_uses_ssa(nir_ssa_def *ssa, unsigned depth)
{
nir_alu_type atype = nir_type_invalid;
/* try to infer a type: if it's wrong then whatever, but at least we tried */
nir_foreach_use_including_if(src, ssa) {
if (src->is_if)
return nir_type_bool;
atype = infer_nir_alu_type_from_use(src, depth);
if (atype)
break;
}
return atype ? atype : nir_type_uint;
}
static nir_alu_type
infer_nir_alu_type_from_uses_reg(nir_register *reg, unsigned depth)
{
nir_alu_type atype = nir_type_invalid;
/* try to infer a type: if it's wrong then whatever, but at least we tried */
nir_foreach_use_including_if(src, reg) {
if (src->is_if)
return nir_type_bool;
atype = infer_nir_alu_type_from_use(src, depth);
if (atype)
break;
}
return atype ? atype : nir_type_uint;
}
static SpvId
get_bvec_type(struct ntv_context *ctx, int num_components)
{
@ -2488,22 +2611,37 @@ emit_load_const(struct ntv_context *ctx, nir_load_const_instr *load_const)
SpvId components[NIR_MAX_VEC_COMPONENTS];
nir_alu_type atype;
if (bit_size == 1) {
atype = nir_type_bool;
for (int i = 0; i < num_components; i++)
components[i] = spirv_builder_const_bool(&ctx->builder,
load_const->value[i].b);
atype = nir_type_bool;
} else {
atype = infer_nir_alu_type_from_uses_ssa(&load_const->def, 0);
for (int i = 0; i < num_components; i++) {
uint64_t tmp = nir_const_value_as_uint(load_const->value[i],
bit_size);
components[i] = emit_uint_const(ctx, bit_size, tmp);
switch (atype) {
case nir_type_uint: {
uint64_t tmp = nir_const_value_as_uint(load_const->value[i], bit_size);
components[i] = emit_uint_const(ctx, bit_size, tmp);
break;
}
case nir_type_int: {
int64_t tmp = nir_const_value_as_int(load_const->value[i], bit_size);
components[i] = emit_int_const(ctx, bit_size, tmp);
break;
}
case nir_type_float: {
double tmp = nir_const_value_as_float(load_const->value[i], bit_size);
components[i] = emit_float_const(ctx, bit_size, tmp);
break;
}
default:
unreachable("this shouldn't happen!");
}
}
atype = nir_type_uint;
}
if (num_components > 1) {
SpvId type = get_vec_from_bit_size(ctx, bit_size,
num_components);
SpvId type = get_alu_type(ctx, atype, num_components, bit_size);
SpvId value = spirv_builder_const_composite(&ctx->builder,
type, components,
num_components);