diff --git a/src/compiler/nir/meson.build b/src/compiler/nir/meson.build
index 0ca71ad77a9..42ff91799df 100644
--- a/src/compiler/nir/meson.build
+++ b/src/compiler/nir/meson.build
@@ -258,6 +258,7 @@ files_libnir = files(
   'nir_opt_memcpy.c',
   'nir_opt_move.c',
   'nir_opt_move_discards_to_top.c',
+  'nir_opt_mqsad.c',
   'nir_opt_non_uniform_access.c',
   'nir_opt_offsets.c',
   'nir_opt_peephole_select.c',
diff --git a/src/compiler/nir/nir.h b/src/compiler/nir/nir.h
index 9b8d39d00bb..f6871466fce 100644
--- a/src/compiler/nir/nir.h
+++ b/src/compiler/nir/nir.h
@@ -3908,6 +3908,9 @@ typedef struct nir_shader_compiler_options {
    bool has_rotate16;
    bool has_rotate32;
 
+   /** Backend supports shfr */
+   bool has_shfr32;
+
    /** Backend supports ternary addition */
    bool has_iadd3;
 
@@ -6425,6 +6428,8 @@ bool nir_opt_gcm(nir_shader *shader, bool value_number);
 
 bool nir_opt_idiv_const(nir_shader *shader, unsigned min_bit_size);
 
+bool nir_opt_mqsad(nir_shader *shader);
+
 typedef enum {
    nir_opt_if_optimize_phi_true_false = (1 << 0),
    nir_opt_if_avoid_64bit_phis = (1 << 1),
diff --git a/src/compiler/nir/nir_lower_alu_width.c b/src/compiler/nir/nir_lower_alu_width.c
index 0d3f4e9feb3..b8ee78a2005 100644
--- a/src/compiler/nir/nir_lower_alu_width.c
+++ b/src/compiler/nir/nir_lower_alu_width.c
@@ -232,6 +232,7 @@ lower_alu_instr_width(nir_builder *b, nir_instr *instr, void *_data)
    case nir_op_unpack_snorm_4x8:
    case nir_op_unpack_unorm_2x16:
    case nir_op_unpack_snorm_2x16:
+   case nir_op_mqsad_4x8:
       /* There is no scalar version of these ops, unless we were to break it
        * down to bitshifts and math (which is definitely not intended).
        */
diff --git a/src/compiler/nir/nir_opcodes.py b/src/compiler/nir/nir_opcodes.py
index ef8c8fe28fe..7bd7a776b11 100644
--- a/src/compiler/nir/nir_opcodes.py
+++ b/src/compiler/nir/nir_opcodes.py
@@ -896,6 +896,12 @@ opcode("uror", 0, tuint, [0, 0], [tuint, tuint32], False, "", """
          (src0 << (-src1 & rotate_mask));
 """)
 
+opcode("shfr", 0, tuint32, [0, 0, 0], [tuint32, tuint32, tuint32], False, "", """
+   uint32_t rotate_mask = sizeof(src0) * 8 - 1;
+   dst = (src1 >> (src2 & rotate_mask)) |
+         (src0 << (-src2 & rotate_mask));
+""")
+
 bitwise_description = """
 Bitwise {0}, also used as a boolean {0} for hardware supporting integers.
 """
@@ -1141,6 +1147,14 @@ then add them together. There is also a third source which is a 32-bit
 unsigned integer and added to the result.
 """)
 
+opcode("mqsad_4x8", 4, tuint32, [1, 2, 4], [tuint32, tuint32, tuint32], False, "", """
+uint64_t src = src1.x | ((uint64_t)src1.y << 32);
+dst.x = msad(src0.x, src, src2.x);
+dst.y = msad(src0.x, src >> 8, src2.y);
+dst.z = msad(src0.x, src >> 16, src2.z);
+dst.w = msad(src0.x, src >> 24, src2.w);
+""")
+
 # Combines the first component of each input to make a 3-component vector.
 triop_horiz("vec3", 3, 1, 1, 1, """
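
For reference: shfr is a 32-bit funnel shift right, selecting a 32-bit window from the 64-bit concatenation src0:src1 at bit offset src2 (the fold above spells this as two masked shifts, matching NIR's masked-shift semantics), and mqsad_4x8 evaluates four masked SADs of a reference dword against a byte-sliding window of a 64-bit source. A standalone C model of these semantics (illustrative only; the msad() sketch assumes the util helper's masked-SAD behavior, where zero bytes of the reference are skipped):

   #include <stdint.h>

   /* Sketch of the masked-SAD helper: zero bytes of `ref` do not contribute. */
   static uint32_t msad(uint32_t ref, uint32_t src, uint32_t accum)
   {
      for (unsigned i = 0; i < 32; i += 8) {
         uint8_t r = ref >> i, s = src >> i;
         if (r)
            accum += r > s ? r - s : s - r;
      }
      return accum;
   }

   /* mqsad_4x8: four MSADs against byte offsets 0/8/16/24 of src1.y:src1.x. */
   static void mqsad_4x8(uint32_t ref, const uint32_t src1[2],
                         const uint32_t accum[4], uint32_t dst[4])
   {
      uint64_t src = src1[0] | ((uint64_t)src1[1] << 32);
      for (unsigned i = 0; i < 4; i++)
         dst[i] = msad(ref, (uint32_t)(src >> (i * 8)), accum[i]);
   }
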
triop_horiz("vec3", 3, 1, 1, 1, """ diff --git a/src/compiler/nir/nir_opt_algebraic.py b/src/compiler/nir/nir_opt_algebraic.py index 9bb8b1c7dcf..19a647a47fc 100644 --- a/src/compiler/nir/nir_opt_algebraic.py +++ b/src/compiler/nir/nir_opt_algebraic.py @@ -1418,6 +1418,14 @@ optimizations.extend([ (('uror@32', a, b), ('ior', ('ushr', a, b), ('ishl', a, ('isub', 32, b))), '!options->has_rotate32'), (('uror@64', a, b), ('ior', ('ushr', a, b), ('ishl', a, ('isub', 64, b)))), + (('bitfield_select', 0xff000000, ('ishl', 'b@32', 24), ('ushr', a, 8)), ('shfr', b, a, 8), 'options->has_shfr32'), + (('bitfield_select', 0xffff0000, ('ishl', 'b@32', 16), ('extract_u16', a, 1)), ('shfr', b, a, 16), 'options->has_shfr32'), + (('bitfield_select', 0xffffff00, ('ishl', 'b@32', 8), ('extract_u8', a, 3)), ('shfr', b, a, 24), 'options->has_shfr32'), + (('ior', ('ishl', 'b@32', 24), ('ushr', a, 8)), ('shfr', b, a, 8), 'options->has_shfr32'), + (('ior', ('ishl', 'b@32', 16), ('extract_u16', a, 1)), ('shfr', b, a, 16), 'options->has_shfr32'), + (('ior', ('ishl', 'b@32', 8), ('extract_u8', a, 3)), ('shfr', b, a, 24), 'options->has_shfr32'), + (('ior', ('ishl', 'b@32', ('iadd', 32, ('ineg', c))), ('ushr@32', a, c)), ('shfr', b, a, c), 'options->has_shfr32'), + # bfi(X, a, b) = (b & ~X) | (a & X) # If X = ~0: (b & 0) | (a & 0xffffffff) = a # If X = 0: (b & 0xffffffff) | (a & 0) = b diff --git a/src/compiler/nir/nir_opt_mqsad.c b/src/compiler/nir/nir_opt_mqsad.c new file mode 100644 index 00000000000..140eb3bb9d8 --- /dev/null +++ b/src/compiler/nir/nir_opt_mqsad.c @@ -0,0 +1,146 @@ +/* + * Copyright 2023 Valve Corporation + * SPDX-License-Identifier: MIT + */ +#include "nir.h" +#include "nir_builder.h" +#include "nir_worklist.h" + +/* + * This pass recognizes certain patterns of nir_op_shfr and nir_op_msad_4x8 and replaces it + * with a single nir_op_mqsad_4x8 instruction. + */ + +struct mqsad { + nir_scalar ref; + nir_scalar src[2]; + + nir_scalar accum[4]; + nir_alu_instr *msad[4]; + unsigned first_msad_index; + uint8_t mask; +}; + +static bool +is_mqsad_compatible(struct mqsad *mqsad, nir_scalar ref, nir_scalar src0, nir_scalar src1, + unsigned idx, nir_alu_instr *msad) +{ + if (!nir_scalar_equal(ref, mqsad->ref) || !nir_scalar_equal(src0, mqsad->src[0])) + return false; + if ((mqsad->mask & 0b1110) && idx && !nir_scalar_equal(src1, mqsad->src[1])) + return false; + + /* Ensure that this MSAD doesn't depend on any previous MSAD. */ + nir_instr_worklist *wl = nir_instr_worklist_create(); + nir_instr_worklist_add_ssa_srcs(wl, &msad->instr); + nir_foreach_instr_in_worklist(instr, wl) { + if (instr->block != msad->instr.block || instr->index < mqsad->first_msad_index) + continue; + + u_foreach_bit(i, mqsad->mask) { + if (instr == &mqsad->msad[i]->instr) { + nir_instr_worklist_destroy(wl); + return false; + } + } + + nir_instr_worklist_add_ssa_srcs(wl, instr); + } + nir_instr_worklist_destroy(wl); + + return true; +} + +static void +parse_msad(nir_alu_instr *msad, struct mqsad *mqsad) +{ + if (msad->def.num_components != 1) + return; + + nir_scalar msad_s = nir_get_scalar(&msad->def, 0); + nir_scalar ref = nir_scalar_chase_alu_src(msad_s, 0); + nir_scalar accum = nir_scalar_chase_alu_src(msad_s, 2); + + unsigned idx = 0; + nir_scalar src0 = nir_scalar_chase_alu_src(msad_s, 1); + nir_scalar src1; + if (nir_scalar_is_alu(src0) && nir_scalar_alu_op(src0) == nir_op_shfr) { + nir_scalar amount_s = nir_scalar_chase_alu_src(src0, 2); + uint32_t amount = nir_scalar_is_const(amount_s) ? 
diff --git a/src/compiler/nir/nir_opt_mqsad.c b/src/compiler/nir/nir_opt_mqsad.c
new file mode 100644
index 00000000000..140eb3bb9d8
--- /dev/null
+++ b/src/compiler/nir/nir_opt_mqsad.c
@@ -0,0 +1,146 @@
+/*
+ * Copyright 2023 Valve Corporation
+ * SPDX-License-Identifier: MIT
+ */
+#include "nir.h"
+#include "nir_builder.h"
+#include "nir_worklist.h"
+
+/*
+ * This pass recognizes certain patterns of nir_op_shfr and nir_op_msad_4x8 and
+ * replaces them with a single nir_op_mqsad_4x8 instruction.
+ */
+
+struct mqsad {
+   nir_scalar ref;
+   nir_scalar src[2];
+
+   nir_scalar accum[4];
+   nir_alu_instr *msad[4];
+   unsigned first_msad_index;
+   uint8_t mask;
+};
+
+static bool
+is_mqsad_compatible(struct mqsad *mqsad, nir_scalar ref, nir_scalar src0, nir_scalar src1,
+                    unsigned idx, nir_alu_instr *msad)
+{
+   if (!nir_scalar_equal(ref, mqsad->ref) || !nir_scalar_equal(src0, mqsad->src[0]))
+      return false;
+   if ((mqsad->mask & 0b1110) && idx && !nir_scalar_equal(src1, mqsad->src[1]))
+      return false;
+
+   /* Ensure that this MSAD doesn't depend on any previous MSAD. */
+   nir_instr_worklist *wl = nir_instr_worklist_create();
+   nir_instr_worklist_add_ssa_srcs(wl, &msad->instr);
+   nir_foreach_instr_in_worklist(instr, wl) {
+      if (instr->block != msad->instr.block || instr->index < mqsad->first_msad_index)
+         continue;
+
+      u_foreach_bit(i, mqsad->mask) {
+         if (instr == &mqsad->msad[i]->instr) {
+            nir_instr_worklist_destroy(wl);
+            return false;
+         }
+      }
+
+      nir_instr_worklist_add_ssa_srcs(wl, instr);
+   }
+   nir_instr_worklist_destroy(wl);
+
+   return true;
+}
+
+static void
+parse_msad(nir_alu_instr *msad, struct mqsad *mqsad)
+{
+   if (msad->def.num_components != 1)
+      return;
+
+   nir_scalar msad_s = nir_get_scalar(&msad->def, 0);
+   nir_scalar ref = nir_scalar_chase_alu_src(msad_s, 0);
+   nir_scalar accum = nir_scalar_chase_alu_src(msad_s, 2);
+
+   unsigned idx = 0;
+   nir_scalar src0 = nir_scalar_chase_alu_src(msad_s, 1);
+   nir_scalar src1 = {0};
+   if (nir_scalar_is_alu(src0) && nir_scalar_alu_op(src0) == nir_op_shfr) {
+      nir_scalar amount_s = nir_scalar_chase_alu_src(src0, 2);
+      uint32_t amount = nir_scalar_is_const(amount_s) ? nir_scalar_as_uint(amount_s) : 0;
+      if (amount == 8 || amount == 16 || amount == 24) {
+         idx = amount / 8;
+         src1 = nir_scalar_chase_alu_src(src0, 0);
+         src0 = nir_scalar_chase_alu_src(src0, 1);
+      }
+   }
+
+   if (mqsad->mask && !is_mqsad_compatible(mqsad, ref, src0, src1, idx, msad))
+      memset(mqsad, 0, sizeof(*mqsad));
+
+   /* Add this instruction to the in-progress MQSAD. */
+   mqsad->ref = ref;
+   mqsad->src[0] = src0;
+   if (idx)
+      mqsad->src[1] = src1;
+
+   mqsad->accum[idx] = accum;
+   mqsad->msad[idx] = msad;
+   if (!mqsad->mask)
+      mqsad->first_msad_index = msad->instr.index;
+   mqsad->mask |= 1 << idx;
+}
+
+static void
+create_msad(nir_builder *b, struct mqsad *mqsad)
+{
+   nir_def *mqsad_def = nir_mqsad_4x8(b, nir_channel(b, mqsad->ref.def, mqsad->ref.comp),
+                                      nir_vec_scalars(b, mqsad->src, 2),
+                                      nir_vec_scalars(b, mqsad->accum, 4));
+
+   for (unsigned i = 0; i < 4; i++)
+      nir_def_rewrite_uses(&mqsad->msad[i]->def, nir_channel(b, mqsad_def, i));
+
+   memset(mqsad, 0, sizeof(*mqsad));
+}
+
+bool
+nir_opt_mqsad(nir_shader *shader)
+{
+   bool progress = false;
+   nir_foreach_function_impl(impl, shader) {
+      bool progress_impl = false;
+
+      nir_metadata_require(impl, nir_metadata_instr_index);
+
+      nir_foreach_block(block, impl) {
+         struct mqsad mqsad;
+         memset(&mqsad, 0, sizeof(mqsad));
+
+         nir_foreach_instr(instr, block) {
+            if (instr->type != nir_instr_type_alu)
+               continue;
+
+            nir_alu_instr *alu = nir_instr_as_alu(instr);
+            if (alu->op != nir_op_msad_4x8)
+               continue;
+
+            parse_msad(alu, &mqsad);
+
+            if (mqsad.mask == 0xf) {
+               nir_builder b = nir_builder_at(nir_before_instr(instr));
+               create_msad(&b, &mqsad);
+               progress_impl = true;
+            }
+         }
+      }
+
+      if (progress_impl) {
+         nir_metadata_preserve(impl, nir_metadata_block_index | nir_metadata_dominance);
+         progress = true;
+      } else {
+         nir_metadata_preserve(impl, nir_metadata_all);
+      }
+   }
+
+   return progress;
+}
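
The pass is a single forward scan over each block: every msad_4x8 either joins the in-progress group (same ref and window sources, with the byte offset recovered from the shfr amount) or restarts it, and once all four offsets are present the group is fused. The rewrite is sound because of the identity below, expressed with the illustrative msad()/mqsad_4x8() sketches from the opcode section:

   #include <assert.h>
   #include <stdint.h>

   /* Four scalar MSADs over byte-shifted windows of one 64-bit source compute
    * exactly the four components of a single mqsad_4x8. */
   static void check_fusion(uint32_t ref, const uint32_t src[2], const uint32_t accum[4])
   {
      uint64_t win = src[0] | ((uint64_t)src[1] << 32);
      uint32_t fused[4];
      mqsad_4x8(ref, src, accum, fused);
      for (unsigned i = 0; i < 4; i++) {
         /* In the shader this operand is shfr(src.y, src.x, 8 * i). */
         assert(msad(ref, (uint32_t)(win >> (i * 8)), accum[i]) == fused[i]);
      }
   }
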
diff --git a/src/compiler/nir/tests/algebraic_tests.cpp b/src/compiler/nir/tests/algebraic_tests.cpp
index a6fcc660efa..45fc0b2df34 100644
--- a/src/compiler/nir/tests/algebraic_tests.cpp
+++ b/src/compiler/nir/tests/algebraic_tests.cpp
@@ -37,6 +37,8 @@
 protected:
    void test_2src_op(nir_op op, int64_t src0, int64_t src1);
 
+   void require_one_alu(nir_op op);
+
    nir_variable *res_var;
 };
 
@@ -85,6 +87,18 @@ void algebraic_test_base::test_2src_op(nir_op op, int64_t src0, int64_t src1)
    test_op(op, nir_imm_int(b, src0), nir_imm_int(b, src1), NULL, NULL, desc);
 }
 
+void algebraic_test_base::require_one_alu(nir_op op)
+{
+   unsigned count = 0;
+   nir_foreach_instr(instr, nir_start_block(b->impl)) {
+      if (instr->type == nir_instr_type_alu) {
+         ASSERT_TRUE(nir_instr_as_alu(instr)->op == op);
+         ASSERT_EQ(count, 0);
+         count++;
+      }
+   }
+}
+
 class nir_opt_algebraic_test : public algebraic_test_base {
 protected:
    virtual void run_pass() {
@@ -99,6 +113,13 @@
    }
 };
 
+class nir_opt_mqsad_test : public algebraic_test_base {
+protected:
+   virtual void run_pass() {
+      nir_opt_mqsad(b->shader);
+   }
+};
+
 TEST_F(nir_opt_algebraic_test, umod_pow2_src2)
 {
    for (int i = 0; i <= 9; i++)
@@ -162,14 +183,58 @@ TEST_F(nir_opt_algebraic_test, msad)
       nir_opt_dce(b->shader);
    }
 
-   unsigned count = 0;
-   nir_foreach_instr(instr, nir_start_block(b->impl)) {
-      if (instr->type == nir_instr_type_alu) {
-         ASSERT_TRUE(nir_instr_as_alu(instr)->op == nir_op_msad_4x8);
-         ASSERT_EQ(count, 0);
-         count++;
+   require_one_alu(nir_op_msad_4x8);
+}
+
+TEST_F(nir_opt_mqsad_test, mqsad)
+{
+   options.lower_bitfield_extract = true;
+   options.has_bfe = true;
+   options.has_msad = true;
+   options.has_shfr32 = true;
+
+   nir_def *ref = nir_load_var(b, nir_local_variable_create(b->impl, glsl_int_type(), "ref"));
+   nir_def *src = nir_load_var(b, nir_local_variable_create(b->impl, glsl_ivec_type(2), "src"));
+   nir_def *accum = nir_load_var(b, nir_local_variable_create(b->impl, glsl_ivec_type(4), "accum"));
+
+   nir_def *srcx = nir_channel(b, src, 0);
+   nir_def *srcy = nir_channel(b, src, 1);
+
+   nir_def *res[4];
+   for (unsigned i = 0; i < 4; i++) {
+      nir_def *src1 = srcx;
+      switch (i) {
+      case 0:
+         break;
+      case 1:
+         src1 = nir_bitfield_select(b, nir_imm_int(b, 0xff000000), nir_ishl_imm(b, srcy, 24),
+                                    nir_ushr_imm(b, srcx, 8));
+         break;
+      case 2:
+         src1 = nir_bitfield_select(b, nir_imm_int(b, 0xffff0000), nir_ishl_imm(b, srcy, 16),
+                                    nir_extract_u16(b, srcx, nir_imm_int(b, 1)));
+         break;
+      case 3:
+         src1 = nir_bitfield_select(b, nir_imm_int(b, 0xffffff00), nir_ishl_imm(b, srcy, 8),
+                                    nir_extract_u8_imm(b, srcx, 3));
+         break;
       }
+
+      res[i] = nir_msad_4x8(b, ref, src1, nir_channel(b, accum, i));
    }
+
+   nir_store_var(b, nir_local_variable_create(b->impl, glsl_ivec_type(4), "res"), nir_vec(b, res, 4), 0xf);
+
+   while (nir_opt_algebraic(b->shader)) {
+      nir_opt_constant_folding(b->shader);
+      nir_opt_dce(b->shader);
+   }
+
+   ASSERT_TRUE(nir_opt_mqsad(b->shader));
+   nir_copy_prop(b->shader);
+   nir_opt_dce(b->shader);
+
+   require_one_alu(nir_op_mqsad_4x8);
 }
 
 TEST_F(nir_opt_idiv_const_test, umod)
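
The test exercises the full pipeline: nir_opt_algebraic first turns the bitfield_select forms into shfr (which requires has_shfr32), then nir_opt_mqsad fuses the four msad_4x8 into one mqsad_4x8. A hypothetical driver hookup would look similar; only has_shfr32 and nir_opt_mqsad come from this series, the wrapper function below is illustrative:

   /* Assumes the driver sets shader->options->has_shfr32 (and has_msad) so
    * that nir_opt_algebraic forms the shfr chains this pass looks for. */
   static void
   my_driver_optimize(nir_shader *nir)
   {
      bool progress = false;
      NIR_PASS(progress, nir, nir_opt_mqsad);
      if (progress) {
         NIR_PASS(progress, nir, nir_copy_prop);
         NIR_PASS(progress, nir, nir_opt_dce);
      }
   }
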