zink: do not convert bools to/from uint

Since bools are the only 1-bit type, we always know if an SSA-def is a
bool or not. So we don't need to marshal it to uint.

So let's simplify the code a bit here.

Tested-by: Marge Bot <https://gitlab.freedesktop.org/mesa/mesa/merge_requests/3763>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/merge_requests/3763>
This commit is contained in:
Erik Faye-Lund 2020-02-10 15:45:22 +01:00 committed by Marge Bot
parent 4d016de250
commit 9903f10636

View file

@ -140,9 +140,9 @@ get_fvec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_component
static SpvId
get_ivec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_components)
{
assert(bit_size == 1 || bit_size == 32); // only 32-bit ints supported so far
assert(bit_size == 32); // only 32-bit ints supported so far
SpvId int_type = spirv_builder_type_int(&ctx->builder, MAX2(bit_size, 32));
SpvId int_type = spirv_builder_type_int(&ctx->builder, bit_size);
if (num_components > 1)
return spirv_builder_type_vector(&ctx->builder, int_type,
num_components);
@ -154,9 +154,9 @@ get_ivec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_component
static SpvId
get_uvec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_components)
{
assert(bit_size == 1 || bit_size == 32); // only 32-bit uints supported so far
assert(bit_size == 32); // only 32-bit uints supported so far
SpvId uint_type = spirv_builder_type_uint(&ctx->builder, MAX2(bit_size, 32));
SpvId uint_type = spirv_builder_type_uint(&ctx->builder, bit_size);
if (num_components > 1)
return spirv_builder_type_vector(&ctx->builder, uint_type,
num_components);
@ -168,8 +168,8 @@ get_uvec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_component
static SpvId
get_dest_uvec_type(struct ntv_context *ctx, nir_dest *dest)
{
return get_uvec_type(ctx, nir_dest_bit_size(*dest),
nir_dest_num_components(*dest));
unsigned bit_size = MAX2(nir_dest_bit_size(*dest), 32);
return get_uvec_type(ctx, bit_size, nir_dest_num_components(*dest));
}
static SpvId
@ -601,7 +601,9 @@ get_alu_src_raw(struct ntv_context *ctx, nir_alu_instr *alu, unsigned src)
int bit_size = nir_src_bit_size(alu->src[src].src);
assert(bit_size == 1 || bit_size == 32);
SpvId raw_type = spirv_builder_type_uint(&ctx->builder, MAX2(bit_size, 32));
SpvId raw_type = bit_size == 1 ? spirv_builder_type_bool(&ctx->builder) :
spirv_builder_type_uint(&ctx->builder, bit_size);
if (used_channels == 1) {
uint32_t indices[] = { alu->src[src].swizzle[0] };
return spirv_builder_emit_composite_extract(&ctx->builder, raw_type,
@ -655,15 +657,6 @@ emit_select(struct ntv_context *ctx, SpvId type, SpvId cond,
return emit_triop(ctx, SpvOpSelect, type, cond, if_true, if_false);
}
static SpvId
bvec_to_uvec(struct ntv_context *ctx, SpvId value, unsigned num_components)
{
SpvId otype = get_uvec_type(ctx, 32, num_components);
SpvId zero = get_uvec_constant(ctx, 32, num_components, 0);
SpvId one = get_uvec_constant(ctx, 32, num_components, UINT32_MAX);
return emit_select(ctx, otype, value, one, zero);
}
static SpvId
uvec_to_bvec(struct ntv_context *ctx, SpvId value, unsigned num_components)
{
@ -725,22 +718,22 @@ store_dest(struct ntv_context *ctx, nir_dest *dest, SpvId result, nir_alu_type t
unsigned num_components = nir_dest_num_components(*dest);
unsigned bit_size = nir_dest_bit_size(*dest);
switch (nir_alu_type_get_base_type(type)) {
case nir_type_bool:
assert(bit_size == 1);
result = bvec_to_uvec(ctx, result, num_components);
break;
if (bit_size != 1) {
switch (nir_alu_type_get_base_type(type)) {
case nir_type_bool:
assert("bool should have bit-size 1");
case nir_type_uint:
break; /* nothing to do! */
case nir_type_uint:
break; /* nothing to do! */
case nir_type_int:
case nir_type_float:
result = bitcast_to_uvec(ctx, result, bit_size, num_components);
break;
case nir_type_int:
case nir_type_float:
result = bitcast_to_uvec(ctx, result, bit_size, num_components);
break;
default:
unreachable("unsupported nir_alu_type");
default:
unreachable("unsupported nir_alu_type");
}
}
store_dest_raw(ctx, dest, result);
@ -874,22 +867,25 @@ get_alu_src(struct ntv_context *ctx, nir_alu_instr *alu, unsigned src)
unsigned bit_size = nir_src_bit_size(alu->src[src].src);
nir_alu_type type = nir_op_infos[alu->op].input_types[src];
switch (nir_alu_type_get_base_type(type)) {
case nir_type_bool:
assert(bit_size == 1);
return uvec_to_bvec(ctx, raw_value, num_components);
case nir_type_int:
return bitcast_to_ivec(ctx, raw_value, bit_size, num_components);
case nir_type_uint:
if (bit_size == 1)
return raw_value;
else {
switch (nir_alu_type_get_base_type(type)) {
case nir_type_bool:
unreachable("bool should have bit-size 1");
case nir_type_float:
return bitcast_to_fvec(ctx, raw_value, bit_size, num_components);
case nir_type_int:
return bitcast_to_ivec(ctx, raw_value, bit_size, num_components);
default:
unreachable("unknown nir_alu_type");
case nir_type_uint:
return raw_value;
case nir_type_float:
return bitcast_to_fvec(ctx, raw_value, bit_size, num_components);
default:
unreachable("unknown nir_alu_type");
}
}
}
@ -907,9 +903,12 @@ get_dest_type(struct ntv_context *ctx, nir_dest *dest, nir_alu_type type)
unsigned num_components = nir_dest_num_components(*dest);
unsigned bit_size = nir_dest_bit_size(*dest);
if (bit_size == 1)
return get_bvec_type(ctx, num_components);
switch (nir_alu_type_get_base_type(type)) {
case nir_type_bool:
return get_bvec_type(ctx, num_components);
unreachable("bool should have bit-size 1");
case nir_type_int:
return get_ivec_type(ctx, bit_size, num_components);
@ -1231,9 +1230,6 @@ emit_load_const(struct ntv_context *ctx, nir_load_const_instr *load_const)
constant = emit_uint_const(ctx, bit_size, load_const->value[0].u32);
}
if (bit_size == 1)
constant = bvec_to_uvec(ctx, constant, num_components);
store_ssa_def(ctx, &load_const->def, constant);
}
@ -1283,6 +1279,9 @@ emit_load_ubo(struct ntv_context *ctx, nir_intrinsic_instr *intr)
num_components);
}
if (nir_dest_bit_size(intr->dest) == 1)
result = uvec_to_bvec(ctx, result, num_components);
store_dest(ctx, &intr->dest, result, nir_type_uint);
} else
unreachable("uniform-addressing not yet supported");
@ -1767,10 +1766,8 @@ emit_cf_list(struct ntv_context *ctx, struct exec_list *list);
static SpvId
get_src_bool(struct ntv_context *ctx, nir_src *src)
{
SpvId def = get_src(ctx, src);
assert(nir_src_bit_size(*src) == 1);
unsigned num_components = nir_src_num_components(*src);
return uvec_to_bvec(ctx, def, num_components);
return get_src(ctx, src);
}
static void