nir: Add is_null_constant to nir_constant

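Indicates that all of the values contained within the constant are zero,
regardless of type. This enables some optimizations, such as zero-filling
the destination when writing the constant and returning early when
resolving a constant deref.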

Reviewed-by: Alyssa Rosenzweig <alyssa@rosenzweig.io>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/23173>
This commit is contained in:
Jesse Natalie 2023-05-18 10:31:50 -07:00 committed by Marge Bot
parent 009d2de88f
commit 4edfb67fd4
7 changed files with 54 additions and 26 deletions

View file

@ -395,6 +395,9 @@ typedef struct nir_constant {
*/
nir_const_value values[NIR_MAX_VEC_COMPONENTS];
/* Indicates all the values are 0s which can enable some optimizations */
bool is_null_constant;
/* we could get this from the var->type but makes clone *much* easier to
* not have to care about the type.
*/

View file

@ -136,6 +136,7 @@ nir_constant_clone(const nir_constant *c, nir_variable *nvar)
nir_constant *nc = ralloc(nvar, nir_constant);
memcpy(nc->values, c->values, sizeof(nc->values));
nc->is_null_constant = c->is_null_constant;
nc->num_elements = c->num_elements;
nc->elements = ralloc_array(nvar, nir_constant *, c->num_elements);
for (unsigned i = 0; i < c->num_elements; i++) {

View file

@ -2530,6 +2530,11 @@ static void
write_constant(void *dst, size_t dst_size,
const nir_constant *c, const struct glsl_type *type)
{
if (c->is_null_constant) {
memset(dst, 0, dst_size);
return;
}
if (glsl_type_is_vector_or_scalar(type)) {
const unsigned num_components = glsl_get_vector_elements(type);
const unsigned bit_size = glsl_get_bit_size(type);

View file

@ -122,6 +122,12 @@ const_value_for_deref(nir_deref_instr *deref)
if (var->constant_initializer == NULL)
goto fail;
if (var->constant_initializer->is_null_constant) {
/* Doesn't matter what casts are in the way, it's all zeros */
nir_deref_path_finish(&path);
return var->constant_initializer->values;
}
nir_constant *c = var->constant_initializer;
nir_const_value *v = NULL; /* Vector value for array-deref-of-vec */

View file

@ -714,9 +714,13 @@ print_var_decl(nir_variable *var, print_state *state)
}
if (var->constant_initializer) {
fprintf(fp, " = { ");
print_constant(var->constant_initializer, var->type, state);
fprintf(fp, " }");
if (var->constant_initializer->is_null_constant) {
fprintf(fp, " = null");
} else {
fprintf(fp, " = { ");
print_constant(var->constant_initializer, var->type, state);
fprintf(fp, " }");
}
}
if (glsl_type_is_sampler(var->type) && var->data.sampler.is_inline_sampler) {
fprintf(fp, " = { %s, %s, %s }",

View file

@ -187,11 +187,15 @@ read_constant(read_ctx *ctx, nir_variable *nvar)
{
nir_constant *c = ralloc(nvar, nir_constant);
static const nir_const_value zero_vals[ARRAY_SIZE(c->values)] = { 0 };
blob_copy_bytes(ctx->blob, (uint8_t *)c->values, sizeof(c->values));
c->is_null_constant = memcmp(c->values, zero_vals, sizeof(c->values)) == 0;
c->num_elements = blob_read_uint32(ctx->blob);
c->elements = ralloc_array(nvar, nir_constant *, c->num_elements);
for (unsigned i = 0; i < c->num_elements; i++)
for (unsigned i = 0; i < c->num_elements; i++) {
c->elements[i] = read_constant(ctx, nvar);
c->is_null_constant &= c->elements[i]->is_null_constant;
}
return c;
}

View file

@ -963,7 +963,7 @@ validate_call_instr(nir_call_instr *instr, validate_state *state)
static void
validate_const_value(nir_const_value *val, unsigned bit_size,
validate_state *state)
bool is_null_constant, validate_state *state)
{
/* In order for block copies to work properly for things like instruction
* comparisons and [de]serialization, we require the unused bits of the
@ -971,24 +971,26 @@ validate_const_value(nir_const_value *val, unsigned bit_size,
*/
nir_const_value cmp_val;
memset(&cmp_val, 0, sizeof(cmp_val));
switch (bit_size) {
case 1:
cmp_val.b = val->b;
break;
case 8:
cmp_val.u8 = val->u8;
break;
case 16:
cmp_val.u16 = val->u16;
break;
case 32:
cmp_val.u32 = val->u32;
break;
case 64:
cmp_val.u64 = val->u64;
break;
default:
validate_assert(state, !"Invalid load_const bit size");
if (!is_null_constant) {
switch (bit_size) {
case 1:
cmp_val.b = val->b;
break;
case 8:
cmp_val.u8 = val->u8;
break;
case 16:
cmp_val.u16 = val->u16;
break;
case 32:
cmp_val.u32 = val->u32;
break;
case 64:
cmp_val.u64 = val->u64;
break;
default:
validate_assert(state, !"Invalid load_const bit size");
}
}
validate_assert(state, memcmp(val, &cmp_val, sizeof(cmp_val)) == 0);
}
@ -999,7 +1001,7 @@ validate_load_const_instr(nir_load_const_instr *instr, validate_state *state)
validate_ssa_def(&instr->def, state);
for (unsigned i = 0; i < instr->def.num_components; i++)
validate_const_value(&instr->value[i], instr->def.bit_size, state);
validate_const_value(&instr->value[i], instr->def.bit_size, false, state);
}
static void
@ -1483,7 +1485,7 @@ validate_constant(nir_constant *c, const struct glsl_type *type,
unsigned num_components = glsl_get_vector_elements(type);
unsigned bit_size = glsl_get_bit_size(type);
for (unsigned i = 0; i < num_components; i++)
validate_const_value(&c->values[i], bit_size, state);
validate_const_value(&c->values[i], bit_size, c->is_null_constant, state);
for (unsigned i = num_components; i < NIR_MAX_VEC_COMPONENTS; i++)
validate_assert(state, c->values[i].u64 == 0);
} else {
@ -1492,11 +1494,14 @@ validate_constant(nir_constant *c, const struct glsl_type *type,
for (unsigned i = 0; i < c->num_elements; i++) {
const struct glsl_type *elem_type = glsl_get_struct_field(type, i);
validate_constant(c->elements[i], elem_type, state);
validate_assert(state, !c->is_null_constant || c->elements[i]->is_null_constant);
}
} else if (glsl_type_is_array_or_matrix(type)) {
const struct glsl_type *elem_type = glsl_get_array_element(type);
for (unsigned i = 0; i < c->num_elements; i++)
for (unsigned i = 0; i < c->num_elements; i++) {
validate_constant(c->elements[i], elem_type, state);
validate_assert(state, !c->is_null_constant || c->elements[i]->is_null_constant);
}
} else {
validate_assert(state, !"Invalid type for nir_constant");
}