diff --git a/src/compiler/nir/nir_opt_shrink_vectors.c b/src/compiler/nir/nir_opt_shrink_vectors.c
index 90f44e3f301..80481989184 100644
--- a/src/compiler/nir/nir_opt_shrink_vectors.c
+++ b/src/compiler/nir/nir_opt_shrink_vectors.c
@@ -349,6 +349,102 @@ opt_shrink_vectors_ssa_undef(nir_ssa_undef_instr *instr)
    return shrink_dest_to_read_mask(&instr->def);
 }
 
+static bool
+opt_shrink_vectors_phi(nir_builder *b, nir_phi_instr *instr)
+{
+   nir_ssa_def *def = &instr->dest.ssa;
+
+   /* Early out if there's nothing to do. */
+   if (def->num_components == 1)
+      return false;
+
+   /* Ignore large vectors for now. */
+   if (def->num_components > 4)
+      return false;
+
+   /* Check the uses. */
+   nir_component_mask_t mask = 0;
+   nir_foreach_use(src, def) {
+      if (src->parent_instr->type != nir_instr_type_alu)
+         return false;
+
+      nir_alu_instr *alu = nir_instr_as_alu(src->parent_instr);
+
+      nir_alu_src *alu_src = exec_node_data(nir_alu_src, src, src);
+      int src_idx = alu_src - &alu->src[0];
+      nir_component_mask_t src_read_mask = nir_alu_instr_src_read_mask(alu, src_idx);
+
+      nir_ssa_def *alu_def = &alu->dest.dest.ssa;
+
+      /* We don't mark the channels used if the only reader is the original phi.
+       * This can happen in the case of loops.
+       */
+      nir_foreach_use(alu_use_src, alu_def) {
+         if (alu_use_src->parent_instr != &instr->instr) {
+            mask |= src_read_mask;
+         }
+      }
+
+      /* However, even if the instruction only points back at the phi, we still
+       * need to check that the swizzles are trivial.
+       */
+      if (nir_op_is_vec(alu->op)) {
+         if (src_idx != alu->src[src_idx].swizzle[0]) {
+            mask |= src_read_mask;
+         }
+      } else if (!nir_alu_src_is_trivial_ssa(alu, src_idx)) {
+         mask |= src_read_mask;
+      }
+   }
+
+   /* DCE will handle this. */
+   if (mask == 0)
+      return false;
+
+   /* Nothing to shrink? */
+   if (BITFIELD_MASK(def->num_components) == mask)
+      return false;
+
+   /* Set up the reswizzles. */
+   unsigned num_components = 0;
+   uint8_t reswizzle[NIR_MAX_VEC_COMPONENTS] = { 0 };
+   uint8_t src_reswizzle[NIR_MAX_VEC_COMPONENTS] = { 0 };
+   for (unsigned i = 0; i < def->num_components; i++) {
+      if (!((mask >> i) & 0x1))
+         continue;
+      src_reswizzle[num_components] = i;
+      reswizzle[i] = num_components++;
+   }
+
+   /* Shrink the phi; this part is simple. */
+   def->num_components = num_components;
+
+   /* We can't swizzle phi sources directly, so insert an extra mov with the
+    * correct swizzle and let the other parts of nir_opt_shrink_vectors do
+    * their job on the original source instruction. If the original source was
+    * only used in the phi, the movs will disappear later after copy propagation.
+    */
+   nir_foreach_phi_src(phi_src, instr) {
+      b->cursor = nir_after_instr_and_phis(phi_src->src.ssa->parent_instr);
+
+      nir_alu_src alu_src = {
+         .src = nir_src_for_ssa(phi_src->src.ssa)
+      };
+
+      for (unsigned i = 0; i < num_components; i++)
+         alu_src.swizzle[i] = src_reswizzle[i];
+      nir_ssa_def *mov = nir_mov_alu(b, alu_src, num_components);
+
+      nir_instr_rewrite_src_ssa(&instr->instr, &phi_src->src, mov);
+   }
+   b->cursor = nir_before_instr(&instr->instr);
+
+   /* Reswizzle readers. */
+   reswizzle_alu_uses(def, reswizzle);
+
+   return true;
+}
+
 static bool
 opt_shrink_vectors_instr(nir_builder *b, nir_instr *instr)
 {
@@ -367,6 +463,9 @@ opt_shrink_vectors_instr(nir_builder *b, nir_instr *instr)
    case nir_instr_type_ssa_undef:
       return opt_shrink_vectors_ssa_undef(nir_instr_as_ssa_undef(instr));
 
+   case nir_instr_type_phi:
+      return opt_shrink_vectors_phi(b, nir_instr_as_phi(instr));
+
    default:
       return false;
    }
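Note: the mask/reswizzle bookkeeping in opt_shrink_vectors_phi above is compact, so the following standalone sketch spells out what the two arrays end up holding. The mask value, array sizes, and printing are illustrative assumptions, not part of the patch:

/* Pack the read channels of a vec4 whose x and z components are the only
 * ones read (mask = 0b0101), mirroring the loop in opt_shrink_vectors_phi.
 */
#include <stdint.h>
#include <stdio.h>

int main(void)
{
   const unsigned mask = 0x5;             /* components 0 (x) and 2 (z) are read */
   const unsigned old_num_components = 4; /* the phi was a vec4 */

   unsigned num_components = 0;
   uint8_t reswizzle[4] = { 0 };     /* old channel -> new channel */
   uint8_t src_reswizzle[4] = { 0 }; /* new channel -> old channel */

   for (unsigned i = 0; i < old_num_components; i++) {
      if (!((mask >> i) & 0x1))
         continue;
      src_reswizzle[num_components] = i;
      reswizzle[i] = num_components++;
   }

   /* Prints "vec4 -> vec2", then "new 0 <- old 0" and "new 1 <- old 2". */
   printf("vec%u -> vec%u\n", old_num_components, num_components);
   for (unsigned i = 0; i < num_components; i++)
      printf("new %u <- old %u\n", i, src_reswizzle[i]);
   return 0;
}

The two directions matter: reswizzle (old -> new) is what reswizzle_alu_uses() applies to the remaining readers of the shrunk phi, while src_reswizzle (new -> old) supplies the swizzle for the compensating movs inserted on each phi source.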
diff --git a/src/compiler/nir/tests/opt_shrink_vectors_tests.cpp b/src/compiler/nir/tests/opt_shrink_vectors_tests.cpp
index 49329115a20..a9b86d76218 100644
--- a/src/compiler/nir/tests/opt_shrink_vectors_tests.cpp
+++ b/src/compiler/nir/tests/opt_shrink_vectors_tests.cpp
@@ -278,3 +278,339 @@ TEST_F(nir_opt_shrink_vectors_test, opt_shrink_vectors_vec8)
 
    nir_validate_shader(bld.shader, NULL);
 }
+
+TEST_F(nir_opt_shrink_vectors_test, opt_shrink_phis_loop_simple)
+{
+   /* Test that the phi is shrunk in the following case.
+    *
+    * v = vec4(0.0, 0.0, 0.0, 0.0);
+    * while (v.y < 3) {
+    *    v.y += 1.0;
+    * }
+    *
+    * This mimics the NIR for loops that come out of nine+ttn.
+    */
+   nir_ssa_def *v = nir_imm_vec4(&bld, 0.0, 0.0, 0.0, 0.0);
+   nir_ssa_def *increment = nir_imm_float(&bld, 1.0);
+   nir_ssa_def *loop_max = nir_imm_float(&bld, 3.0);
+
+   nir_phi_instr *const phi = nir_phi_instr_create(bld.shader);
+   nir_ssa_def *phi_def = &phi->dest.ssa;
+
+   nir_loop *loop = nir_push_loop(&bld);
+
+   nir_ssa_dest_init(&phi->instr, &phi->dest,
+                     v->num_components, v->bit_size,
+                     NULL);
+
+   nir_phi_instr_add_src(phi, v->parent_instr->block,
+                         nir_src_for_ssa(v));
+
+   nir_ssa_def *fge = nir_fge(&bld, phi_def, loop_max);
+   nir_alu_instr *fge_alu_instr = nir_instr_as_alu(fge->parent_instr);
+   fge->num_components = 1;
+   fge_alu_instr->dest.write_mask = BITFIELD_MASK(1);
+   fge_alu_instr->src[0].swizzle[0] = 1;
+
+   nir_if *nif = nir_push_if(&bld, fge);
+   {
+      nir_jump_instr *jump = nir_jump_instr_create(bld.shader, nir_jump_break);
+      nir_builder_instr_insert(&bld, &jump->instr);
+   }
+   nir_pop_if(&bld, nif);
+
+   nir_ssa_def *fadd = nir_fadd(&bld, phi_def, increment);
+   nir_alu_instr *fadd_alu_instr = nir_instr_as_alu(fadd->parent_instr);
+   fadd->num_components = 1;
+   fadd_alu_instr->dest.write_mask = BITFIELD_MASK(1);
+   fadd_alu_instr->src[0].swizzle[0] = 1;
+
+   nir_ssa_scalar srcs[4] = {{0}};
+   for (unsigned i = 0; i < 4; i++) {
+      srcs[i] = nir_get_ssa_scalar(phi_def, i);
+   }
+   srcs[1] = nir_get_ssa_scalar(fadd, 0);
+   nir_ssa_def *vec = nir_vec_scalars(&bld, srcs, 4);
+
+   nir_phi_instr_add_src(phi, vec->parent_instr->block,
+                         nir_src_for_ssa(vec));
+
+   nir_pop_loop(&bld, loop);
+
+   bld.cursor = nir_before_block(nir_loop_first_block(loop));
+   nir_builder_instr_insert(&bld, &phi->instr);
+
+   /* Generated nir:
+    *
+    * impl main {
+    *    block block_0:
+    *    * preds: *
+    *    vec1 32 ssa_0 = deref_var &in (shader_in vec2)
+    *    vec2 32 ssa_1 = intrinsic load_deref (ssa_0) (access=0)
+    *    vec4 32 ssa_2 = load_const (0x00000000, 0x00000000, 0x00000000, 0x00000000) = (0.000000, 0.000000, 0.000000, 0.000000)
+    *    vec1 32 ssa_3 = load_const (0x3f800000 = 1.000000)
+    *    vec1 32 ssa_4 = load_const (0x40400000 = 3.000000)
+    *    * succs: block_1 *
+    *    loop {
+    *       block block_1:
+    *       * preds: block_0 block_4 *
+    *       vec4 32 ssa_8 = phi block_0: ssa_2, block_4: ssa_7
+    *       vec1 1 ssa_5 = fge ssa_8.y, ssa_4
+    *       * succs: block_2 block_3 *
+    *       if ssa_5 {
+    *          block block_2:
+    *          * preds: block_1 *
+    *          break
+    *          * succs: block_5 *
+    *       } else {
+    *          block block_3:
+    *          * preds: block_1 *
+    *          * succs: block_4 *
+    *       }
+    *       block block_4:
+    *       * preds: block_3 *
+    *       vec1 32 ssa_6 = fadd ssa_8.y, ssa_3
+    *       vec4 32 ssa_7 = vec4 ssa_8.x, ssa_6, ssa_8.z, ssa_8.w
+    *       * succs: block_1 *
+    *    }
+    *    block block_5:
+    *    * preds: block_2 *
+    *    * succs: block_6 *
+    *    block block_6:
+    * }
+    */
+
+   nir_validate_shader(bld.shader, NULL);
+
+   ASSERT_TRUE(nir_opt_shrink_vectors(bld.shader));
+   ASSERT_TRUE(phi_def->num_components == 1);
+   check_swizzle(&fge_alu_instr->src[0], "x");
+   check_swizzle(&fadd_alu_instr->src[0], "x");
+
+   nir_validate_shader(bld.shader, NULL);
+}
+
+TEST_F(nir_opt_shrink_vectors_test, opt_shrink_phis_loop_swizzle)
+{
+   /* Test that the phi is shrunk properly in the following case, where
+    * some swizzling happens between the channels.
+    *
+    * v = vec4(0.0, 0.0, 0.0, 0.0);
+    * while (v.z < 3) {
+    *    v = vec4(v.x, v.z + 1.0, v.y, v.w);
+    * }
+    */
+   nir_ssa_def *v = nir_imm_vec4(&bld, 0.0, 0.0, 0.0, 0.0);
+   nir_ssa_def *increment = nir_imm_float(&bld, 1.0);
+   nir_ssa_def *loop_max = nir_imm_float(&bld, 3.0);
+
+   nir_phi_instr *const phi = nir_phi_instr_create(bld.shader);
+   nir_ssa_def *phi_def = &phi->dest.ssa;
+
+   nir_loop *loop = nir_push_loop(&bld);
+
+   nir_ssa_dest_init(&phi->instr, &phi->dest,
+                     v->num_components, v->bit_size,
+                     NULL);
+
+   nir_phi_instr_add_src(phi, v->parent_instr->block,
+                         nir_src_for_ssa(v));
+
+   nir_ssa_def *fge = nir_fge(&bld, phi_def, loop_max);
+   nir_alu_instr *fge_alu_instr = nir_instr_as_alu(fge->parent_instr);
+   fge->num_components = 1;
+   fge_alu_instr->dest.write_mask = BITFIELD_MASK(1);
+   fge_alu_instr->src[0].swizzle[0] = 2;
+
+   nir_if *nif = nir_push_if(&bld, fge);
+
+   nir_jump_instr *jump = nir_jump_instr_create(bld.shader, nir_jump_break);
+   nir_builder_instr_insert(&bld, &jump->instr);
+
+   nir_pop_if(&bld, nif);
+
+   nir_ssa_def *fadd = nir_fadd(&bld, phi_def, increment);
+   nir_alu_instr *fadd_alu_instr = nir_instr_as_alu(fadd->parent_instr);
+   fadd->num_components = 1;
+   fadd_alu_instr->dest.write_mask = BITFIELD_MASK(1);
+   fadd_alu_instr->src[0].swizzle[0] = 2;
+
+   nir_ssa_scalar srcs[4] = {{0}};
+   srcs[0] = nir_get_ssa_scalar(phi_def, 0);
+   srcs[1] = nir_get_ssa_scalar(fadd, 0);
+   srcs[2] = nir_get_ssa_scalar(phi_def, 1);
+   srcs[3] = nir_get_ssa_scalar(phi_def, 3);
+   nir_ssa_def *vec = nir_vec_scalars(&bld, srcs, 4);
+
+   nir_phi_instr_add_src(phi, vec->parent_instr->block,
+                         nir_src_for_ssa(vec));
+
+   nir_pop_loop(&bld, loop);
+
+   bld.cursor = nir_before_block(nir_loop_first_block(loop));
+   nir_builder_instr_insert(&bld, &phi->instr);
+
+   /* Generated nir:
+    *
+    * impl main {
+    *    block block_0:
+    *    * preds: *
+    *    vec1 32 ssa_0 = deref_var &in (shader_in vec2)
+    *    vec2 32 ssa_1 = intrinsic load_deref (ssa_0) (access=0)
+    *    vec4 32 ssa_2 = load_const (0x00000000, 0x00000000, 0x00000000, 0x00000000) = (0.000000, 0.000000, 0.000000, 0.000000)
+    *    vec1 32 ssa_3 = load_const (0x3f800000 = 1.000000)
+    *    vec1 32 ssa_4 = load_const (0x40400000 = 3.000000)
+    *    * succs: block_1 *
+    *    loop {
+    *       block block_1:
+    *       * preds: block_0 block_4 *
+    *       vec4 32 ssa_8 = phi block_0: ssa_2, block_4: ssa_7
+    *       vec1 1 ssa_5 = fge ssa_8.z, ssa_4
+    *       * succs: block_2 block_3 *
+    *       if ssa_5 {
+    *          block block_2:
+    *          * preds: block_1 *
+    *          break
+    *          * succs: block_5 *
+    *       } else {
+    *          block block_3:
+    *          * preds: block_1 *
+    *          * succs: block_4 *
+    *       }
+    *       block block_4:
+    *       * preds: block_3 *
+    *       vec1 32 ssa_6 = fadd ssa_8.z, ssa_3
+    *       vec4 32 ssa_7 = vec4 ssa_8.x, ssa_6, ssa_8.y, ssa_8.w
+    *       * succs: block_1 *
+    *    }
+    *    block block_5:
+    *    * preds: block_2 *
+    *    * succs: block_6 *
+    *    block block_6:
+    * }
+    */
+
+   nir_validate_shader(bld.shader, NULL);
+
+   ASSERT_TRUE(nir_opt_shrink_vectors(bld.shader));
+   ASSERT_TRUE(phi_def->num_components == 2);
+
+   check_swizzle(&fge_alu_instr->src[0], "y");
+   check_swizzle(&fadd_alu_instr->src[0], "y");
+
+   nir_validate_shader(bld.shader, NULL);
+}
+
+TEST_F(nir_opt_shrink_vectors_test, opt_shrink_phis_loop_phi_out)
+{
+   /* Test that the phi is not shrunk when it is used by an intrinsic.
+    *
+    * v = vec4(0.0, 0.0, 0.0, 0.0);
+    * while (v.y < 3) {
+    *    v.y += 1.0;
+    * }
+    * out = v;
+    */
+   nir_ssa_def *v = nir_imm_vec4(&bld, 0.0, 0.0, 0.0, 0.0);
+   nir_ssa_def *increment = nir_imm_float(&bld, 1.0);
+   nir_ssa_def *loop_max = nir_imm_float(&bld, 3.0);
+
+   nir_phi_instr *const phi = nir_phi_instr_create(bld.shader);
+   nir_ssa_def *phi_def = &phi->dest.ssa;
+
+   nir_loop *loop = nir_push_loop(&bld);
+
+   nir_ssa_dest_init(&phi->instr, &phi->dest,
+                     v->num_components, v->bit_size,
+                     NULL);
+
+   nir_phi_instr_add_src(phi, v->parent_instr->block,
+                         nir_src_for_ssa(v));
+
+   nir_ssa_def *fge = nir_fge(&bld, phi_def, loop_max);
+   nir_alu_instr *fge_alu_instr = nir_instr_as_alu(fge->parent_instr);
+   fge->num_components = 1;
+   fge_alu_instr->dest.write_mask = BITFIELD_MASK(1);
+   fge_alu_instr->src[0].swizzle[0] = 1;
+
+   nir_if *nif = nir_push_if(&bld, fge);
+   {
+      nir_jump_instr *jump = nir_jump_instr_create(bld.shader, nir_jump_break);
+      nir_builder_instr_insert(&bld, &jump->instr);
+   }
+   nir_pop_if(&bld, nif);
+
+   nir_ssa_def *fadd = nir_fadd(&bld, phi_def, increment);
+   nir_alu_instr *fadd_alu_instr = nir_instr_as_alu(fadd->parent_instr);
+   fadd->num_components = 1;
+   fadd_alu_instr->dest.write_mask = BITFIELD_MASK(1);
+   fadd_alu_instr->src[0].swizzle[0] = 1;
+
+   nir_ssa_scalar srcs[4] = {{0}};
+   for (unsigned i = 0; i < 4; i++) {
+      srcs[i] = nir_get_ssa_scalar(phi_def, i);
+   }
+   srcs[1] = nir_get_ssa_scalar(fadd, 0);
+   nir_ssa_def *vec = nir_vec_scalars(&bld, srcs, 4);
+
+   nir_phi_instr_add_src(phi, vec->parent_instr->block,
+                         nir_src_for_ssa(vec));
+
+   nir_pop_loop(&bld, loop);
+
+   out_var = nir_variable_create(bld.shader,
+                                 nir_var_shader_out,
+                                 glsl_vec_type(4), "out4");
+
+   nir_store_var(&bld, out_var, phi_def, BITFIELD_MASK(4));
+
+   bld.cursor = nir_before_block(nir_loop_first_block(loop));
+   nir_builder_instr_insert(&bld, &phi->instr);
+
+   /* Generated nir:
+    *
+    * impl main {
+    *    block block_0:
+    *    * preds: *
+    *    vec1 32 ssa_0 = deref_var &in (shader_in vec2)
+    *    vec2 32 ssa_1 = intrinsic load_deref (ssa_0) (access=0)
+    *    vec4 32 ssa_2 = load_const (0x00000000, 0x00000000, 0x00000000, 0x00000000) = (0.000000, 0.000000, 0.000000, 0.000000)
+    *    vec1 32 ssa_3 = load_const (0x3f800000 = 1.000000)
+    *    vec1 32 ssa_4 = load_const (0x40400000 = 3.000000)
+    *    * succs: block_1 *
+    *    loop {
+    *       block block_1:
+    *       * preds: block_0 block_4 *
+    *       vec4 32 ssa_9 = phi block_0: ssa_2, block_4: ssa_7
+    *       vec1 1 ssa_5 = fge ssa_9.y, ssa_4
+    *       * succs: block_2 block_3 *
+    *       if ssa_5 {
+    *          block block_2:
+    *          * preds: block_1 *
+    *          break
+    *          * succs: block_5 *
+    *       } else {
+    *          block block_3:
+    *          * preds: block_1 *
+    *          * succs: block_4 *
+    *       }
+    *       block block_4:
+    *       * preds: block_3 *
+    *       vec1 32 ssa_6 = fadd ssa_9.y, ssa_3
+    *       vec4 32 ssa_7 = vec4 ssa_9.x, ssa_6, ssa_9.z, ssa_9.w
+    *       * succs: block_1 *
+    *    }
+    *    block block_5:
+    *    * preds: block_2 *
+    *    vec1 32 ssa_8 = deref_var &out4 (shader_out vec4)
+    *    intrinsic store_deref (ssa_8, ssa_9) (wrmask=xyzw *15*, access=0)
+    *    * succs: block_6 *
+    *    block block_6:
+    * }
+    */
+
+   nir_validate_shader(bld.shader, NULL);
+
+   ASSERT_FALSE(nir_opt_shrink_vectors(bld.shader));
+   ASSERT_TRUE(phi_def->num_components == 4);
+}
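As the comment in opt_shrink_vectors_phi notes, the compensating movs inserted on the phi sources are expected to disappear after copy propagation. A minimal sketch of a fixed-point cleanup loop, assuming a nir_shader *shader in scope (the exact pass selection is an illustrative assumption; drivers differ):

/* Hypothetical driver optimization loop: shrink vectors, let copy
 * propagation fold the inserted movs, and let DCE drop the now-dead
 * wide values.  Repeat until nothing changes.
 */
bool progress;
do {
   progress = false;
   NIR_PASS(progress, shader, nir_opt_shrink_vectors);
   NIR_PASS(progress, shader, nir_copy_prop);
   NIR_PASS(progress, shader, nir_opt_dce);
} while (progress);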