From 4edfb67fd4911fea127a3ee070dfd1eb8898a76b Mon Sep 17 00:00:00 2001 From: Jesse Natalie Date: Thu, 18 May 2023 10:31:50 -0700 Subject: [PATCH] nir: Add is_null_constant to nir_constant Indicates that the values contained within are 0s, regardless of type. Enables some optimizations. Reviewed-by: Alyssa Rosenzweig Part-of: --- src/compiler/nir/nir.h | 3 ++ src/compiler/nir/nir_clone.c | 1 + src/compiler/nir/nir_lower_io.c | 5 +++ src/compiler/nir/nir_opt_constant_folding.c | 6 +++ src/compiler/nir/nir_print.c | 10 +++-- src/compiler/nir/nir_serialize.c | 6 ++- src/compiler/nir/nir_validate.c | 49 ++++++++++++--------- 7 files changed, 54 insertions(+), 26 deletions(-) diff --git a/src/compiler/nir/nir.h b/src/compiler/nir/nir.h index 0677930f661..cf7242f1a17 100644 --- a/src/compiler/nir/nir.h +++ b/src/compiler/nir/nir.h @@ -395,6 +395,9 @@ typedef struct nir_constant { */ nir_const_value values[NIR_MAX_VEC_COMPONENTS]; + /* Indicates all the values are 0s which can enable some optimizations */ + bool is_null_constant; + /* we could get this from the var->type but makes clone *much* easier to * not have to care about the type. */ diff --git a/src/compiler/nir/nir_clone.c b/src/compiler/nir/nir_clone.c index 84e1e40acec..5774cde21f1 100644 --- a/src/compiler/nir/nir_clone.c +++ b/src/compiler/nir/nir_clone.c @@ -136,6 +136,7 @@ nir_constant_clone(const nir_constant *c, nir_variable *nvar) nir_constant *nc = ralloc(nvar, nir_constant); memcpy(nc->values, c->values, sizeof(nc->values)); + nc->is_null_constant = c->is_null_constant; nc->num_elements = c->num_elements; nc->elements = ralloc_array(nvar, nir_constant *, c->num_elements); for (unsigned i = 0; i < c->num_elements; i++) { diff --git a/src/compiler/nir/nir_lower_io.c b/src/compiler/nir/nir_lower_io.c index ad83e9e0138..9e6e021e3c4 100644 --- a/src/compiler/nir/nir_lower_io.c +++ b/src/compiler/nir/nir_lower_io.c @@ -2530,6 +2530,11 @@ static void write_constant(void *dst, size_t dst_size, const nir_constant *c, const struct glsl_type *type) { + if (c->is_null_constant) { + memset(dst, 0, dst_size); + return; + } + if (glsl_type_is_vector_or_scalar(type)) { const unsigned num_components = glsl_get_vector_elements(type); const unsigned bit_size = glsl_get_bit_size(type); diff --git a/src/compiler/nir/nir_opt_constant_folding.c b/src/compiler/nir/nir_opt_constant_folding.c index 4b27bc36081..945055f5451 100644 --- a/src/compiler/nir/nir_opt_constant_folding.c +++ b/src/compiler/nir/nir_opt_constant_folding.c @@ -122,6 +122,12 @@ const_value_for_deref(nir_deref_instr *deref) if (var->constant_initializer == NULL) goto fail; + if (var->constant_initializer->is_null_constant) { + /* Doesn't matter what casts are in the way, it's all zeros */ + nir_deref_path_finish(&path); + return var->constant_initializer->values; + } + nir_constant *c = var->constant_initializer; nir_const_value *v = NULL; /* Vector value for array-deref-of-vec */ diff --git a/src/compiler/nir/nir_print.c b/src/compiler/nir/nir_print.c index bb9588a3800..3a7864fb148 100644 --- a/src/compiler/nir/nir_print.c +++ b/src/compiler/nir/nir_print.c @@ -714,9 +714,13 @@ print_var_decl(nir_variable *var, print_state *state) } if (var->constant_initializer) { - fprintf(fp, " = { "); - print_constant(var->constant_initializer, var->type, state); - fprintf(fp, " }"); + if (var->constant_initializer->is_null_constant) { + fprintf(fp, " = null"); + } else { + fprintf(fp, " = { "); + print_constant(var->constant_initializer, var->type, state); + fprintf(fp, " }"); + } } if (glsl_type_is_sampler(var->type) && var->data.sampler.is_inline_sampler) { fprintf(fp, " = { %s, %s, %s }", diff --git a/src/compiler/nir/nir_serialize.c b/src/compiler/nir/nir_serialize.c index 862f2894d81..3004f1abb48 100644 --- a/src/compiler/nir/nir_serialize.c +++ b/src/compiler/nir/nir_serialize.c @@ -187,11 +187,15 @@ read_constant(read_ctx *ctx, nir_variable *nvar) { nir_constant *c = ralloc(nvar, nir_constant); + static const nir_const_value zero_vals[ARRAY_SIZE(c->values)] = { 0 }; blob_copy_bytes(ctx->blob, (uint8_t *)c->values, sizeof(c->values)); + c->is_null_constant = memcmp(c->values, zero_vals, sizeof(c->values)) == 0; c->num_elements = blob_read_uint32(ctx->blob); c->elements = ralloc_array(nvar, nir_constant *, c->num_elements); - for (unsigned i = 0; i < c->num_elements; i++) + for (unsigned i = 0; i < c->num_elements; i++) { c->elements[i] = read_constant(ctx, nvar); + c->is_null_constant &= c->elements[i]->is_null_constant; + } return c; } diff --git a/src/compiler/nir/nir_validate.c b/src/compiler/nir/nir_validate.c index b30491650e4..81f63046066 100644 --- a/src/compiler/nir/nir_validate.c +++ b/src/compiler/nir/nir_validate.c @@ -963,7 +963,7 @@ validate_call_instr(nir_call_instr *instr, validate_state *state) static void validate_const_value(nir_const_value *val, unsigned bit_size, - validate_state *state) + bool is_null_constant, validate_state *state) { /* In order for block copies to work properly for things like instruction * comparisons and [de]serialization, we require the unused bits of the @@ -971,24 +971,26 @@ validate_const_value(nir_const_value *val, unsigned bit_size, */ nir_const_value cmp_val; memset(&cmp_val, 0, sizeof(cmp_val)); - switch (bit_size) { - case 1: - cmp_val.b = val->b; - break; - case 8: - cmp_val.u8 = val->u8; - break; - case 16: - cmp_val.u16 = val->u16; - break; - case 32: - cmp_val.u32 = val->u32; - break; - case 64: - cmp_val.u64 = val->u64; - break; - default: - validate_assert(state, !"Invalid load_const bit size"); + if (!is_null_constant) { + switch (bit_size) { + case 1: + cmp_val.b = val->b; + break; + case 8: + cmp_val.u8 = val->u8; + break; + case 16: + cmp_val.u16 = val->u16; + break; + case 32: + cmp_val.u32 = val->u32; + break; + case 64: + cmp_val.u64 = val->u64; + break; + default: + validate_assert(state, !"Invalid load_const bit size"); + } } validate_assert(state, memcmp(val, &cmp_val, sizeof(cmp_val)) == 0); } @@ -999,7 +1001,7 @@ validate_load_const_instr(nir_load_const_instr *instr, validate_state *state) validate_ssa_def(&instr->def, state); for (unsigned i = 0; i < instr->def.num_components; i++) - validate_const_value(&instr->value[i], instr->def.bit_size, state); + validate_const_value(&instr->value[i], instr->def.bit_size, false, state); } static void @@ -1483,7 +1485,7 @@ validate_constant(nir_constant *c, const struct glsl_type *type, unsigned num_components = glsl_get_vector_elements(type); unsigned bit_size = glsl_get_bit_size(type); for (unsigned i = 0; i < num_components; i++) - validate_const_value(&c->values[i], bit_size, state); + validate_const_value(&c->values[i], bit_size, c->is_null_constant, state); for (unsigned i = num_components; i < NIR_MAX_VEC_COMPONENTS; i++) validate_assert(state, c->values[i].u64 == 0); } else { @@ -1492,11 +1494,14 @@ validate_constant(nir_constant *c, const struct glsl_type *type, for (unsigned i = 0; i < c->num_elements; i++) { const struct glsl_type *elem_type = glsl_get_struct_field(type, i); validate_constant(c->elements[i], elem_type, state); + validate_assert(state, !c->is_null_constant || c->elements[i]->is_null_constant); } } else if (glsl_type_is_array_or_matrix(type)) { const struct glsl_type *elem_type = glsl_get_array_element(type); - for (unsigned i = 0; i < c->num_elements; i++) + for (unsigned i = 0; i < c->num_elements; i++) { validate_constant(c->elements[i], elem_type, state); + validate_assert(state, !c->is_null_constant || c->elements[i]->is_null_constant); + } } else { validate_assert(state, !"Invalid type for nir_constant"); }