nir/opt_large_constants: optimize constant arrays with just two different values
Some checks are pending
macOS-CI / macOS-CI (dri) (push) Waiting to run
macOS-CI / macOS-CI (xlib) (push) Waiting to run

Often, games just use arrays to select between 1.0 and 0.0 or -1.0.

In the case where all values are the same except one index, use a
compare instead of a shift. It's impossible to optimize the shift to
just a compare because of NIR's SM5 shift semantics, but when we know the
array length, it works just fine.

Foz-DB Navi21:
Totals from 3393 (2.96% of 114627) affected shaders:
MaxWaves: 87039 -> 87087 (+0.06%)
Instrs: 4991034 -> 4977962 (-0.26%); split: -0.28%, +0.02%
CodeSize: 27505196 -> 27509988 (+0.02%); split: -0.08%, +0.10%
VGPRs: 156216 -> 154720 (-0.96%)
SpillSGPRs: 812 -> 801 (-1.35%); split: -1.60%, +0.25%
Latency: 38221096 -> 38207053 (-0.04%); split: -0.10%, +0.06%
InvThroughput: 9518564 -> 9469903 (-0.51%); split: -0.52%, +0.01%
VClause: 121340 -> 121370 (+0.02%); split: -0.05%, +0.07%
SClause: 127822 -> 127996 (+0.14%); split: -0.01%, +0.14%
Copies: 437743 -> 437832 (+0.02%); split: -0.40%, +0.43%
Branches: 173910 -> 173893 (-0.01%); split: -0.17%, +0.16%
PreSGPRs: 147137 -> 147957 (+0.56%); split: -0.01%, +0.57%
PreVGPRs: 126313 -> 126296 (-0.01%); split: -0.09%, +0.08%
VALU: 3309713 -> 3288169 (-0.65%); split: -0.66%, +0.01%
SALU: 762369 -> 770904 (+1.12%); split: -0.03%, +1.15%
VMEM: 182394 -> 182392 (-0.00%)
SMEM: 201777 -> 201801 (+0.01%); split: -0.00%, +0.01%

Reviewed-by: Konstantin Seurer <konstantin.seurer@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/40539>
This commit is contained in:
Georg Lehmann 2026-03-20 15:28:26 +01:00 committed by Marge Bot
parent 98d486ea77
commit 0975e1513a
2 changed files with 152 additions and 38 deletions

View file

@ -110,13 +110,31 @@ write_const_values(void *dst, const nir_const_value *src,
}
}
typedef enum small_constant_encoding {
SMALL_CONST_INT,
SMALL_CONST_FLOAT,
SMALL_CONST_BCSEL,
} small_constant_encoding;
struct small_constant {
uint64_t data;
int64_t min;
uint32_t bit_size;
bool is_float;
uint32_t denom;
uint32_t bit_stride;
uint32_t bit_size;
small_constant_encoding encoding;
union {
/* int/float */
struct {
int64_t min;
uint32_t denom;
};
/* bcsel */
struct {
uint64_t sel_true;
uint64_t sel_false;
};
};
};
struct var_info {
@ -226,11 +244,49 @@ handle_constant_store(void *mem_ctx, struct var_info *info,
#define NIR_SMALL_CONSTANT_MAX_ABS_VALUE 255
static bool
get_small_constant_bcsel(struct small_constant *info, uint32_t array_len,
uint32_t bit_size, nir_const_value *values)
{
nir_const_value *other = NULL;
uint64_t data = 0;
for (unsigned i = 1; i < array_len; i++) {
uint64_t val = nir_const_value_as_uint(values[i], bit_size);
if (nir_const_value_as_uint(values[0], bit_size) == val)
continue;
if (other && nir_const_value_as_uint(*other, bit_size) != val)
return false;
other = &values[i];
data |= BITFIELD64_BIT(i);
}
info->sel_false = nir_const_value_as_uint(values[0], bit_size);
if (other)
info->sel_true = nir_const_value_as_uint(*other, bit_size);
else
info->sel_true = info->sel_false;
if (util_bitcount64(data) * 2 > array_len) {
data = ~data & BITFIELD64_MASK(array_len);
SWAP(info->sel_true, info->sel_false);
}
info->data = data;
info->bit_size = array_len > 32 ? 64 : 32;
info->encoding = SMALL_CONST_BCSEL;
info->bit_stride = 1;
return true;
}
static bool
get_small_constant_component(const nir_shader_compiler_options *options,
struct small_constant *info, uint32_t array_len,
uint32_t bit_size, nir_const_value *values)
{
if (get_small_constant_bcsel(info, array_len, bit_size, values))
return true;
int64_t min = INT64_MAX;
bool is_float = true;
@ -285,9 +341,7 @@ get_small_constant_component(const nir_shader_compiler_options *options,
}
}
if (bit_size == 1) {
min = 0;
} else if (!is_float) {
if (!is_float) {
min = INT64_MAX;
for (unsigned i = 0; i < array_len; i++) {
int64_t integer = nir_const_value_as_int(values[i], bit_size);
@ -301,8 +355,6 @@ get_small_constant_component(const nir_shader_compiler_options *options,
if (is_float)
i64_elem = nir_const_value_as_float(values[i], bit_size) * denom;
else if (bit_size == 1)
i64_elem = nir_const_value_as_uint(values[i], bit_size);
else
i64_elem = nir_const_value_as_int(values[i], bit_size);
@ -327,8 +379,6 @@ get_small_constant_component(const nir_shader_compiler_options *options,
if (is_float)
i64_elem = nir_const_value_as_float(values[i], bit_size) * denom;
else if (bit_size == 1)
i64_elem = nir_const_value_as_uint(values[i], bit_size);
else
i64_elem = nir_const_value_as_int(values[i], bit_size);
@ -342,7 +392,7 @@ get_small_constant_component(const nir_shader_compiler_options *options,
/* Limit bit_size >= 32 to avoid unnecessary conversions. */
info->bit_size = MAX2(util_next_power_of_two(used_bits * array_len), 32);
info->min = min;
info->is_float = is_float;
info->encoding = is_float ? SMALL_CONST_FLOAT : SMALL_CONST_INT;
info->denom = denom;
info->bit_stride = used_bits;
return true;
@ -405,6 +455,25 @@ build_small_constant_load(nir_builder *b, nir_deref_instr *deref,
for (unsigned c = 0; c < info->num_components; c++) {
const struct small_constant *constant = &info->small_constant[c];
if (constant->encoding == SMALL_CONST_BCSEL) {
assert(constant->bit_stride == 1);
if (util_is_power_of_two_nonzero64(constant->data)) {
ret[c] = nir_ieq_imm(b, index, ffsll(constant->data) - 1);
} else {
nir_def *imm = nir_imm_intN_t(b, constant->data, constant->bit_size);
ret[c] = nir_ushr(b, imm, index);
ret[c] = nir_test_mask(b, ret[c], 0x1);
}
nir_def *sel_true = nir_imm_intN_t(b, constant->sel_true, bit_size);
nir_def *sel_false = nir_imm_intN_t(b, constant->sel_false, bit_size);
ret[c] = nir_bcsel(b, ret[c], sel_true, sel_false);
continue;
}
nir_def *imm = nir_imm_intN_t(b, constant->data, constant->bit_size);
nir_def *shift = nir_imul_imm(b, index, constant->bit_stride);
@ -416,27 +485,21 @@ build_small_constant_load(nir_builder *b, nir_deref_instr *deref,
if (ret[c]->bit_size == 64)
ret[c] = nir_unpack_64_2x32_split_x(b, ret[c]);
if (bit_size == 64 && !constant->is_float)
if (bit_size == 64 && constant->encoding == SMALL_CONST_INT)
ret[c] = nir_u2u64(b, ret[c]);
ret[c] = nir_iadd_imm(b, ret[c], constant->min);
if (bit_size < 8) {
/* Booleans are special-cased to be 32-bit */
assert(glsl_type_is_boolean(deref->type));
ret[c] = nir_ine_imm(b, ret[c], 0);
} else {
if (constant->is_float) {
if (constant->min >= 0)
ret[c] = nir_u2fN(b, ret[c], bit_size);
else
ret[c] = nir_i2fN(b, ret[c], bit_size);
if (constant->encoding == SMALL_CONST_FLOAT) {
if (constant->min >= 0)
ret[c] = nir_u2fN(b, ret[c], bit_size);
else
ret[c] = nir_i2fN(b, ret[c], bit_size);
if (constant->denom != 1)
ret[c] = nir_fmul_imm(b, ret[c], 1.0f / (float)constant->denom);
} else {
ret[c] = nir_u2uN(b, ret[c], bit_size);
}
if (constant->denom != 1)
ret[c] = nir_fmul_imm(b, ret[c], 1.0f / (float)constant->denom);
} else {
ret[c] = nir_u2uN(b, ret[c], bit_size);
}
}

View file

@ -133,16 +133,19 @@ TEST_F(nir_large_constants_test, small_bool_array)
decl_function main () (entrypoint)
impl main {
block b0: // preds:
32 %0 = @load_workgroup_index
32 %1 = load_const (0x000000aa = 170)
32 %2 = ushr %1 (0xaa), %0
32 %3 = load_const (0x00000001)
32 %4 = iand %2, %3 (0x1)
32 %5 = load_const (0x00000000)
1 %6 = ine %4, %5 (0x0)
@use (%6)
// succs: b1
block b0: // preds:
32 %0 = @load_workgroup_index
32 %1 = load_const (0x000000aa = 170)
32 %2 = ushr %1 (0xaa), %0
32 %3 = load_const (0x00000001)
32 %4 = iand %2, %3 (0x1)
32 %5 = load_const (0x00000000)
1 %6 = ine %4, %5 (0x0)
1 %7 = load_const (true)
1 %8 = load_const (false)
1 %9 = bcsel %6, %7 (true), %8 (false)
@use (%9)
// succs: b1
block b1:
}
)"));
@ -334,3 +337,51 @@ TEST_F(nir_large_constants_test, small_fraction_array)
}
)"));
}
TEST_F(nir_large_constants_test, bcsel_vec)
{
uint32_t length = 4;
array = nir_local_variable_create(b->impl, glsl_array_type(glsl_vec4_type(), length, 0), "array");
for (uint32_t i = 0; i < length; i++)
nir_store_array_var_imm(b, array, i, nir_imm_vec4(b, i == 0, i == 1, i == 2, i == 3), 0xf);
run_test();
check_nir_string(NIR_REFERENCE_SHADER(R"(
shader: MESA_SHADER_COMPUTE
name: nir_large_constants_test
workgroup_size: 1, 1, 1
max_subgroup_size: 128
min_subgroup_size: 1
decl_function main () (entrypoint)
impl main {
block b0: // preds:
32 %0 = @load_workgroup_index
32 %1 = load_const (0x00000000)
1 %2 = ieq %0, %1 (0x0)
32 %3 = load_const (0x3f800000 = 1.000000 = 1065353216)
32 %4 = load_const (0x00000000 = 0.000000)
32 %5 = bcsel %2, %3 (0x3f800000), %4 (0x0)
32 %6 = load_const (0x00000001)
1 %7 = ieq %0, %6 (0x1)
32 %8 = load_const (0x3f800000 = 1.000000 = 1065353216)
32 %9 = load_const (0x00000000 = 0.000000)
32 %10 = bcsel %7, %8 (0x3f800000), %9 (0x0)
32 %11 = load_const (0x00000002)
1 %12 = ieq %0, %11 (0x2)
32 %13 = load_const (0x3f800000 = 1.000000 = 1065353216)
32 %14 = load_const (0x00000000 = 0.000000)
32 %15 = bcsel %12, %13 (0x3f800000), %14 (0x0)
32 %16 = load_const (0x00000003)
1 %17 = ieq %0, %16 (0x3)
32 %18 = load_const (0x3f800000 = 1.000000 = 1065353216)
32 %19 = load_const (0x00000000 = 0.000000)
32 %20 = bcsel %17, %18 (0x3f800000), %19 (0x0)
32x4 %21 = vec4 %5, %10, %15, %20
@use (%21)
// succs: b1
block b1:
}
)"));
}