diff --git a/src/compiler/nir/nir.c b/src/compiler/nir/nir.c index 655dc884382..d07550a6b03 100644 --- a/src/compiler/nir/nir.c +++ b/src/compiler/nir/nir.c @@ -70,6 +70,7 @@ reg_create(void *mem_ctx, struct exec_list *list) list_inithead(®->if_uses); reg->num_components = 0; + reg->bit_size = 32; reg->num_array_elems = 0; reg->is_packed = false; reg->name = NULL; @@ -1325,6 +1326,7 @@ nir_ssa_def_init(nir_instr *instr, nir_ssa_def *def, list_inithead(&def->uses); list_inithead(&def->if_uses); def->num_components = num_components; + def->bit_size = 32; /* FIXME: Add an input paremeter or guess? */ if (instr->block) { nir_function_impl *impl = diff --git a/src/compiler/nir/nir.h b/src/compiler/nir/nir.h index 418682f2caf..8f411793d9d 100644 --- a/src/compiler/nir/nir.h +++ b/src/compiler/nir/nir.h @@ -345,6 +345,9 @@ typedef struct nir_register { unsigned num_components; /** < number of vector components */ unsigned num_array_elems; /** < size of array (0 for no array) */ + /* The bit-size of each channel; must be one of 8, 16, 32, or 64 */ + uint8_t bit_size; + /** generic register index. */ unsigned index; @@ -452,6 +455,9 @@ typedef struct nir_ssa_def { struct list_head if_uses; uint8_t num_components; + + /* The bit-size of each channel; must be one of 8, 16, 32, or 64 */ + uint8_t bit_size; } nir_ssa_def; struct nir_src; diff --git a/src/compiler/nir/nir_validate.c b/src/compiler/nir/nir_validate.c index 0c32d5fe07a..9f18d1c33e4 100644 --- a/src/compiler/nir/nir_validate.c +++ b/src/compiler/nir/nir_validate.c @@ -179,9 +179,12 @@ validate_alu_src(nir_alu_instr *instr, unsigned index, validate_state *state) nir_alu_src *src = &instr->src[index]; unsigned num_components; - if (src->src.is_ssa) + unsigned src_bit_size; + if (src->src.is_ssa) { + src_bit_size = src->src.ssa->bit_size; num_components = src->src.ssa->num_components; - else { + } else { + src_bit_size = src->src.reg.reg->bit_size; if (src->src.reg.reg->is_packed) num_components = 4; /* can't check anything */ else @@ -194,6 +197,24 @@ validate_alu_src(nir_alu_instr *instr, unsigned index, validate_state *state) assert(src->swizzle[i] < num_components); } + nir_alu_type src_type = nir_op_infos[instr->op].input_types[index]; + + /* 8-bit float isn't a thing */ + if (nir_alu_type_get_base_type(src_type) == nir_type_float) + assert(src_bit_size == 16 || src_bit_size == 32 || src_bit_size == 64); + + if (nir_alu_type_get_type_size(src_type)) { + /* This source has an explicit bit size */ + assert(nir_alu_type_get_type_size(src_type) == src_bit_size); + } else { + if (!nir_alu_type_get_type_size(nir_op_infos[instr->op].output_type)) { + unsigned dest_bit_size = + instr->dest.dest.is_ssa ? instr->dest.dest.ssa.bit_size + : instr->dest.dest.reg.reg->bit_size; + assert(dest_bit_size == src_bit_size); + } + } + validate_src(&src->src, state); } @@ -263,8 +284,10 @@ validate_dest(nir_dest *dest, validate_state *state) } static void -validate_alu_dest(nir_alu_dest *dest, validate_state *state) +validate_alu_dest(nir_alu_instr *instr, validate_state *state) { + nir_alu_dest *dest = &instr->dest; + unsigned dest_size = dest->dest.is_ssa ? dest->dest.ssa.num_components : dest->dest.reg.reg->num_components; @@ -282,6 +305,17 @@ validate_alu_dest(nir_alu_dest *dest, validate_state *state) assert(nir_op_infos[alu->op].output_type == nir_type_float || !dest->saturate); + unsigned bit_size = dest->dest.is_ssa ? dest->dest.ssa.bit_size + : dest->dest.reg.reg->bit_size; + nir_alu_type type = nir_op_infos[instr->op].output_type; + + /* 8-bit float isn't a thing */ + if (nir_alu_type_get_base_type(type) == nir_type_float) + assert(bit_size == 16 || bit_size == 32 || bit_size == 64); + + assert(nir_alu_type_get_type_size(type) == 0 || + nir_alu_type_get_type_size(type) == bit_size); + validate_dest(&dest->dest, state); } @@ -294,7 +328,7 @@ validate_alu_instr(nir_alu_instr *instr, validate_state *state) validate_alu_src(instr, i, state); } - validate_alu_dest(&instr->dest, state); + validate_alu_dest(instr, state); } static void