diff --git a/src/compiler/nir/meson.build b/src/compiler/nir/meson.build index b4ab113974a..c9c6d141805 100644 --- a/src/compiler/nir/meson.build +++ b/src/compiler/nir/meson.build @@ -255,6 +255,7 @@ files_libnir = files( 'nir_opt_offsets.c', 'nir_opt_peephole_select.c', 'nir_opt_phi_precision.c', + 'nir_opt_phi_to_bool.c', 'nir_opt_preamble.c', 'nir_opt_ray_queries.c', 'nir_opt_reassociate_bfi.c', diff --git a/src/compiler/nir/nir.h b/src/compiler/nir/nir.h index dd1bd03bbd2..b8d859d5b4e 100644 --- a/src/compiler/nir/nir.h +++ b/src/compiler/nir/nir.h @@ -6081,6 +6081,8 @@ bool nir_remove_single_src_phis_block(nir_block *block); bool nir_opt_phi_precision(nir_shader *shader); +bool nir_opt_phi_to_bool(nir_shader *shader); + bool nir_opt_shrink_stores(nir_shader *shader, bool shrink_image_store); bool nir_opt_shrink_vectors(nir_shader *shader, bool shrink_start); diff --git a/src/compiler/nir/nir_opt_phi_to_bool.c b/src/compiler/nir/nir_opt_phi_to_bool.c new file mode 100644 index 00000000000..e2673fda9f9 --- /dev/null +++ b/src/compiler/nir/nir_opt_phi_to_bool.c @@ -0,0 +1,236 @@ +/* + * Copyright 2025 Valve Corporation + * SPDX-License-Identifier: MIT + */ + +#include "nir.h" +#include "nir_builder.h" +#include "nir_worklist.h" + +/* Various other IRs do not have 1bit booleans and instead use 0/1, 0/-1, 0/1.0 + * This pass detects phis with all sources in one of these representations and + * converts the phi to 1bit. The cleanup of related alu is left to other passes + * like nir_opt_algebraic. + */ + +/* This enum is used to store what kind of bool the ssa def is in pass_flags. + * It's a mask to allow multiple types for constant 0 and undef. + */ +enum bool_type { + /* 0 is false, 1 is true. */ + bool_type_single_bit = BITFIELD_BIT(0), + /* 0 is false, -1 is true. */ + bool_type_all_bits = BITFIELD_BIT(1), + /* 0 is false, 1.0 is true. */ + bool_type_float = BITFIELD_BIT(2), + + bool_type_all_types = BITFIELD_MASK(3), +}; + +static inline uint8_t +src_pass_flags(nir_src *src) +{ + return src->ssa->parent_instr->pass_flags; +} + +static inline nir_block * +block_get_loop_preheader(nir_block *block) +{ + nir_cf_node *parent = block->cf_node.parent; + if (parent->type != nir_cf_node_loop) + return NULL; + if (block != nir_cf_node_cf_tree_first(parent)) + return NULL; + return nir_cf_node_as_block(nir_cf_node_prev(parent)); +} + +static uint8_t +get_bool_types_const(nir_load_const_instr *load) +{ + uint8_t res = bool_type_all_types; + unsigned bit_size = load->def.bit_size; + for (unsigned i = 0; i < load->def.num_components; i++) { + int64_t ival = nir_const_value_as_int(load->value[i], bit_size); + if (ival == 0) + continue; + else if (ival == 1) + res &= bool_type_single_bit; + else if (ival == -1) + res &= bool_type_all_bits; + else if (bit_size >= 16 && nir_const_value_as_float(load->value[i], bit_size) == 1.0) + res &= bool_type_float; + else + res = 0; + } + return res; +} + +static uint8_t +get_bool_types_phi(nir_phi_instr *phi) +{ + uint8_t res = bool_type_all_types; + nir_foreach_phi_src(phi_src, phi) + res &= src_pass_flags(&phi_src->src); + return res; +} + +static uint8_t +negate_int_bool_types(nir_src *src) +{ + uint8_t src_types = src_pass_flags(src); + uint8_t res = 0; + if (src_types & bool_type_single_bit) + res |= bool_type_all_bits; + if (src_types & bool_type_all_bits) + res |= bool_type_single_bit; + return res; +} + +static uint8_t +get_bool_types_alu(nir_alu_instr *alu) +{ + switch (alu->op) { + case nir_op_b2i8: + case nir_op_b2i16: + case nir_op_b2i32: + case nir_op_b2i64: + return bool_type_single_bit; + case nir_op_b2b8: + case nir_op_b2b16: + case nir_op_b2b32: + return bool_type_all_bits; + case nir_op_b2f16: + case nir_op_b2f32: + case nir_op_b2f64: + return bool_type_float; + case nir_op_ineg: + return negate_int_bool_types(&alu->src[0].src); + case nir_op_inot: + return src_pass_flags(&alu->src[0].src) & bool_type_all_bits; + case nir_op_bcsel: + return src_pass_flags(&alu->src[1].src) & src_pass_flags(&alu->src[2].src); + case nir_op_iand: + if (src_pass_flags(&alu->src[0].src) & bool_type_all_bits) + return src_pass_flags(&alu->src[1].src); + if (src_pass_flags(&alu->src[1].src) & bool_type_all_bits) + return src_pass_flags(&alu->src[0].src); + FALLTHROUGH; + case nir_op_imin: + case nir_op_imax: + case nir_op_umin: + case nir_op_umax: + case nir_op_ior: + case nir_op_ixor: + return src_pass_flags(&alu->src[0].src) & src_pass_flags(&alu->src[1].src); + case nir_op_fmax: + case nir_op_fmin: + case nir_op_fmul: + case nir_op_fmulz: + return src_pass_flags(&alu->src[0].src) & src_pass_flags(&alu->src[1].src) & bool_type_float; + default: + return 0; + } +} + +static uint8_t +get_bool_types(nir_instr *instr) +{ + switch (instr->type) { + case nir_instr_type_undef: + return bool_type_all_types; + case nir_instr_type_load_const: + return get_bool_types_const(nir_instr_as_load_const(instr)); + case nir_instr_type_phi: + return get_bool_types_phi(nir_instr_as_phi(instr)); + case nir_instr_type_alu: + return get_bool_types_alu(nir_instr_as_alu(instr)); + default: + return 0; + } +} + +static bool +phi_to_bool(nir_builder *b, nir_phi_instr *phi, void *unused) +{ + if (!phi->instr.pass_flags || phi->def.bit_size == 1) + return false; + + enum bool_type type = BITFIELD_BIT(ffs(phi->instr.pass_flags) - 1); + + unsigned bit_size = phi->def.bit_size; + phi->def.bit_size = 1; + + nir_foreach_phi_src(phi_src, phi) { + b->cursor = nir_after_block_before_jump(phi_src->pred); + nir_def *src = phi_src->src.ssa; + if (src == &phi->def) + continue; + else if (nir_src_is_undef(phi_src->src)) + src = nir_undef(b, phi->def.num_components, 1); + else if (type == bool_type_float) + src = nir_fneu_imm(b, src, 0); + else + src = nir_i2b(b, src); + + nir_src_rewrite(&phi_src->src, src); + } + + b->cursor = nir_after_phis(phi->instr.block); + + nir_def *res = &phi->def; + if (type == bool_type_single_bit) + res = nir_b2iN(b, res, bit_size); + else if (type == bool_type_all_bits) + res = nir_bcsel(b, res, nir_imm_intN_t(b, -1, bit_size), nir_imm_intN_t(b, 0, bit_size)); + else if (type == bool_type_float) + res = nir_b2fN(b, res, bit_size); + else + unreachable("invalid bool_type"); + + nir_foreach_use_safe(src, &phi->def) { + if (nir_src_parent_instr(src) == &phi->instr || + nir_src_parent_instr(src) == res->parent_instr) + continue; + nir_src_rewrite(src, res); + } + + return true; +} + +bool +nir_opt_phi_to_bool(nir_shader *shader) +{ + nir_instr_worklist *worklist = nir_instr_worklist_create(); + + nir_foreach_function_impl(impl, shader) { + nir_foreach_block(block, impl) { + nir_block *preheader = block_get_loop_preheader(block); + nir_foreach_instr(instr, block) { + if (instr->type == nir_instr_type_phi && preheader) { + nir_phi_src *phi_src = nir_phi_get_src_from_block(nir_instr_as_phi(instr), preheader); + instr->pass_flags = src_pass_flags(&phi_src->src); + /* We only know the types of the preheader phi source + * so we need to revisit it later if nessecary. + */ + if (instr->pass_flags) + nir_instr_worklist_push_tail(worklist, instr); + } else { + instr->pass_flags = get_bool_types(instr); + } + } + } + } + + nir_foreach_instr_in_worklist(instr, worklist) { + uint8_t bool_types = get_bool_types(instr); + if (instr->pass_flags != bool_types) { + instr->pass_flags = bool_types; + nir_foreach_use(use, nir_instr_def(instr)) + nir_instr_worklist_push_tail(worklist, nir_src_parent_instr(use)); + } + } + + nir_instr_worklist_destroy(worklist); + + return nir_shader_phi_pass(shader, phi_to_bool, nir_metadata_control_flow, NULL); +}